-
Notifications
You must be signed in to change notification settings - Fork 48
Implement lazy loading to defer metadata RPCs until data access time … #253
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -163,6 +163,7 @@ def open( | |
| executor_kwargs: dict[str, Any] | None = None, | ||
| getitem_kwargs: dict[str, int] | None = None, | ||
| fast_time_slicing: bool = False, | ||
| lazy_load: bool = False, | ||
| ) -> EarthEngineStore: | ||
| if mode != 'r': | ||
| raise ValueError( | ||
|
|
@@ -186,6 +187,7 @@ def open( | |
| executor_kwargs=executor_kwargs, | ||
| getitem_kwargs=getitem_kwargs, | ||
| fast_time_slicing=fast_time_slicing, | ||
| lazy_load=lazy_load, | ||
| ) | ||
|
|
||
| def __init__( | ||
|
|
@@ -206,10 +208,12 @@ def __init__( | |
| executor_kwargs: dict[str, Any] | None = None, | ||
| getitem_kwargs: dict[str, int] | None = None, | ||
| fast_time_slicing: bool = False, | ||
| lazy_load: bool = False, | ||
| ): | ||
| self.ee_init_kwargs = ee_init_kwargs | ||
| self.ee_init_if_necessary = ee_init_if_necessary | ||
| self.fast_time_slicing = fast_time_slicing | ||
| self.lazy_load = lazy_load | ||
|
|
||
| # Initialize executor_kwargs | ||
| if executor_kwargs is None: | ||
|
|
@@ -227,8 +231,11 @@ def __init__( | |
| self.primary_dim_name = primary_dim_name or 'time' | ||
| self.primary_dim_property = primary_dim_property or 'system:time_start' | ||
|
|
||
| # Always need to get size for n_images | ||
| self.n_images = self.get_info['size'] | ||
| self._props = self.get_info['props'] | ||
| # These are loaded lazily if lazy_load=True | ||
| if 'props' in self.get_info: | ||
| self._props = self.get_info['props'] | ||
| # Metadata should apply to all imgs. | ||
| self._img_info: types.ImageInfo = self.get_info['first'] | ||
|
|
||
|
|
@@ -281,57 +288,106 @@ def __init__( | |
|
|
||
| @functools.cached_property | ||
| def get_info(self) -> dict[str, Any]: | ||
| """Make all getInfo() calls to EE at once.""" | ||
| """Make all getInfo() calls to EE at once. | ||
|
|
||
| If lazy_load is True, only performs essential metadata calls and defers | ||
| other calls until data access time. | ||
| """ | ||
|
|
||
| if not hasattr(self, '_info_cache'): | ||
| self._info_cache = {} | ||
|
|
||
| # Perform minimal RPCs if lazy loading is enabled | ||
| if getattr(self, 'lazy_load', False): | ||
|
||
| # Only fetch essential metadata needed for dataset structure | ||
| if not self._info_cache: | ||
| rpcs = [ | ||
| ('size', self.image_collection.size()), | ||
| ('first', self.image_collection.first()), | ||
| ] | ||
|
|
||
| if isinstance(self.projection, ee.Projection): | ||
| rpcs.append(('projection', self.projection)) | ||
|
|
||
| if isinstance(self.geometry, ee.Geometry): | ||
| rpcs.append(('bounds', self.geometry.bounds(1, proj=self.projection))) | ||
| else: | ||
| rpcs.append( | ||
| ( | ||
| 'bounds', | ||
| self.image_collection.first() | ||
| .geometry() | ||
| .bounds(1, proj=self.projection), | ||
| ) | ||
| ) | ||
|
|
||
| info = ee.List([rpc for _, rpc in rpcs]).getInfo() | ||
| self._info_cache.update(dict(zip((name for name, _ in rpcs), info))) | ||
|
|
||
| return self._info_cache | ||
|
|
||
| # Full metadata loading if not lazy | ||
| if not self._info_cache or len(self._info_cache) < 5: # Check if we have full metadata | ||
|
||
| rpcs = [ | ||
| ('size', self.image_collection.size()), | ||
| ('props', self.image_collection.toDictionary()), | ||
| ('first', self.image_collection.first()), | ||
| ] | ||
|
|
||
| rpcs = [ | ||
| ('size', self.image_collection.size()), | ||
| ('props', self.image_collection.toDictionary()), | ||
| ('first', self.image_collection.first()), | ||
| ] | ||
| if isinstance(self.projection, ee.Projection): | ||
| rpcs.append(('projection', self.projection)) | ||
|
|
||
| if isinstance(self.projection, ee.Projection): | ||
| rpcs.append(('projection', self.projection)) | ||
| if isinstance(self.geometry, ee.Geometry): | ||
| rpcs.append(('bounds', self.geometry.bounds(1, proj=self.projection))) | ||
| else: | ||
| rpcs.append( | ||
| ( | ||
| 'bounds', | ||
| self.image_collection.first() | ||
| .geometry() | ||
| .bounds(1, proj=self.projection), | ||
| ) | ||
| ) | ||
|
|
||
| if isinstance(self.geometry, ee.Geometry): | ||
| rpcs.append(('bounds', self.geometry.bounds(1, proj=self.projection))) | ||
| else: | ||
| # TODO(#29, #30): This RPC call takes the longest time to compute. This | ||
| # requires a full scan of the images in the collection, which happens on the | ||
| # EE backend. This is essential because we want the primary dimension of the | ||
| # opened dataset to be something relevant to the data, like time (start | ||
| # time) as opposed to a random index number. | ||
| # | ||
| # One optimization that could prove really fruitful: read the first and last | ||
| # (few) values of the primary dim (read: time) and interpolate the rest | ||
| # client-side. Ideally, this would live behind a xarray-backend-specific | ||
| # feature flag, since it's not guaranteed that data is this consistent. | ||
| columns = ['system:id', self.primary_dim_property] | ||
| rpcs.append( | ||
| ( | ||
| 'bounds', | ||
| self.image_collection.first() | ||
| .geometry() | ||
| .bounds(1, proj=self.projection), | ||
| 'properties', | ||
| ( | ||
| self.image_collection.reduceColumns( | ||
| ee.Reducer.toList().repeat(len(columns)), columns | ||
| ).get('list') | ||
| ), | ||
| ) | ||
| ) | ||
|
|
||
| # TODO(#29, #30): This RPC call takes the longest time to compute. This | ||
| # requires a full scan of the images in the collection, which happens on the | ||
| # EE backend. This is essential because we want the primary dimension of the | ||
| # opened dataset to be something relevant to the data, like time (start | ||
| # time) as opposed to a random index number. | ||
| # | ||
| # One optimization that could prove really fruitful: read the first and last | ||
| # (few) values of the primary dim (read: time) and interpolate the rest | ||
| # client-side. Ideally, this would live behind a xarray-backend-specific | ||
| # feature flag, since it's not guaranteed that data is this consistent. | ||
| columns = ['system:id', self.primary_dim_property] | ||
| rpcs.append( | ||
| ( | ||
| 'properties', | ||
| ( | ||
| self.image_collection.reduceColumns( | ||
| ee.Reducer.toList().repeat(len(columns)), columns | ||
| ).get('list') | ||
| ), | ||
| ) | ||
| ) | ||
|
|
||
| info = ee.List([rpc for _, rpc in rpcs]).getInfo() | ||
|
|
||
| return dict(zip((name for name, _ in rpcs), info)) | ||
| info = ee.List([rpc for _, rpc in rpcs]).getInfo() | ||
| self._info_cache.update(dict(zip((name for name, _ in rpcs), info))) | ||
|
|
||
| return self._info_cache | ||
|
|
||
| @property | ||
| def image_collection_properties(self) -> tuple[list[str], list[str]]: | ||
| if self.lazy_load and 'properties' not in self._info_cache: | ||
| # Fetch properties on-demand if lazy loading is enabled | ||
| columns = ['system:id', self.primary_dim_property] | ||
| properties = ( | ||
| self.image_collection.reduceColumns( | ||
| ee.Reducer.toList().repeat(len(columns)), columns | ||
| ).get('list') | ||
| ).getInfo() | ||
|
||
| self._info_cache['properties'] = properties | ||
|
|
||
| system_ids, primary_coord = self.get_info['properties'] | ||
| return (system_ids, primary_coord) | ||
|
|
||
|
|
@@ -1044,6 +1100,7 @@ def open_dataset( | |
| executor_kwargs: dict[str, Any] | None = None, | ||
| getitem_kwargs: dict[str, int] | None = None, | ||
| fast_time_slicing: bool = False, | ||
| lazy_load: bool = False, | ||
| ) -> xarray.Dataset: # type: ignore | ||
| """Open an Earth Engine ImageCollection as an Xarray Dataset. | ||
|
|
||
|
|
@@ -1126,6 +1183,9 @@ def open_dataset( | |
| makes slicing an ImageCollection across time faster. This optimization | ||
| loads EE images in a slice by ID, so any modifications to images in a | ||
| computed ImageCollection will not be reflected. | ||
| lazy_load (optional): If True, defers metadata RPCs to data access time, | ||
| making opening datasets faster. Similar to xr.open_zarr(..., chunks=None) | ||
| behavior. Defaults to False. | ||
| Returns: | ||
| An xarray.Dataset that streams in remote data from Earth Engine. | ||
| """ | ||
|
|
@@ -1158,6 +1218,7 @@ def open_dataset( | |
| executor_kwargs=executor_kwargs, | ||
| getitem_kwargs=getitem_kwargs, | ||
| fast_time_slicing=fast_time_slicing, | ||
| lazy_load=lazy_load, | ||
| ) | ||
|
|
||
| store_entrypoint = backends_store.StoreBackendEntrypoint() | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,6 +17,7 @@ | |
| import os | ||
| import pathlib | ||
| import tempfile | ||
| import time | ||
|
|
||
| from absl.testing import absltest | ||
| from google.auth import identity_pool | ||
|
|
@@ -556,6 +557,46 @@ def test_fast_time_slicing(self): | |
| fast_slicing = xr.open_dataset(**params, fast_time_slicing=True) | ||
| fast_slicing_data = getattr(fast_slicing[dict(time=0)], band).as_numpy() | ||
| self.assertTrue(np.all(fast_slicing_data > 0)) | ||
|
|
||
| def test_lazy_loading(self): | ||
| """Test that lazy loading defers metadata RPCs until data access time.""" | ||
| ic = ee.ImageCollection('ECMWF/ERA5_LAND/HOURLY').filterDate( | ||
| '1992-10-05', '1992-10-06') # Using a smaller date range for the test | ||
|
|
||
| # Open dataset with lazy loading | ||
| start_time = time.time() | ||
|
||
| lazy_ds = xr.open_dataset( | ||
| ic, | ||
| engine=xee.EarthEngineBackendEntrypoint, | ||
| lazy_load=True, | ||
| ) | ||
| lazy_open_time = time.time() - start_time | ||
|
|
||
| # Open dataset without lazy loading | ||
| start_time = time.time() | ||
| regular_ds = xr.open_dataset( | ||
| ic, | ||
| engine=xee.EarthEngineBackendEntrypoint, | ||
| lazy_load=False, | ||
| ) | ||
| regular_open_time = time.time() - start_time | ||
|
|
||
| # Verify that lazy opening is faster than regular opening | ||
| self.assertLess(lazy_open_time, regular_open_time, | ||
| f"Lazy loading ({lazy_open_time:.2f}s) should be faster than regular loading ({regular_open_time:.2f}s)") | ||
|
Comment on lines
+583
to
+586
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Thanks, I'm really happy to see this test. |
||
|
|
||
| # Verify that both datasets have the same structure | ||
| self.assertEqual(lazy_ds.dims, regular_ds.dims) | ||
| self.assertEqual(list(lazy_ds.data_vars), list(regular_ds.data_vars)) | ||
|
|
||
| # Access data and verify it's the same | ||
| var_name = list(lazy_ds.data_vars)[0] | ||
| lazy_data = lazy_ds[var_name].isel(time=0).values | ||
| regular_data = regular_ds[var_name].isel(time=0).values | ||
|
|
||
| # Both should have same shape and data should not be all zeros or NaNs | ||
| self.assertEqual(lazy_data.shape, regular_data.shape) | ||
| self.assertTrue(np.allclose(lazy_data, regular_data, equal_nan=True)) | ||
|
Comment on lines
+588
to
+599
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Thanks, this is good to have in the test. |
||
|
|
||
| @absltest.skipIf(_SKIP_RASTERIO_TESTS, 'rioxarray module not loaded') | ||
| def test_write_projected_dataset_to_raster(self): | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm concerned that we have two levels of caching: this private dictionary and the
`functools.cached_property` level. Is there any way we can use one system instead of both? I anticipate cache invalidation problems down the line. One way we could clean this system up is by breaking out what info we get into multiple functions (that are also cached): e.g. a helper fn gets the first-line essential stuff below. Since it's cached, we'd need less
`if...else` logic to manage state; that would happen in Python's memoizer decorator — we'd make use of the cache just by calling the helper functions. WDYT about this approach?