Skip to content

Commit 7d16694

Browse files
authored
Merge pull request #5 from BrianMichell/v1_private_api_CP
Resync with v1
2 parents 616104c + 968bc60 commit 7d16694

File tree

12 files changed

+1609
-5
lines changed

12 files changed

+1609
-5
lines changed

noxfile.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,7 @@ def tests(session: Session) -> None:
200200
"pygments",
201201
"pytest-dependency",
202202
"s3fs",
203+
"zfpy", # TODO(BrianMichell): Ensure this is pulling from the pyproject.toml
203204
],
204205
)
205206

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ dependencies = [
3636
"tqdm (>=4.67.0,<5.0.0)",
3737
"xarray>=2025.3.1",
3838
"zarr (>=3.0.4,<3.0.7)",
39+
"pint (>=0.24.3,<0.25)",
40+
"xarray (>=2025.4.0)",
3941
]
4042

4143
[project.optional-dependencies]

src/mdio/core/v1/__init__.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
"""MDIO core v1 package initialization.
2+
3+
Exposes the MDIO overloads and core v1 functionality.
4+
"""
5+
6+
from ._overloads import mdio
7+
from ._serializer import make_coordinate
8+
from ._serializer import make_dataset
9+
from ._serializer import make_dataset_metadata
10+
from ._serializer import make_named_dimension
11+
from ._serializer import make_variable
12+
from .builder import MDIODatasetBuilder
13+
from .builder import write_mdio_metadata
14+
from .factory import SCHEMA_TEMPLATE_MAP
15+
from .factory import MDIOSchemaType
16+
17+
18+
__all__ = [
19+
"MDIODatasetBuilder",
20+
"make_coordinate",
21+
"make_dataset",
22+
"make_dataset_metadata",
23+
"make_named_dimension",
24+
"make_variable",
25+
"mdio",
26+
"write_mdio_metadata",
27+
"MDIOSchemaType",
28+
"SCHEMA_TEMPLATE_MAP",
29+
]

src/mdio/core/v1/_overloads.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
"""Overloads for xarray.
2+
3+
The intent of overloading here is:
4+
1. To provide a consistent mdio.* naming scheme.
5+
2. To simplify the API for users where it makes sense (e.g. MDIO v1 uses Zarr and not HDF5).
6+
"""
7+
8+
import xarray as xr
9+
from xarray import DataArray as _DataArray
10+
from xarray import Dataset as _Dataset
11+
12+
13+
class MDIODataset(_Dataset):
    """xarray.Dataset subclass with MDIO v1 extensions.

    Adds a ``to_mdio`` spelling of ``to_zarr`` so users interact with a
    consistent ``mdio.*`` naming scheme (MDIO v1 is Zarr-backed).
    """

    # No extra instance state; keeps the subclass compatible with
    # xarray's own __slots__ discipline.
    __slots__ = ()

    def to_mdio(self, store=None, *args, **kwargs):
        """Write this dataset to an MDIO (Zarr) store.

        Thin alias for :meth:`xarray.Dataset.to_zarr`: ``store`` is
        forwarded as the target, all other arguments pass through.
        (Removed the leftover "hello world" debug print.)
        """
        return super().to_zarr(*args, store=store, **kwargs)
22+
23+
24+
class MDIODataArray(_DataArray):
    """xarray.DataArray subclass with MDIO v1 extensions.

    Adds a ``to_mdio`` spelling of ``to_zarr`` so users interact with a
    consistent ``mdio.*`` naming scheme (MDIO v1 is Zarr-backed).
    """

    # No extra instance state; keeps the subclass compatible with
    # xarray's own __slots__ discipline.
    __slots__ = ()

    def to_mdio(self, store=None, *args, **kwargs):
        """Write this array to an MDIO (Zarr) store.

        Thin alias for :meth:`xarray.DataArray.to_zarr`: ``store`` is
        forwarded as the target, all other arguments pass through.
        (Removed the leftover "hello world" debug print.)
        """
        return super().to_zarr(*args, store=store, **kwargs)
