diff --git a/xarray/backends/api.py b/xarray/backends/api.py
index 48d1a4c5135..67eead54510 100644
--- a/xarray/backends/api.py
+++ b/xarray/backends/api.py
@@ -36,6 +36,7 @@
     T_PathFileOrDataStore,
     _find_absolute_paths,
     _normalize_path,
+    datatree_from_dict_with_io_cleanup,
 )
 from xarray.backends.locks import get_dask_scheduler
 from xarray.coders import CFDatetimeCoder, CFTimedeltaCoder
@@ -538,6 +539,35 @@ def _datatree_from_backend_datatree(
     return tree
 
 
+async def _maybe_create_default_indexes_async(ds):
+    import asyncio
+
+    # Determine which coords need default indexes
+    to_index_names = [
+        name
+        for name, coord in ds.coords.items()
+        if coord.dims == (name,) and name not in ds.xindexes
+    ]
+
+    if not to_index_names:
+        return ds
+
+    async def load_var(var):
+        try:
+            return await var.load_async()
+        except NotImplementedError:
+            # Backends without async loading fall back to a worker thread
+            return await asyncio.to_thread(var.load)
+
+    await asyncio.gather(
+        *[load_var(ds.coords[name].variable) for name in to_index_names]
+    )
+
+    # Build indexes (data is now in memory, so no remote I/O per coord)
+    to_index = {name: ds.coords[name].variable for name in to_index_names}
+    return ds.assign_coords(Coordinates(to_index))
+
+
 def open_dataset(
     filename_or_obj: T_PathFileOrDataStore,
     *,
@@ -1253,6 +1283,137 @@ def open_datatree(
     return tree
 
 
+async def open_datatree_async(
+    filename_or_obj: T_PathFileOrDataStore,
+    *,
+    engine: T_Engine = None,
+    chunks: T_Chunks = None,
+    cache: bool | None = None,
+    decode_cf: bool | None = None,
+    mask_and_scale: bool | Mapping[str, bool] | None = None,
+    decode_times: bool
+    | CFDatetimeCoder
+    | Mapping[str, bool | CFDatetimeCoder]
+    | None = None,
+    decode_timedelta: bool
+    | CFTimedeltaCoder
+    | Mapping[str, bool | CFTimedeltaCoder]
+    | None = None,
+    use_cftime: bool | Mapping[str, bool] | None = None,
+    concat_characters: bool | Mapping[str, bool] | None = None,
+    decode_coords: Literal["coordinates", "all"] | bool | None = None,
+    drop_variables: str | Iterable[str] | None = None,
+    create_default_indexes: bool = True,
+    inline_array: bool = False,
+    chunked_array_type: str | None = None,
+    from_array_kwargs: dict[str, Any] | None = None,
+    backend_kwargs: dict[str, Any] | None = None,
+    **kwargs,
+) -> DataTree:
+    """Async version of open_datatree that concurrently builds default indexes.
+
+    Supports the "zarr" engine (both Zarr v2 and v3). For other engines, a
+    ValueError is raised.
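+
+    A minimal usage sketch (``store`` is a stand-in for any zarr-compatible
+    path, URL, or store object)::
+
+        import asyncio
+
+        tree = asyncio.run(open_datatree_async(store, engine="zarr"))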
+ """ + import asyncio + + if cache is None: + cache = chunks is None + + if backend_kwargs is not None: + kwargs.update(backend_kwargs) + + if engine is None: + engine = plugins.guess_engine(filename_or_obj) + + if from_array_kwargs is None: + from_array_kwargs = {} + + # Only zarr supports async lazy loading at present + if engine != "zarr": + raise ValueError(f"Engine {engine!r} does not support asynchronous operations") + + backend = plugins.get_backend(engine) + + decoders = _resolve_decoders_kwargs( + decode_cf, + open_backend_dataset_parameters=backend.open_dataset_parameters, + mask_and_scale=mask_and_scale, + decode_times=decode_times, + decode_timedelta=decode_timedelta, + concat_characters=concat_characters, + use_cftime=use_cftime, + decode_coords=decode_coords, + ) + + overwrite_encoded_chunks = kwargs.pop("overwrite_encoded_chunks", None) + + # Prefer backend async group opening if available (currently zarr only) + if hasattr(backend, "open_groups_as_dict_async"): + groups_dict = await backend.open_groups_as_dict_async( + filename_or_obj, + drop_variables=drop_variables, + **decoders, + **kwargs, + ) + backend_tree = datatree_from_dict_with_io_cleanup(groups_dict) + else: + backend_tree = backend.open_datatree( + filename_or_obj, + drop_variables=drop_variables, + **decoders, + **kwargs, + ) + + # Protect variables for caching behavior consistency + _protect_datatree_variables_inplace(backend_tree, cache) + + # For each dataset in the tree, concurrently create default indexes (if requested) + results: dict[str, Dataset] = {} + + async def process_node(path: str, node_ds: Dataset) -> tuple[str, Dataset]: + ds = node_ds + if create_default_indexes: + ds = await _maybe_create_default_indexes_async(ds) + # Optional chunking (synchronous) + if chunks is not None: + ds = _chunk_ds( + ds, + filename_or_obj, + engine, + chunks, + overwrite_encoded_chunks, + inline_array, + chunked_array_type, + from_array_kwargs, + node=path, + **decoders, + **kwargs, + ) + return path, ds + + # Build tasks + tasks = [ + process_node(path, node.dataset) + for path, [node] in group_subtrees(backend_tree) + ] + + # Execute concurrently and collect + for fut in asyncio.as_completed(tasks): + path, ds = await fut + results[path] = ds + + # Build DataTree + tree = DataTree.from_dict(results) + + # Carry over close handlers from backend tree when needed (mirrors sync path) + if create_default_indexes or chunks is not None: + for _path, [node] in group_subtrees(backend_tree): + tree[_path].set_close(node._close) + + return tree + + def open_groups( filename_or_obj: T_PathFileOrDataStore, *, diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index f0578ca9352..d303cef3ec3 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import base64 import json import os @@ -1785,6 +1786,80 @@ def open_groups_as_dict( groups_dict[group_name] = group_ds return groups_dict + async def open_groups_as_dict_async( + self, + filename_or_obj: T_PathFileOrDataStore, + *, + mask_and_scale=True, + decode_times=True, + concat_characters=True, + decode_coords=True, + drop_variables: str | Iterable[str] | None = None, + use_cftime=None, + decode_timedelta=None, + group: str | None = None, + mode="r", + synchronizer=None, + consolidated=None, + chunk_store=None, + storage_options=None, + zarr_version=None, + zarr_format=None, + ) -> dict[str, Dataset]: + """Asynchronously open each group into a Dataset concurrently. 
+        """
+        filename_or_obj = _normalize_path(filename_or_obj)
+
+        # Determine the parent group path context
+        if group:
+            parent = str(NodePath("/") / NodePath(group))
+        else:
+            parent = str(NodePath("/"))
+
+        # Discover group stores (synchronous metadata step)
+        stores = ZarrStore.open_store(
+            filename_or_obj,
+            group=parent,
+            mode=mode,
+            synchronizer=synchronizer,
+            consolidated=consolidated,
+            consolidate_on_close=False,
+            chunk_store=chunk_store,
+            storage_options=storage_options,
+            zarr_version=zarr_version,
+            zarr_format=zarr_format,
+        )
+
+        async def open_one(path_group: str, store) -> tuple[str, Dataset]:
+            store_entrypoint = StoreBackendEntrypoint()
+
+            def _load_sync():
+                with close_on_error(store):
+                    return store_entrypoint.open_dataset(
+                        store,
+                        mask_and_scale=mask_and_scale,
+                        decode_times=decode_times,
+                        concat_characters=concat_characters,
+                        decode_coords=decode_coords,
+                        drop_variables=drop_variables,
+                        use_cftime=use_cftime,
+                        decode_timedelta=decode_timedelta,
+                    )
+
+            ds = await asyncio.to_thread(_load_sync)
+            if group:
+                group_name = str(NodePath(path_group).relative_to(parent))
+            else:
+                group_name = str(NodePath(path_group))
+            return group_name, ds
+
+        tasks = [open_one(path_group, store) for path_group, store in stores.items()]
+        results = await asyncio.gather(*tasks)
+        return dict(results)
+
 
 def _iter_zarr_groups(root: ZarrGroup, parent: str = "/") -> Iterable[str]:
     parent_nodepath = NodePath(parent)
diff --git a/xarray/tests/test_backends_zarr_async.py b/xarray/tests/test_backends_zarr_async.py
new file mode 100644
index 00000000000..4f6abee3e2c
--- /dev/null
+++ b/xarray/tests/test_backends_zarr_async.py
@@ -0,0 +1,324 @@
+"""Tests for asynchronous zarr group loading functionality."""
+
+from __future__ import annotations
+
+import asyncio
+import contextlib
+from unittest.mock import patch
+
+import numpy as np
+import pytest
+
+import xarray as xr
+from xarray.backends.api import _maybe_create_default_indexes_async, open_datatree_async
+from xarray.backends.zarr import ZarrBackendEntrypoint
+from xarray.testing import assert_equal
+from xarray.tests import (
+    has_zarr_v3,
+    parametrize_zarr_format,
+    requires_zarr,
+    requires_zarr_v3,
+)
+
+if has_zarr_v3:
+    from zarr.storage import MemoryStore
+
+
+def create_dataset_with_coordinates(n_coords=5):
+    """Create a dataset with coordinate variables to trigger indexing."""
+    coords = {}
+    for i in range(n_coords):
+        coords[f"coord_{i}"] = (f"coord_{i}", np.arange(3))
+
+    coord_names = list(coords.keys())
+    data_vars = {}
+
+    if len(coord_names) >= 2:
+        data_vars["temperature"] = (coord_names[:2], np.random.random((3, 3)))
+    if len(coord_names) >= 1:
+        data_vars["pressure"] = (coord_names[:1], np.random.random(3))
+
+    data_vars["simple"] = ([], np.array(42.0))
+
+    ds = xr.Dataset(data_vars=data_vars, coords=coords)
+    return ds
+
+
+def create_test_datatree(n_groups=3, coords_per_group=5):
+    """Create a DataTree for testing with multiple groups."""
+    root_ds = create_dataset_with_coordinates(coords_per_group)
+    tree_dict = {"/": root_ds}
+
+    for i in range(n_groups):
+        group_name = f"/group_{i:03d}"
+        group_ds = create_dataset_with_coordinates(n_coords=coords_per_group)
+        tree_dict[group_name] = group_ds
+
+    tree = xr.DataTree.from_dict(tree_dict)
+    return tree
+
+
+@requires_zarr
+class TestAsyncZarrGroupLoading:
+    """Tests for asynchronous zarr group loading functionality."""
functionality.""" + + @contextlib.contextmanager + def create_zarr_store(self): + """Create a zarr target for testing.""" + if has_zarr_v3: + with MemoryStore() as store: + yield store + else: + from zarr.storage import MemoryStore as V2MemoryStore + + store = V2MemoryStore() + yield store + + @parametrize_zarr_format + def test_async_datatree_roundtrip(self, zarr_format): + """Test that async datatree loading preserves data integrity.""" + + dtree = create_test_datatree(n_groups=3, coords_per_group=4) + + with self.create_zarr_store() as store: + dtree.to_zarr(store, mode="w", consolidated=False, zarr_format=zarr_format) + + async def load_async(): + return await open_datatree_async( + store, + consolidated=False, + zarr_format=zarr_format, + create_default_indexes=True, + engine="zarr", + ) + + dtree_async = asyncio.run(load_async()) + assert_equal(dtree, dtree_async) + + def test_async_error_handling_unsupported_engine(self): + """Test that async functions properly handle unsupported engines.""" + + async def test_unsupported_engine(): + with pytest.raises( + ValueError, match="does not support asynchronous operations" + ): + await open_datatree_async("/fake/path", engine="netcdf4") + + asyncio.run(test_unsupported_engine()) + + @pytest.mark.asyncio + @requires_zarr_v3 + async def test_async_concurrent_loading(self): + """Test that async loading uses concurrent calls for multiple groups.""" + import zarr + + dtree = create_test_datatree(n_groups=3, coords_per_group=4) + + with self.create_zarr_store() as store: + dtree.to_zarr(store, mode="w", consolidated=False, zarr_format=3) + + target_class = zarr.AsyncGroup + original_method = target_class.getitem + + with patch.object( + target_class, "getitem", wraps=original_method, autospec=True + ) as mocked_method: + dtree_async = await open_datatree_async( + store, + consolidated=False, + zarr_format=3, + create_default_indexes=True, + engine="zarr", + ) + + assert_equal(dtree, dtree_async) + + assert mocked_method.call_count > 0 + mocked_method.assert_awaited() + + @pytest.mark.asyncio + @parametrize_zarr_format + async def test_async_root_only_datatree(self, zarr_format): + """Test async loading of datatree with only root node (no child groups).""" + + root_ds = create_dataset_with_coordinates(3) + dtree = xr.DataTree(root_ds) + + with self.create_zarr_store() as store: + dtree.to_zarr(store, mode="w", consolidated=False, zarr_format=zarr_format) + + dtree_async = await open_datatree_async( + store, + consolidated=False, + zarr_format=zarr_format, + create_default_indexes=True, + engine="zarr", + ) + + assert len(list(dtree_async.subtree)) == 1 + assert dtree_async.path == "/" + assert dtree_async.ds is not None + + @pytest.mark.asyncio + @parametrize_zarr_format + @pytest.mark.parametrize("n_groups", [1, 3, 10]) + async def test_async_multiple_groups(self, zarr_format, n_groups): + """Test async loading of datatree with varying numbers of groups.""" + dtree = create_test_datatree(n_groups=n_groups, coords_per_group=3) + + with self.create_zarr_store() as store: + dtree.to_zarr(store, mode="w", consolidated=False, zarr_format=zarr_format) + + # Load asynchronously + dtree_async = await open_datatree_async( + store, + consolidated=False, + zarr_format=zarr_format, + create_default_indexes=True, + engine="zarr", + ) + + expected_groups = ["/"] + [f"/group_{i:03d}" for i in range(n_groups)] + group_paths = [node.path for node in dtree_async.subtree] + + assert len(group_paths) == len(expected_groups) + for expected_path in expected_groups: + 
+
+    @pytest.mark.asyncio
+    @parametrize_zarr_format
+    async def test_async_create_default_indexes_false(self, zarr_format):
+        """Test that create_default_indexes=False prevents index creation."""
+        dtree = create_test_datatree(n_groups=2, coords_per_group=3)
+
+        with self.create_zarr_store() as store:
+            dtree.to_zarr(store, mode="w", consolidated=False, zarr_format=zarr_format)
+
+            dtree_async = await open_datatree_async(
+                store,
+                consolidated=False,
+                zarr_format=zarr_format,
+                create_default_indexes=False,
+                engine="zarr",
+            )
+
+            assert len(list(dtree_async.subtree)) == 3
+
+            for node in dtree_async.subtree:
+                dataset = node.dataset
+                if dataset is not None:
+                    coord_names = [
+                        name
+                        for name, coord in dataset.coords.items()
+                        if coord.dims == (name,)
+                    ]
+                    for coord_name in coord_names:
+                        assert coord_name not in dataset.xindexes, (
+                            f"Index should not exist for coordinate '{coord_name}' "
+                            "when create_default_indexes=False"
+                        )
+
+    def test_sync_vs_async_api_compatibility(self):
+        """Test that sync and async APIs have compatible signatures."""
+        import inspect
+
+        from xarray.backends.api import open_datatree
+
+        sync_sig = inspect.signature(open_datatree)
+        async_sig = inspect.signature(open_datatree_async)
+
+        sync_params = list(sync_sig.parameters.keys())
+        async_params = list(async_sig.parameters.keys())
+
+        for param in sync_params:
+            assert param in async_params, (
+                f"Parameter '{param}' missing from async version"
+            )
+
+    @pytest.mark.asyncio
+    @requires_zarr
+    @parametrize_zarr_format
+    async def test_backend_open_groups_async_equivalence(self, zarr_format):
+        """Backend async group opening returns the same groups and datasets as sync."""
+        dtree = create_test_datatree(n_groups=3, coords_per_group=4)
+        backend = ZarrBackendEntrypoint()
+
+        with self.create_zarr_store() as store:
+            dtree.to_zarr(store, mode="w", consolidated=False, zarr_format=zarr_format)
+
+            groups_sync = backend.open_groups_as_dict(
+                store,
+                consolidated=False,
+                zarr_format=zarr_format,
+            )
+
+            groups_async = await backend.open_groups_as_dict_async(
+                store,
+                consolidated=False,
+                zarr_format=zarr_format,
+            )
+
+            assert set(groups_sync.keys()) == set(groups_async.keys())
+            for k in list(groups_sync.keys())[:2]:
+                assert_equal(groups_sync[k], groups_async[k])
+
+    @pytest.mark.asyncio
+    async def test_maybe_create_default_indexes_async_no_coords_needing_indexes(self):
+        """Test _maybe_create_default_indexes_async with no coordinates needing indexes."""
+        ds = xr.Dataset(
+            {
+                "temperature": (("x", "y"), np.random.random((3, 4))),
+            }
+        )
+
+        result = await _maybe_create_default_indexes_async(ds)
+        assert_equal(ds, result)
+        assert len(result.xindexes) == 0
+
+    @pytest.mark.asyncio
+    async def test_maybe_create_default_indexes_async_creates_indexes(self):
+        """Test _maybe_create_default_indexes_async creates indexes for coordinate variables."""
+        coords = {"time": ("time", np.arange(5)), "x": ("x", np.arange(3))}
+        data_vars = {
+            "temperature": (("time", "x"), np.random.random((5, 3))),
+        }
+        ds = xr.Dataset(data_vars, coords)
+        ds_no_indexes = ds.drop_indexes(["time", "x"])
+
+        assert len(ds_no_indexes.xindexes) == 0
+
+        result = await _maybe_create_default_indexes_async(ds_no_indexes)
+
+        assert "time" in result.xindexes
+        assert "x" in result.xindexes
+        assert len(result.xindexes) == 2
+
+    @pytest.mark.asyncio
+    async def test_maybe_create_default_indexes_async_partial_indexes(self):
+        """Test with a mix of coords that need indexes and coords that don't."""
+        coords = {
+            "time": ("time", np.arange(5)),
+            "x": ("x", np.arange(3)),
+        }
+        data_vars = {
+            "temperature": (("time", "x"), np.random.random((5, 3))),
+        }
+        ds = xr.Dataset(data_vars=data_vars, coords=coords)
+        ds_partial = ds.drop_indexes(["x"])
+
+        assert "time" in ds_partial.xindexes
+        assert "x" not in ds_partial.xindexes
+
+        result = await _maybe_create_default_indexes_async(ds_partial)
+
+        assert "time" in result.xindexes
+        assert "x" in result.xindexes
+
+    @pytest.mark.asyncio
+    async def test_maybe_create_default_indexes_async_all_indexes_exist(self):
+        """Test that the original dataset is returned when all coords already have indexes."""
+        ds = create_dataset_with_coordinates(n_coords=2)
+
+        assert len(ds.xindexes) > 0
+
+        result = await _maybe_create_default_indexes_async(ds)
+        assert result is ds  # Same object returned