Skip to content

Commit 4485f55

Browse files
committed
post_dry_run: add caching to determine_global_extent workflow #1565
Introduces caching on `_extract_spatial_extent_from_load_stac_item_collection` calls. This involves introducing new caching utilities (`GetOrCallCacheInterface`) and making some data structs cache-key compatible (`__hash__` + `__eq__`).
1 parent eea3aa9 commit 4485f55

File tree

4 files changed

+295
-14
lines changed

4 files changed

+295
-14
lines changed

openeogeotrellis/_backend/post_dry_run.py

Lines changed: 46 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from openeogeotrellis.constants import EVAL_ENV_KEY
1818

1919
import openeogeotrellis.load_stac
20+
from openeogeotrellis.util.caching import GetOrCallCache, GetOrCallCacheInterface, AlwaysCallWithoutCache
2021
from openeogeotrellis.util.geometry import BoundingBoxMerger
2122
from openeogeotrellis.util.math import logarithmic_round
2223

@@ -198,7 +199,10 @@ class AlignedExtentResult:
198199

199200

200201
def _extract_spatial_extent_from_constraint(
201-
source_constraint: SourceConstraint, *, catalog: AbstractCollectionCatalog
202+
source_constraint: SourceConstraint,
203+
*,
204+
catalog: AbstractCollectionCatalog,
205+
cache: Optional[GetOrCallCacheInterface] = None,
202206
) -> Union[None, AlignedExtentResult]:
203207
"""
204208
Extract spatial extent from given source constraint (if any), and align it to target grid.
@@ -211,18 +215,22 @@ def _extract_spatial_extent_from_constraint(
211215
if source_process == "load_collection":
212216
collection_id = source_id[1][0]
213217
return _extract_spatial_extent_from_constraint_load_collection(
214-
collection_id=collection_id, constraint=constraint, catalog=catalog
218+
collection_id=collection_id, constraint=constraint, catalog=catalog, cache=cache
215219
)
216220
elif source_process == "load_stac":
217221
url = source_id[1][0]
218-
return _extract_spatial_extent_from_constraint_load_stac(stac_url=url, constraint=constraint)
222+
return _extract_spatial_extent_from_constraint_load_stac(stac_url=url, constraint=constraint, cache=cache)
219223
else:
220224
# TODO?
221225
return None
222226

223227

224228
def _extract_spatial_extent_from_constraint_load_collection(
225-
collection_id: str, *, constraint: dict, catalog: AbstractCollectionCatalog
229+
collection_id: str,
230+
*,
231+
constraint: dict,
232+
catalog: AbstractCollectionCatalog,
233+
cache: Optional[GetOrCallCacheInterface] = None,
226234
) -> Union[None, AlignedExtentResult]:
227235
try:
228236
metadata = catalog.get_collection_metadata(collection_id)
@@ -233,7 +241,7 @@ def _extract_spatial_extent_from_constraint_load_collection(
233241
stac_url = deep_get(metadata, "_vito", "data_source", "url")
234242
load_stac_feature_flags = deep_get(metadata, "_vito", "data_source", "load_stac_feature_flags", default={})
235243
return _extract_spatial_extent_from_constraint_load_stac(
236-
stac_url=stac_url, constraint=constraint, feature_flags=load_stac_feature_flags
244+
stac_url=stac_url, constraint=constraint, feature_flags=load_stac_feature_flags, cache=cache
237245
)
238246

239247
# TODO Extracting pixel grid info from collection metadata might might be unreliable
@@ -272,7 +280,11 @@ def _extract_spatial_extent_from_constraint_load_collection(
272280

273281

274282
def _extract_spatial_extent_from_constraint_load_stac(
275-
stac_url: str, *, constraint: dict, feature_flags: Optional[dict] = None
283+
stac_url: str,
284+
*,
285+
constraint: dict,
286+
feature_flags: Optional[dict] = None,
287+
cache: Optional[GetOrCallCacheInterface] = None,
276288
) -> Union[None, AlignedExtentResult]:
277289
spatial_extent_from_pg = constraint.get("spatial_extent") or constraint.get("weak_spatial_extent")
278290
spatiotemporal_extent = openeogeotrellis.load_stac._spatiotemporal_extent_from_load_params(
@@ -293,12 +305,23 @@ def _extract_spatial_extent_from_constraint_load_stac(
293305

294306
pixel_buffer_size = deep_get(constraint, "pixel_buffer", "buffer_size", default=None)
295307

296-
return _extract_spatial_extent_from_load_stac_item_collection(
297-
stac_url=stac_url,
298-
spatiotemporal_extent=spatiotemporal_extent,
299-
property_filter_pg_map=property_filter_pg_map,
300-
resample_grid=resample_grid,
301-
pixel_buffer_size=pixel_buffer_size,
308+
return (cache or AlwaysCallWithoutCache()).get_or_call(
309+
key=(
310+
stac_url,
311+
spatiotemporal_extent,
312+
str(property_filter_pg_map),
313+
resample_grid,
314+
pixel_buffer_size,
315+
str(feature_flags),
316+
),
317+
callback=lambda: _extract_spatial_extent_from_load_stac_item_collection(
318+
stac_url=stac_url,
319+
spatiotemporal_extent=spatiotemporal_extent,
320+
property_filter_pg_map=property_filter_pg_map,
321+
resample_grid=resample_grid,
322+
pixel_buffer_size=pixel_buffer_size,
323+
feature_flags=feature_flags,
324+
),
302325
)
303326

304327

@@ -309,6 +332,7 @@ def _extract_spatial_extent_from_load_stac_item_collection(
309332
property_filter_pg_map: Optional[openeogeotrellis.load_stac.PropertyFilterPGMap] = None,
310333
resample_grid: Optional[_GridInfo] = None,
311334
pixel_buffer_size: Union[Tuple[float, float], int, float, None] = None,
335+
feature_flags: Optional[dict] = None,
312336
) -> Union[None, AlignedExtentResult]:
313337
extent_orig: Union[BoundingBox, None] = spatiotemporal_extent.spatial_extent.as_bbox()
314338
extent_variants = {"original": extent_orig}
@@ -430,6 +454,7 @@ def determine_global_extent(
430454
*,
431455
source_constraints: List[SourceConstraint],
432456
catalog: AbstractCollectionCatalog,
457+
cache: Optional[GetOrCallCacheInterface] = None,
433458
) -> dict:
434459
"""
435460
Go through all source constraints, extract the aligned extent from each (possibly with variations)
@@ -439,9 +464,16 @@ def determine_global_extent(
439464
# e.g. add stats to AlignedExtentResult for better informed decision?
440465
aligned_merger = BoundingBoxMerger()
441466
variant_mergers: Dict[str, BoundingBoxMerger] = collections.defaultdict(BoundingBoxMerger)
467+
# Enable the implementation to use caching
468+
# when going through all these source constraints
469+
# as there might be duplication that's hidden
470+
# by differences from non-relevant details
471+
cache = cache or GetOrCallCache(max_size=100)
442472
for source_id, constraint in source_constraints:
443473
try:
444-
aligned_extent_result = _extract_spatial_extent_from_constraint((source_id, constraint), catalog=catalog)
474+
aligned_extent_result = _extract_spatial_extent_from_constraint(
475+
(source_id, constraint), catalog=catalog, cache=cache
476+
)
445477
except Exception as e:
446478
raise SpatialExtentExtractionError(
447479
f"Failed to extract spatial extent from {source_id=} with {constraint=}: {e=}"
@@ -456,6 +488,7 @@ def determine_global_extent(
456488
_log.info(
457489
f"determine_global_extent: skipping {name=} from {source_id=} with {aligned_extent_result.variants=}"
458490
)
491+
_log.info(f"_extract_spatial_extent_from_constraint cache stats: {cache!s}")
459492

460493
global_extent: BoundingBox = aligned_merger.get()
461494
global_extent_variants: Dict[str, BoundingBox] = {name: merger.get() for name, merger in variant_mergers.items()}

openeogeotrellis/util/caching.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
import collections
2+
3+
from typing import Dict, Any, Optional, Callable
4+
5+
6+
class GetOrCallCacheInterface:
    """
    Interface for "memoizing" function calls with an explicit cache key.

    Unlike the well known `functools.lru_cache`,
    it requires to specify an explicit cache key,
    which gives more control over how to handle arguments, e.g.:
    - serialize or transform arguments that are not hashable by default
    - ignore certain arguments
    - add cache key components that are not arguments to the function

    Usage example:

        result = cache.get_or_call(
            key=(name, str(sorted(options.items()))),
            callback=lambda: function_to_cache(name, options),
        )

    """

    def get_or_call(self, key, callback: Callable[[], Any]) -> Any:
        """
        Try to get item from cache.
        If not available: call callback to build it and store result in cache.

        :param key: key to store item at (can be a simple string,
            or something more complex like a tuple of strings/ints)
        :param callback: item builder to call when item is not in cache
        :return: item (from cache or freshly built)
        """
        raise NotImplementedError()
37+
38+
39+
class AlwaysCallWithoutCache(GetOrCallCacheInterface):
    """No-op cache implementation: ignores the key and invokes the callback every time."""

    def get_or_call(self, key, callback: Callable[[], Any]) -> Any:
        # No lookup, no storage: always produce the value fresh.
        result = callback()
        return result
44+
45+
46+
class GetOrCallCache(GetOrCallCacheInterface):
    """
    In-memory dictionary based cache for "memoizing" function calls.

    Supports a maximum size by pruning least frequently used items.
    """

    # TODO: support least recently used (LRU) pruning strategy as well?

    def __init__(self, max_size: Optional[int] = None):
        # Cached values, keyed by caller-provided cache key.
        self._cache: Dict[Any, Any] = {}
        # Per-key usage counts, driving least-frequently-used pruning.
        self._usage_stats: Dict[Any, int] = {}
        self._max_size = max_size
        # Global counters ("get_or_call"/"hit"/"miss"/"prune") for logging.
        self._stats = collections.defaultdict(int)

    def get_or_call(self, key, callback: Callable[[], Any]) -> Any:
        """
        Return cached item for `key`, or call `callback` to build (and cache) it.

        :param key: hashable cache key
        :param callback: item builder to call when item is not in cache
        :return: item (from cache or freshly built)
        """
        self._stats["get_or_call"] += 1
        if key in self._cache:
            self._usage_stats[key] += 1
            self._stats["hit"] += 1
            return self._cache[key]
        else:
            value = callback()
            # Only store the value if pruning could make room for it
            # (e.g. with max_size=0 nothing is ever stored).
            if self._ensure_capacity_for(count=1):
                self._cache[key] = value
                self._usage_stats[key] = 1
            self._stats["miss"] += 1
            return value

    def _ensure_capacity_for(self, count: int = 1) -> bool:
        """
        Make sure there is room (according to max_size setting)
        for adding the given number of new items
        by pruning least used items from the cache if necessary.

        :return: whether there is room for `count` additional items
        """
        if self._max_size is not None:
            to_prune = len(self._cache) + count - self._max_size
            if to_prune > 0:
                # Evict the keys with the lowest usage counts.
                least_used_keys = sorted(self._usage_stats, key=self._usage_stats.get)[:to_prune]
                for key in least_used_keys:
                    del self._cache[key]
                    del self._usage_stats[key]
                    self._stats["prune"] += 1
            return len(self._cache) + count <= self._max_size
        else:
            return True

    def stats(self) -> dict:
        """Return snapshot of global counters (hits, misses, prunes, ...)."""
        return dict(self._stats)

    def __str__(self):
        # Fixed typo: previously rendered as "<GetOrCallCachesize size=...>".
        return f"<GetOrCallCache size={len(self._cache)} stats={dict(self._stats)}>"

    def clear(self):
        """Drop all cached items (global stats counters are kept)."""
        self._cache.clear()
        # Bug fix: also clear usage stats. Stale entries here would later be
        # selected for pruning and trigger a KeyError on the emptied cache.
        self._usage_stats.clear()

tests/_backend/test_post_dry_run.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import collections
2+
import dirty_equals
3+
14
from typing import Callable, List, Optional
25

36
import openeo_driver.ProcessGraphDeserializer
@@ -19,7 +22,7 @@
1922
determine_global_extent,
2023
)
2124
from openeogeotrellis.load_stac import _ProjectionMetadata
22-
25+
from openeogeotrellis.util.caching import AlwaysCallWithoutCache
2326

2427
class TestGridInfo:
2528
def test_minimal(self):
@@ -979,6 +982,53 @@ def test_determine_global_extent_load_stac_minimal(
979982
},
980983
}
981984

985+
    @pytest.mark.parametrize(
        ["cache", "expected_searches"],
        [
            (None, 1),  # cache=None: use default caching
            (AlwaysCallWithoutCache(), 3),
        ],
    )
    def test_determine_global_extent_load_stac_caching(
        self,
        dummy_catalog,
        extract_source_constraints,
        dummy_stac_api_server,
        dummy_stac_api,
        cache,
        expected_searches,
    ):
        """
        With the default cache (cache=None) the three equivalent constraints
        should trigger a single STAC `/search` request; with
        `AlwaysCallWithoutCache`, each constraint triggers its own search.
        """
        # Multiple source constraints with same essentials,
        # and some differences in non-essential details
        source_constraints = [
            (
                ("load_stac", (f"{dummy_stac_api}/collections/collection-123", (), ("B02",))),
                {"bands": ["B02"], "breakfast": "cereal"},
            ),
            (
                ("load_stac", (f"{dummy_stac_api}/collections/collection-123", (), ("B02",))),
                {"bands": ["B02"], "shoe:size": 42},
            ),
            (
                ("load_stac", (f"{dummy_stac_api}/collections/collection-123", (), ("B02",))),
                {"bands": ["B02"], "pets": ["dog"]},
            ),
        ]
        global_extent = determine_global_extent(
            source_constraints=source_constraints, catalog=dummy_catalog, cache=cache
        )
        # Same expected bbox regardless of caching: caching must not change results.
        expected = BoundingBox(west=2, south=49, east=7, north=52, crs="EPSG:4326").approx(abs=1e-6)
        assert global_extent == {
            "global_extent": expected,
            "global_extent_variants": {
                "assets_full_bbox": expected,
            },
        }
        # Check request history histogram by path for search requests
        assert collections.Counter(
            (r["method"], r["path"]) for r in dummy_stac_api_server.request_history
        ) == dirty_equals.IsPartialDict({("GET", "/search"): expected_searches})
1031+
9821032
def test_extract_spatial_extent_from_constraint_load_collection_type_stac_minimal(
9831033
self, dummy_catalog, extract_source_constraints, dummy_stac_api
9841034
):

0 commit comments

Comments
 (0)