Skip to content

Commit 775f680

Browse files
Fix compression and serializer encodings (#748)
* Fix compression and serializer encodings * Fix zarr v2 support * Use the proper encoding for zfp mode * Remove configuration for write_header as it is not used by Zarr or Numcodecs * reduce repetition and simplify logic * move compressor check to beginning * handle linter complaints and handle None --------- Co-authored-by: Altay Sansal <[email protected]>
1 parent f74c218 commit 775f680

File tree

3 files changed

+54
-46
lines changed

3 files changed

+54
-46
lines changed

src/mdio/builder/schemas/compressors.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,6 @@ class ZFP(CamelCaseStrictModel):
7171
description="Fixed precision in terms of number of uncompressed bits per value.",
7272
)
7373

74-
write_header: bool = Field(
75-
default=True,
76-
description="Encode array shape, scalar type, and compression parameters.",
77-
)
78-
7974
@model_validator(mode="after")
8075
def check_requirements(self) -> ZFP:
8176
"""Check if ZFP parameters make sense."""

src/mdio/builder/xarray_builder.py

Lines changed: 30 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,11 @@
1414
try:
1515
# zfpy is an optional dependency for ZFP compression
1616
# It is not installed by default, so we check for its presence and import it only if available.
17-
from zfpy import ZFPY as zfpy_ZFPY # noqa: N811
17+
from numcodecs import ZFPY as zfpy_ZFPY # noqa: N811
18+
from zarr.codecs.numcodecs import ZFPY as zarr_ZFPY # noqa: N811
1819
except ImportError:
1920
zfpy_ZFPY = None # noqa: N816
21+
zarr_ZFPY = None # noqa: N816
2022

2123
from mdio.builder.schemas.compressors import ZFP as mdio_ZFP # noqa: N811
2224
from mdio.builder.schemas.compressors import Blosc as mdio_Blosc
@@ -121,33 +123,34 @@ def _get_zarr_chunks(var: Variable, all_named_dims: dict[str, NamedDimension]) -
121123
return _get_zarr_shape(var, all_named_dims=all_named_dims)
122124

123125

124-
def _convert_compressor(
126+
def _compressor_to_encoding(
125127
compressor: mdio_Blosc | mdio_ZFP | None,
126-
) -> BloscCodec | Blosc | zfpy_ZFPY | None:
128+
) -> dict[str, BloscCodec | Blosc | zfpy_ZFPY | zarr_ZFPY] | None:
127129
"""Convert a compressor to a numcodecs compatible format."""
128130
if compressor is None:
129131
return None
130132

133+
if not isinstance(compressor, (mdio_Blosc, mdio_ZFP)):
134+
msg = f"Unsupported compressor model: {type(compressor)}"
135+
raise TypeError(msg)
136+
137+
is_v2 = zarr.config.get("default_zarr_format") == ZarrFormat.V2
138+
kwargs = compressor.model_dump(exclude={"name"}, mode="json")
139+
131140
if isinstance(compressor, mdio_Blosc):
132-
blosc_kwargs = compressor.model_dump(exclude={"name"}, mode="json")
133-
if zarr.config.get("default_zarr_format") == ZarrFormat.V2:
134-
blosc_kwargs["shuffle"] = -1 if blosc_kwargs["shuffle"] is None else blosc_kwargs["shuffle"]
135-
return Blosc(**blosc_kwargs)
136-
return BloscCodec(**blosc_kwargs)
137-
138-
if isinstance(compressor, mdio_ZFP):
139-
if zfpy_ZFPY is None:
140-
msg = "zfpy and numcodecs are required to use ZFP compression"
141-
raise ImportError(msg)
142-
return zfpy_ZFPY(
143-
mode=compressor.mode.value,
144-
tolerance=compressor.tolerance,
145-
rate=compressor.rate,
146-
precision=compressor.precision,
147-
)
148-
149-
msg = f"Unsupported compressor model: {type(compressor)}"
150-
raise TypeError(msg)
141+
if is_v2 and kwargs["shuffle"] is None:
142+
kwargs["shuffle"] = -1
143+
codec_cls = Blosc if is_v2 else BloscCodec
144+
return {"compressors": codec_cls(**kwargs)}
145+
146+
# must be ZFP beyond here
147+
if zfpy_ZFPY is None:
148+
msg = "zfpy and numcodecs are required to use ZFP compression"
149+
raise ImportError(msg)
150+
kwargs["mode"] = compressor.mode.int_code
151+
if is_v2:
152+
return {"compressors": zfpy_ZFPY(**kwargs)}
153+
return {"serializer": zarr_ZFPY(**kwargs), "compressors": None}
151154

152155

153156
def _get_fill_value(data_type: ScalarType | StructuredType | str) -> any:
@@ -222,10 +225,14 @@ def to_xarray_dataset(mdio_ds: Dataset) -> xr_Dataset: # noqa: PLR0912
222225

