diff --git a/CHANGELOG.md b/CHANGELOG.md index 620ae2ad9..7172a3a0a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - `MultiBackendJobManager`: add `download_results` option to enable/disable the automated download of job results once completed by the job manager ([#744](https://github.com/Open-EO/openeo-python-client/issues/744)) +- Support UDF based spatial and temporal extents in `load_collection`, `load_stac` and `filter_temporal` ([#831](https://github.com/Open-EO/openeo-python-client/pull/831)) ### Changed diff --git a/openeo/internal/graph_building.py b/openeo/internal/graph_building.py index 6f5918ea2..10b395fdd 100644 --- a/openeo/internal/graph_building.py +++ b/openeo/internal/graph_building.py @@ -95,6 +95,8 @@ def print_json( class _FromNodeMixin(abc.ABC): """Mixin for classes that want to hook into the generation of a "from_node" reference.""" + # TODO: rename this class: it's more an interface than a mixin, and "from node" might be confusing as explained below. + @abc.abstractmethod def from_node(self) -> PGNode: # TODO: "from_node" is a bit a confusing name: diff --git a/openeo/rest/connection.py b/openeo/rest/connection.py index 07ec9e10e..25f9030fb 100644 --- a/openeo/rest/connection.py +++ b/openeo/rest/connection.py @@ -35,7 +35,12 @@ import openeo from openeo.config import config_log, get_config_option from openeo.internal.documentation import openeo_process -from openeo.internal.graph_building import FlatGraphableMixin, PGNode, as_flat_graph +from openeo.internal.graph_building import ( + FlatGraphableMixin, + PGNode, + _FromNodeMixin, + as_flat_graph, +) from openeo.internal.jupyter import VisualDict, VisualList from openeo.internal.processes.builder import ProcessBuilderBase from openeo.internal.warnings import deprecated, legacy_alias @@ -1186,8 +1191,8 @@ def load_collection( self, collection_id: Union[str, Parameter], spatial_extent: Union[dict, Parameter, shapely.geometry.base.BaseGeometry, str, Path, None] = None, - temporal_extent: Union[Sequence[InputDate], Parameter, str, None] = None, - bands: Union[Iterable[str], Parameter, str, None] = None, + temporal_extent: Union[Sequence[InputDate], Parameter, str, _FromNodeMixin, None] = None, + bands: Union[Iterable[str], Parameter, str, _FromNodeMixin, None] = None, properties: Union[ Dict[str, Union[PGNode, Callable]], List[CollectionProperty], CollectionProperty, None ] = None, @@ -1287,8 +1292,10 @@ def load_result( def load_stac( self, url: str, - spatial_extent: Union[dict, Parameter, shapely.geometry.base.BaseGeometry, str, Path, None] = None, - temporal_extent: Union[Sequence[InputDate], Parameter, str, None] = None, + spatial_extent: Union[ + dict, Parameter, shapely.geometry.base.BaseGeometry, str, Path, _FromNodeMixin, None + ] = None, + temporal_extent: Union[Sequence[InputDate], Parameter, str, _FromNodeMixin, None] = None, bands: Union[Iterable[str], Parameter, str, None] = None, properties: Union[ Dict[str, Union[PGNode, Callable]], List[CollectionProperty], CollectionProperty, None diff --git a/openeo/rest/datacube.py b/openeo/rest/datacube.py index bd19bb871..238d0bd51 100644 --- a/openeo/rest/datacube.py +++ b/openeo/rest/datacube.py @@ -91,7 +91,7 @@ # Type annotation aliases -InputDate = Union[str, datetime.date, Parameter, PGNode, ProcessBuilderBase, None] +InputDate = Union[str, datetime.date, Parameter, PGNode, ProcessBuilderBase, _FromNodeMixin, None] class DataCube(_ProcessGraphAbstraction): @@ -165,8 +165,10 @@ def load_collection( cls, collection_id: Union[str, Parameter], connection: Optional[Connection] = None, - spatial_extent: Union[dict, Parameter, shapely.geometry.base.BaseGeometry, str, pathlib.Path, None] = None, - temporal_extent: Union[Sequence[InputDate], Parameter, str, None] = None, + spatial_extent: Union[ + dict, Parameter, shapely.geometry.base.BaseGeometry, str, pathlib.Path, _FromNodeMixin, None + ] = None, + temporal_extent: Union[Sequence[InputDate], Parameter, str, _FromNodeMixin, None] = None, bands: Union[Iterable[str], Parameter, str, None] = None, fetch_metadata: bool = True, properties: Union[ @@ -480,22 +482,22 @@ def _get_temporal_extent( *args, start_date: InputDate = None, end_date: InputDate = None, - extent: Union[Sequence[InputDate], Parameter, str, None] = None, - ) -> Union[List[Union[str, Parameter, PGNode, None]], Parameter]: + extent: Union[Sequence[InputDate], Parameter, str, _FromNodeMixin, None] = None, + ) -> Union[List[Union[str, Parameter, PGNode, _FromNodeMixin, None]], Parameter, _FromNodeMixin]: """Parameter aware temporal_extent normalizer""" # TODO: move this outside of DataCube class # TODO: return extent as tuple instead of list - if len(args) == 1 and isinstance(args[0], Parameter): + if len(args) == 1 and isinstance(args[0], (Parameter, _FromNodeMixin)): assert start_date is None and end_date is None and extent is None return args[0] - elif len(args) == 0 and isinstance(extent, Parameter): + elif len(args) == 0 and isinstance(extent, (Parameter, _FromNodeMixin)): assert start_date is None and end_date is None # TODO: warn about unexpected parameter schema return extent else: def convertor(d: Any) -> Any: # TODO: can this be generalized through _FromNodeMixin? - if isinstance(d, Parameter) or isinstance(d, PGNode): + if isinstance(d, Parameter) or isinstance(d, _FromNodeMixin): # TODO: warn about unexpected parameter schema return d elif isinstance(d, ProcessBuilderBase): @@ -531,7 +533,7 @@ def filter_temporal( *args, start_date: InputDate = None, end_date: InputDate = None, - extent: Union[Sequence[InputDate], Parameter, str, None] = None, + extent: Union[Sequence[InputDate], Parameter, str, _FromNodeMixin, None] = None, ) -> DataCube: """ Limit the DataCube to a certain date range, which can be specified in several ways: diff --git a/tests/rest/datacube/test_datacube.py b/tests/rest/datacube/test_datacube.py index f99f24726..4389d52ce 100644 --- a/tests/rest/datacube/test_datacube.py +++ b/tests/rest/datacube/test_datacube.py @@ -18,8 +18,10 @@ import shapely import shapely.geometry +import openeo.processes from openeo import collection_property from openeo.api.process import Parameter +from openeo.internal.graph_building import PGNode from openeo.metadata import SpatialDimension from openeo.rest import BandMathException, OpenEoClientException from openeo.rest._testing import build_capabilities @@ -698,6 +700,69 @@ def test_filter_temporal_single_arg(s2cube: DataCube, arg, expect_failure): _ = s2cube.filter_temporal(arg) +@pytest.mark.parametrize( + "udf_factory", + [ + (lambda data, udf, runtime: openeo.processes.run_udf(data=data, udf=udf, runtime=runtime)), + (lambda data, udf, runtime: PGNode(process_id="run_udf", data=data, udf=udf, runtime=runtime)), + ], +) +def test_filter_temporal_from_udf(s2cube: DataCube, udf_factory): + temporal_extent = udf_factory(data=[1, 2, 3], udf="print('hello time')", runtime="Python") + cube = s2cube.filter_temporal(temporal_extent) + assert get_download_graph(cube, drop_save_result=True) == { + "loadcollection1": { + "process_id": "load_collection", + "arguments": {"id": "S2", "spatial_extent": None, "temporal_extent": None}, + }, + "runudf1": { + "process_id": "run_udf", + "arguments": {"data": [1, 2, 3], "udf": "print('hello time')", "runtime": "Python"}, + }, + "filtertemporal1": { + "process_id": "filter_temporal", + "arguments": { + "data": {"from_node": "loadcollection1"}, + "extent": {"from_node": "runudf1"}, + }, + }, + } + + +@pytest.mark.parametrize( + "udf_factory", + [ + (lambda data, udf, runtime: openeo.processes.run_udf(data=data, udf=udf, runtime=runtime)), + (lambda data, udf, runtime: PGNode(process_id="run_udf", data=data, udf=udf, runtime=runtime)), + ], +) +def test_filter_temporal_start_end_from_udf(s2cube: DataCube, udf_factory): + start = udf_factory(data=[1, 2, 3], udf="print('hello start')", runtime="Python") + end = udf_factory(data=[4, 5, 6], udf="print('hello end')", runtime="Python") + cube = s2cube.filter_temporal(start_date=start, end_date=end) + assert get_download_graph(cube, drop_save_result=True) == { + "loadcollection1": { + "process_id": "load_collection", + "arguments": {"id": "S2", "spatial_extent": None, "temporal_extent": None}, + }, + "runudf1": { + "process_id": "run_udf", + "arguments": {"data": [1, 2, 3], "udf": "print('hello start')", "runtime": "Python"}, + }, + "runudf2": { + "process_id": "run_udf", + "arguments": {"data": [4, 5, 6], "udf": "print('hello end')", "runtime": "Python"}, + }, + "filtertemporal1": { + "process_id": "filter_temporal", + "arguments": { + "data": {"from_node": "loadcollection1"}, + "extent": [{"from_node": "runudf1"}, {"from_node": "runudf2"}], + }, + }, + } + + def test_max_time(s2cube, api_version): im = s2cube.max_time() graph = _get_leaf_node(im, force_flat=True) diff --git a/tests/rest/datacube/test_datacube100.py b/tests/rest/datacube/test_datacube100.py index 97801365b..2788c6a7a 100644 --- a/tests/rest/datacube/test_datacube100.py +++ b/tests/rest/datacube/test_datacube100.py @@ -2375,6 +2375,70 @@ def test_load_collection_parameterized_extents(con100, spatial_extent, temporal_ } +@pytest.mark.parametrize( + "udf_factory", + [ + (lambda data, udf, runtime: openeo.processes.run_udf(data=data, udf=udf, runtime=runtime)), + (lambda data, udf, runtime: PGNode(process_id="run_udf", data=data, udf=udf, runtime=runtime)), + ], +) +def test_load_collection_extents_from_udf(con100, udf_factory): + spatial_extent = udf_factory(data=[1, 2, 3], udf="print('hello space')", runtime="Python") + temporal_extent = udf_factory(data=[4, 5, 6], udf="print('hello time')", runtime="Python") + cube = con100.load_collection("S2", spatial_extent=spatial_extent, temporal_extent=temporal_extent) + assert get_download_graph(cube, drop_save_result=True) == { + "runudf1": { + "process_id": "run_udf", + "arguments": {"data": [1, 2, 3], "udf": "print('hello space')", "runtime": "Python"}, + }, + "runudf2": { + "process_id": "run_udf", + "arguments": {"data": [4, 5, 6], "udf": "print('hello time')", "runtime": "Python"}, + }, + "loadcollection1": { + "process_id": "load_collection", + "arguments": { + "id": "S2", + "spatial_extent": {"from_node": "runudf1"}, + "temporal_extent": {"from_node": "runudf2"}, + }, + }, + } + + +@pytest.mark.parametrize( + "udf_factory", + [ + (lambda data, udf, runtime: openeo.processes.run_udf(data=data, udf=udf, runtime=runtime)), + (lambda data, udf, runtime: PGNode(process_id="run_udf", data=data, udf=udf, runtime=runtime)), + ], +) +def test_load_collection_temporal_extent_from_udf(con100, udf_factory): + temporal_extent = [ + udf_factory(data=[1, 2, 3], udf="print('hello start')", runtime="Python"), + udf_factory(data=[4, 5, 6], udf="print('hello end')", runtime="Python"), + ] + cube = con100.load_collection("S2", temporal_extent=temporal_extent) + assert get_download_graph(cube, drop_save_result=True) == { + "runudf1": { + "process_id": "run_udf", + "arguments": {"data": [1, 2, 3], "udf": "print('hello start')", "runtime": "Python"}, + }, + "runudf2": { + "process_id": "run_udf", + "arguments": {"data": [4, 5, 6], "udf": "print('hello end')", "runtime": "Python"}, + }, + "loadcollection1": { + "process_id": "load_collection", + "arguments": { + "id": "S2", + "spatial_extent": None, + "temporal_extent": [{"from_node": "runudf1"}, {"from_node": "runudf2"}], + }, + }, + } + + def test_apply_dimension_temporal_cumsum_with_target(con100, test_data): cumsum = con100.load_collection("S2").apply_dimension('cumsum', dimension="t", target_dimension="MyNewTime") actual_graph = cumsum.flat_graph() diff --git a/tests/rest/test_connection.py b/tests/rest/test_connection.py index 0fc0a3976..6cf05e62a 100644 --- a/tests/rest/test_connection.py +++ b/tests/rest/test_connection.py @@ -17,6 +17,7 @@ import shapely.geometry import openeo +import openeo.processes from openeo import BatchJob from openeo.api.process import Parameter from openeo.internal.graph_building import FlatGraphableMixin, PGNode @@ -3715,6 +3716,73 @@ def test_load_stac_spatial_extent_vector_cube(self, dummy_backend): }, } + @pytest.mark.parametrize( + "udf_factory", + [ + (lambda data, udf, runtime: openeo.processes.run_udf(data=data, udf=udf, runtime=runtime)), + (lambda data, udf, runtime: PGNode(process_id="run_udf", data=data, udf=udf, runtime=runtime)), + ], + ) + def test_load_stac_extents_from_udf(self, dummy_backend, udf_factory): + spatial_extent = udf_factory(data=[1, 2, 3], udf="print('hello space')", runtime="Python") + temporal_extent = udf_factory(data=[4, 5, 6], udf="print('hello time')", runtime="Python") + cube = dummy_backend.connection.load_stac( + "https://stac.test/data", spatial_extent=spatial_extent, temporal_extent=temporal_extent + ) + cube.execute() + assert dummy_backend.get_sync_pg() == { + "runudf1": { + "process_id": "run_udf", + "arguments": {"data": [1, 2, 3], "udf": "print('hello space')", "runtime": "Python"}, + }, + "runudf2": { + "process_id": "run_udf", + "arguments": {"data": [4, 5, 6], "udf": "print('hello time')", "runtime": "Python"}, + }, + "loadstac1": { + "process_id": "load_stac", + "arguments": { + "url": "https://stac.test/data", + "spatial_extent": {"from_node": "runudf1"}, + "temporal_extent": {"from_node": "runudf2"}, + }, + "result": True, + }, + } + + @pytest.mark.parametrize( + "udf_factory", + [ + (lambda data, udf, runtime: openeo.processes.run_udf(data=data, udf=udf, runtime=runtime)), + (lambda data, udf, runtime: PGNode(process_id="run_udf", data=data, udf=udf, runtime=runtime)), + ], + ) + def test_load_stac_temporal_extent_from_udf(self, dummy_backend, udf_factory): + temporal_extent = [ + udf_factory(data=[1, 2, 3], udf="print('hello start')", runtime="Python"), + udf_factory(data=[4, 5, 6], udf="print('hello end')", runtime="Python"), + ] + cube = dummy_backend.connection.load_stac("https://stac.test/data", temporal_extent=temporal_extent) + cube.execute() + assert dummy_backend.get_sync_pg() == { + "runudf1": { + "process_id": "run_udf", + "arguments": {"data": [1, 2, 3], "udf": "print('hello start')", "runtime": "Python"}, + }, + "runudf2": { + "process_id": "run_udf", + "arguments": {"data": [4, 5, 6], "udf": "print('hello end')", "runtime": "Python"}, + }, + "loadstac1": { + "process_id": "load_stac", + "arguments": { + "url": "https://stac.test/data", + "temporal_extent": [{"from_node": "runudf1"}, {"from_node": "runudf2"}], + }, + "result": True, + }, + } + @pytest.mark.parametrize( "data",