33+
34+
35+
class MDIO:
    """MDIO namespace for overloaded types and functions.

    Exposes ``mdio.Dataset``, ``mdio.DataArray``, and ``mdio.open`` so
    callers use MDIO-branded names instead of xarray's.
    """

    Dataset = MDIODataset
    DataArray = MDIODataArray

    @staticmethod
    def open(store, *args, engine="zarr", consolidated=False, **kwargs):
        """Open a Zarr store as an :class:`MDIODataset`.

        Defaults to the ``zarr`` engine with unconsolidated metadata;
        casts the returned xarray.Dataset (and its variables) to the MDIO
        subclasses. (Removed the leftover "hello world" debug print.)
        """
        ds = xr.open_dataset(
            store,
            *args,
            engine=engine,
            consolidated=consolidated,
            **kwargs,
        )
        # Re-tag with the MDIO subclasses. They add no state
        # (__slots__ = ()), so reassigning __class__ is structurally safe.
        ds.__class__ = MDIODataset
        # NOTE(review): xarray may construct fresh DataArray wrappers on
        # later attribute access, so these per-variable casts may not
        # persist — confirm before relying on them.
        for var in ds.data_vars.values():  # keys unused; iterate values
            var.__class__ = MDIODataArray
        for coord in ds.coords.values():
            coord.__class__ = MDIODataArray
        return ds


# Module-level singleton namespace imported as `from ._overloads import mdio`.
mdio = MDIO()

src/mdio/core/v1/_serializer.py

Lines changed: 267 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,267 @@
1+
"""Internal serialization module for MDIO v1 datasets.
2+
3+
This module contains internal implementation details for serializing MDIO schema models
4+
to Zarr storage. This API is not considered stable and may change without notice.
5+
"""
6+
7+
from datetime import datetime
8+
from typing import Any
9+
from typing import Dict
10+
from typing import List
11+
from typing import Optional
12+
13+
import numpy as np
14+
from numcodecs import Blosc as NumcodecsBlosc
15+
16+
from mdio.core.v1._overloads import mdio
17+
from mdio.schemas.compressors import ZFP
18+
from mdio.schemas.compressors import Blosc
19+
from mdio.schemas.dimension import NamedDimension
20+
from mdio.schemas.dtype import ScalarType
21+
from mdio.schemas.dtype import StructuredType
22+
from mdio.schemas.metadata import UserAttributes
23+
from mdio.schemas.v1.dataset import Dataset as MDIODataset
24+
from mdio.schemas.v1.dataset import DatasetMetadata
25+
from mdio.schemas.v1.units import AllUnits
26+
from mdio.schemas.v1.variable import Coordinate
27+
from mdio.schemas.v1.variable import Variable
28+
from mdio.schemas.v1.variable import VariableMetadata
29+
30+
31+
# Optional ZFP support: the zfpy wheel (and the matching numcodecs codec)
# may be absent, in which case both names are set to None and
# _convert_compressor raises ImportError only when ZFP is actually requested.
try:
    import zfpy as zfpy_base  # Base library
    from numcodecs import ZFPY  # Codec
except ImportError as err:
    # Bug fix: the original printed the ImportError *class* object
    # (f"... {ImportError}") instead of the caught exception instance.
    print(f"Tried to import zfpy and numcodecs zfpy but failed because {err}")
    zfpy_base = None
    ZFPY = None
38+
39+
40+
def make_named_dimension(name: str, size: int) -> NamedDimension:
    """Construct a :class:`NamedDimension` from a name and its size."""
    dimension = NamedDimension(name=name, size=size)
    return dimension
43+
44+
45+
def make_coordinate(
    name: str,
    dimensions: List[NamedDimension | str],
    data_type: ScalarType | StructuredType,
    long_name: str = None,
    metadata: Optional[List[AllUnits | UserAttributes]] = None,
) -> Coordinate:
    """Construct a :class:`Coordinate` from its name, dimensions, dtype, and metadata."""
    coordinate_fields = {
        "name": name,
        "long_name": long_name,
        "dimensions": dimensions,
        "data_type": data_type,
        "metadata": metadata,
    }
    return Coordinate(**coordinate_fields)
