Skip to content

Commit 333c6db

Browse files
committed
Improved performance of binning by using integers instead of floats
1 parent ece9ac1 commit 333c6db

File tree

1 file changed

+33
-27
lines changed

1 file changed

+33
-27
lines changed

efast/binning.py

Lines changed: 33 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -110,7 +110,7 @@ def open_reflectance_bands(self, band_names: list[str]) -> dict[str, xr.DataArra
110110
for band_name in band_names:
111111
file_name = determine_file_name_from_reflectance_variable_name(band_name)
112112
# by default `mask_and_scale=None` behaves as if `mask_and_scale=True`
113-
band_ds = xr.open_dataset(self.path / file_name)
113+
band_ds = xr.open_dataset(self.path / file_name, mask_and_scale=False)
114114
bands.append(band_ds[band_name])
115115

116116
return {name: band for (name, band) in zip(band_names, bands)}
@@ -198,47 +198,53 @@ def bin_to_grid(ds: xr.Dataset, bands: Iterable[str], grid: Grid,*, super_sampli
198198
return binned
199199

200200
def bin_to_grid_numpy(ds: xr.Dataset, bands: Iterable[str], grid: Grid,*, super_sampling: int=1, interpolation_order: int=1) -> NDArray:
201-
lat = ds["lat"]
202-
lon = ds["lon"]
201+
lat2d = ds["lat"].data
202+
lon2d = ds["lon"].data
203203

204-
#lat = ndimage.zoom(lat, super_sampling, order=interpolation_order).ravel()
205-
#lon = ndimage.zoom(lon, super_sampling, order=interpolation_order).ravel()
206-
lat = super_sample_opencv(lat.data, super_sampling, interpolation=cv2.INTER_LINEAR).ravel()
207-
lon = super_sample_opencv(lon.data, super_sampling, interpolation=cv2.INTER_LINEAR).ravel()
204+
lat = super_sample_opencv(lat2d, super_sampling, interpolation=cv2.INTER_LINEAR).ravel()
205+
lon = super_sample_opencv(lon2d, super_sampling, interpolation=cv2.INTER_LINEAR).ravel()
208206

209207
width = grid.lon.shape[0] - 1
210208
height = grid.lat.shape[0] - 1
211209

212210
pixel_size = (grid.lon[-1] - grid.lon[0]) / width
213-
bin_idx_row = (lat - grid.lat[0]) / pixel_size
214-
bin_idx_col = (lon - grid.lon[0]) / pixel_size
215211

216-
# TODO test
217-
bin_idx_row = bin_idx_row.astype(int)
218-
bin_idx_col = bin_idx_col.astype(int)
212+
# Reuse for outputs
213+
bin_idx_buf = np.zeros_like(lat)
214+
215+
bin_idx_row = np.divide((lat - grid.lat[0]), pixel_size, out=bin_idx_buf).astype(int)
216+
bin_idx_col = np.divide((lon - grid.lon[0]), pixel_size, out=bin_idx_buf).astype(int)
217+
218+
bin_idx = bin_idx_row * width
219+
bin_idx += bin_idx_col
219220

220-
bin_idx = bin_idx_row * width + bin_idx_col
221+
222+
221223
bin_idx[(bin_idx_row < 0) | (bin_idx_row > height) | (bin_idx_col < 0) | (bin_idx_col > width)] = -1
224+
#bin_idx[(bin_idx_row < 0)] = -1
225+
#bin_idx[(bin_idx_row > height)] = -1
226+
#bin_idx[(bin_idx_col < 0)] = -1
227+
#bin_idx[(bin_idx_col > width)] = -1
222228

223229
counts, _ = np.histogram(bin_idx, width * height, range=(0, width*height))
224-
#counts, _, _ = np.histogram2d(bin_idx_row, bin_idx_col, bins=(range(height + 1), range(width + 1)))#, range=(0, width * height))
225230

226231
binned = []
232+
means = None
233+
sampled_data = None if super_sampling == 1 else np.zeros((lat2d.shape[0] * super_sampling, lat2d.shape[1] * super_sampling), dtype=np.int16)
234+
235+
FILL_VALUE = -10000 # TODO move
227236
for band in bands:
228237
data = ds[band].data
229-
data[np.isnan(data)] = 0
238+
data[data == FILL_VALUE] = 0
230239
if super_sampling != 1:
231-
# TODO could reuse allocation
232-
data = super_sample(data, super_sampling)
233-
if data.dtype == np.float32:
234-
# TODO otherwise we get weird results
235-
data = data.astype(np.float64)
236-
hist, _ = np.histogram(bin_idx, range(width * height + 1), weights=data.ravel(), range=(0, width*height))
237-
#hist, _, _ = np.histogram2d(bin_idx_row, bin_idx_col, (range(height + 1), range(width + 1)), weights=data.ravel(), range=(0, width * height))
238-
# TODO divide by zero
239-
#means = (hist / counts).reshape((height, width))
240-
means = (hist / counts).reshape(height, width)
241-
binned.append(means)
240+
super_sample(data, super_sampling, out=sampled_data)
241+
else:
242+
sampled_data = data
243+
244+
hist, _ = np.histogram(bin_idx, range(width * height + 1), weights=sampled_data.astype(np.int32).reshape(-1), range=(0, width*height))
245+
means = np.divide(hist, counts, out=means)
246+
scaled_means = means.reshape(height, width) * SCALE_FACTOR
247+
binned.append(scaled_means)
242248

243249
binned = np.array(binned)
244250
return binned
@@ -290,5 +296,5 @@ def super_sample_opencv(arr, factor,*, out=None, interpolation=cv2.INTER_NEAREST
290296
if out is None:
291297
out = np.zeros((arr.shape[0] * factor, arr.shape[1] * factor), dtype=arr.dtype)
292298

293-
cv2.resize(arr, dst=out, dsize=out.shape[::-1], fx=2, fy=2, interpolation=interpolation)
299+
res = cv2.resize(arr, dst=out, dsize=out.shape[::-1], fx=2, fy=2, interpolation=interpolation)
294300
return out

0 commit comments

Comments
 (0)