From 05c6e8ea56180fcf9b74b567ae0d751296cd33b5 Mon Sep 17 00:00:00 2001 From: malmans2 Date: Wed, 10 Dec 2025 15:30:43 +0100 Subject: [PATCH] test compute --- tests/test_20_open_dataset.py | 54 +++++++++++++++++++++++++++-------- xarray_esgf/client.py | 6 ++-- 2 files changed, 46 insertions(+), 14 deletions(-) diff --git a/tests/test_20_open_dataset.py b/tests/test_20_open_dataset.py index 98f9b1a..25043f8 100644 --- a/tests/test_20_open_dataset.py +++ b/tests/test_20_open_dataset.py @@ -1,10 +1,36 @@ from pathlib import Path +import dask import pytest import xarray as xr -@pytest.mark.parametrize("download", [True, False]) +class CountingScheduler: + """Simple dask scheduler counting the number of computes. + + Reference: https://stackoverflow.com/questions/53289286/""" + + def __init__(self, max_computes: int = 0) -> None: + self.total_computes = 0 + self.max_computes = max_computes + + def __call__(self, dsk, keys, **kwargs): # type: ignore[no-untyped-def] + self.total_computes += 1 + if self.total_computes > self.max_computes: + msg = f"Too many computes. Total: {self.total_computes:d} > max: {self.max_computes:d}." + raise RuntimeError(msg) + return dask.get(dsk, keys, **kwargs) + + +def raise_if_dask_computes(max_computes: int = 0) -> dask.config.set: + scheduler = CountingScheduler(max_computes) + return dask.config.set(scheduler=scheduler) + + +@pytest.mark.parametrize( + "download", + [True, False], +) def test_open_dataset(tmp_path: Path, index_node: str, download: bool) -> None: esgpull_path = tmp_path / "esgpull" selection = { @@ -21,16 +47,16 @@ def test_open_dataset(tmp_path: Path, index_node: str, download: bool) -> None: '"CMIP6.ScenarioMIP.EC-Earth-Consortium.EC-Earth3-CC.ssp585.r1i1p1f1.fx.areacella.gr.v20210113.areacella_fx_EC-Earth3-CC_ssp585_r1i1p1f1_gr.nc"', ] } - ds = xr.open_dataset( - selection, # type: ignore[arg-type] - esgpull_path=esgpull_path, - concat_dims="experiment_id", - engine="esgf", - index_node=index_node, - download=download, - chunks={}, - ) - + with raise_if_dask_computes(): + ds = xr.open_dataset( + selection, # type: ignore[arg-type] + esgpull_path=esgpull_path, + concat_dims="experiment_id", + engine="esgf", + index_node=index_node, + download=download, + chunks={}, + ) assert (esgpull_path / "data" / "CMIP6").exists() is download # Chunks @@ -38,7 +64,7 @@ def test_open_dataset(tmp_path: Path, index_node: str, download: bool) -> None: assert not ds[dim].chunks assert ds.chunksizes == { "experiment_id": (1, 1), - "time": (12, 12), + "time": (24,), "lat": (256,), "lon": (512,), "bnds": (2,), @@ -83,3 +109,7 @@ def test_open_dataset(tmp_path: Path, index_node: str, download: bool) -> None: "CMIP6.ScenarioMIP.EC-Earth-Consortium.EC-Earth3-CC.ssp585.r1i1p1f1.Amon.tas.gr.v20210113", "CMIP6.ScenarioMIP.EC-Earth-Consortium.EC-Earth3-CC.ssp585.r1i1p1f1.fx.areacella.gr.v20210113", ] + + # Compute + with raise_if_dask_computes(1): + ds.compute() diff --git a/xarray_esgf/client.py b/xarray_esgf/client.py index 90d7ca5..ef9f268 100644 --- a/xarray_esgf/client.py +++ b/xarray_esgf/client.py @@ -132,7 +132,6 @@ def open_dataset( ): ds = xr.open_dataset( self._client.fs[file].drs if download else file.url, - chunks=-1, engine="h5netcdf", drop_variables=drop_variables, ) @@ -171,6 +170,9 @@ def open_dataset( for name, var in obj.variables.items(): if name not in obj.dims: - var.encoding["preferred_chunks"] = dict(var.chunksizes) + var.encoding["preferred_chunks"] = { + dim: (1,) * var.sizes[dim] + for dim in set(var.dims) & set(concat_dims) + } return obj