1010
1111import abc
1212import collections
13+ import copy
1314import json
1415import sys
1516from contextlib import nullcontext
1617from pathlib import Path
17- from typing import Any , Dict , Iterator , Optional , Tuple , Union
18+ from typing import Any , Dict , Iterable , Iterator , List , Optional , Tuple , Union
1819
1920from openeo .api .process import Parameter
2021from openeo .internal .process_graph_visitor import (
@@ -243,7 +244,7 @@ def walk(x) -> Iterator[PGNode]:
243244 yield from walk (self .arguments )
244245
245246
246- def as_flat_graph (x : Union [dict , FlatGraphableMixin , Path , Any ]) -> Dict [str , dict ]:
247+ def as_flat_graph (x : Union [dict , FlatGraphableMixin , Path , List [ FlatGraphableMixin ], Any ]) -> Dict [str , dict ]:
247248 """
248249 Convert given object to a internal flat dict graph representation.
249250 """
@@ -252,12 +253,15 @@ def as_flat_graph(x: Union[dict, FlatGraphableMixin, Path, Any]) -> Dict[str, di
252253 # including `{"process_graph": {nodes}}` ("process graph")
253254 # or just the raw process graph nodes?
254255 if isinstance (x , dict ):
256+ # Assume given dict is already a flat graph representation
255257 return x
256258 elif isinstance (x , FlatGraphableMixin ):
257259 return x .flat_graph ()
258260 elif isinstance (x , (str , Path )):
259261 # Assume a JSON resource (raw JSON, path to local file, JSON url, ...)
260262 return load_json_resource (x )
263+ elif isinstance (x , (list , tuple )) and all (isinstance (i , FlatGraphableMixin ) for i in x ):
264+ return MultiLeafGraph (x ).flat_graph ()
261265 raise ValueError (x )
262266
263267
@@ -322,20 +326,29 @@ def generate(self, process_id: str):
322326
323327class GraphFlattener (ProcessGraphVisitor ):
324328
325- def __init__ (self , node_id_generator : FlatGraphNodeIdGenerator = None ):
329+ def __init__ (self , node_id_generator : FlatGraphNodeIdGenerator = None , multi_input_mode : bool = False ):
326330 super ().__init__ ()
327331 self ._node_id_generator = node_id_generator or FlatGraphNodeIdGenerator ()
328332 self ._last_node_id = None
329333 self ._flattened : Dict [str , dict ] = {}
330334 self ._argument_stack = []
331335 self ._node_cache = {}
336+ self ._multi_input_mode = multi_input_mode
332337
333338 def flatten (self , node : PGNode ) -> Dict [str , dict ]:
334339 """Consume given nested process graph and return flat dict representation"""
340+ if self ._flattened and not self ._multi_input_mode :
341+ raise RuntimeError ("Flattening multiple graphs, but not in multi-input mode" )
335342 self .accept_node (node )
336343 assert len (self ._argument_stack ) == 0
337- self ._flattened [self ._last_node_id ]["result" ] = True
338- return self ._flattened
344+ return self .flattened (set_result_flag = not self ._multi_input_mode )
345+
346+ def flattened (self , set_result_flag : bool = True ) -> Dict [str , dict ]:
347+ flat_graph = copy .deepcopy (self ._flattened )
348+ if set_result_flag :
349+ # TODO #583 an "end" node is not necessarily a "result" node
350+ flat_graph [self ._last_node_id ]["result" ] = True
351+ return flat_graph
339352
340353 def accept_node (self , node : PGNode ):
341354 # Process reused nodes only first time and remember node id.
@@ -438,3 +451,26 @@ def _process_from_parameter(self, name: str) -> Any:
438451 if name not in self ._parameters :
439452 raise ProcessGraphVisitException ("No substitution value for parameter {p!r}." .format (p = name ))
440453 return self ._parameters [name ]
454+
455+
456+ class MultiLeafGraph (FlatGraphableMixin ):
457+ """
458+ Container for process graphs with multiple leaf/result nodes.
459+ """
460+
461+ __slots__ = ["_leaves" ]
462+
463+ def __init__ (self , leaves : Iterable [FlatGraphableMixin ]):
464+ self ._leaves = list (leaves )
465+
466+ def flat_graph (self ) -> Dict [str , dict ]:
467+ flattener = GraphFlattener (multi_input_mode = True )
468+ for leaf in self ._leaves :
469+ if isinstance (leaf , PGNode ):
470+ flattener .flatten (leaf )
471+ elif isinstance (leaf , _FromNodeMixin ):
472+ flattener .flatten (leaf .from_node ())
473+ else :
474+ raise ValueError (f"Unsupported type { type (leaf )} " )
475+
476+ return flattener .flattened (set_result_flag = True )
0 commit comments