60+
61+
62+
def make_variable(  # noqa: C901
    name: str,
    dimensions: List[NamedDimension | str],
    data_type: ScalarType | StructuredType,
    long_name: str = None,
    compressor: Blosc | ZFP | None = None,
    coordinates: Optional[List[Coordinate | str]] = None,
    metadata: Optional[
        List[AllUnits | UserAttributes] | Dict[str, Any] | VariableMetadata
    ] = None,
) -> Variable:
    """Create a Variable with the given parameters.

    Normalizes the three accepted ``metadata`` shapes (list of schema
    models, plain dict, or ready-made VariableMetadata) into a single
    VariableMetadata before constructing the Variable.

    Args:
        name: Name of the variable
        dimensions: List of dimensions
        data_type: Data type of the variable
        long_name: Optional long name
        compressor: Optional compressor
        coordinates: Optional list of coordinates
        metadata: Optional metadata

    Returns:
        Variable: A Variable instance with the specified parameters.

    Raises:
        TypeError: If the metadata type is not supported.
    """
    # Convert metadata to VariableMetadata if needed
    # NOTE(review): truthiness check means an *empty* list/dict is treated
    # the same as None (no metadata) — confirm that is intended.
    var_metadata = None
    if metadata:
        if isinstance(metadata, list):
            # Convert list of metadata to dict
            metadata_dict = {}
            for md in metadata:
                if isinstance(md, AllUnits):
                    # For units_v1, if it's a single element, use it directly
                    if isinstance(md.units_v1, list) and len(md.units_v1) == 1:
                        metadata_dict["units_v1"] = md.units_v1[0]
                    else:
                        metadata_dict["units_v1"] = md.units_v1
                elif isinstance(md, UserAttributes):
                    # For attributes, if it's a single element, use it directly
                    # NOTE(review): pydantic model_dump returns a dict, so the
                    # list check below looks unreachable — verify against the
                    # UserAttributes model definition.
                    attrs = md.model_dump(by_alias=True)
                    if isinstance(attrs, list) and len(attrs) == 1:
                        metadata_dict["attributes"] = attrs[0]
                    else:
                        metadata_dict["attributes"] = attrs
            # List entries that are neither AllUnits nor UserAttributes are
            # silently ignored.
            var_metadata = VariableMetadata(**metadata_dict)
        elif isinstance(metadata, dict):
            # Convert camelCase keys to snake_case for VariableMetadata
            converted_dict = {}
            for key, value in metadata.items():
                if key == "unitsV1":
                    # For units_v1, if it's a single element array, use the element directly
                    if isinstance(value, list) and len(value) == 1:
                        converted_dict["units_v1"] = value[0]
                    else:
                        converted_dict["units_v1"] = value
                else:
                    # All other keys pass through unchanged.
                    converted_dict[key] = value
            var_metadata = VariableMetadata(**converted_dict)
        elif isinstance(metadata, VariableMetadata):
            # Already the target type; use as-is.
            var_metadata = metadata
        else:
            raise TypeError(f"Unsupported metadata type: {type(metadata)}")

    # Create the variable with all attributes explicitly set
    return Variable(
        name=name,
        long_name=long_name,
        dimensions=dimensions,
        data_type=data_type,
        compressor=compressor,
        coordinates=coordinates,
        metadata=var_metadata,
    )
139+
140+
141+
def make_dataset_metadata(
    name: str,
    api_version: str,
    created_on: datetime,
    attributes: Optional[Dict[str, Any]] = None,
) -> DatasetMetadata:
    """Build a :class:`DatasetMetadata` carrying identity and provenance fields."""
    metadata_fields = {
        "name": name,
        "api_version": api_version,
        "created_on": created_on,
        "attributes": attributes,
    }
    return DatasetMetadata(**metadata_fields)
154+
155+
156+
def make_dataset(
    variables: List[Variable],
    metadata: DatasetMetadata,
) -> MDIODataset:
    """Assemble an MDIO schema :class:`Dataset` from variables plus dataset metadata."""
    return MDIODataset(variables=variables, metadata=metadata)
