Skip to content

Commit 61ffa3a

Browse files
committed
fix fill value for structs
1 parent 58c1d8f commit 61ffa3a

File tree

1 file changed

+21
-39
lines changed

1 file changed

+21
-39
lines changed

tests/unit/v1/test_dataset_serializer.py

Lines changed: 21 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,9 @@
22

33
from pathlib import Path
44

5+
import numpy as np
56
import pytest
67
from dask import array as dask_array
7-
from numpy import array as np_array
8-
from numpy import dtype as np_dtype
9-
from numpy import isnan as np_isnan
10-
from numpy import zeros as np_zeros
118
from xarray import DataArray as xr_DataArray
129
from zarr import zeros as zarr_zeros
1310

@@ -45,7 +42,7 @@
4542
ZFPY = None
4643
HAS_ZFPY = False
4744

48-
from numcodecs import Blosc as nc_Blosc
45+
from numcodecs.zarr3 import Blosc as nc_Blosc
4946

5047
from mdio.schemas.compressors import ZFP as MDIO_ZFP
5148
from mdio.schemas.compressors import Blosc as mdio_Blosc
@@ -190,7 +187,7 @@ def test_get_fill_value() -> None:
190187
ScalarType.FLOAT64,
191188
]
192189
for scalar_type in scalar_types:
193-
assert np_isnan(_get_fill_value(scalar_type))
190+
assert np.isnan(_get_fill_value(scalar_type))
194191

195192
scalar_types = [
196193
ScalarType.UINT8,
@@ -212,8 +209,8 @@ def test_get_fill_value() -> None:
212209
for scalar_type in scalar_types:
213210
val = _get_fill_value(scalar_type)
214211
assert isinstance(val, complex)
215-
assert np_isnan(val.real)
216-
assert np_isnan(val.imag)
212+
assert np.isnan(val.real)
213+
assert np.isnan(val.imag)
217214

218215
# Test 2: StructuredType
219216
f1 = StructuredField(name="cdp_x", format=ScalarType.INT32)
@@ -222,9 +219,9 @@ def test_get_fill_value() -> None:
222219
f4 = StructuredField(name="some_scalar", format=ScalarType.FLOAT16)
223220
structured_type = StructuredType(fields=[f1, f2, f3, f4])
224221

225-
expected = np_array(
222+
expected = np.array(
226223
(0, 0, 0.0, 0.0),
227-
dtype=np_dtype([("cdp_x", "<i4"), ("cdp_y", "<i4"), ("elevation", "<f2"), ("some_scalar", "<f2")]),
224+
dtype=np.dtype([("cdp_x", "<i4"), ("cdp_y", "<i4"), ("elevation", "<f2"), ("some_scalar", "<f2")]),
228225
)
229226
result = _get_fill_value(structured_type)
230227
assert expected == result
@@ -258,24 +255,12 @@ def test_convert_compressor() -> None:
258255
)
259256
)
260257
assert isinstance(result_blosc, nc_Blosc)
261-
assert result_blosc.cname == "lz4" # BloscAlgorithm.LZ4.value
262-
assert result_blosc.clevel == 5
263-
assert result_blosc.shuffle == -1 # BloscShuffle.UTOSHUFFLE = -1
264-
assert result_blosc.blocksize == 1024
258+
assert result_blosc.codec_config["cname"] == "lz4" # BloscAlgorithm.LZ4.value
259+
assert result_blosc.codec_config["clevel"] == 5
260+
assert result_blosc.codec_config["shuffle"] == -1 # BloscShuffle.UTOSHUFFLE = -1
261+
assert result_blosc.codec_config["blocksize"] == 1024
265262

