Skip to content

Commit f9f873d

Browse files
committed
post_dry_run: add caching to determine_global_extent workflow #1565
introduces caching on _extract_spatial_extent_from_load_stac_item_collection calls. This involves introducing new caching utilities (GetOrCallCacheInterface) and making some data structs cache-key compatible (__hash__ + __eq__)
1 parent fe55983 commit f9f873d

File tree

6 files changed

+391
-21
lines changed

6 files changed

+391
-21
lines changed

openeogeotrellis/_backend/post_dry_run.py

Lines changed: 54 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from openeogeotrellis.constants import EVAL_ENV_KEY
1818

1919
import openeogeotrellis.load_stac
20+
from openeogeotrellis.util.caching import GetOrCallCache, GetOrCallCacheInterface, AlwaysCallWithoutCache
2021
from openeogeotrellis.util.geometry import BoundingBoxMerger
2122
from openeogeotrellis.util.math import logarithmic_round
2223

@@ -57,14 +58,15 @@ def __init__(
5758
def __repr__(self) -> str:
5859
return f"_GridInfo(crs={self.crs_raw!r}, resolution={self.resolution!r}, extent_x={self.extent_x!r}, extent_y={self.extent_y!r})"
5960

61+
def _key(self) -> tuple:
62+
return (self.crs_raw, self.resolution, self.extent_x, self.extent_y)
63+
64+
def __hash__(self):
65+
return hash(self._key())
6066
def __eq__(self, other) -> bool:
61-
return (
62-
isinstance(other, _GridInfo)
63-
and self.crs_raw == other.crs_raw
64-
and self.resolution == other.resolution
65-
and self.extent_x == other.extent_x
66-
and self.extent_y == other.extent_y
67-
)
67+
if isinstance(other, _GridInfo):
68+
return self._key() == other._key()
69+
return NotImplemented
6870

6971
@classmethod
7072
def from_datacube_metadata(cls, metadata: dict) -> _GridInfo:
@@ -197,7 +199,10 @@ class AlignedExtentResult:
197199

198200

199201
def _extract_spatial_extent_from_constraint(
200-
source_constraint: SourceConstraint, *, catalog: AbstractCollectionCatalog
202+
source_constraint: SourceConstraint,
203+
*,
204+
catalog: AbstractCollectionCatalog,
205+
cache: Optional[GetOrCallCacheInterface] = None,
201206
) -> Union[None, AlignedExtentResult]:
202207
"""
203208
Extract spatial extent from given source constraint (if any), and align it to target grid.
@@ -210,18 +215,22 @@ def _extract_spatial_extent_from_constraint(
210215
if source_process == "load_collection":
211216
collection_id = source_id[1][0]
212217
return _extract_spatial_extent_from_constraint_load_collection(
213-
collection_id=collection_id, constraint=constraint, catalog=catalog
218+
collection_id=collection_id, constraint=constraint, catalog=catalog, cache=cache
214219
)
215220
elif source_process == "load_stac":
216221
url = source_id[1][0]
217-
return _extract_spatial_extent_from_constraint_load_stac(stac_url=url, constraint=constraint)
222+
return _extract_spatial_extent_from_constraint_load_stac(stac_url=url, constraint=constraint, cache=cache)
218223
else:
219224
# TODO?
220225
return None
221226

222227

223228
def _extract_spatial_extent_from_constraint_load_collection(
224-
collection_id: str, *, constraint: dict, catalog: AbstractCollectionCatalog
229+
collection_id: str,
230+
*,
231+
constraint: dict,
232+
catalog: AbstractCollectionCatalog,
233+
cache: Optional[GetOrCallCacheInterface] = None,
225234
) -> Union[None, AlignedExtentResult]:
226235
try:
227236
metadata = catalog.get_collection_metadata(collection_id)
@@ -232,7 +241,7 @@ def _extract_spatial_extent_from_constraint_load_collection(
232241
stac_url = deep_get(metadata, "_vito", "data_source", "url")
233242
load_stac_feature_flags = deep_get(metadata, "_vito", "data_source", "load_stac_feature_flags", default={})
234243
return _extract_spatial_extent_from_constraint_load_stac(
235-
stac_url=stac_url, constraint=constraint, feature_flags=load_stac_feature_flags
244+
stac_url=stac_url, constraint=constraint, feature_flags=load_stac_feature_flags, cache=cache
236245
)
237246

238247
# TODO Extracting pixel grid info from collection metadata might be unreliable
@@ -271,7 +280,11 @@ def _extract_spatial_extent_from_constraint_load_collection(
271280

272281

273282
def _extract_spatial_extent_from_constraint_load_stac(
274-
stac_url: str, *, constraint: dict, feature_flags: Optional[dict] = None
283+
stac_url: str,
284+
*,
285+
constraint: dict,
286+
feature_flags: Optional[dict] = None,
287+
cache: Optional[GetOrCallCacheInterface] = None,
275288
) -> Union[None, AlignedExtentResult]:
276289
spatial_extent_from_pg = constraint.get("spatial_extent") or constraint.get("weak_spatial_extent")
277290
spatiotemporal_extent = openeogeotrellis.load_stac._spatiotemporal_extent_from_load_params(
@@ -292,12 +305,23 @@ def _extract_spatial_extent_from_constraint_load_stac(
292305

293306
pixel_buffer_size = deep_get(constraint, "pixel_buffer", "buffer_size", default=None)
294307

295-
return _extract_spatial_extent_from_load_stac_item_collection(
296-
stac_url=stac_url,
297-
spatiotemporal_extent=spatiotemporal_extent,
298-
property_filter_pg_map=property_filter_pg_map,
299-
resample_grid=resample_grid,
300-
pixel_buffer_size=pixel_buffer_size,
308+
return (cache or AlwaysCallWithoutCache()).get_or_call(
309+
key=(
310+
stac_url,
311+
spatiotemporal_extent,
312+
str(property_filter_pg_map),
313+
resample_grid,
314+
pixel_buffer_size,
315+
str(feature_flags),
316+
),
317+
callback=lambda: _extract_spatial_extent_from_load_stac_item_collection(
318+
stac_url=stac_url,
319+
spatiotemporal_extent=spatiotemporal_extent,
320+
property_filter_pg_map=property_filter_pg_map,
321+
resample_grid=resample_grid,
322+
pixel_buffer_size=pixel_buffer_size,
323+
feature_flags=feature_flags,
324+
),
301325
)
302326

303327

@@ -308,6 +332,7 @@ def _extract_spatial_extent_from_load_stac_item_collection(
308332
property_filter_pg_map: Optional[openeogeotrellis.load_stac.PropertyFilterPGMap] = None,
309333
resample_grid: Optional[_GridInfo] = None,
310334
pixel_buffer_size: Union[Tuple[float, float], int, float, None] = None,
335+
feature_flags: Optional[dict] = None,
311336
) -> Union[None, AlignedExtentResult]:
312337
extent_orig: Union[BoundingBox, None] = spatiotemporal_extent.spatial_extent.as_bbox()
313338
extent_variants = {"original": extent_orig}
@@ -429,6 +454,7 @@ def determine_global_extent(
429454
*,
430455
source_constraints: List[SourceConstraint],
431456
catalog: AbstractCollectionCatalog,
457+
cache: Optional[GetOrCallCacheInterface] = None,
432458
) -> dict:
433459
"""
434460
Go through all source constraints, extract the aligned extent from each (possibly with variations)
@@ -438,9 +464,16 @@ def determine_global_extent(
438464
# e.g. add stats to AlignedExtentResult for better informed decision?
439465
aligned_merger = BoundingBoxMerger()
440466
variant_mergers: Dict[str, BoundingBoxMerger] = collections.defaultdict(BoundingBoxMerger)
467+
# Enable the implementation to use caching
468+
# when going through all these source constraints
469+
# as there might be duplication that's hidden
470+
# by differences from non-relevant details
471+
cache = cache or GetOrCallCache(max_size=100)
441472
for source_id, constraint in source_constraints:
442473
try:
443-
aligned_extent_result = _extract_spatial_extent_from_constraint((source_id, constraint), catalog=catalog)
474+
aligned_extent_result = _extract_spatial_extent_from_constraint(
475+
(source_id, constraint), catalog=catalog, cache=cache
476+
)
444477
except Exception as e:
445478
raise SpatialExtentExtractionError(
446479
f"Failed to extract spatial extent from {source_id=} with {constraint=}: {e=}"
@@ -455,6 +488,7 @@ def determine_global_extent(
455488
_log.info(
456489
f"determine_global_extent: skipping {name=} from {source_id=} with {aligned_extent_result.variants=}"
457490
)
491+
_log.info(f"_extract_spatial_extent_from_constraint cache stats: {cache!s}")
458492

459493
global_extent: BoundingBox = aligned_merger.get()
460494
global_extent_variants: Dict[str, BoundingBox] = {name: merger.get() for name, merger in variant_mergers.items()}

openeogeotrellis/load_stac.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -931,13 +931,26 @@ class _TemporalExtent:
931931
"""
932932

933933
# TODO: move this to a more generic location for better reuse
934+
# TODO: enforce/ensure immutability
935+
# TODO: re-implement in dataclasses/attrs to better enforce immutability and simplify equality/hash implementation
934936

935937
__slots__ = ("from_date", "to_date")
936938

937939
def __init__(self, from_date: DateTimeLikeOrNone, to_date: DateTimeLikeOrNone):
938940
self.from_date: Union[datetime.datetime, None] = to_datetime_utc_unless_none(from_date)
939941
self.to_date: Union[datetime.datetime, None] = to_datetime_utc_unless_none(to_date)
940942

943+
def _key(self) -> tuple:
944+
return (self.from_date, self.to_date)
945+
946+
def __hash__(self):
947+
return hash(self._key())
948+
949+
def __eq__(self, other):
950+
if isinstance(other, _TemporalExtent):
951+
return self._key() == other._key()
952+
return NotImplemented
953+
941954
def as_tuple(self) -> Tuple[Union[datetime.datetime, None], Union[datetime.datetime, None]]:
942955
return self.from_date, self.to_date
943956

@@ -994,6 +1007,8 @@ class _SpatialExtent:
9941007
"""
9951008

9961009
# TODO: move this to a more generic location for better reuse
1010+
# TODO: enforce/ensure immutability
1011+
# TODO: re-implement in dataclasses/attrs to better enforce immutability and simplify equality/hash implementation
9971012

9981013
__slots__ = ("_bbox", "_bbox_lonlat_shape")
9991014

@@ -1003,6 +1018,17 @@ def __init__(self, *, bbox: Union[BoundingBox, None]):
10031018
# cache for shapely polygon in lon/lat
10041019
self._bbox_lonlat_shape = self._bbox.reproject("EPSG:4326").as_polygon() if self._bbox else None
10051020

1021+
def _key(self) -> tuple:
1022+
return (self._bbox,)
1023+
1024+
def __hash__(self):
1025+
return hash(self._key())
1026+
1027+
def __eq__(self, other):
1028+
if isinstance(other, _SpatialExtent):
1029+
return self._key() == other._key()
1030+
return NotImplemented
1031+
10061032
def as_bbox(self, crs: Optional[str] = None) -> Union[BoundingBox, None]:
10071033
bbox = self._bbox
10081034
if bbox and crs:
@@ -1018,6 +1044,9 @@ def intersects(self, bbox: Union[List[float], Tuple[float, float, float, float],
10181044

10191045
class _SpatioTemporalExtent:
10201046
# TODO: move this to a more generic location for better reuse
1047+
# TODO: enforce/ensure immutability
1048+
# TODO: re-implement in dataclasses/attrs to better enforce immutability and simplify equality/hash implementation
1049+
10211050
def __init__(
10221051
self,
10231052
*,
@@ -1028,6 +1057,17 @@ def __init__(
10281057
self._spatial_extent = _SpatialExtent(bbox=bbox)
10291058
self._temporal_extent = _TemporalExtent(from_date=from_date, to_date=to_date)
10301059

1060+
def _key(self) -> tuple:
1061+
return (self._spatial_extent, self._temporal_extent)
1062+
1063+
def __hash__(self):
1064+
return hash(self._key())
1065+
1066+
def __eq__(self, other):
1067+
if isinstance(other, _SpatioTemporalExtent):
1068+
return self._key() == other._key()
1069+
return NotImplemented
1070+
10311071
@property
10321072
def spatial_extent(self) -> _SpatialExtent:
10331073
return self._spatial_extent

openeogeotrellis/util/caching.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
import collections
2+
3+
from typing import Dict, Any, Optional, Callable
4+
5+
6+
class GetOrCallCacheInterface:
7+
"""
8+
Interface for "memoizing" function calls with an explicit cache key.
9+
10+
Unlike the well known `functools.lru` cache,
11+
it requires to specify an explicit cache key,
12+
which gives more control over how to handle arguments, e.g.:
13+
- serialize or transform arguments that are not hashable by default
14+
- ignore certain arguments
15+
- add cache key components that are not arguments to the function
16+
17+
Usage example:
18+
19+
result = cache.get_or_call(
20+
key=(name, str(sorted(options.items()))),
21+
callback=lambda: function_to_cache(name, options)
22+
)
23+
24+
"""
25+
26+
def get_or_call(self, key, callback: Callable[[], Any]) -> Any:
27+
"""
28+
Try to get item from cache.
29+
If not available: call callable to build it and store result in cache.
30+
31+
:param key: key to store item at (can be a simple string,
32+
or something more complex like a tuple of strings/ints)
33+
:param callback: item builder to call when item is not in cache
34+
:return: item (from cache or freshly built)
35+
"""
36+
raise NotImplementedError()
37+
38+
39+
class AlwaysCallWithoutCache(GetOrCallCacheInterface):
40+
"""Don't cache, just call"""
41+
42+
def get_or_call(self, key, callback: Callable[[], Any]) -> Any:
43+
return callback()
44+
45+
46+
class GetOrCallCache(GetOrCallCacheInterface):
47+
"""
48+
In-memory dictionary based cache for "memoizing" function calls.
49+
50+
Supports a maximum size by pruning least frequently used items.
51+
"""
52+
53+
# TODO: support least recently used (LRU) pruning strategy as well?
54+
55+
def __init__(self, max_size: Optional[int] = None):
56+
self._cache = {}
57+
self._usage_stats = {}
58+
self._max_size = max_size
59+
self._stats = collections.defaultdict(int)
60+
61+
def get_or_call(self, key, callback: Callable[[], Any]) -> Any:
62+
self._stats["get_or_call"] += 1
63+
if key in self._cache:
64+
self._usage_stats[key] += 1
65+
self._stats["hit"] += 1
66+
return self._cache[key]
67+
else:
68+
value = callback()
69+
if self._ensure_capacity_for(count=1):
70+
self._cache[key] = value
71+
self._usage_stats[key] = 1
72+
self._stats["miss"] += 1
73+
return value
74+
75+
def _ensure_capacity_for(self, count: int = 1) -> bool:
76+
"""
77+
Make sure there is room (according to max_size setting)
78+
for adding the given number of new items
79+
by pruning least used items from the cache if necessary.
80+
"""
81+
if self._max_size is not None:
82+
to_prune = len(self._cache) + count - self._max_size
83+
if to_prune > 0:
84+
least_used_keys = sorted(self._usage_stats, key=self._usage_stats.get)[:to_prune]
85+
for key in least_used_keys:
86+
del self._cache[key]
87+
del self._usage_stats[key]
88+
self._stats["prune"] += 1
89+
return len(self._cache) + count <= self._max_size
90+
else:
91+
return True
92+
93+
def stats(self) -> dict:
94+
return dict(self._stats)
95+
96+
def __str__(self):
97+
return f"<GetOrCallCache size={len(self._cache)} stats={dict(self._stats)}>"
98+
99+
def clear(self):
100+
self._cache.clear()

0 commit comments

Comments
 (0)