11"""Construct an MDIO dataset and write to Zarr."""
2- import xarray as xr
3- import numpy as np
4- import dask .array as da
5- from zarr .core .chunk_key_encodings import V2ChunkKeyEncoding
2+
63from typing import Any
74
8- from mdio .schema .v1 .dataset import Dataset as MDIODataset
5+ import numpy as np
6+ from numcodecs import Blosc as NumcodecsBlosc
7+ from zarr .core .chunk_key_encodings import V2ChunkKeyEncoding # noqa: F401
8+
9+ from mdio .core .v1 ._overloads import mdio
10+ from mdio .schema .compressors import ZFP
11+ from mdio .schema .compressors import Blosc
912from mdio .schema .dimension import NamedDimension
10- from mdio .schema .dtype import ScalarType , StructuredType
11- from mdio .schema .compressors import Blosc , ZFP
13+ from mdio .schema .dtype import ScalarType
14+ from mdio .schema .dtype import StructuredType
15+ from mdio .schema .v1 .dataset import Dataset as MDIODataset
1216from mdio .schema .v1 .variable import Coordinate
13- from mdio .core .v1 ._overloads import mdio
1417
15- from numcodecs import Blosc as NumcodecsBlosc
1618
1719try :
18- import zfpy as BaseZFPY # Base library
19- from numcodecs import ZFPY as NumcodecsZFPY # Codec
20+ import zfpy as zfpy_base # Base library
21+ from numcodecs import ZFPY # Codec
2022except ImportError :
2123 print (f"Tried to import zfpy and numcodecs zfpy but failed because { ImportError } " )
22- BaseZFPY = None
23- NumcodecsZFPY = None
24+ zfpy_base = None
25+ ZFPY = None
2426
2527
26- def _convert_compressor (model : Blosc | ZFP | None ) -> NumcodecsBlosc | NumcodecsZFPY | None :
28+ def _convert_compressor (
29+ model : Blosc | ZFP | None ,
30+ ) -> NumcodecsBlosc | ZFPY | None :
2731 if isinstance (model , Blosc ):
2832 return NumcodecsBlosc (
2933 cname = model .algorithm .value ,
3034 clevel = model .level ,
3135 shuffle = model .shuffle .value ,
32- blocksize = model .blocksize if model .blocksize > 0 else 0
36+ blocksize = model .blocksize if model .blocksize > 0 else 0 ,
3337 )
3438 elif isinstance (model , ZFP ):
35- if BaseZFPY is None or NumcodecsZFPY is None :
39+ if zfpy_base is None or ZFPY is None :
3640 raise ImportError ("zfpy and numcodecs are required to use ZFP compression" )
37- return NumcodecsZFPY (
41+ return ZFPY (
3842 mode = model .mode .value ,
3943 tolerance = model .tolerance ,
4044 rate = model .rate ,
@@ -46,17 +50,20 @@ def _convert_compressor(model: Blosc | ZFP | None) -> NumcodecsBlosc | Numcodecs
     raise TypeError(f"Unsupported compressor model: {type(model)}")


-def _construct_mdio_dataset(mdio_ds: MDIODataset) -> mdio.Dataset:
+def _construct_mdio_dataset(mdio_ds: MDIODataset) -> mdio.Dataset:  # noqa: C901
     """Build an MDIO dataset with correct dimensions and dtypes.
-
+
     This internal function constructs the underlying data structure for an MDIO dataset,
     handling dimension mapping, data types, and metadata organization.
-
+
     Args:
         mdio_ds: The source MDIO dataset to construct from.
-
+
     Returns:
         The constructed dataset with proper MDIO structure and metadata.
+
+    Raises:
+        TypeError: If an unsupported data type is encountered.
     """
     # Collect dimension sizes
     dims: dict[str, int] = {}
@@ -68,7 +75,9 @@ def _construct_mdio_dataset(mdio_ds: MDIODataset) -> mdio.Dataset:
     # Build data variables
     data_vars: dict[str, mdio.DataArray] = {}
     for var in mdio_ds.variables:
-        dim_names = [d.name if isinstance(d, NamedDimension) else d for d in var.dimensions]
+        dim_names = [
+            d.name if isinstance(d, NamedDimension) else d for d in var.dimensions
+        ]
         shape = tuple(dims[name] for name in dim_names)
         dt = var.data_type
         if isinstance(dt, ScalarType):
@@ -79,23 +88,23 @@ def _construct_mdio_dataset(mdio_ds: MDIODataset) -> mdio.Dataset:
             raise TypeError(f"Unsupported data_type: {dt}")
         arr = np.zeros(shape, dtype=dtype)
         data_array = mdio.DataArray(arr, dims=dim_names)
-        data_array.encoding['fill_value'] = 0.0
-
+        data_array.encoding["fill_value"] = 0.0
+
         # Set long_name if present
         if var.long_name is not None:
             data_array.attrs["long_name"] = var.long_name
-
+
         # Set coordinates if present, excluding dimension names
         if var.coordinates is not None:
             dim_set = set(dim_names)
             coord_names = [
-                c.name if isinstance(c, Coordinate) else c
-                for c in var.coordinates
+                c.name if isinstance(c, Coordinate) else c
+                for c in var.coordinates
                 if (c.name if isinstance(c, Coordinate) else c) not in dim_set
             ]
             if coord_names:
                 data_array.attrs["coordinates"] = " ".join(coord_names)
-
+
         # Attach variable metadata into DataArray attributes
         if var.metadata is not None:
             md = var.metadata.model_dump(
@@ -119,21 +128,24 @@ def _construct_mdio_dataset(mdio_ds: MDIODataset) -> mdio.Dataset:
     return ds


-def Write_MDIO_metadata(mdio_ds: MDIODataset, store: str, **kwargs: Any) -> mdio.Dataset:
+def write_mdio_metadata(
+    mdio_ds: MDIODataset, store: str, **kwargs: Any
+) -> mdio.Dataset:
     """Write MDIO metadata to a Zarr store and return the constructed mdio.Dataset.
-
+
     This function constructs an mdio.Dataset from the MDIO dataset and writes its metadata
     to a Zarr store. The actual data is not written, only the metadata structure is created.
     """
     ds = _construct_mdio_dataset(mdio_ds)
-
+
     def _generate_encodings() -> dict:
         """Generate encodings for each variable in the MDIO dataset.
-
+
         Returns:
             Dictionary mapping variable names to their encoding configurations.
         """
-        dimension_separator_encoding = V2ChunkKeyEncoding(separator="/").to_dict()
+        # TODO: Re-enable chunk_key_encoding when supported by xarray
+        # dimension_separator_encoding = V2ChunkKeyEncoding(separator="/").to_dict()
         global_encodings = {}
         for var in mdio_ds.variables:
             fill_value = 0
@@ -160,6 +172,6 @@ def _generate_encodings() -> dict:
         safe_chunks=False,
         compute=False,
         encoding=_generate_encodings(),
-        **kwargs
+        **kwargs,
     )
     return ds
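
For reference, _convert_compressor maps the schema's Blosc model onto the numcodecs codec roughly as sketched below; the concrete parameter values are illustrative, not taken from this change:

from numcodecs import Blosc as NumcodecsBlosc

# Rough equivalent of _convert_compressor applied to a Blosc model with
# algorithm="zstd", level=5, shuffle=1 (byte shuffle), blocksize=0 (auto).
codec = NumcodecsBlosc(cname="zstd", clevel=5, shuffle=1, blocksize=0)
print(codec)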
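
And a minimal usage sketch for the renamed write_mdio_metadata entry point; the module alias, the pre-built dataset_model object, and the output path are assumptions for illustration only:

# Assumption: `constructor` is the module modified in this diff, imported from
# wherever it lives in the package (the path is not shown in this excerpt).
# Assumption: `dataset_model` is an MDIODataset instance validated elsewhere.
ds = constructor.write_mdio_metadata(dataset_model, "output.mdio")

# Only the Zarr metadata hierarchy is created (the store write uses compute=False);
# array values can be written into the returned mdio.Dataset afterwards.
print(ds)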