3 changes: 3 additions & 0 deletions pyproject.toml
@@ -63,6 +63,9 @@ show_error_codes = true
 
 [tool.pytest.ini_options]
 addopts = ["-vv"]
+filterwarnings = [
+    "ignore:.*Using unknown location.*:UserWarning:esgpull",
+]
 testpaths = ["tests"]
 
 [tool.ruff]
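Note: pytest parses each filterwarnings entry as action:message:category:module, and (unlike Python's own -W option) treats the message and module parts as regular expressions. A rough stdlib equivalent of the entry added above, for reference:

import warnings

# Roughly what pytest installs at session start for the ini entry above;
# the message and module arguments are regular expressions.
warnings.filterwarnings(
    "ignore",
    message=".*Using unknown location.*",
    category=UserWarning,
    module="esgpull",
)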
20 changes: 20 additions & 0 deletions tests/test_20_open_dataset.py
@@ -83,3 +83,23 @@ def test_open_dataset(tmp_path: Path, index_node: str, download: bool) -> None:
         "CMIP6.ScenarioMIP.EC-Earth-Consortium.EC-Earth3-CC.ssp585.r1i1p1f1.Amon.tas.gr.v20210113",
         "CMIP6.ScenarioMIP.EC-Earth-Consortium.EC-Earth3-CC.ssp585.r1i1p1f1.fx.areacella.gr.v20210113",
     ]
+
+
+def test_combine_coords(tmp_path: Path, index_node: str) -> None:
+    esgpull_path = tmp_path / "esgpull"
+    selection = {
+        "query": [
+            '"areacella_fx_IPSL-CM6A-LR_historical_r1i1p1f1_gr.nc"',
+            '"orog_fx_IPSL-CM6A-LR_historical_r1i1p1f1_gr.nc"',
+        ]
+    }
+    ds = xr.open_dataset(
+        selection,  # type: ignore[arg-type]
+        esgpull_path=esgpull_path,
+        concat_dims="experiment_id",
+        engine="esgf",
+        index_node=index_node,
+        chunks={},
+    )
+    assert set(ds.coords) == {"areacella", "lat", "lon", "experiment_id", "orog"}
+    assert not ds.data_vars
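Note: the new test queries only fixed (fx) fields, so once the engine promotes every variable to a coordinate, nothing should remain as a data variable. A minimal offline sketch of the invariant being asserted, using toy data instead of an ESGF query:

import xarray as xr

# Two fixed fields on a shared grid; promoting them with set_coords
# leaves a dataset that has coordinates but no data variables.
ds = xr.Dataset(
    {
        "areacella": ("lat", [0.5, 0.5]),
        "orog": ("lat", [3.0, 4.0]),
    }
)
ds = ds.set_coords(["areacella", "orog"])
assert set(ds.coords) == {"areacella", "orog"}
assert not ds.data_vars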
33 changes: 26 additions & 7 deletions xarray_esgf/client.py
@@ -2,7 +2,7 @@
 import dataclasses
 import logging
 from collections import defaultdict
-from collections.abc import Callable, Iterable
+from collections.abc import Callable, Hashable, Iterable
 from functools import cached_property
 from pathlib import Path
 from typing import Literal, get_args
@@ -114,14 +114,13 @@ def download(self) -> list[File]:
             raise ExceptionGroup(msg, exceptions)
         return files
 
-    @use_new_combine_kwarg_defaults
-    def open_dataset(
+    def _open_datasets(
         self,
         concat_dims: DATASET_ID_KEYS | Iterable[DATASET_ID_KEYS] | None,
         drop_variables: str | Iterable[str] | None = None,
         download: bool = False,
         show_progress: bool = True,
-    ) -> Dataset:
+    ) -> dict[str, Dataset]:
         if isinstance(concat_dims, str):
             concat_dims = [concat_dims]
         concat_dims = concat_dims or []
@@ -138,7 +137,7 @@ def open_dataset(
                 chunks=-1,
                 engine="h5netcdf",
                 drop_variables=drop_variables,
-                storage_options={"verify_ssl": self.verify_ssl},
+                storage_options={"ssl": self.verify_ssl},
             )
             grouped_objects[file.dataset_id].append(ds.drop_encoding())
 
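Note: fsspec's HTTPFileSystem forwards extra storage options to each aiohttp request, and aiohttp deprecated verify_ssl= in favour of ssl=; presumably that is what motivates the key rename in this hunk. A hedged sketch of the same option passed to fsspec by hand:

import fsspec

# Assumed equivalent: extra kwargs reach the aiohttp request, where
# `ssl` accepts a bool or an SSLContext (False skips verification).
fs = fsspec.filesystem("https", ssl=False)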
@@ -165,17 +164,37 @@ def open_dataset(
             combined_datasets[dataset_id] = ds
             LOGGER.debug(f"{dataset_id}: {dict(ds.sizes)}")
 
+        return combined_datasets
+
+    @use_new_combine_kwarg_defaults
+    def open_dataset(
+        self,
+        concat_dims: DATASET_ID_KEYS | Iterable[DATASET_ID_KEYS] | None,
+        drop_variables: str | Iterable[str] | None = None,
+        download: bool = False,
+        show_progress: bool = True,
+    ) -> Dataset:
+        combined_datasets = self._open_datasets(
+            concat_dims, drop_variables, download, show_progress
+        )
+
         obj = xr.combine_by_coords(
-            combined_datasets.values(),
+            [ds.reset_coords() for ds in combined_datasets.values()],
             join="exact",
             combine_attrs="drop_conflicts",
         )
         if isinstance(obj, DataArray):
             obj = obj.to_dataset()
-        obj.attrs["dataset_ids"] = sorted(grouped_objects)
 
+        coords: set[Hashable] = set()
+        for ds in combined_datasets.values():
+            coords.update(ds.coords)
+        obj = obj.set_coords(coords)
+
         for name, var in obj.variables.items():
             if name not in obj.dims:
                 var.encoding["preferred_chunks"] = dict(var.chunksizes)
 
+        obj.attrs["dataset_ids"] = sorted(combined_datasets)
+
         return obj
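Note: the demote/combine/re-promote sequence above keeps non-index coordinates out of combine_by_coords and then restores their coordinate status from the union of the inputs' coordinate names. A toy round-trip under the same keyword arguments (hypothetical data, no ESGF access):

import xarray as xr

# Two per-dataset results sharing an index coordinate but carrying
# different auxiliary variables, mirroring the new test.
a = xr.Dataset(coords={"lat": [0.0, 1.0], "areacella": ("lat", [0.5, 0.5])})
b = xr.Dataset(coords={"lat": [0.0, 1.0], "orog": ("lat", [3.0, 4.0])})

obj = xr.combine_by_coords(
    [ds.reset_coords() for ds in (a, b)],  # demote non-index coords to data_vars
    join="exact",
    combine_attrs="drop_conflicts",
)

coords = set()
for ds in (a, b):
    coords.update(ds.coords)  # union: {"lat", "areacella", "orog"}
obj = obj.set_coords(coords)  # re-promote; names already coords are no-ops

assert set(obj.coords) == {"lat", "areacella", "orog"}
assert not obj.data_vars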