Skip to content

Commit 15bd804

Browse files
committed
Extend and simplify roundd.
1 parent 8010dc1 commit 15bd804

File tree

1 file changed

+39
-58
lines changed

1 file changed

+39
-58
lines changed

raster_tools/roundd.py

Lines changed: 39 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,83 +1,58 @@
11
# (c) Nelen & Schuurmans. GPL licensed, see LICENSE.rst.
22
# -*- coding: utf-8 -*-
33
"""
4-
Round raster to set decimals.
4+
Round values, change "no data" value, or both.
55
"""
66

7-
from os.path import dirname, exists
7+
from os.path import exists
88
import argparse
9-
import os
10-
from osgeo import gdal
9+
from osgeo.gdal import GetDriverByName, Open
1110
import numpy as np
1211
from raster_tools import datasets
1312

14-
# output driver and optinos
15-
DRIVER = gdal.GetDriverByName('gtiff')
13+
DRIVER = GetDriverByName('gtiff')
1614
OPTIONS = ['compress=deflate', 'tiled=yes']
1715

18-
progress = True
1916

20-
21-
class Exchange(object):
22-
def __init__(self, source_path, target_path):
23-
"""
24-
Read source, create target array.
25-
"""
26-
dataset = gdal.Open(source_path)
27-
band = dataset.GetRasterBand(1)
28-
29-
self.source = band.ReadAsArray()
30-
self.no_data_value = band.GetNoDataValue()
31-
32-
self.shape = self.source.shape
33-
34-
self.kwargs = {
35-
'no_data_value': self.no_data_value,
36-
'projection': dataset.GetProjection(),
37-
'geo_transform': dataset.GetGeoTransform(),
38-
}
39-
40-
self.target_path = target_path
41-
self.target = np.full_like(self.source, self.no_data_value)
42-
43-
def round(self, decimals):
44-
""" Round target. """
45-
active = self.source != self.no_data_value
46-
self.target[active] = self.source[active].round(decimals)
47-
48-
def save(self):
49-
""" Save. """
50-
# prepare dirs
51-
subdir = dirname(self.target_path)
52-
if subdir:
53-
os.makedirs(subdir, exist_ok=True)
54-
55-
# write tiff
56-
array = self.target[np.newaxis]
57-
with datasets.Dataset(array, **self.kwargs) as dataset:
58-
DRIVER.CreateCopy(self.target_path, dataset, options=OPTIONS)
59-
60-
61-
def roundd(source_path, target_path, decimals):
62-
""" Round decimals. """
17+
def roundd(source_path, target_path, decimals=None, no_data_value=None):
6318
# skip existing
6419
if exists(target_path):
65-
print('{} skipped.'.format(target_path))
20+
print(f'{target_path} skipped.')
6621
return
6722

6823
# skip when missing sources
6924
if not exists(source_path):
70-
print('Raster source "{}" not found.'.format(source_path))
25+
print(f'{target_path} not found.')
7126
return
7227

73-
# read
74-
exchange = Exchange(source_path, target_path)
28+
dataset = Open(source_path)
29+
band = dataset.GetRasterBand(1)
30+
31+
values = band.ReadAsArray()
32+
active = values != band.GetNoDataValue()
33+
34+
kwargs = {
35+
'projection': dataset.GetProjection(),
36+
'geo_transform': dataset.GetGeoTransform(),
37+
}
7538

76-
if decimals:
77-
exchange.round(decimals)
39+
# round
40+
if decimals is not None:
41+
values[active] = values[active].round(decimals)
7842

79-
# save
80-
exchange.save()
43+
# change "no data" value
44+
if no_data_value is not None:
45+
values[~active] = no_data_value
46+
else:
47+
no_data_value = band.GetNoDatavalue()
48+
49+
kwargs["no_data_value"] = no_data_value
50+
51+
# write tiff
52+
array = values[np.newaxis]
53+
with datasets.Dataset(array, **kwargs) as dataset:
54+
DRIVER.CreateCopy(target_path, dataset, options=OPTIONS)
55+
print(f'{target_path} written.')
8156

8257

8358
def get_parser():
@@ -101,6 +76,12 @@ def get_parser():
10176
dest='decimals',
10277
help='Round the result to this number of decimals.',
10378
)
79+
parser.add_argument(
80+
'-n', '--no-data-value',
81+
type=float,
82+
dest='no_data_value',
83+
help='Use this as new No Data Value.',
84+
)
10485

10586
return parser
10687

0 commit comments

Comments
 (0)