Skip to content

Commit aaf1aa4

Browse files
committed
Update overloads
1 parent d7d5730 commit aaf1aa4

File tree

3 files changed

+93
-88
lines changed

3 files changed

+93
-88
lines changed

src/mdio/core/v1/__init__.py

Lines changed: 8 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,12 @@
1-
# mdio/__init__.py
1+
"""
2+
MDIO core v1 package initialization.
3+
Exposes the MDIO overloads and core v1 functionality.
4+
"""
25

3-
import xarray as _xr
4-
from ._overloads import open_mdio, to_mdio
6+
from ._overloads import mdio
7+
from .constructor import Write_MDIO_metadata
58

69
__all__ = [
7-
# explicit overrides / aliases
8-
"open_mdio",
9-
"to_mdio",
10-
# everything else will be auto-populated by __dir__ / __getattr__
10+
"mdio",
11+
"Write_MDIO_metadata",
1112
]
12-
13-
def __getattr__(name: str):
14-
"""
15-
Fallback: anything not defined in mdio/__init__.py
16-
gets looked up on xarray.
17-
"""
18-
if hasattr(_xr, name):
19-
return getattr(_xr, name)
20-
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
21-
22-
def __dir__():
23-
"""
24-
Make dir(mdio) list our overrides and then all public xarray names.
25-
"""
26-
xr_public = [n for n in dir(_xr) if not n.startswith("_")]
27-
return sorted(__all__ + xr_public)

src/mdio/core/v1/_overloads.py

Lines changed: 56 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -5,32 +5,62 @@
55
1. To provide a consistent mdio.* naming scheme.
66
2. To simplify the API for users where it makes sense (e.g. MDIO v1 uses Zarr and not HDF5).
77
"""
8+
import xarray as xr
9+
from xarray import Dataset as _Dataset, DataArray as _DataArray
810

911

10-
import xarray as _xr
11-
from xarray import Dataset as _Dataset, DataArray as _DataArray
12+
class MDIODataset(_Dataset):
    """xarray.Dataset subclass with MDIO v1 extensions."""

    __slots__ = ()

    def to_mdio(self, store=None, *args, **kwargs):
        """Write this dataset to a Zarr store (alias for :meth:`to_zarr`).

        Prints a greeting, then delegates to ``Dataset.to_zarr``.

        ``store`` is forwarded positionally: the previous
        ``to_zarr(store=store, *args, ...)`` form raised
        ``TypeError: got multiple values for argument 'store'`` whenever any
        extra positional argument was supplied, because ``*args`` also bound
        to the first positional parameter.
        """
        print("👋 hello world from mdio.to_mdio!")
        return super().to_zarr(store, *args, **kwargs)
22+
23+
24+
class MDIODataArray(_DataArray):
    """xarray.DataArray subclass with MDIO v1 extensions."""

    __slots__ = ()

    def to_mdio(self, store=None, *args, **kwargs):
        """Write this array to a Zarr store (alias for :meth:`to_zarr`).

        Prints a greeting, then delegates to ``DataArray.to_zarr``.

        ``store`` is forwarded positionally: the previous
        ``to_zarr(store=store, *args, ...)`` form raised
        ``TypeError: got multiple values for argument 'store'`` whenever any
        extra positional argument was supplied.
        """
        print("👋 hello world from mdio.to_mdio!")
        return super().to_zarr(store, *args, **kwargs)
34+
35+
36+
class MDIO:
    """MDIO namespace exposing the overloaded types and an ``open`` constructor."""

    Dataset = MDIODataset
    DataArray = MDIODataArray

    @staticmethod
    def open(store, *args, engine="zarr", consolidated=False, **kwargs):
        """Open a Zarr store as an :class:`MDIODataset`.

        Prints a greeting, opens the store via ``xarray.open_dataset``, and
        re-classes the returned Dataset and every data variable / coordinate
        to the MDIO subclasses so ``.to_mdio`` is available on all of them.

        Parameters
        ----------
        store : store target accepted by ``xarray.open_dataset``
        engine : backend engine, defaults to ``"zarr"``
        consolidated : whether to use consolidated Zarr metadata
        *args, **kwargs : forwarded to ``xarray.open_dataset``
        """
        print("👋 hello world from mdio.open!")
        # *args before keywords: same call semantics as before, but the
        # conventional ordering no longer reads as keyword-before-star.
        ds = xr.open_dataset(
            store,
            *args,
            engine=engine,
            consolidated=consolidated,
            **kwargs,
        )
        # Re-class in place: no copy, preserves all internal state.
        ds.__class__ = MDIODataset
        for var in ds.data_vars.values():
            var.__class__ = MDIODataArray
        for coord in ds.coords.values():
            coord.__class__ = MDIODataArray
        return ds
63+
1264

13-
def open_mdio(store, *args, engine="zarr", consolidated=False, **kwargs):
14-
"""
15-
Our mdio version of xr.open_zarr. Prints a greeting,
16-
then calls xr.open_dataset(..., engine="zarr").
17-
"""
18-
print("👋 hello world from mdio.open_mdio!")
19-
return _xr.open_dataset(store, *args,
20-
engine=engine,
21-
consolidated=consolidated,
22-
**kwargs)
23-
24-
def to_mdio(self, *args, **kwargs):
25-
"""
26-
Alias for .to_zarr, renamed to .to_mdio,
27-
so you get a consistent mdio.* naming.
28-
"""
29-
print("👋 hello world from mdio.to_mdio!")
30-
print(f"kwargs: {kwargs}")
31-
return self.to_zarr(*args, **kwargs)
32-
33-
# Monkey-patch Dataset and DataArray so that you can do:
34-
# ds.to_mdio(...) and arr.to_mdio(...)
35-
_Dataset.to_mdio = to_mdio
36-
_DataArray.to_mdio = to_mdio
65+
# Module-level singleton namespace: callers do
# `from mdio.core.v1._overloads import mdio` and then use
# mdio.open(...) / mdio.Dataset / mdio.DataArray.
mdio = MDIO()

src/mdio/core/v1/constructor.py

Lines changed: 29 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""Construct an Xarray Dataset from an MDIO v1 Dataset and write to Zarr."""
1+
"""Construct an MDIO dataset and write to Zarr."""
22
import xarray as xr
33
import numpy as np
44
import dask.array as da
@@ -10,18 +10,19 @@
1010
from mdio.schema.dtype import ScalarType, StructuredType
1111
from mdio.schema.compressors import Blosc, ZFP
1212
from mdio.schema.v1.variable import Coordinate
13-
13+
from mdio.core.v1._overloads import mdio
1414