266-
# Test 3: mdio_Blosc with blocksize 0 - should use 0 as blocksize
267-
result_blosc_zero = _convert_compressor(
268-
mdio_Blosc(
269-
algorithm=mdio_BloscAlgorithm.ZSTD,
270-
level=3,
271-
shuffle=mdio_BloscShuffle.AUTOSHUFFLE,
272-
blocksize=0,
273-
)
274-
)
275-
assert isinstance(result_blosc_zero, nc_Blosc)
276-
assert result_blosc_zero.blocksize == 0
277-
278-
# Test 4: mdio_ZFP compressor - should return zfpy_ZFPY if available
263+
# Test 3: mdio_ZFP compressor - should return zfpy_ZFPY if available
279264
zfp_compressor = MDIO_ZFP(mode=mdio_ZFPMode.FIXED_RATE, tolerance=0.01, rate=8.0, precision=16)
280265

281266
if HAS_ZFPY:
@@ -345,8 +330,8 @@ def test_buf_reproducer_dask_to_zarr(tmp_path: Path) -> None:
345330
# https://github.com/TGSAI/mdio-python/issues/582
346331

347332
# Create a data type and the fill value
348-
dtype = np_dtype([("inline", "int32"), ("cdp_x", "float64")])
349-
dtype_fill_value = np_zeros((), dtype=dtype)
333+
dtype = np.dtype([("inline", "int32"), ("cdp_x", "float64")])
334+
dtype_fill_value = np.zeros((), dtype=dtype)
350335

351336
my_attr_encoding = {"fill_value": dtype_fill_value}
352337

@@ -367,10 +352,9 @@ def test_to_zarr_from_zarr_zeros_1(tmp_path: Path) -> None:
367352
Set encoding in as DataArray attributes
368353
"""
369354
# Create a data type and the fill value
370-
dtype = np_dtype([("inline", "int32"), ("cdp_x", "float64")])
371-
dtype_fill_value = np_zeros((), dtype=dtype)
355+
dtype = np.dtype([("inline", "int32"), ("cdp_x", "float64")])
372356

373-
my_attr_encoding = {"fill_value": dtype_fill_value}
357+
my_attr_encoding = {"fill_value": np.void((0, 0), dtype=dtype)}
374358

375359
# Create a zarr array using the data type,
376360
# Specify encoding as the array attribute
@@ -388,10 +372,9 @@ def test_to_zarr_from_zarr_zeros_2(tmp_path: Path) -> None:
388372
Set encoding in the to_zar method
389373
"""
390374
# Create a data type and the fill value
391-
dtype = np_dtype([("inline", "int32"), ("cdp_x", "float64")])
392-
dtype_fill_value = np_zeros((), dtype=dtype)
375+
dtype = np.dtype([("inline", "int32"), ("cdp_x", "float64")])
393376

394-
my_attr_encoding = {"fill_value": dtype_fill_value}
377+
my_attr_encoding = {"fill_value": np.void((0, 0), dtype=dtype)}
395378

396379
# Create a zarr array using the data type,
397380
# Do not specify encoding as the array attribute
@@ -407,14 +390,13 @@ def test_to_zarr_from_zarr_zeros_2(tmp_path: Path) -> None:
407390
def test_to_zarr_from_np(tmp_path: Path) -> None:
408391
"""Test writing XArray dataset with data as NumPy array to Zarr."""
409392
# Create a data type and the fill value
410-
dtype = np_dtype([("inline", "int32"), ("cdp_x", "float64")])
411-
dtype_fill_value = np_zeros((), dtype=dtype)
393+
dtype = np.dtype([("inline", "int32"), ("cdp_x", "float64")])
412394

413-
my_attr_encoding = {"fill_value": dtype_fill_value}
395+
my_attr_encoding = {"fill_value": np.void((0, 0), dtype=dtype)}
414396

415397
# Create a zarr array using the data type
416398
# Do not specify encoding as the array attribute
417-
data = np_zeros((36, 36), dtype=dtype)
399+
data = np.zeros((36, 36), dtype=dtype)
418400
aa = xr_DataArray(name="myattr", data=data)
419401

420402
file_path = output_path(tmp_path, "to_zarr/zarr_np", debugging=False)

0 commit comments

Comments
 (0)