Skip to content

Commit 443fbc6

Browse files
perf(gribjump): compute coordinates once instead of per-field
Extract grid coordinates before the field loop instead of inside it, reducing O(N) calls to geography methods to O(1). Also make shared grid indices explicit in ExtractionRequestCollection and replace assert with proper ValueError check.
1 parent 43b1d33 commit 443fbc6

File tree

1 file changed

+68
-71
lines changed

1 file changed

+68
-71
lines changed

src/earthkit/data/sources/gribjump.py

Lines changed: 68 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,6 @@ def split_mars_requests(request: dict[str, Any]) -> list[dict[str, Any]]:
3737
returns result arrays without metadata, so each field must be requested individually
3838
to map outputs correctly.
3939
40-
NOTE: Parsing of MARS requests should ideally not be handled here but in a dedicated
41-
component like pymetkit. Consider updating this function once something appropriate
42-
is available.
43-
4440
Parameters
4541
----------
4642
request : dict[str, Any]
@@ -129,17 +125,13 @@ def mask_to_ranges(mask: np.ndarray) -> list[tuple[int, int]]:
129125

130126
@dataclasses.dataclass
131127
class ExtractionRequest:
132-
"""
133-
Simple wrapper of pygribjump.ExtractionRequest and the original FDB request dict.
134-
135-
Can be removed once pygribjump.ExtractionRequest provides a reference to the request dictionary
136-
with original MARS keyword types.
128+
"""Wrapper around pygribjump.ExtractionRequest that also stores the original request dict.
137129
138130
Parameters
139131
----------
140132
extraction_request : pygj.ExtractionRequest
141133
The GribJump extraction request object.
142-
request : dict[str, str]
134+
request : dict[str, Any]
143135
The original request dictionary used to create the extraction request.
144136
"""
145137

@@ -162,8 +154,7 @@ def build_extraction_request(
162154
mask: Optional[np.ndarray] = None,
163155
indices: Optional[np.ndarray] = None,
164156
) -> ExtractionRequest:
165-
"""
166-
Builds an ExtractionRequest from the given request dictionary and optional parameters.
157+
"""Builds an ExtractionRequest from the given request dictionary and optional parameters.
167158
168159
Parameters
169160
----------
@@ -201,6 +192,17 @@ def build_extraction_request(
201192

202193

203194
class ExtractionRequestCollection(UserList):
195+
"""Collection of extraction requests sharing the same grid indices."""
196+
197+
def __init__(self, extraction_requests: list[ExtractionRequest], indices: np.ndarray):
198+
"""Internal constructor. Use from_mars_requests() to create instances."""
199+
super().__init__(extraction_requests)
200+
self._indices = indices
201+
202+
@property
203+
def indices(self) -> np.ndarray:
204+
"""Grid indices shared by all extraction requests."""
205+
return self._indices
204206

205207
@classmethod
206208
def from_mars_requests(
@@ -210,25 +212,7 @@ def from_mars_requests(
210212
mask: Optional[np.ndarray] = None,
211213
indices: Optional[np.ndarray] = None,
212214
) -> "ExtractionRequestCollection":
213-
"""Creates an ExtractionRequestCollection from MARS requests.
214-
215-
One of the parameters `ranges`, `mask`, or `indices` must be provided.
216-
217-
Parameters
218-
----------
219-
mars_requests : list[dict[str, str]]
220-
List of MARS requests, each represented as a dictionary of keywords.
221-
ranges : Optional[list[tuple[int, int]]], optional
222-
The ranges for the extraction requests, by default None.
223-
mask : Optional[np.ndarray], optional
224-
The mask for the extraction requests, by default None.
225-
indices : Optional[np.ndarray], optional
226-
The indices for the extraction requests, by default None.
227-
Returns
228-
-------
229-
ExtractionRequestCollection
230-
A collection of ExtractionRequest objects created from the MARS requests.
231-
"""
215+
"""Create collection from MARS requests. Exactly one of ranges/mask/indices must be provided."""
232216

233217
if sum(opt is not None for opt in (ranges, mask, indices)) != 1:
234218
raise ValueError(
@@ -243,34 +227,39 @@ def from_mars_requests(
243227
mask = None
244228

245229
extraction_requests = [build_extraction_request(req, ranges, mask, indices) for req in mars_requests]
246-
return cls(extraction_requests)
230+
231+
# All requests share the same indices; get from the first one
232+
if not extraction_requests:
233+
raise ValueError("Cannot create ExtractionRequestCollection from empty mars_requests list")
234+
canonical_indices = extraction_requests[0].indices()
235+
236+
return cls(extraction_requests, canonical_indices)
247237

248238

249239
class FieldExtractList(SimpleFieldList):
250240
"""Lazily loaded representation of points extracted from multiple fields using GribJump.
251241
252-
.. warning::
253-
This implementation is **not thread-safe**. Concurrent access from multiple threads
254-
may result in race conditions during lazy loading. Use appropriate synchronization
255-
if accessing from multiple threads.
256-
257242
.. note::
258243
This class should not be instantiated directly. Use the ``gribjump`` source instead:
259244
``earthkit.data.from_source("gribjump", request, ranges=ranges)``
260245
261-
This class inherits from SimpleFieldList and provides lazy loading of grid point
262-
extractions from GRIB fields via GribJump. FieldList operations like ``sel()``,
263-
``group_by()``, etc. might work but are not guaranteed to be fully functional.
246+
.. warning::
247+
This implementation is **not thread-safe**. Concurrent access from multiple threads
248+
may result in race conditions during lazy loading.
249+
250+
This class provides lazy loading of grid point extractions from GRIB fields via GribJump.
251+
FieldList operations like ``sel()``, ``group_by()``, etc. might work but are not guaranteed
252+
to be fully functional.
264253
265254
Known Limitations
266255
-----------------
267-
* **No validation**: Grid indices are not validated against actual field grids.
268-
Incorrect indices may return unexpected grid points.
269-
* **Not thread-safe**: Concurrent access may cause race conditions during lazy loading.
270-
* **Limited metadata**: Only metadata from the request dictionary is available,
271-
except for latitude/longitude coordinates when ``fetch_coords_from_fdb=True`` is used.
272-
* **No efficient slicing**: Lazy loading of selections/slices is not supported.
273-
* **Serialization issues**: Pickling/unpickling might not work reliably.
256+
* Grid indices are not validated against actual field grids. Incorrect indices may return
257+
unexpected grid points.
258+
* Not thread-safe - concurrent access may cause race conditions during lazy loading.
259+
* Limited metadata - only metadata from the request dictionary is available, except for
260+
latitude/longitude coordinates when ``fetch_coords_from_fdb=True`` is used.
261+
* No efficient slicing - lazy loading of selections/slices is not supported.
262+
* Serialization issues - pickling/unpickling might not work reliably.
274263
"""
275264

276265
def __init__(
@@ -283,7 +272,6 @@ def __init__(
283272
self._requests = requests
284273
self._fdb_retriever = fdb_retriever
285274

286-
# These attributes are set lazily after loading the data.
287275
self._loaded = False
288276
self._grid_indices = None
289277
self._reference_metadata: Optional[GribMetadata] = None
@@ -307,23 +295,26 @@ def _load(self):
307295
context = {"origin": "earthkit-data"}
308296
extraction_results = self._gj.extract(extraction_requests, ctx=context)
309297

298+
# Get the shared indices from the collection
299+
indices = self._requests.indices
300+
301+
# Pre-compute grid coordinates once for all fields
302+
coords = self._get_grid_coordinates(indices)
303+
310304
fields = []
311-
indices = None
312-
ranges = None
313305
for request, result in zip(self._requests, extraction_results):
314-
if ranges is None:
315-
ranges = request.ranges
316-
indices = request.indices()
317-
else:
318-
if request.ranges != ranges:
319-
raise ValueError(
320-
f"Extraction request has different ranges than the first request: {request.ranges} != {ranges}"
321-
)
322306
arr = result.values_flat
323307
shape = arr.shape
324308

325309
metadata = UserMetadata(request.request, shape=shape)
326-
metadata = self._enrich_metadata_with_coordinates(indices, metadata)
310+
if coords is not None:
311+
grid_latitudes, grid_longitudes = coords
312+
metadata = metadata.override(
313+
{
314+
"latitudes": grid_latitudes,
315+
"longitudes": grid_longitudes,
316+
}
317+
)
327318

328319
field = ArrayField(arr, metadata)
329320
fields.append(field)
@@ -350,21 +341,26 @@ def _load_reference_metadata(self):
350341
self._reference_metadata = metadata
351342
return metadata
352343

353-
def _enrich_metadata_with_coordinates(self, indices: np.ndarray, metadata: UserMetadata) -> UserMetadata:
354-
"""Enriches the metadata with coordinates if reference metadata is available."""
344+
def _get_grid_coordinates(self, indices: np.ndarray) -> Optional[tuple[np.ndarray, np.ndarray]]:
345+
"""Get latitude/longitude coordinates at the specified grid indices.
346+
347+
Parameters
348+
----------
349+
indices : np.ndarray
350+
Grid indices to extract coordinates for.
351+
352+
Returns
353+
-------
354+
Optional[tuple[np.ndarray, np.ndarray]]
355+
A tuple of (latitudes, longitudes) arrays, or None if no reference metadata is available.
356+
"""
355357
if (reference_metadata := self._load_reference_metadata()) is None:
356-
return metadata
358+
return None
357359

358360
reference_geography = reference_metadata.geography
359361
grid_latitudes = reference_geography.latitudes()[indices]
360362
grid_longitudes = reference_geography.longitudes()[indices]
361-
metadata = metadata.override(
362-
{
363-
"latitudes": grid_latitudes,
364-
"longitudes": grid_longitudes,
365-
}
366-
)
367-
return metadata
363+
return (grid_latitudes, grid_longitudes)
368364

369365
def to_xarray(self, *args, **kwargs):
370366
kwargs = kwargs.copy()
@@ -424,11 +420,12 @@ def __init__(
424420
A 1D array of grid indices to retrieve, by default None.
425421
fetch_coords_from_fdb : bool, optional
426422
If set to True, loads the first field's metadata from the FDB to extract the coordinates
427-
at the specified indices.
423+
at the specified indices, by default False.
428424
fdb_kwargs : Optional[dict[str, Any]], optional
429425
Only used when `fetch_coords_from_fdb=True`. A dict of
430426
keyword arguments passed to the `pyfdb.FDB` constructor. These arguments are only passed
431-
to the FDB when fetching coordinates and is not used by GribJump for the extraction itself.
427+
to the FDB when fetching coordinates and is not used by GribJump for the extraction itself,
428+
by default None.
432429
"""
433430

434431
super().__init__(**kwargs)

0 commit comments

Comments
 (0)