Skip to content

Commit d3407c3

Browse files
committed
Merge branch 'issue391-multiple-result-nodes'
2 parents 1fef132 + d84fcc4 commit d3407c3

File tree

14 files changed

+823
-22
lines changed

14 files changed

+823
-22
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
99

1010
### Added
1111

12+
- Added `MultiResult` helper class to build process graphs with multiple result nodes ([#391](https://github.com/Open-EO/openeo-python-client/issues/391))
13+
1214
### Changed
1315

1416
### Removed

docs/api.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,15 @@ openeo.rest.mlmodel
4747
:inherited-members:
4848

4949

50+
openeo.rest.multiresult
51+
-----------------------
52+
53+
.. automodule:: openeo.rest.multiresult
54+
:members: MultiResult
55+
:inherited-members:
56+
:special-members: __init__
57+
58+
5059
openeo.metadata
5160
----------------
5261

docs/datacube_construction.rst

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,3 +196,55 @@ Re-parameterization
196196
```````````````````
197197

198198
TODO
199+
200+
201+
202+
.. _multi-result-process-graphs:
203+
Building process graphs with multiple result nodes
204+
===================================================
205+
206+
.. note::
207+
Multi-result support is added in version 0.35.0
208+
209+
Most openEO use cases are just about building a single result data cube,
210+
which is readily covered in the openEO Python client library through classes like
211+
:py:class:`~openeo.rest.datacube.DataCube` and :py:class:`~openeo.rest.vectorcube.VectorCube`.
212+
It is straightforward to create a batch job from these, or execute/download them synchronously.
213+
214+
The openEO API also allows multiple result nodes in a single process graph,
215+
for example to persist intermediate results or produce results in different output formats.
216+
To support this, the openEO Python client library provides the :py:class:`~openeo.rest.multiresult.MultiResult` class,
217+
which allows grouping multiple :py:class:`~openeo.rest.datacube.DataCube` and :py:class:`~openeo.rest.vectorcube.VectorCube` objects
218+
in a single entity that can be used to create or run batch jobs. For example:
219+
220+
221+
.. code-block:: python
222+
223+
from openeo import MultiResult
224+
225+
cube1 = ...
226+
cube2 = ...
227+
multi_result = MultiResult([cube1, cube2])
228+
job = multi_result.create_job()
229+
230+
231+
Moreover, it is not necessary to explicitly create such a
232+
:py:class:`~openeo.rest.multiresult.MultiResult` object,
233+
as the :py:meth:`Connection.create_job() <openeo.rest.connection.Connection.create_job>` method
234+
directly supports passing multiple data cube objects in a list,
235+
which will be automatically grouped as a multi-result:
236+
237+
.. code-block:: python
238+
239+
cube1 = ...
240+
cube2 = ...
241+
job = connection.create_job([cube1, cube2])
242+
243+
244+
.. important::
245+
246+
Only a single :py:class:`~openeo.rest.connection.Connection` can be in play
247+
when grouping multiple results like this.
248+
As everything is to be merged in a single process graph
249+
to be sent to a single backend,
250+
it is not possible to mix cubes created from different connections.

openeo/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ class BaseOpenEoException(Exception):
1818
from openeo.rest.datacube import UDF, DataCube
1919
from openeo.rest.graph_building import collection_property
2020
from openeo.rest.job import BatchJob, RESTJob
21+
from openeo.rest.multiresult import MultiResult
2122

2223

2324
def client_version() -> str:

openeo/internal/graph_building.py

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,12 @@
1010

1111
import abc
1212
import collections
13+
import copy
1314
import json
1415
import sys
1516
from contextlib import nullcontext
1617
from pathlib import Path
17-
from typing import Any, Dict, Iterator, Optional, Tuple, Union
18+
from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union
1819

1920
from openeo.api.process import Parameter
2021
from openeo.internal.process_graph_visitor import (
@@ -243,7 +244,7 @@ def walk(x) -> Iterator[PGNode]:
243244
yield from walk(self.arguments)
244245

245246

246-
def as_flat_graph(x: Union[dict, FlatGraphableMixin, Path, Any]) -> Dict[str, dict]:
247+
def as_flat_graph(x: Union[dict, FlatGraphableMixin, Path, List[FlatGraphableMixin], Any]) -> Dict[str, dict]:
247248
"""
248249
Convert given object to an internal flat dict graph representation.
249250
"""
@@ -252,12 +253,15 @@ def as_flat_graph(x: Union[dict, FlatGraphableMixin, Path, Any]) -> Dict[str, di
252253
# including `{"process_graph": {nodes}}` ("process graph")
253254
# or just the raw process graph nodes?
254255
if isinstance(x, dict):
256+
# Assume given dict is already a flat graph representation
255257
return x
256258
elif isinstance(x, FlatGraphableMixin):
257259
return x.flat_graph()
258260
elif isinstance(x, (str, Path)):
259261
# Assume a JSON resource (raw JSON, path to local file, JSON url, ...)
260262
return load_json_resource(x)
263+
elif isinstance(x, (list, tuple)) and all(isinstance(i, FlatGraphableMixin) for i in x):
264+
return MultiLeafGraph(x).flat_graph()
261265
raise ValueError(x)
262266

263267

@@ -322,20 +326,29 @@ def generate(self, process_id: str):
322326

323327
class GraphFlattener(ProcessGraphVisitor):
324328

325-
def __init__(self, node_id_generator: FlatGraphNodeIdGenerator = None):
329+
def __init__(self, node_id_generator: FlatGraphNodeIdGenerator = None, multi_input_mode: bool = False):
326330
super().__init__()
327331
self._node_id_generator = node_id_generator or FlatGraphNodeIdGenerator()
328332
self._last_node_id = None
329333
self._flattened: Dict[str, dict] = {}
330334
self._argument_stack = []
331335
self._node_cache = {}
336+
self._multi_input_mode = multi_input_mode
332337

333338
def flatten(self, node: PGNode) -> Dict[str, dict]:
334339
"""Consume given nested process graph and return flat dict representation"""
340+
if self._flattened and not self._multi_input_mode:
341+
raise RuntimeError("Flattening multiple graphs, but not in multi-input mode")
335342
self.accept_node(node)
336343
assert len(self._argument_stack) == 0
337-
self._flattened[self._last_node_id]["result"] = True
338-
return self._flattened
344+
return self.flattened(set_result_flag=not self._multi_input_mode)
345+
346+
def flattened(self, set_result_flag: bool = True) -> Dict[str, dict]:
347+
flat_graph = copy.deepcopy(self._flattened)
348+
if set_result_flag:
349+
# TODO #583 an "end" node is not necessarily a "result" node
350+
flat_graph[self._last_node_id]["result"] = True
351+
return flat_graph
339352

340353
def accept_node(self, node: PGNode):
341354
# Process reused nodes only first time and remember node id.
@@ -438,3 +451,26 @@ def _process_from_parameter(self, name: str) -> Any:
438451
if name not in self._parameters:
439452
raise ProcessGraphVisitException("No substitution value for parameter {p!r}.".format(p=name))
440453
return self._parameters[name]
454+
455+
456+
class MultiLeafGraph(FlatGraphableMixin):
457+
"""
458+
Container for process graphs with multiple leaf/result nodes.
459+
"""
460+
461+
__slots__ = ["_leaves"]
462+
463+
def __init__(self, leaves: Iterable[FlatGraphableMixin]):
464+
self._leaves = list(leaves)
465+
466+
def flat_graph(self) -> Dict[str, dict]:
467+
flattener = GraphFlattener(multi_input_mode=True)
468+
for leaf in self._leaves:
469+
if isinstance(leaf, PGNode):
470+
flattener.flatten(leaf)
471+
elif isinstance(leaf, _FromNodeMixin):
472+
flattener.flatten(leaf.from_node())
473+
else:
474+
raise ValueError(f"Unsupported type {type(leaf)}")
475+
476+
return flattener.flattened(set_result_flag=True)

openeo/rest/_testing.py

Lines changed: 78 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
from __future__ import annotations
2+
13
import collections
24
import json
35
import re
4-
from typing import Callable, Iterator, Optional, Sequence, Union
6+
from typing import Callable, Iterable, Iterator, Optional, Sequence, Tuple, Union
57

68
from openeo import Connection, DataCube
79
from openeo.rest.vectorcube import VectorCube
@@ -19,8 +21,12 @@ class DummyBackend:
1921
and allows inspection of posted process graphs
2022
"""
2123

24+
# TODO: move to openeo.testing
25+
2226
__slots__ = (
27+
"_requests_mock",
2328
"connection",
29+
"file_formats",
2430
"sync_requests",
2531
"batch_jobs",
2632
"validation_requests",
@@ -33,8 +39,14 @@ class DummyBackend:
3339
# Default result (can serve both as JSON or binary data)
3440
DEFAULT_RESULT = b'{"what?": "Result data"}'
3541

36-
def __init__(self, requests_mock, connection: Connection):
42+
def __init__(
43+
self,
44+
requests_mock,
45+
connection: Connection,
46+
):
47+
self._requests_mock = requests_mock
3748
self.connection = connection
49+
self.file_formats = {"input": {}, "output": {}}
3850
self.sync_requests = []
3951
self.batch_jobs = {}
4052
self.validation_requests = []
@@ -69,6 +81,59 @@ def __init__(self, requests_mock, connection: Connection):
6981
)
7082
requests_mock.post(connection.build_url("/validation"), json=self._handle_post_validation)
7183

84+
@classmethod
85+
def at_url(cls, root_url: str, *, requests_mock, capabilities: Optional[dict] = None) -> DummyBackend:
86+
"""
87+
Factory to build dummy backend from given root URL
88+
including creation of connection and mocking of capabilities doc
89+
"""
90+
root_url = root_url.rstrip("/") + "/"
91+
requests_mock.get(root_url, json=build_capabilities(**(capabilities or {})))
92+
connection = Connection(root_url)
93+
return cls(requests_mock=requests_mock, connection=connection)
94+
95+
def setup_collection(
96+
self,
97+
collection_id: str,
98+
*,
99+
temporal: Union[bool, Tuple[str, str]] = True,
100+
bands: Sequence[str] = ("B1", "B2", "B3"),
101+
):
102+
# TODO: also mock `/collections` overview
103+
# TODO: option to override cube_dimensions as a whole, or override dimension names
104+
cube_dimensions = {
105+
"x": {"type": "spatial"},
106+
"y": {"type": "spatial"},
107+
}
108+
109+
if temporal:
110+
cube_dimensions["t"] = {
111+
"type": "temporal",
112+
"extent": temporal if isinstance(temporal, tuple) else [None, None],
113+
}
114+
if bands:
115+
cube_dimensions["bands"] = {"type": "bands", "values": list(bands)}
116+
117+
self._requests_mock.get(
118+
self.connection.build_url(f"/collections/{collection_id}"),
119+
# TODO: add more metadata?
120+
json={
121+
"id": collection_id,
122+
# define temporal and band dim
123+
"cube:dimensions": {"t": {"type": "temporal"}, "bands": {"type": "bands"}},
124+
},
125+
)
126+
return self
127+
128+
def setup_file_format(self, name: str, type: str = "output", gis_data_types: Iterable[str] = ("raster",)):
129+
self.file_formats[type][name] = {
130+
"title": name,
131+
"gis_data_types": list(gis_data_types),
132+
"parameters": {},
133+
}
134+
self._requests_mock.get(self.connection.build_url("/file_formats"), json=self.file_formats)
135+
return self
136+
72137
def _handle_post_result(self, request, context):
73138
"""handler of `POST /result` (synchronous execute)"""
74139
pg = request.json()["process"]["process_graph"]
@@ -150,10 +215,20 @@ def get_sync_pg(self) -> dict:
150215
return self.sync_requests[0]
151216

152217
def get_batch_pg(self) -> dict:
153-
"""Get one and only batch process graph"""
218+
"""
219+
Get process graph of the one and only batch job.
220+
Fails when there is none or more than one.
221+
"""
154222
assert len(self.batch_jobs) == 1
155223
return self.batch_jobs[max(self.batch_jobs.keys())]["pg"]
156224

225+
def get_validation_pg(self) -> dict:
226+
"""
227+
Get process graph of the one and only validation request.
228+
"""
229+
assert len(self.validation_requests) == 1
230+
return self.validation_requests[0]
231+
157232
def get_pg(self, process_id: Optional[str] = None) -> dict:
158233
"""
159234
Get one and only batch process graph (sync or batch)

0 commit comments

Comments
 (0)