Skip to content

Commit da629e5

Browse files
Include UDF context when running local UDF. #556
1 parent ee7941a commit da629e5

File tree

2 files changed

+35
-4
lines changed

2 files changed

+35
-4
lines changed

openeo/udf/run_code.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,9 @@ def run_udf_code(code: str, data: UdfData) -> UdfData:
235235
raise OpenEoUdfException("No UDF found.")
236236

237237

238-
def execute_local_udf(udf: Union[str, openeo.UDF], datacube: Union[str, xarray.DataArray, XarrayDataCube], fmt='netcdf'):
238+
def execute_local_udf(
239+
udf: Union[str, openeo.UDF], datacube: Union[str, pathlib.Path, xarray.DataArray, XarrayDataCube], fmt="netcdf"
240+
):
239241
"""
240242
Locally executes an user defined function on a previously downloaded datacube.
241243
@@ -245,7 +247,11 @@ def execute_local_udf(udf: Union[str, openeo.UDF], datacube: Union[str, xarray.D
245247
:return: the resulting DataCube
246248
"""
247249
if isinstance(udf, openeo.UDF):
248-
udf = udf.code
250+
udf_code = udf.code
251+
elif isinstance(udf, str):
252+
udf_code = udf
253+
else:
254+
raise ValueError(udf)
249255

250256
if isinstance(datacube, (str, pathlib.Path)):
251257
d = XarrayDataCube.from_file(path=datacube, fmt=fmt)
@@ -266,13 +272,13 @@ def execute_local_udf(udf: Union[str, openeo.UDF], datacube: Union[str, xarray.D
266272
.astype(numpy.float64)
267273
)
268274
# wrap to udf_data
269-
udf_data = UdfData(datacube_list=[d])
275+
udf_data = UdfData(datacube_list=[d], user_context=udf.context)
270276

271277
# TODO: enrich to other types like time series, vector data,... probalby by adding named arguments
272278
# signature: UdfData(proj, datacube_list, feature_collection_list, structured_data_list, ml_model_list, metadata)
273279

274280
# run the udf through the same routine as it would have been parsed in the backend
275-
result = run_udf_code(udf, udf_data)
281+
result = run_udf_code(udf_code, udf_data)
276282
return result
277283

278284

tests/udf/test_run_code.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,31 @@ def test_run_local_udf_from_file_netcdf(tmp_path):
307307
assert result[2, 0, 4, 3] == _ndvi(2034, 2134)
308308

309309

310+
def test_run_local_udf_from_file_netcdf_with_context(tmp_path):
311+
udf_code = _get_udf_code("multiply_factor.py")
312+
xdc = _build_xdc(
313+
ts=[numpy.datetime64("2020-08-01"), numpy.datetime64("2020-08-11"), numpy.datetime64("2020-08-21")],
314+
bands=["bandzero", "bandone"],
315+
xs=[10.0, 11.0, 12.0, 13.0, 14.0],
316+
ys=[20.0, 21.0, 22.0, 23.0, 24.0, 25.0],
317+
)
318+
assert xdc.array.shape == (3, 2, 5, 6)
319+
data_path = tmp_path / "data.nc"
320+
xdc.save_to_file(path=data_path, fmt="netcdf")
321+
322+
factor = 100
323+
udf = UDF(udf_code, runtime="Python", context={"factor": factor})
324+
res = execute_local_udf(udf, data_path, fmt="netcdf")
325+
326+
assert isinstance(res, UdfData)
327+
result = res.get_datacube_list()[0].get_array()
328+
329+
assert result.shape == (3, 2, 6, 5)
330+
swapped_result = result.transpose("t", "bands", "x", "y")
331+
expected = xdc.array * factor
332+
xarray.testing.assert_equal(swapped_result, expected)
333+
334+
310335
def _is_package_available(name: str) -> bool:
311336
# TODO: move this to a more general test utility module.
312337
return importlib.util.find_spec(name) is not None

0 commit comments

Comments
 (0)