
Commit ece9ac1

Added binning implementation, test, profiling
1 parent 69b179c commit ece9ac1

6 files changed: +457 -1 lines changed

efast/binning.py

Lines changed: 294 additions & 0 deletions
import re
import warnings
from collections import namedtuple
from pathlib import Path
from typing import Iterable, Self

import cv2
import numpy as np
import shapely
import xarray as xr
from numpy.typing import NDArray
from scipy import ndimage  # used only by the commented-out zoom alternative below
from scipy import stats


def binning_s3_py(
    download_dir,
    binning_dir,
    footprint,
    s3_bands=["SDR_Oa04", "SDR_Oa06", "SDR_Oa08", "SDR_Oa17"],
    max_zenith_angle=30,
    crs="EPSG:4326",
):
    """
    TODO: "binning" might be a misnomer, as the function does more than just binning
    """

    # read required bands of sentinel-3 product
    # reproject to EPSG 4326 (~300m grid), this step is likely unnecessary
    # bin to SEA_grid
    # reproject to EPSG 4326 (66xxx slices)

    pass


def get_reflectance_filename(index: int):
    if not (0 < index <= 21):
        raise ValueError(
            "The index must be an integer between 1 and 21 (both inclusive)"
        )

    return f"SDR_Oa{index:02}"


GEOLOCATION_FILE_NAME = "geolocation.nc"
FLAG_FILE_NAME = "flags.nc"

SCALE_FACTOR = 1e-4  # TODO read from netcdf


# TODO better assume that band_names are already actual band names.
# Calculate which netcdf file you need independently.
class SynProduct:
    FLAG_VARIABLE_NAMES = [
        "CLOUD_flags",
        "OLC_flags",
        "SLN_flags",
        "SLO_flags",
        "SYN_flags",
    ]

    def __init__(self, path: Path | str, band_names: list[str]):
        self.path = Path(path)
        self.band_names = band_names

    def read_bands(self) -> xr.Dataset:
        # sort bands into flag bands, reflectance bands, others
        flag_band_names = list(filter(lambda b: b.endswith("_flags"), self.band_names))
        reflectance_band_names = list(
            filter(lambda b: b.startswith("SDR_Oa"), self.band_names)
        )
        remaining_band_names = [
            b
            for b in self.band_names
            if b not in set([*flag_band_names, *reflectance_band_names])
        ]

        if len(remaining_band_names) != 0:
            raise ValueError(
                f"Band names '{remaining_band_names}' are neither "
                "flags nor reflectance bands. Cannot open."
            )

        # geolocation
        geolocation_filename = self.path / GEOLOCATION_FILE_NAME
        geolocation_ds = xr.open_dataset(geolocation_filename)
        lat = geolocation_ds["lat"].data
        lon = geolocation_ds["lon"].data

        # reflectance bands
        reflectance_bands = self.open_reflectance_bands(reflectance_band_names)

        # flag bands
        flag_bands = self.open_flag_bands(flag_band_names)

        # all variables, including the 2D lat/lon arrays, live on the
        # product's pixel grid
        dims = ["x", "y"]
        bands = {**reflectance_bands, **flag_bands}
        bands = {name: (dims, band.data) for (name, band) in bands.items()}
        bands["lat"] = (dims, lat)
        bands["lon"] = (dims, lon)
        # join; set_coords returns a new dataset, so keep the result
        ds = xr.Dataset(bands)
        ds = ds.set_coords(("lat", "lon"))

        return ds

    def open_reflectance_bands(self, band_names: list[str]) -> dict[str, xr.DataArray]:
        bands: list[xr.DataArray] = []
        for band_name in band_names:
            file_name = determine_file_name_from_reflectance_variable_name(band_name)
            # by default `mask_and_scale=None` behaves as if `mask_and_scale=True`
            band_ds = xr.open_dataset(self.path / file_name)
            bands.append(band_ds[band_name])

        return {name: band for (name, band) in zip(band_names, bands)}

    def open_flag_bands(self, band_names: list[str]) -> dict[str, xr.DataArray]:
        flag_ds = xr.open_dataset(self.path / FLAG_FILE_NAME)

        band_name_exists = [bn in flag_ds.variables.keys() for bn in band_names]
        nonexistent_bands = [
            band for (band, exists) in zip(band_names, band_name_exists) if not exists
        ]

        if len(nonexistent_bands) > 0:
            raise ValueError(
                f"Could not find bands '{nonexistent_bands}' in file {self.path / FLAG_FILE_NAME}"
            )

        return {bn: flag_ds[bn] for bn in band_names}


def determine_file_name_from_reflectance_variable_name(varname: str):
    index_group = 1
    pattern = r"SDR_Oa(\d\d)"
    m = re.match(pattern, varname)
    if m is None:
        raise ValueError(
            f"variable name '{varname}' does not match pattern '{pattern}'. "
            "Not a reflectance band."
        )

    return f"Syn_Oa{int(m.group(index_group)):02}_reflectance.nc"


class BBox:
    def __init__(
        self, *, lat_min: float, lat_max: float, lon_min: float, lon_max: float
    ) -> None:
        self.lat_min = lat_min
        self.lat_max = lat_max
        self.lon_min = lon_min
        self.lon_max = lon_max

    @classmethod
    def from_wkt(cls, wkt: str) -> Self:
        geom = shapely.from_wkt(wkt)
        envelope = shapely.envelope(geom)
        lon_min, lat_min, lon_max, lat_max = envelope.bounds
        return cls(
            lat_min=lat_min,
            lat_max=lat_max,
            lon_min=lon_min,
            lon_max=lon_max,
        )


Grid = namedtuple("Grid", ["lat", "lon"])


def bin_to_grid(
    ds: xr.Dataset,
    bands: Iterable[str],
    grid: Grid,
    *,
    super_sampling: int = 1,
    interpolation_order: int = 1,
) -> NDArray:
    lat = ds["lat"]
    lon = ds["lon"]

    # lat = ndimage.zoom(lat, super_sampling, order=interpolation_order).ravel()
    # lon = ndimage.zoom(lon, super_sampling, order=interpolation_order).ravel()
    lat = super_sample_opencv(lat.data, super_sampling, interpolation=cv2.INTER_LINEAR).ravel()
    lon = super_sample_opencv(lon.data, super_sampling, interpolation=cv2.INTER_LINEAR).ravel()

    binned = []

    for band in bands:
        data = ds[band].data
        if super_sampling != 1:
            data = super_sample(data, super_sampling)
        res = stats.binned_statistic_2d(
            lat,
            lon,
            values=data.ravel(),
            statistic="mean",
            bins=(grid.lat, grid.lon),  # definition of target grid
        )
        binned.append(res.statistic)

    binned = np.array(binned)
    return binned


def bin_to_grid_numpy(
    ds: xr.Dataset,
    bands: Iterable[str],
    grid: Grid,
    *,
    super_sampling: int = 1,
    interpolation_order: int = 1,
) -> NDArray:
    lat = ds["lat"]
    lon = ds["lon"]

    # lat = ndimage.zoom(lat, super_sampling, order=interpolation_order).ravel()
    # lon = ndimage.zoom(lon, super_sampling, order=interpolation_order).ravel()
    lat = super_sample_opencv(lat.data, super_sampling, interpolation=cv2.INTER_LINEAR).ravel()
    lon = super_sample_opencv(lon.data, super_sampling, interpolation=cv2.INTER_LINEAR).ravel()

    width = grid.lon.shape[0] - 1
    height = grid.lat.shape[0] - 1

    # equal-angle grid: one pixel size is valid for both rows and columns
    pixel_size = (grid.lon[-1] - grid.lon[0]) / width
    bin_idx_row = (lat - grid.lat[0]) / pixel_size
    bin_idx_col = (lon - grid.lon[0]) / pixel_size

    # TODO test
    bin_idx_row = np.floor(bin_idx_row).astype(int)
    bin_idx_col = np.floor(bin_idx_col).astype(int)

    # flatten (row, col) into a single bin index; pixels outside the grid get
    # index -1, which falls outside the histogram range and is dropped
    bin_idx = bin_idx_row * width + bin_idx_col
    bin_idx[
        (bin_idx_row < 0)
        | (bin_idx_row >= height)
        | (bin_idx_col < 0)
        | (bin_idx_col >= width)
    ] = -1

    counts, _ = np.histogram(bin_idx, width * height, range=(0, width * height))

    binned = []
    for band in bands:
        data = ds[band].data
        data[np.isnan(data)] = 0  # NaN pixels contribute zero weight but are still counted
        if super_sampling != 1:
            # TODO could reuse allocation
            data = super_sample(data, super_sampling)
        if data.dtype == np.float32:
            # TODO otherwise we get weird results
            data = data.astype(np.float64)
        hist, _ = np.histogram(
            bin_idx, width * height, range=(0, width * height), weights=data.ravel()
        )
        # TODO divide by zero
        means = (hist / counts).reshape(height, width)
        binned.append(means)

    binned = np.array(binned)
    return binned


def create_geogrid(bbox: BBox, num_rows: int = 66792):
    # -90 and 90 are included, 0 also included. lat has one more entry than
    # num_rows; rows are defined by the spaces between lat entries.
    lat = np.linspace(0, 180, num=num_rows + 1, endpoint=True) - 90

    # 180 and 0 are included, -180 is not.
    # lon has 2 * num_rows entries.
    # the last bin is between lon[-1] and lon[0] (antimeridian)
    lon = np.linspace(0, 360, num=num_rows * 2 + 1, endpoint=True) - 180
    lon = lon[1:]
    # TODO return type with names to not confuse lat/lon

    # one lat bound before the first bound at or above the bounding box min
    # TODO I can calculate lat_idx_min (and the others) directly
    lat_idx_min = np.argmax(lat >= bbox.lat_min) - 1
    # first lat bound that is larger than the bounding box max
    lat_idx_max = np.argmax(lat > bbox.lat_max)
    # one lon bound before the first bound at or above the bounding box min
    lon_idx_min = np.argmax(lon >= bbox.lon_min) - 1
    # first lon bound that is larger than the bounding box max
    lon_idx_max = np.argmax(lon > bbox.lon_max)

    grid = Grid(
        lat=lat[lat_idx_min : lat_idx_max + 1],
        lon=lon[lon_idx_min : lon_idx_max + 1],
    )
    return grid


def super_sample(arr, factor, *, out=None):
    # return super_sample_kron(arr, factor, out=out)
    # return super_sample_repeat(arr, factor, out=out)
    return super_sample_opencv(arr, factor, out=out)


def super_sample_kron(arr, factor, *, out=None):
    if out is not None:
        warnings.warn("Parameter 'out' not supported for kron super sampling")
    kernel = np.ones((factor, factor))
    return np.kron(arr, kernel)


def super_sample_repeat(arr, factor, *, out=None):
    if out is not None:
        warnings.warn("Parameter 'out' not supported for repeat super sampling")
    return arr.repeat(factor, axis=1).repeat(factor, axis=0)


def super_sample_opencv(arr, factor, *, out=None, interpolation=cv2.INTER_NEAREST):
    if out is None:
        out = np.zeros((arr.shape[0] * factor, arr.shape[1] * factor), dtype=arr.dtype)

    # cv2.resize expects dsize as (width, height); fx/fy are ignored when dsize is given
    cv2.resize(arr, dst=out, dsize=out.shape[::-1], interpolation=interpolation)
    return out
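
A minimal usage sketch for the module above (not part of the commit; the polygon, band name, and grid size are made-up illustration values): build a target grid from a bounding box, bin one synthetic band with both implementations, and check that they agree. Empty bins come out as NaN in both code paths (the numpy path additionally emits a divide-by-zero RuntimeWarning for them).

import numpy as np
import xarray as xr

rng = np.random.default_rng(42)
ny, nx = 120, 80

# synthetic curvilinear geolocation plus one reflectance band
ds = xr.Dataset(
    {
        "SDR_Oa04": (("x", "y"), rng.random((ny, nx))),
        "lat": (("x", "y"), rng.uniform(54.0, 57.0, size=(ny, nx))),
        "lon": (("x", "y"), rng.uniform(8.0, 12.0, size=(ny, nx))),
    }
)

bbox = BBox.from_wkt("POLYGON ((8 54, 12 54, 12 57, 8 57, 8 54))")
grid = create_geogrid(bbox, num_rows=1800)  # 0.1-degree bins for a quick test

binned_scipy = bin_to_grid(ds, ["SDR_Oa04"], grid, super_sampling=2)
binned_numpy = bin_to_grid_numpy(ds, ["SDR_Oa04"], grid, super_sampling=2)
np.testing.assert_allclose(binned_scipy, binned_numpy, equal_nan=True)

# the three super-sampling variants agree for nearest-neighbour upsampling
a = rng.random((5, 7))
np.testing.assert_array_equal(super_sample_kron(a, 3), super_sample_repeat(a, 3))
np.testing.assert_array_equal(super_sample_repeat(a, 3), super_sample_opencv(a, 3))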

efast/eoprofiling.py

Lines changed: 38 additions & 0 deletions
# -*- coding: utf-8 -*-

"""Generic profiling context"""

__author__ = "Martin Böttcher, Brockmann Consult GmbH"
__copyright__ = "Copyright 2023, Brockmann Consult GmbH"
__license__ = "TBD"
__version__ = "0.5"
__email__ = "[email protected]"
__status__ = "Development"

# changes in 1.1:
# ...


class Profiling:

    def __init__(self, output: str | None = None):
        self._output = output

    def __enter__(self):
        if self._output:
            import cProfile

            self._profile = cProfile.Profile()
            self._profile.enable()

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self._output:
            self._profile.disable()
            import io
            import pstats

            buffer = io.StringIO()
            stats = pstats.Stats(self._profile, stream=buffer)
            stats.sort_stats("tottime").print_stats()
            with open(self._output, "w") as f:
                f.write(buffer.getvalue())
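
A short usage sketch for the context manager (the output file name is arbitrary): profiling is only active when an output path is given, so the same code can run unprofiled by passing None.

from efast.eoprofiling import Profiling

def expensive():
    return sum(i * i for i in range(1_000_000))

# writes tottime-sorted cProfile statistics to profile.txt
with Profiling("profile.txt"):
    expensive()

# no output path: the context manager does nothing
with Profiling(None):
    expensive()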

efast/s3_processing.py

Lines changed: 2 additions & 1 deletion
@@ -27,6 +27,7 @@
 
 import os
 import re
+from pathlib import Path
 
 from datetime import datetime
 
@@ -180,7 +181,7 @@ def binning_s3(
        string += f"{variable}_mean,"
    subset_node_id = graph.subset_op(binning_node_id, string[:-1])
    graph.write_op(subset_node_id, output_path)  # export as .tif
-   graph.run(snap_gpt_path, snap_parallelization, snap_memory)
+   graph.run(snap_gpt_path, snap_parallelization, snap_memory, graph_filename=str(Path(__file__).parent.parent / "test_temporary_data" / "s3_pre_snap_graph.xml"))
 
 
 def produce_median_composite(

eoprofiling.py

Lines changed: 38 additions & 0 deletions
# -*- coding: utf-8 -*-

"""Generic profiling context"""

__author__ = "Martin Böttcher, Brockmann Consult GmbH"
__copyright__ = "Copyright 2023, Brockmann Consult GmbH"
__license__ = "TBD"
__version__ = "0.5"
__email__ = "[email protected]"
__status__ = "Development"

# changes in 1.1:
# ...


class Profiling:

    def __init__(self, output: str | None = None):
        self._output = output

    def __enter__(self):
        if self._output:
            import cProfile

            self._profile = cProfile.Profile()
            self._profile.enable()

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self._output:
            self._profile.disable()
            import io
            import pstats

            buffer = io.StringIO()
            stats = pstats.Stats(self._profile, stream=buffer)
            stats.sort_stats("tottime").print_stats()
            with open(self._output, "w") as f:
                f.write(buffer.getvalue())

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -25,6 +25,7 @@ dependencies = [
     "snap-graph @ git+https://github.com/DHI-GRAS/snap-graph",
     "creodias-finder @ git+https://github.com/DHI-GRAS/creodias-finder",
 
+    "opencv-python",
     "jupyterlab>=4.2.5",
     "openeo[localprocessing]>0.30",
     "openeo-processes-dask[implementations, ml]",
