Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,18 @@ ds = xarray.open_dataset(
)
```

Open an ImageCollection with lazy loading to defer metadata RPCs until data access time:

```python
ic = ee.ImageCollection('ECMWF/ERA5_LAND/HOURLY').filterDate(
'1992-10-05', '1993-03-31')
ds = xarray.open_dataset(
ic,
engine='ee',
lazy_load=True # Defers metadata RPCs for faster dataset opening
)
```

Open multiple ImageCollections into one `xarray.Dataset`, all with the same
projection:

Expand Down
143 changes: 102 additions & 41 deletions xee/ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ def open(
executor_kwargs: dict[str, Any] | None = None,
getitem_kwargs: dict[str, int] | None = None,
fast_time_slicing: bool = False,
lazy_load: bool = False,
) -> EarthEngineStore:
if mode != 'r':
raise ValueError(
Expand All @@ -186,6 +187,7 @@ def open(
executor_kwargs=executor_kwargs,
getitem_kwargs=getitem_kwargs,
fast_time_slicing=fast_time_slicing,
lazy_load=lazy_load,
)

def __init__(
Expand All @@ -206,10 +208,12 @@ def __init__(
executor_kwargs: dict[str, Any] | None = None,
getitem_kwargs: dict[str, int] | None = None,
fast_time_slicing: bool = False,
lazy_load: bool = False,
):
self.ee_init_kwargs = ee_init_kwargs
self.ee_init_if_necessary = ee_init_if_necessary
self.fast_time_slicing = fast_time_slicing
self.lazy_load = lazy_load

# Initialize executor_kwargs
if executor_kwargs is None:
Expand All @@ -227,8 +231,11 @@ def __init__(
self.primary_dim_name = primary_dim_name or 'time'
self.primary_dim_property = primary_dim_property or 'system:time_start'

# Always need to get size for n_images
self.n_images = self.get_info['size']
self._props = self.get_info['props']
# These are loaded lazily if lazy_load=True
if 'props' in self.get_info:
self._props = self.get_info['props']
# Metadata should apply to all imgs.
self._img_info: types.ImageInfo = self.get_info['first']

Expand Down Expand Up @@ -281,57 +288,106 @@ def __init__(

@functools.cached_property
def get_info(self) -> dict[str, Any]:
  """Make all getInfo() calls to EE at once.

  If lazy_load is True, only the metadata essential for constructing the
  dataset structure is fetched here; the expensive full-collection
  'properties' scan (and collection-level 'props') is deferred until data
  access time (see `image_collection_properties`).

  Returns:
    A dict mapping metadata names ('size', 'first', 'projection', 'bounds',
    and — for non-lazy loads — 'props' and 'properties') to the results of
    the corresponding EE RPCs.
  """
  # A plain dict backs the functools cache so that a lazy (partial) result
  # can later be upgraded in place to the full metadata set without
  # invalidating the cached_property.
  if not hasattr(self, '_info_cache'):
    self._info_cache = {}

  def _essential_rpcs() -> list[tuple[str, Any]]:
    # RPCs required for the dataset structure regardless of lazy_load.
    rpcs = [
        ('size', self.image_collection.size()),
        ('first', self.image_collection.first()),
    ]
    if isinstance(self.projection, ee.Projection):
      rpcs.append(('projection', self.projection))
    if isinstance(self.geometry, ee.Geometry):
      rpcs.append(('bounds', self.geometry.bounds(1, proj=self.projection)))
    else:
      rpcs.append(
          (
              'bounds',
              self.image_collection.first()
              .geometry()
              .bounds(1, proj=self.projection),
          )
      )
    return rpcs

  if self.lazy_load:
    # Perform only the minimal RPCs now; everything else is fetched
    # on demand at data access time.
    if not self._info_cache:
      rpcs = _essential_rpcs()
      info = ee.List([rpc for _, rpc in rpcs]).getInfo()
      self._info_cache.update(dict(zip((name for name, _ in rpcs), info)))
    return self._info_cache

  # Full metadata loading. 'properties' is only ever present after a full
  # load, so its absence (rather than a magic cache size) tells us the cache
  # holds at most the lazy subset and must be refreshed.
  if 'properties' not in self._info_cache:
    rpcs = _essential_rpcs()
    rpcs.append(('props', self.image_collection.toDictionary()))

    # TODO(#29, #30): This RPC call takes the longest time to compute. This
    # requires a full scan of the images in the collection, which happens on
    # the EE backend. This is essential because we want the primary dimension
    # of the opened dataset to be something relevant to the data, like time
    # (start time) as opposed to a random index number.
    #
    # One optimization that could prove really fruitful: read the first and
    # last (few) values of the primary dim (read: time) and interpolate the
    # rest client-side. Ideally, this would live behind a
    # xarray-backend-specific feature flag, since it's not guaranteed that
    # data is this consistent.
    columns = ['system:id', self.primary_dim_property]
    rpcs.append(
        (
            'properties',
            self.image_collection.reduceColumns(
                ee.Reducer.toList().repeat(len(columns)), columns
            ).get('list'),
        )
    )

    info = ee.List([rpc for _, rpc in rpcs]).getInfo()
    self._info_cache.update(dict(zip((name for name, _ in rpcs), info)))

  return self._info_cache

@property
def image_collection_properties(self) -> tuple[list[str], list[str]]:
  """Per-image properties used to build the primary (time) coordinate.

  Returns:
    A tuple `(system_ids, primary_coord)` where `system_ids` holds the
    'system:id' value of every image in the collection and `primary_coord`
    holds the corresponding primary-dimension property values
    (e.g. 'system:time_start').
  """
  if self.lazy_load and 'properties' not in self._info_cache:
    # Lazy loading deferred this expensive full-collection scan at open
    # time; fetch it on first access and memoize it in the shared cache so
    # subsequent accesses (and get_info) reuse it.
    columns = ['system:id', self.primary_dim_property]
    properties = (
        self.image_collection.reduceColumns(
            ee.Reducer.toList().repeat(len(columns)), columns
        ).get('list')
    ).getInfo()
    self._info_cache['properties'] = properties

  system_ids, primary_coord = self.get_info['properties']
  return (system_ids, primary_coord)

Expand Down Expand Up @@ -1044,6 +1100,7 @@ def open_dataset(
executor_kwargs: dict[str, Any] | None = None,
getitem_kwargs: dict[str, int] | None = None,
fast_time_slicing: bool = False,
lazy_load: bool = False,
) -> xarray.Dataset: # type: ignore
"""Open an Earth Engine ImageCollection as an Xarray Dataset.

Expand Down Expand Up @@ -1126,6 +1183,9 @@ def open_dataset(
makes slicing an ImageCollection across time faster. This optimization
loads EE images in a slice by ID, so any modifications to images in a
computed ImageCollection will not be reflected.
lazy_load (optional): If True, defers metadata RPCs to data access time,
making opening datasets faster. Similar to xr.open_zarr(..., chunks=None)
behavior. Defaults to False.
Returns:
An xarray.Dataset that streams in remote data from Earth Engine.
"""
Expand Down Expand Up @@ -1158,6 +1218,7 @@ def open_dataset(
executor_kwargs=executor_kwargs,
getitem_kwargs=getitem_kwargs,
fast_time_slicing=fast_time_slicing,
lazy_load=lazy_load,
)

store_entrypoint = backends_store.StoreBackendEntrypoint()
Expand Down
41 changes: 41 additions & 0 deletions xee/ext_integration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import os
import pathlib
import tempfile
import time

from absl.testing import absltest
from google.auth import identity_pool
Expand Down Expand Up @@ -556,6 +557,46 @@ def test_fast_time_slicing(self):
fast_slicing = xr.open_dataset(**params, fast_time_slicing=True)
fast_slicing_data = getattr(fast_slicing[dict(time=0)], band).as_numpy()
self.assertTrue(np.all(fast_slicing_data > 0))

def test_lazy_loading(self):
  """Test that lazy loading defers metadata RPCs until data access time."""
  # A small date range keeps the collection (and thus the test) fast.
  ic = ee.ImageCollection('ECMWF/ERA5_LAND/HOURLY').filterDate(
      '1992-10-05', '1992-10-06')

  # Open the dataset with lazy loading. time.perf_counter() is monotonic
  # and high-resolution, so it is the right tool for measuring durations
  # (time.time() can jump backwards on wall-clock adjustments).
  start_time = time.perf_counter()
  lazy_ds = xr.open_dataset(
      ic,
      engine=xee.EarthEngineBackendEntrypoint,
      lazy_load=True,
  )
  lazy_open_time = time.perf_counter() - start_time

  # Open the same collection eagerly for comparison.
  start_time = time.perf_counter()
  regular_ds = xr.open_dataset(
      ic,
      engine=xee.EarthEngineBackendEntrypoint,
      lazy_load=False,
  )
  regular_open_time = time.perf_counter() - start_time

  # Lazy opening skips the expensive metadata RPCs, so it should be faster.
  self.assertLess(
      lazy_open_time,
      regular_open_time,
      f"Lazy loading ({lazy_open_time:.2f}s) should be faster than "
      f"regular loading ({regular_open_time:.2f}s)",
  )

  # Both datasets must expose the same structure...
  self.assertEqual(lazy_ds.dims, regular_ds.dims)
  self.assertEqual(list(lazy_ds.data_vars), list(regular_ds.data_vars))

  # ...and yield identical data once values are actually read.
  var_name = list(lazy_ds.data_vars)[0]
  lazy_data = lazy_ds[var_name].isel(time=0).values
  regular_data = regular_ds[var_name].isel(time=0).values

  self.assertEqual(lazy_data.shape, regular_data.shape)
  self.assertTrue(np.allclose(lazy_data, regular_data, equal_nan=True))
Comment on lines +588 to +599
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks this is good to have in the test.


@absltest.skipIf(_SKIP_RASTERIO_TESTS, 'rioxarray module not loaded')
def test_write_projected_dataset_to_raster(self):
Expand Down