Skip to content

Commit 5ffc554

Browse files
Merge pull request #709 from Open-EO/execute_local_udf_context
Execute local udf context
2 parents ee7941a + ab9c070 commit 5ffc554

File tree

3 files changed

+39
-5
lines changed

3 files changed

+39
-5
lines changed

openeo/udf/run_code.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,9 @@ def run_udf_code(code: str, data: UdfData) -> UdfData:
235235
raise OpenEoUdfException("No UDF found.")
236236

237237

238-
def execute_local_udf(udf: Union[str, openeo.UDF], datacube: Union[str, xarray.DataArray, XarrayDataCube], fmt='netcdf'):
238+
def execute_local_udf(
239+
udf: Union[str, openeo.UDF], datacube: Union[str, pathlib.Path, xarray.DataArray, XarrayDataCube], fmt="netcdf"
240+
):
239241
"""
240242
Locally executes an user defined function on a previously downloaded datacube.
241243
@@ -244,8 +246,8 @@ def execute_local_udf(udf: Union[str, openeo.UDF], datacube: Union[str, xarray.D
244246
:param fmt: format of the file if datacube is string
245247
:return: the resulting DataCube
246248
"""
247-
if isinstance(udf, openeo.UDF):
248-
udf = udf.code
249+
if isinstance(udf, str):
250+
udf = openeo.UDF(code=udf)
249251

250252
if isinstance(datacube, (str, pathlib.Path)):
251253
d = XarrayDataCube.from_file(path=datacube, fmt=fmt)
@@ -266,13 +268,13 @@ def execute_local_udf(udf: Union[str, openeo.UDF], datacube: Union[str, xarray.D
266268
.astype(numpy.float64)
267269
)
268270
# wrap to udf_data
269-
udf_data = UdfData(datacube_list=[d])
271+
udf_data = UdfData(datacube_list=[d], user_context=udf.context)
270272

271273
# TODO: enrich to other types like time series, vector data,... probalby by adding named arguments
272274
# signature: UdfData(proj, datacube_list, feature_collection_list, structured_data_list, ml_model_list, metadata)
273275

274276
# run the udf through the same routine as it would have been parsed in the backend
275-
result = run_udf_code(udf, udf_data)
277+
result = run_udf_code(udf.code, udf_data)
276278
return result
277279

278280

tests/udf/test_run_code.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,31 @@ def test_run_local_udf_from_file_netcdf(tmp_path):
307307
assert result[2, 0, 4, 3] == _ndvi(2034, 2134)
308308

309309

310+
def test_run_local_udf_from_file_netcdf_with_context(tmp_path):
311+
udf_code = _get_udf_code("multiply_factor.py")
312+
xdc = _build_xdc(
313+
ts=[numpy.datetime64("2020-08-01"), numpy.datetime64("2020-08-11"), numpy.datetime64("2020-08-21")],
314+
bands=["bandzero", "bandone"],
315+
xs=[10.0, 11.0, 12.0, 13.0, 14.0],
316+
ys=[20.0, 21.0, 22.0, 23.0, 24.0, 25.0],
317+
)
318+
assert xdc.array.shape == (3, 2, 5, 6)
319+
data_path = tmp_path / "data.nc"
320+
xdc.save_to_file(path=data_path, fmt="netcdf")
321+
322+
factor = 100
323+
udf = UDF(udf_code, runtime="Python", context={"factor": factor})
324+
res = execute_local_udf(udf, data_path, fmt="netcdf")
325+
326+
assert isinstance(res, UdfData)
327+
result = res.get_datacube_list()[0].get_array()
328+
329+
assert result.shape == (3, 2, 6, 5)
330+
swapped_result = result.transpose("t", "bands", "x", "y")
331+
expected = xdc.array * factor
332+
xarray.testing.assert_equal(swapped_result, expected)
333+
334+
310335
def _is_package_available(name: str) -> bool:
311336
# TODO: move this to a more general test utility module.
312337
return importlib.util.find_spec(name) is not None
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from openeo.udf import XarrayDataCube
2+
3+
4+
def apply_datacube(cube: XarrayDataCube, context: dict) -> XarrayDataCube:
5+
factor = context["factor"]
6+
array = cube.get_array() * factor
7+
return XarrayDataCube(array)

0 commit comments

Comments
 (0)