Skip to content

Commit 35a0b9f

Browse files
jdriessoxofaan
andauthored
Dry run performance optimization for huge process graphs (#427)
* add test that shows large graph is slow * replace double-for loop with list comprehension: is faster when processing 4000000 traces * logging impacts performance: reduce it * #426 changelog * skip long test * Fix pytest skip #427 * Further optimization of DryRunDataTracer.get_trace_leaves #426/#427 --------- Co-authored-by: Stefaan Lippens <[email protected]>
1 parent f0425e8 commit 35a0b9f

File tree

5 files changed

+4842
-18
lines changed

5 files changed

+4842
-18
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ and start a new "In Progress" section above it.
2222
## In progress: 0.136.0
2323

2424
- Start supporting custom `UdfRuntimes` implementation in `OpenEoBackendImplementation` ([#415](https://github.com/Open-EO/openeo-python-driver/issues/415))
25-
25+
- Process graph parsing (dry-run) for very large graphs got faster. ([#426](https://github.com/Open-EO/openeo-python-driver/issues/426))
2626

2727
## 0.135.0
2828

openeo_driver/ProcessGraphDeserializer.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1914,14 +1914,13 @@ def check_subgraph_for_data_mask_optimization(args: dict) -> bool:
19141914

19151915

19161916
def apply_process(process_id: str, args: dict, namespace: Union[str, None], env: EvalEnv) -> DriverDataCube:
1917-
_log.debug(f"apply_process {process_id} with {args}")
1917+
_log.debug(f"apply_process {process_id} ")
19181918
parameters = env.collect_parameters()
19191919

19201920
if process_id == "mask" and args.get("replacement", None) is None \
19211921
and smart_bool(env.get("data_mask_optimization", True)):
19221922
mask_node = args.get("mask", None)
19231923
# evaluate the mask
1924-
_log.debug(f"data_mask: convert_node(mask_node): {mask_node}")
19251924
the_mask = convert_node(mask_node, env=env)
19261925
dry_run_tracer: DryRunDataTracer = env.get(ENV_DRY_RUN_TRACER)
19271926
if not dry_run_tracer and check_subgraph_for_data_mask_optimization(args):
@@ -1958,7 +1957,6 @@ def apply_process(process_id: str, args: dict, namespace: Union[str, None], env:
19581957

19591958
try:
19601959
process_function = process_registry.get_function(process_id, namespace=(namespace or "backend"))
1961-
_log.debug(f"Applying process {process_id} to arguments {args}")
19621960
#TODO: for API compliance, we would actually first need to check if a UDP with same name exists.
19631961
# we would however prefer to avoid overriding predefined functions with UDP's.
19641962
# if we want to do this, we require caching in UDP registry to avoid expensive UDP lookups. We only need to cache the list of UDP names for a given user.

openeo_driver/dry_run.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737

3838
import logging
3939
from enum import Enum
40-
from typing import List, Optional, Tuple, Union
40+
from typing import List, Optional, Tuple, Union, Iterator
4141

4242
import numpy
4343
import shapely.geometry.base
@@ -100,8 +100,10 @@
100100
class DataTraceBase:
101101
"""Base class for data traces."""
102102

103+
__slots__ = ["children"]
104+
103105
def __init__(self):
104-
self.children = []
106+
self.children: List[DataTraceBase] = []
105107

106108
def __hash__(self):
107109
# Identity hash (e.g. memory address)
@@ -369,19 +371,19 @@ def get_trace_leaves(self) -> List[DataTraceBase]:
369371
Get all nodes in the tree of traces that are not parent of another trace.
370372
In openEO this could be for instance a save_result process that ends the workflow.
371373
"""
372-
leaves = []
373-
374-
def get_leaves(tree: DataTraceBase) -> List[DataTraceBase]:
375-
return (
376-
[tree] if len(tree.children) == 0 else [leaf for child in tree.children for leaf in get_leaves(child)]
377-
)
378-
379-
for trace in self._traces:
380-
for leaf in get_leaves(trace):
381-
if leaf not in leaves:
382-
leaves.append(leaf)
374+
visited = set()
375+
376+
def get_leaves(trace: DataTraceBase) -> Iterator[DataTraceBase]:
377+
nonlocal visited
378+
if trace not in visited:
379+
visited.add(trace)
380+
if trace.children:
381+
for child in trace.children:
382+
yield from get_leaves(child)
383+
else:
384+
yield trace
383385

384-
return leaves
386+
return [leave for trace in self._traces for leave in get_leaves(trace)]
385387

386388
def get_metadata_links(self):
387389
result = {}

0 commit comments

Comments
 (0)