165+
166+
167+
def _convert_compressor(
    model: Blosc | ZFP | None,
) -> NumcodecsBlosc | ZFPY | None:
    """Translate an MDIO compressor schema model into its numcodecs codec.

    Guard-clause style: handle None first, then each supported model type,
    and fail loudly on anything else.
    """
    if model is None:
        return None
    if isinstance(model, Blosc):
        # Negative blocksize values collapse to 0 (numcodecs' "auto").
        blocksize = model.blocksize if model.blocksize > 0 else 0
        return NumcodecsBlosc(
            cname=model.algorithm.value,
            clevel=model.level,
            shuffle=model.shuffle.value,
            blocksize=blocksize,
        )
    if isinstance(model, ZFP):
        # ZFP support is optional; the module-level import guard leaves
        # these names as None when the wheels are missing.
        if zfpy_base is None or ZFPY is None:
            raise ImportError("zfpy and numcodecs are required to use ZFP compression")
        return ZFPY(
            mode=model.mode.value,
            tolerance=model.tolerance,
            rate=model.rate,
            precision=model.precision,
        )
    raise TypeError(f"Unsupported compressor model: {type(model)}")
190+
191+
192+
def _construct_mdio_dataset(mdio_ds: MDIODataset) -> mdio.Dataset:  # noqa: C901
    """Build an MDIO dataset with correct dimensions and dtypes.

    This internal function constructs the underlying data structure for an MDIO dataset,
    handling dimension mapping, data types, and metadata organization.

    Args:
        mdio_ds: The source MDIO dataset to construct from.

    Returns:
        The constructed dataset with proper MDIO structure and metadata.

    Raises:
        TypeError: If an unsupported data type is encountered.
    """
    # Collect dimension sizes
    # NOTE(review): later variables overwrite earlier sizes for a repeated
    # dimension name; assumes sizes are consistent across variables — confirm.
    dims: dict[str, int] = {}
    for var in mdio_ds.variables:
        for d in var.dimensions:
            if isinstance(d, NamedDimension):
                dims[d.name] = d.size

    # Build data variables
    data_vars: dict[str, mdio.DataArray] = {}
    for var in mdio_ds.variables:
        # Dimensions may be NamedDimension objects or plain name strings;
        # string entries must already appear in `dims` or the shape lookup
        # below raises KeyError.
        dim_names = [
            d.name if isinstance(d, NamedDimension) else d for d in var.dimensions
        ]
        shape = tuple(dims[name] for name in dim_names)
        dt = var.data_type
        if isinstance(dt, ScalarType):
            dtype = np.dtype(dt.value)
        elif isinstance(dt, StructuredType):
            # Structured dtype: one (name, format) pair per schema field.
            dtype = np.dtype([(f.name, f.format.value) for f in dt.fields])
        else:
            raise TypeError(f"Unsupported data_type: {dt}")
        # Data is zero-initialized; actual values are written later.
        arr = np.zeros(shape, dtype=dtype)
        data_array = mdio.DataArray(arr, dims=dim_names)
        data_array.encoding["fill_value"] = 0.0

        # Set long_name if present
        if var.long_name is not None:
            data_array.attrs["long_name"] = var.long_name

        # Set coordinates if present, excluding dimension names
        # (dimension coordinates are implicit in xarray).
        if var.coordinates is not None:
            dim_set = set(dim_names)
            coord_names = [
                c.name if isinstance(c, Coordinate) else c
                for c in var.coordinates
                if (c.name if isinstance(c, Coordinate) else c) not in dim_set
            ]
            if coord_names:
                # Space-separated list, CF-style "coordinates" attribute.
                data_array.attrs["coordinates"] = " ".join(coord_names)

        # Attach variable metadata into DataArray attributes
        if var.metadata is not None:
            md = var.metadata.model_dump(
                by_alias=True,
                exclude_none=True,
                exclude={"chunk_grid"},
            )
            # Unwrap single-element lists so attrs hold scalars where possible.
            for key, value in md.items():
                if isinstance(value, list) and len(value) == 1:
                    md[key] = value[0]
            data_array.attrs.update(md)
        data_vars[var.name] = data_array

    ds = mdio.Dataset(data_vars)
    # Attach dataset metadata (camelCase keys to match the serialized schema).
    ds.attrs["apiVersion"] = mdio_ds.metadata.api_version
    ds.attrs["createdOn"] = str(mdio_ds.metadata.created_on)
    ds.attrs["name"] = mdio_ds.metadata.name
    if mdio_ds.metadata.attributes:
        ds.attrs["attributes"] = mdio_ds.metadata.attributes
    return ds

0 commit comments

Comments
 (0)