From 48e0a19e95c83f6142be0cf989ef921cf50b29f4 Mon Sep 17 00:00:00 2001
From: malmans2
Date: Fri, 9 Jan 2026 17:39:18 +0100
Subject: [PATCH 1/2] fix bug when selecting multiple coords

---
 pyproject.toml                |  3 +++
 tests/test_20_open_dataset.py | 20 ++++++++++++++++++++
 xarray_esgf/client.py         | 34 ++++++++++++++++++++++++++--------
 3 files changed, 49 insertions(+), 8 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index c3d701d..2deceaf 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -63,6 +63,9 @@ show_error_codes = true
 
 [tool.pytest.ini_options]
 addopts = ["-vv"]
+filterwarnings = [
+    "ignore:.*Using unknown location.*:UserWarning:esgpull",
+]
 testpaths = ["tests"]
 
 [tool.ruff]
diff --git a/tests/test_20_open_dataset.py b/tests/test_20_open_dataset.py
index 98f9b1a..f1f80ce 100644
--- a/tests/test_20_open_dataset.py
+++ b/tests/test_20_open_dataset.py
@@ -83,3 +83,23 @@ def test_open_dataset(tmp_path: Path, index_node: str, download: bool) -> None:
         "CMIP6.ScenarioMIP.EC-Earth-Consortium.EC-Earth3-CC.ssp585.r1i1p1f1.Amon.tas.gr.v20210113",
         "CMIP6.ScenarioMIP.EC-Earth-Consortium.EC-Earth3-CC.ssp585.r1i1p1f1.fx.areacella.gr.v20210113",
     ]
+
+
+def test_combine_coords(tmp_path: Path, index_node: str) -> None:
+    esgpull_path = tmp_path / "esgpull"
+    selection = {
+        "query": [
+            '"areacella_fx_IPSL-CM6A-LR_historical_r1i1p1f1_gr.nc"',
+            '"orog_fx_IPSL-CM6A-LR_historical_r1i1p1f1_gr.nc"',
+        ]
+    }
+    ds = xr.open_dataset(
+        selection,  # type: ignore[arg-type]
+        esgpull_path=esgpull_path,
+        concat_dims="experiment_id",
+        engine="esgf",
+        index_node=index_node,
+        chunks={},
+    )
+    assert set(ds.coords) == {"areacella", "lat", "lon", "experiment_id", "orog"}
+    assert not ds.data_vars
diff --git a/xarray_esgf/client.py b/xarray_esgf/client.py
index 6ceec20..a34ecbb 100644
--- a/xarray_esgf/client.py
+++ b/xarray_esgf/client.py
@@ -2,7 +2,7 @@
 import dataclasses
 import logging
 from collections import defaultdict
-from collections.abc import Callable, Iterable
+from collections.abc import Callable, Hashable, Iterable
 from functools import cached_property
 from pathlib import Path
 from typing import Literal, get_args
@@ -114,14 +114,13 @@ def download(self) -> list[File]:
             raise ExceptionGroup(msg, exceptions)
         return files
 
-    @use_new_combine_kwarg_defaults
-    def open_dataset(
+    def _open_datasets(
         self,
         concat_dims: DATASET_ID_KEYS | Iterable[DATASET_ID_KEYS] | None,
         drop_variables: str | Iterable[str] | None = None,
         download: bool = False,
         show_progress: bool = True,
-    ) -> Dataset:
+    ) -> dict[str, Dataset]:
         if isinstance(concat_dims, str):
             concat_dims = [concat_dims]
         concat_dims = concat_dims or []
@@ -138,10 +137,9 @@ def open_dataset(
                 chunks=-1,
                 engine="h5netcdf",
                 drop_variables=drop_variables,
-                storage_options={"verify_ssl": self.verify_ssl},
+                storage_options={"ssl": self.verify_ssl},
             )
             grouped_objects[file.dataset_id].append(ds.drop_encoding())
-
         combined_datasets = {}
         for dataset_id, datasets in grouped_objects.items():
             dataset_id_dict = dataset_id_to_dict(dataset_id)
@@ -165,17 +163,37 @@
             combined_datasets[dataset_id] = ds
             LOGGER.debug(f"{dataset_id}: {dict(ds.sizes)}")
 
+        return combined_datasets
+
+    @use_new_combine_kwarg_defaults
+    def open_dataset(
+        self,
+        concat_dims: DATASET_ID_KEYS | Iterable[DATASET_ID_KEYS] | None,
+        drop_variables: str | Iterable[str] | None = None,
+        download: bool = False,
+        show_progress: bool = True,
+    ) -> Dataset:
+        combined_datasets = self._open_datasets(
+            concat_dims, drop_variables, download, show_progress
+        )
+
         obj = xr.combine_by_coords(
-            combined_datasets.values(),
+            [ds.reset_coords() for ds in combined_datasets.values()],
             join="exact",
             combine_attrs="drop_conflicts",
         )
         if isinstance(obj, DataArray):
             obj = obj.to_dataset()
-        obj.attrs["dataset_ids"] = sorted(grouped_objects)
+
+        coords: set[Hashable] = set()
+        for ds in combined_datasets.values():
+            coords.update(ds.coords)
+        obj = obj.set_coords(coords)
 
         for name, var in obj.variables.items():
             if name not in obj.dims:
                 var.encoding["preferred_chunks"] = dict(var.chunksizes)
+
+        obj.attrs["dataset_ids"] = sorted(combined_datasets)
+
         return obj

From f77818ed92842ea1ef0750e124800ce1a8635d5b Mon Sep 17 00:00:00 2001
From: malmans2
Date: Fri, 9 Jan 2026 17:41:23 +0100
Subject: [PATCH 2/2] cleanup

---
 xarray_esgf/client.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/xarray_esgf/client.py b/xarray_esgf/client.py
index a34ecbb..05a13af 100644
--- a/xarray_esgf/client.py
+++ b/xarray_esgf/client.py
@@ -140,6 +140,7 @@ def _open_datasets(
             storage_options={"ssl": self.verify_ssl},
         )
         grouped_objects[file.dataset_id].append(ds.drop_encoding())
+
         combined_datasets = {}
         for dataset_id, datasets in grouped_objects.items():
             dataset_id_dict = dataset_id_to_dict(dataset_id)