1515
from numcodecs import Blosc as NumcodecsBlosc
1616

1717
try:
    import zfpy as BaseZFPY  # Base library
    from numcodecs import ZFPY as NumcodecsZFPY  # Codec
except ImportError as err:
    # Bind the exception instance: the previous message interpolated the
    # ImportError *class* itself, never the actual failure reason.
    print(f"Tried to import zfpy and numcodecs zfpy but failed because {err}")
    # ZFP support is optional; downstream code checks these for None.
    BaseZFPY = None
    NumcodecsZFPY = None
2424

25+
2526
def _convert_compressor(model: Blosc | ZFP | None) -> NumcodecsBlosc | NumcodecsZFPY | None:
2627
if isinstance(model, Blosc):
2728
return NumcodecsBlosc(
@@ -45,7 +46,7 @@ def _convert_compressor(model: Blosc | ZFP | None) -> NumcodecsBlosc | Numcodecs
4546
raise TypeError(f"Unsupported compressor model: {type(model)}")
4647

4748

48-
def _construct_mdio_dataset(mdio_ds: MDIODataset) -> xr.Dataset:
49+
def _construct_mdio_dataset(mdio_ds: MDIODataset) -> mdio.Dataset:
4950
"""Build an MDIO dataset with correct dimensions and dtypes.
5051
5152
This internal function constructs the underlying data structure for an MDIO dataset,
@@ -65,7 +66,7 @@ def _construct_mdio_dataset(mdio_ds: MDIODataset) -> xr.Dataset:
6566
dims[d.name] = d.size
6667

6768
# Build data variables
68-
data_vars: dict[str, xr.DataArray] = {}
69+
data_vars: dict[str, mdio.DataArray] = {}
6970
for var in mdio_ds.variables:
7071
dim_names = [d.name if isinstance(d, NamedDimension) else d for d in var.dimensions]
7172
shape = tuple(dims[name] for name in dim_names)
@@ -76,12 +77,8 @@ def _construct_mdio_dataset(mdio_ds: MDIODataset) -> xr.Dataset:
7677
dtype = np.dtype([(f.name, f.format.value) for f in dt.fields])
7778
else:
7879
raise TypeError(f"Unsupported data_type: {dt}")
79-
# arr = da.zeros(shape, dtype=dtype)
8080
arr = np.zeros(shape, dtype=dtype)
81-
data_array = xr.DataArray(arr, dims=dim_names)
82-
# set default fill_value to zero instead of NaN
83-
# TODO: This seems to be ignored by xarray.
84-
# Setting in the _generate_encodings() function does work though.
81+
data_array = mdio.DataArray(arr, dims=dim_names)
8582
data_array.encoding['fill_value'] = 0.0
8683

8784
# Set long_name if present
@@ -90,33 +87,30 @@ def _construct_mdio_dataset(mdio_ds: MDIODataset) -> xr.Dataset:
9087

9188
# Set coordinates if present, excluding dimension names
9289
if var.coordinates is not None:
93-
# Get the set of dimension names for this variable
9490
dim_set = set(dim_names)
95-
# Filter out any coordinates that are also dimensions
9691
coord_names = [
97-
c.name if isinstance(c, Coordinate) else c
98-
for c in var.coordinates
92+
c.name if isinstance(c, Coordinate) else c
93+
for c in var.coordinates
9994
if (c.name if isinstance(c, Coordinate) else c) not in dim_set
10095
]
101-
if coord_names: # Only set coordinates if there are any non-dimension coordinates
96+
if coord_names:
10297
data_array.attrs["coordinates"] = " ".join(coord_names)
10398

104-
# attach variable metadata into DataArray attributes, excluding nulls and chunkGrid
99+
# Attach variable metadata into DataArray attributes
105100
if var.metadata is not None:
106101
md = var.metadata.model_dump(
107102
by_alias=True,
108103
exclude_none=True,
109104
exclude={"chunk_grid"},
110105
)
111-
# Convert single-element lists to objects
112106
for key, value in md.items():
113107
if isinstance(value, list) and len(value) == 1:
114108
md[key] = value[0]
115109
data_array.attrs.update(md)
116110
data_vars[var.name] = data_array
117111

118-
ds = xr.Dataset(data_vars)
119-
# Attach metadata as attrs
112+
ds = mdio.Dataset(data_vars)
113+
# Attach dataset metadata
120114
ds.attrs["apiVersion"] = mdio_ds.metadata.api_version
121115
ds.attrs["createdOn"] = str(mdio_ds.metadata.created_on)
122116
ds.attrs["name"] = mdio_ds.metadata.name
@@ -125,16 +119,13 @@ def _construct_mdio_dataset(mdio_ds: MDIODataset) -> xr.Dataset:
125119
return ds
126120

127121

128-
129-
130-
def Write_MDIO_metadata(mdio_ds: MDIODataset, store: str, **kwargs: Any) -> xr.Dataset:
131-
"""Write MDIO metadata to a Zarr store and return the constructed xarray.Dataset.
122+
def Write_MDIO_metadata(mdio_ds: MDIODataset, store: str, **kwargs: Any) -> mdio.Dataset:
123+
"""Write MDIO metadata to a Zarr store and return the constructed mdio.Dataset.
132124
133-
This function constructs an xarray.Dataset from the MDIO dataset and writes its metadata
125+
This function constructs an mdio.Dataset from the MDIO dataset and writes its metadata
134126
to a Zarr store. The actual data is not written, only the metadata structure is created.
135127
"""
136128
ds = _construct_mdio_dataset(mdio_ds)
137-
# Write to Zarr format v2 with consolidated metadata and all attributes
138129

139130
def _generate_encodings() -> dict:
140131
"""Generate encodings for each variable in the MDIO dataset.
@@ -147,29 +138,28 @@ def _generate_encodings() -> dict:
147138
for var in mdio_ds.variables:
148139
fill_value = 0
149140
if isinstance(var.data_type, StructuredType):
150-
# Create a structured fill value that matches the dtype
151-
# fill_value = np.zeros(1, dtype=[(f.name, f.format.value) for f in var.data_type.fields])[0]
152-
# TODO: Re-enable this once xarray supports this PR https://github.com/zarr-developers/zarr-python/pull/3015
153141
continue
154142
chunks = None
155143
if var.metadata is not None and var.metadata.chunk_grid is not None:
156144
chunks = var.metadata.chunk_grid.configuration.chunk_shape
157145
global_encodings[var.name] = {
158146
"chunks": chunks,
159-
# TODO: Re-enable this once xarray supports this PR https://github.com/pydata/xarray/pull/10274
147+
# TODO: Re-enable chunk_key_encoding when supported by xarray
160148
# "chunk_key_encoding": dimension_separator_encoding,
161149
"_FillValue": fill_value,
162150
"dtype": var.data_type,
163151
"compressors": _convert_compressor(var.compressor),
164152
}
165153
return global_encodings
166154

167-
ds.to_mdio(store,
168-
mode="w",
169-
zarr_format=2,
170-
consolidated=True,
171-
safe_chunks=False, # This ignores the Dask chunks
172-
compute=False, # Ensures only the metadata is written
173-
encoding=_generate_encodings(),
174-
**kwargs)
175-
return ds
155+
ds.to_mdio(
156+
store,
157+
mode="w",
158+
zarr_format=2,
159+
consolidated=True,
160+
safe_chunks=False,
161+
compute=False,
162+
encoding=_generate_encodings(),
163+
**kwargs
164+
)
165+
return ds

0 commit comments

Comments
 (0)