Skip to content

Commit b69a22f

Browse files
committed
post_dry_run: add caching to determine_global_extent workflow #1565
Introduces caching on _extract_spatial_extent_from_load_stac_item_collection calls. This involves introducing new caching utilities (GetOrCallCacheInterface) and making some data structs cache-key compatible (__hash__ + __eq__).
1 parent fbe05fd commit b69a22f

File tree

6 files changed

+387
-21
lines changed

6 files changed

+387
-21
lines changed

openeogeotrellis/_backend/post_dry_run.py

Lines changed: 50 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from openeogeotrellis.constants import EVAL_ENV_KEY
1717

1818
import openeogeotrellis.load_stac
19+
from openeogeotrellis.util.caching import GetOrCallCache, GetOrCallCacheInterface, AlwaysCallWithoutCache
1920
from openeogeotrellis.util.geometry import BoundingBoxMerger
2021
from openeogeotrellis.util.math import logarithmic_round
2122

@@ -56,14 +57,15 @@ def __init__(
5657
def __repr__(self) -> str:
5758
return f"_GridInfo(crs={self.crs_raw!r}, resolution={self.resolution!r}, extent_x={self.extent_x!r}, extent_y={self.extent_y!r})"
5859

60+
def _key(self) -> tuple:
61+
return (self.crs_raw, self.resolution, self.extent_x, self.extent_y)
62+
63+
def __hash__(self):
64+
return hash(self._key())
5965
def __eq__(self, other) -> bool:
60-
return (
61-
isinstance(other, _GridInfo)
62-
and self.crs_raw == other.crs_raw
63-
and self.resolution == other.resolution
64-
and self.extent_x == other.extent_x
65-
and self.extent_y == other.extent_y
66-
)
66+
if isinstance(other, _GridInfo):
67+
return self._key() == other._key()
68+
return NotImplemented
6769

6870
@classmethod
6971
def from_datacube_metadata(cls, metadata: dict) -> _GridInfo:
@@ -196,7 +198,10 @@ class AlignedExtentResult:
196198

197199

198200
def _extract_spatial_extent_from_constraint(
199-
source_constraint: SourceConstraint, *, catalog: AbstractCollectionCatalog
201+
source_constraint: SourceConstraint,
202+
*,
203+
catalog: AbstractCollectionCatalog,
204+
cache: Optional[GetOrCallCacheInterface] = None,
200205
) -> Union[None, AlignedExtentResult]:
201206
"""
202207
Extract spatial extent from given source constraint (if any), and align it to target grid.
@@ -209,18 +214,22 @@ def _extract_spatial_extent_from_constraint(
209214
if source_process == "load_collection":
210215
collection_id = source_id[1][0]
211216
return _extract_spatial_extent_from_constraint_load_collection(
212-
collection_id=collection_id, constraint=constraint, catalog=catalog
217+
collection_id=collection_id, constraint=constraint, catalog=catalog, cache=cache
213218
)
214219
elif source_process == "load_stac":
215220
url = source_id[1][0]
216-
return _extract_spatial_extent_from_constraint_load_stac(stac_url=url, constraint=constraint)
221+
return _extract_spatial_extent_from_constraint_load_stac(stac_url=url, constraint=constraint, cache=cache)
217222
else:
218223
# TODO?
219224
return None
220225

221226

222227
def _extract_spatial_extent_from_constraint_load_collection(
223-
collection_id: str, *, constraint: dict, catalog: AbstractCollectionCatalog
228+
collection_id: str,
229+
*,
230+
constraint: dict,
231+
catalog: AbstractCollectionCatalog,
232+
cache: Optional[GetOrCallCacheInterface] = None,
224233
) -> Union[None, AlignedExtentResult]:
225234
try:
226235
metadata = catalog.get_collection_metadata(collection_id)
@@ -229,7 +238,7 @@ def _extract_spatial_extent_from_constraint_load_collection(
229238

230239
if deep_get(metadata, "_vito", "data_source", "type", default=None) == "stac":
231240
stac_url = deep_get(metadata, "_vito", "data_source", "url")
232-
return _extract_spatial_extent_from_constraint_load_stac(stac_url=stac_url, constraint=constraint)
241+
return _extract_spatial_extent_from_constraint_load_stac(stac_url=stac_url, constraint=constraint, cache=cache)
233242

234243
# TODO Extracting pixel grid info from collection metadata might be unreliable
235244
# and should be replaced by more precise item-level metadata where possible.
@@ -267,7 +276,10 @@ def _extract_spatial_extent_from_constraint_load_collection(
267276

268277

269278
def _extract_spatial_extent_from_constraint_load_stac(
270-
stac_url: str, *, constraint: dict
279+
stac_url: str,
280+
*,
281+
constraint: dict,
282+
cache: Optional[GetOrCallCacheInterface] = None,
271283
) -> Union[None, AlignedExtentResult]:
272284
spatial_extent_from_pg = constraint.get("spatial_extent") or constraint.get("weak_spatial_extent")
273285
spatiotemporal_extent = openeogeotrellis.load_stac._spatiotemporal_extent_from_load_params(
@@ -288,12 +300,21 @@ def _extract_spatial_extent_from_constraint_load_stac(
288300

289301
pixel_buffer_size = deep_get(constraint, "pixel_buffer", "buffer_size", default=None)
290302

291-
return _extract_spatial_extent_from_load_stac_item_collection(
292-
stac_url=stac_url,
293-
spatiotemporal_extent=spatiotemporal_extent,
294-
property_filter_pg_map=property_filter_pg_map,
295-
resample_grid=resample_grid,
296-
pixel_buffer_size=pixel_buffer_size,
303+
return (cache or AlwaysCallWithoutCache()).get_or_call(
304+
key=(
305+
stac_url,
306+
spatiotemporal_extent,
307+
str(property_filter_pg_map),
308+
resample_grid,
309+
pixel_buffer_size,
310+
),
311+
callback=lambda: _extract_spatial_extent_from_load_stac_item_collection(
312+
stac_url=stac_url,
313+
spatiotemporal_extent=spatiotemporal_extent,
314+
property_filter_pg_map=property_filter_pg_map,
315+
resample_grid=resample_grid,
316+
pixel_buffer_size=pixel_buffer_size,
317+
),
297318
)
298319

299320

@@ -412,6 +433,7 @@ def determine_global_extent(
412433
*,
413434
source_constraints: List[SourceConstraint],
414435
catalog: AbstractCollectionCatalog,
436+
cache: Optional[GetOrCallCacheInterface] = None,
415437
) -> dict:
416438
"""
417439
Go through all source constraints, extract the aligned extent from each (possibly with variations)
@@ -421,9 +443,16 @@ def determine_global_extent(
421443
# e.g. add stats to AlignedExtentResult for better informed decision?
422444
aligned_merger = BoundingBoxMerger()
423445
variant_mergers: Dict[str, BoundingBoxMerger] = collections.defaultdict(BoundingBoxMerger)
446+
# Enable the implementation to use caching
447+
# when going through all these source constraints
448+
# as there might be duplication that's hidden
449+
# by differences from non-relevant details
450+
cache = cache or GetOrCallCache(max_size=100)
424451
for source_id, constraint in source_constraints:
425452
try:
426-
aligned_extent_result = _extract_spatial_extent_from_constraint((source_id, constraint), catalog=catalog)
453+
aligned_extent_result = _extract_spatial_extent_from_constraint(
454+
(source_id, constraint), catalog=catalog, cache=cache
455+
)
427456
except Exception as e:
428457
raise SpatialExtentExtractionError(
429458
f"Failed to extract spatial extent from {source_id=} with {constraint=}: {e=}"
@@ -438,6 +467,7 @@ def determine_global_extent(
438467
_log.info(
439468
f"determine_global_extent: skipping {name=} from {source_id=} with {aligned_extent_result.variants=}"
440469
)
470+
_log.info(f"_extract_spatial_extent_from_constraint cache stats: {cache!s}")
441471

442472
global_extent: BoundingBox = aligned_merger.get()
443473
global_extent_variants: Dict[str, BoundingBox] = {name: merger.get() for name, merger in variant_mergers.items()}

openeogeotrellis/load_stac.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -909,13 +909,26 @@ class _TemporalExtent:
909909
"""
910910

911911
# TODO: move this to a more generic location for better reuse
912+
# TODO: enforce/ensure immutability
913+
# TODO: re-implement in dataclasses/attrs to better enforce immutability and simplify equality/hash implementation
912914

913915
__slots__ = ("from_date", "to_date")
914916

915917
def __init__(self, from_date: DateTimeLikeOrNone, to_date: DateTimeLikeOrNone):
916918
self.from_date: Union[datetime.datetime, None] = to_datetime_utc_unless_none(from_date)
917919
self.to_date: Union[datetime.datetime, None] = to_datetime_utc_unless_none(to_date)
918920

921+
def _key(self) -> tuple:
922+
return (self.from_date, self.to_date)
923+
924+
def __hash__(self):
925+
return hash(self._key())
926+
927+
def __eq__(self, other):
928+
if isinstance(other, _TemporalExtent):
929+
return self._key() == other._key()
930+
return NotImplemented
931+
919932
def as_tuple(self) -> Tuple[Union[datetime.datetime, None], Union[datetime.datetime, None]]:
920933
return self.from_date, self.to_date
921934

@@ -972,6 +985,8 @@ class _SpatialExtent:
972985
"""
973986

974987
# TODO: move this to a more generic location for better reuse
988+
# TODO: enforce/ensure immutability
989+
# TODO: re-implement in dataclasses/attrs to better enforce immutability and simplify equality/hash implementation
975990

976991
__slots__ = ("_bbox", "_bbox_lonlat_shape")
977992

@@ -981,6 +996,17 @@ def __init__(self, *, bbox: Union[BoundingBox, None]):
981996
# cache for shapely polygon in lon/lat
982997
self._bbox_lonlat_shape = self._bbox.reproject("EPSG:4326").as_polygon() if self._bbox else None
983998

999+
def _key(self) -> tuple:
1000+
return (self._bbox,)
1001+
1002+
def __hash__(self):
1003+
return hash(self._key())
1004+
1005+
def __eq__(self, other):
1006+
if isinstance(other, _SpatialExtent):
1007+
return self._key() == other._key()
1008+
return NotImplemented
1009+
9841010
def as_bbox(self, crs: Optional[str] = None) -> Union[BoundingBox, None]:
9851011
bbox = self._bbox
9861012
if bbox and crs:
@@ -996,6 +1022,9 @@ def intersects(self, bbox: Union[List[float], Tuple[float, float, float, float],
9961022

9971023
class _SpatioTemporalExtent:
9981024
# TODO: move this to a more generic location for better reuse
1025+
# TODO: enforce/ensure immutability
1026+
# TODO: re-implement in dataclasses/attrs to better enforce immutability and simplify equality/hash implementation
1027+
9991028
def __init__(
10001029
self,
10011030
*,
@@ -1006,6 +1035,17 @@ def __init__(
10061035
self._spatial_extent = _SpatialExtent(bbox=bbox)
10071036
self._temporal_extent = _TemporalExtent(from_date=from_date, to_date=to_date)
10081037

1038+
def _key(self) -> tuple:
1039+
return (self._spatial_extent, self._temporal_extent)
1040+
1041+
def __hash__(self):
1042+
return hash(self._key())
1043+
1044+
def __eq__(self, other):
1045+
if isinstance(other, _SpatioTemporalExtent):
1046+
return self._key() == other._key()
1047+
return NotImplemented
1048+
10091049
@property
10101050
def spatial_extent(self) -> _SpatialExtent:
10111051
return self._spatial_extent

openeogeotrellis/util/caching.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
import collections

from typing import Dict, Any, Optional, Callable


class GetOrCallCacheInterface:
    """
    Interface for "memoizing" function calls with an explicit cache key.

    Unlike the well known `functools.lru_cache`,
    it requires to specify an explicit cache key,
    which gives more control over how to handle arguments, e.g.:
    - serialize or transform arguments that are not hashable by default
    - ignore certain arguments
    - add cache key components that are not arguments to the function

    Usage example:

        result = cache.get_or_call(
            key=(name, str(sorted(options.items()))),
            callback=lambda: function_to_cache(name, options),
        )

    """

    def get_or_call(self, key, callback: Callable[[], Any]) -> Any:
        """
        Try to get item from cache.
        If not available: call callback to build it and store result in cache.

        :param key: key to store item at (can be a simple string,
            or something more complex like a tuple of strings/ints)
        :param callback: item builder to call when item is not in cache
        :return: item (from cache or freshly built)
        """
        raise NotImplementedError()


class AlwaysCallWithoutCache(GetOrCallCacheInterface):
    """Don't cache, just call"""

    def get_or_call(self, key, callback: Callable[[], Any]) -> Any:
        # Key is deliberately ignored: there is nothing to look up or store.
        return callback()


class GetOrCallCache(GetOrCallCacheInterface):
    """
    In-memory dictionary based cache for "memoizing" function calls.

    Supports a maximum size by pruning least frequently used items.
    """

    # TODO: support least recently used (LRU) pruning strategy as well?

    def __init__(self, max_size: Optional[int] = None):
        # Cached values, keyed on the user-provided cache key.
        self._cache = {}
        # Per-key hit count, used to pick least-frequently-used victims when pruning.
        self._usage_stats = {}
        # Maximum number of entries (None means unbounded).
        self._max_size = max_size
        # Aggregate counters (get_or_call/hit/miss/prune) for observability.
        self._stats = collections.defaultdict(int)

    def get_or_call(self, key, callback: Callable[[], Any]) -> Any:
        self._stats["get_or_call"] += 1
        if key in self._cache:
            self._usage_stats[key] += 1
            self._stats["hit"] += 1
            return self._cache[key]
        else:
            # Every non-hit is a miss, even if the result ends up not being stored
            # (e.g. when capacity can not be ensured).
            self._stats["miss"] += 1
            value = callback()
            if self._ensure_capacity_for(count=1):
                self._cache[key] = value
                self._usage_stats[key] = 1
            return value

    def _ensure_capacity_for(self, count: int = 1) -> bool:
        """
        Make sure there is room (according to max_size setting)
        for adding the given number of new items
        by pruning least used items from the cache if necessary.

        :return: whether there is room for `count` additional items.
        """
        if self._max_size is not None:
            to_prune = len(self._cache) + count - self._max_size
            if to_prune > 0:
                least_used_keys = sorted(self._usage_stats, key=self._usage_stats.get)[:to_prune]
                for key in least_used_keys:
                    del self._cache[key]
                    del self._usage_stats[key]
                    self._stats["prune"] += 1
            return len(self._cache) + count <= self._max_size
        else:
            return True

    def stats(self) -> dict:
        """Snapshot of aggregate counters as a plain dict."""
        return dict(self._stats)

    def __str__(self):
        return f"<GetOrCallCache size={len(self._cache)} stats={dict(self._stats)}>"

    def clear(self):
        """Drop all cached entries (aggregate stats are kept)."""
        self._cache.clear()
        # Also drop per-key usage counters: leaving stale keys here would make
        # _ensure_capacity_for try to delete keys that are no longer in the cache
        # (KeyError) after a clear + refill beyond max_size.
        self._usage_stats.clear()

0 commit comments

Comments
 (0)