223226
encoding = {
224227
"chunks": original_chunks,
225-
"compressors": _convert_compressor(v.compressor),
226228
fill_value_key: fill_value,
227229
}
228230

231+
compressor_encodings = _compressor_to_encoding(v.compressor)
232+
233+
if compressor_encodings is not None:
234+
encoding.update(compressor_encodings)
235+
229236
if zarr_format == ZarrFormat.V2:
230237
encoding["chunk_key_encoding"] = {"name": "v2", "configuration": {"separator": "/"}}
231238

tests/unit/v1/test_dataset_serializer.py

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from mdio.builder.schemas.v1.variable import Coordinate
2020
from mdio.builder.schemas.v1.variable import Variable
2121
from mdio.builder.schemas.v1.variable import VariableMetadata
22-
from mdio.builder.xarray_builder import _convert_compressor
22+
from mdio.builder.xarray_builder import _compressor_to_encoding
2323
from mdio.builder.xarray_builder import _get_all_named_dimensions
2424
from mdio.builder.xarray_builder import _get_coord_names
2525
from mdio.builder.xarray_builder import _get_dimension_names
@@ -226,43 +226,49 @@ def test_get_fill_value() -> None:
226226
assert result_none_input is None
227227

228228

229-
def test_convert_compressor() -> None:
230-
"""Simple test for _convert_compressor function covering basic scenarios."""
229+
def test_compressor_to_encoding() -> None:
230+
"""Simple test for _compressor_to_encoding function covering basic scenarios."""
231231
# Test 1: None input - should return None
232-
result_none = _convert_compressor(None)
232+
result_none = _compressor_to_encoding(None)
233233
assert result_none is None
234234

235235
# Test 2: mdio_Blosc compressor - should return nc_Blosc
236236
mdio_compressor = mdio_Blosc(cname=BloscCname.lz4, clevel=5, shuffle=BloscShuffle.bitshuffle, blocksize=1024)
237-
result_blosc = _convert_compressor(mdio_compressor)
237+
result_blosc = _compressor_to_encoding(mdio_compressor)
238238

239-
assert isinstance(result_blosc, BloscCodec)
240-
assert result_blosc.cname == BloscCname.lz4
241-
assert result_blosc.clevel == 5
242-
assert result_blosc.shuffle == BloscShuffle.bitshuffle
243-
assert result_blosc.blocksize == 1024
239+
assert isinstance(result_blosc, dict)
240+
assert "compressors" in result_blosc
241+
assert isinstance(result_blosc["compressors"], BloscCodec)
242+
assert result_blosc["compressors"].cname == BloscCname.lz4
243+
assert result_blosc["compressors"].clevel == 5
244+
assert result_blosc["compressors"].shuffle == BloscShuffle.bitshuffle
245+
assert result_blosc["compressors"].blocksize == 1024
244246

245247
# Test 3: mdio_ZFP compressor - should return zfpy_ZFPY if available
246248
zfp_compressor = MDIO_ZFP(mode=mdio_ZFPMode.FIXED_RATE, tolerance=0.01, rate=8.0, precision=16)
247249

250+
# TODO(BrianMichell): Update to also test zfp compression.
251+
# https://github.com/TGSAI/mdio-python/issues/747
248252
if HAS_ZFPY: # pragma: no cover
249-
result_zfp = _convert_compressor(zfp_compressor)
250-
assert isinstance(result_zfp, ZFPY)
251-
assert result_zfp.mode == 1 # ZFPMode.FIXED_RATE.value = "fixed_rate"
252-
assert result_zfp.tolerance == 0.01
253-
assert result_zfp.rate == 8.0
254-
assert result_zfp.precision == 16
253+
result_zfp = _compressor_to_encoding(zfp_compressor)
254+
assert isinstance(result_zfp, dict)
255+
assert "compressors" not in result_zfp
256+
assert isinstance(result_zfp["serializer"], ZFPY)
257+
assert result_zfp["serializer"].mode == 1 # ZFPMode.FIXED_RATE.value = "fixed_rate"
258+
assert result_zfp["serializer"].tolerance == 0.01
259+
assert result_zfp["serializer"].rate == 8.0
260+
assert result_zfp["serializer"].precision == 16
255261
else:
256262
# Test 5: mdio_ZFP without zfpy installed - should raise ImportError
257263
with pytest.raises(ImportError) as exc_info:
258-
_convert_compressor(zfp_compressor)
264+
_compressor_to_encoding(zfp_compressor)
259265
error_message = str(exc_info.value)
260266
assert "zfpy and numcodecs are required to use ZFP compression" in error_message
261267

262268
# Test 6: Unsupported compressor type - should raise TypeError
263269
unsupported_compressor = "invalid_compressor"
264270
with pytest.raises(TypeError) as exc_info:
265-
_convert_compressor(unsupported_compressor)
271+
_compressor_to_encoding(unsupported_compressor)
266272
error_message = str(exc_info.value)
267273
assert "Unsupported compressor model" in error_message
268274
assert "<class 'str'>" in error_message

0 commit comments

Comments
 (0)