Skip to content

Commit 5c757a4

Browse files
committed
None chunk sizes should work now
1 parent ba0042d commit 5c757a4

File tree

4 files changed

+175
-48
lines changed

4 files changed

+175
-48
lines changed

tests/test_api.py

Lines changed: 77 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -142,10 +142,20 @@ def test_some_slices_local_output_to_existing_dir_force_new(self):
142142
zappend(slices, target_dir=target_dir, force_new=True)
143143
self.assertEqual(False, lock_file.exists())
144144

145-
def test_some_slices_with_class_slice_source(self):
145+
def test_some_slices_with_slice_source_class(self):
146+
class DropTsm(SliceSource):
147+
def __init__(self, slice_ds):
148+
self.slice_ds = slice_ds
149+
150+
def get_dataset(self) -> xr.Dataset:
151+
return self.slice_ds.drop_vars(["tsm"])
152+
153+
def dispose(self):
154+
pass
155+
146156
target_dir = "memory://target.zarr"
147157
slices = [make_test_dataset(index=3 * i) for i in range(3)]
148-
zappend(slices, target_dir=target_dir, slice_source=MySliceSource)
158+
zappend(slices, target_dir=target_dir, slice_source=DropTsm)
149159
ds = xr.open_zarr(target_dir)
150160
self.assertEqual({"time": 9, "y": 50, "x": 100}, ds.sizes)
151161
self.assertEqual({"chl"}, set(ds.data_vars))
@@ -158,13 +168,13 @@ def test_some_slices_with_class_slice_source(self):
158168
ds.attrs,
159169
)
160170

161-
def test_some_slices_with_func_slice_source(self):
162-
def process_slice(slice_ds: xr.Dataset) -> SliceSource:
163-
return MySliceSource(slice_ds)
171+
def test_some_slices_with_slice_source_func(self):
172+
def drop_tsm(slice_ds: xr.Dataset) -> xr.Dataset:
173+
return slice_ds.drop_vars(["tsm"])
164174

165175
target_dir = "memory://target.zarr"
166176
slices = [make_test_dataset(index=3 * i) for i in range(3)]
167-
zappend(slices, target_dir=target_dir, slice_source=process_slice)
177+
zappend(slices, target_dir=target_dir, slice_source=drop_tsm)
168178
ds = xr.open_zarr(target_dir)
169179
self.assertEqual({"time": 9, "y": 50, "x": 100}, ds.sizes)
170180
self.assertEqual({"chl"}, set(ds.data_vars))
@@ -177,9 +187,67 @@ def process_slice(slice_ds: xr.Dataset) -> SliceSource:
177187
ds.attrs,
178188
)
179189

180-
def test_some_slices_with_cropping_slice_source(self):
181-
# TODO: implement me after #78
182-
pass
190+
# See https://github.com/bcdev/zappend/issues/77
191+
def test_some_slices_with_cropping_slice_source_no_chunks_spec(self):
192+
def crop_ds(slice_ds: xr.Dataset) -> xr.Dataset:
193+
w = slice_ds.x.size
194+
h = slice_ds.y.size
195+
return slice_ds.isel(x=slice(5, w - 5), y=slice(5, h - 5))
196+
197+
target_dir = "memory://target.zarr"
198+
slices = [make_test_dataset(index=3 * i) for i in range(3)]
199+
zappend(slices, target_dir=target_dir, slice_source=crop_ds)
200+
ds = xr.open_zarr(target_dir)
201+
self.assertEqual({"time": 9, "y": 40, "x": 90}, ds.sizes)
202+
self.assertEqual({"chl", "tsm"}, set(ds.data_vars))
203+
self.assertEqual({"time", "y", "x"}, set(ds.coords))
204+
self.assertEqual((90,), ds.x.encoding.get("chunks"))
205+
self.assertEqual((40,), ds.y.encoding.get("chunks"))
206+
self.assertEqual((3,), ds.time.encoding.get("chunks"))
207+
# Chunk sizes are the ones of the original array, because we have not
208+
# specified chunks in encoding.
209+
self.assertEqual((1, 25, 45), ds.chl.encoding.get("chunks"))
210+
self.assertEqual((1, 25, 45), ds.tsm.encoding.get("chunks"))
211+
212+
# See https://github.com/bcdev/zappend/issues/77
213+
def test_some_slices_with_cropping_slice_source_with_chunks_spec(self):
214+
def crop_ds(slice_ds: xr.Dataset) -> xr.Dataset:
215+
w = slice_ds.x.size
216+
h = slice_ds.y.size
217+
return slice_ds.isel(x=slice(5, w - 5), y=slice(5, h - 5))
218+
219+
variables = {
220+
"*": {
221+
"encoding": {
222+
"chunks": None,
223+
}
224+
},
225+
"chl": {
226+
"encoding": {
227+
"chunks": [1, None, None],
228+
}
229+
},
230+
"tsm": {
231+
"encoding": {
232+
"chunks": [None, 25, 50],
233+
}
234+
},
235+
}
236+
237+
target_dir = "memory://target.zarr"
238+
slices = [make_test_dataset(index=3 * i) for i in range(3)]
239+
zappend(
240+
slices, target_dir=target_dir, slice_source=crop_ds, variables=variables
241+
)
242+
ds = xr.open_zarr(target_dir)
243+
self.assertEqual({"time": 9, "y": 40, "x": 90}, ds.sizes)
244+
self.assertEqual({"chl", "tsm"}, set(ds.data_vars))
245+
self.assertEqual({"time", "y", "x"}, set(ds.coords))
246+
self.assertEqual((90,), ds.x.encoding.get("chunks"))
247+
self.assertEqual((40,), ds.y.encoding.get("chunks"))
248+
self.assertEqual((3,), ds.time.encoding.get("chunks"))
249+
self.assertEqual((1, 40, 90), ds.chl.encoding.get("chunks"))
250+
self.assertEqual((3, 25, 50), ds.tsm.encoding.get("chunks"))
183251

184252
def test_some_slices_with_inc_append_step(self):
185253
target_dir = "memory://target.zarr"
@@ -395,14 +463,3 @@ def test_some_slices_with_profiling(self):
395463
finally:
396464
if os.path.exists("prof.out"):
397465
os.remove("prof.out")
398-
399-
400-
class MySliceSource(SliceSource):
401-
def __init__(self, slice_ds):
402-
self.slice_ds = slice_ds
403-
404-
def get_dataset(self) -> xr.Dataset:
405-
return self.slice_ds.drop_vars(["tsm"])
406-
407-
def dispose(self):
408-
pass

tests/test_metadata.py

Lines changed: 42 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -291,6 +291,47 @@ def test_variable_encoding_from_netcdf(self):
291291
).to_dict(),
292292
)
293293

294+
def test_variable_encoding_can_deal_with_chunk_size_none(self):
295+
# See https://github.com/bcdev/zappend/issues/77
296+
a = xr.DataArray(np.zeros((2, 3, 4)), dims=("time", "y", "x"))
297+
b = xr.DataArray(np.zeros((2, 3, 4)), dims=("time", "y", "x"))
298+
self.assertEqual(
299+
{
300+
"attrs": {},
301+
"sizes": {"time": 2, "x": 4, "y": 3},
302+
"variables": {
303+
"a": {
304+
"attrs": {},
305+
"dims": ("time", "y", "x"),
306+
"encoding": {"chunks": (1, 3, 4)},
307+
"shape": (2, 3, 4),
308+
},
309+
"b": {
310+
"attrs": {},
311+
"dims": ("time", "y", "x"),
312+
"encoding": {"chunks": (2, 2, 3)},
313+
"shape": (2, 3, 4),
314+
},
315+
},
316+
},
317+
DatasetMetadata.from_dataset(
318+
xr.Dataset(
319+
{
320+
"a": a,
321+
"b": b,
322+
}
323+
),
324+
make_config(
325+
{
326+
"variables": {
327+
"a": {"encoding": {"chunks": [1, None, None]}},
328+
"b": {"encoding": {"chunks": [None, 2, 3]}},
329+
},
330+
}
331+
),
332+
).to_dict(),
333+
)
334+
294335
def test_variable_encoding_normalisation(self):
295336
def normalize(k, v):
296337
metadata = DatasetMetadata.from_dataset(
@@ -363,6 +404,7 @@ def test_it_raises_on_unspecified_variable(self):
363404
),
364405
)
365406

407+
# noinspection PyMethodMayBeStatic
366408
def test_it_raises_on_wrong_size_found_in_ds(self):
367409
with pytest.raises(
368410
ValueError,

zappend/metadata.py

Lines changed: 18 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,8 @@ class Undefined:
2222

2323
Codec = numcodecs.abc.Codec
2424

25+
NoneType = type(None)
26+
2527

2628
class VariableEncoding:
2729
"""The Zarr encoding of a dataset's variable.
@@ -346,6 +348,21 @@ def _get_effective_variables(
346348
chunk_sizes = encoding.pop("chunksizes")
347349
if "chunks" not in encoding:
348350
encoding["chunks"] = chunk_sizes
351+
352+
# Handle case where a chunk size is None to indicate
353+
# dimension is not chunked.
354+
# See https://github.com/bcdev/zappend/issues/77
355+
if (
356+
"chunks" in encoding
357+
and encoding["chunks"] is not None
358+
and None in encoding["chunks"]
359+
):
360+
chunks = encoding["chunks"]
361+
encoding["chunks"] = tuple(
362+
(dataset.sizes[dim_name] if chunk_size is None else chunk_size)
363+
for dim_name, chunk_size in zip(dims, chunks)
364+
)
365+
349366
variables[var_name] = VariableMetadata(
350367
dims=dims, shape=shape, encoding=VariableEncoding(**encoding), attrs=attrs
351368
)
@@ -364,7 +381,7 @@ def _normalize_chunks(value: Any) -> tuple[int, ...] | None:
364381
if not value:
365382
return None
366383
assert isinstance(value, (tuple, list))
367-
return tuple((v if isinstance(v, int) else v[0]) for v in value)
384+
return tuple((v if isinstance(v, (int, NoneType)) else v[0]) for v in value)
368385

369386

370387
def _normalize_number(value: Any) -> int | float | None:

zappend/tailoring.py

Lines changed: 38 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -83,6 +83,7 @@ def tailor_slice_dataset(ctx: Context, slice_ds: xr.Dataset) -> xr.Dataset:
8383

8484

8585
def _strip_dataset(dataset: xr.Dataset, target_metadata: DatasetMetadata) -> xr.Dataset:
86+
"""Remove unwanted variables from `dataset` and return a copy."""
8687
drop_var_names = set(map(str, dataset.variables.keys())) - set(
8788
target_metadata.variables.keys()
8889
)
@@ -92,36 +93,46 @@ def _strip_dataset(dataset: xr.Dataset, target_metadata: DatasetMetadata) -> xr.
9293
def _complete_dataset(
9394
dataset: xr.Dataset, target_metadata: DatasetMetadata
9495
) -> xr.Dataset:
96+
undefined = object()
97+
"""Chunk existing variables according to chunks in encoding or
98+
add missing variables to `dataset` (in-place operation) and return it.
99+
"""
95100
for var_name, var_metadata in target_metadata.variables.items():
96101
var = dataset.variables.get(var_name)
97-
if var is not None:
98-
continue
99-
logger.warning(
100-
f"Variable {var_name!r} not found in slice dataset;" f" creating it."
101-
)
102102
encoding = var_metadata.encoding.to_dict()
103-
chunks = encoding.get("chunks")
104-
if chunks is None:
105-
chunks = var_metadata.shape
106-
if encoding.get("_FillValue") is not None:
107-
# Since we have a defined fill value, the decoded in-memory
108-
# variable uses NaN where fill value will be stored.
109-
# This ia also what xarray does if decode_cf=True.
110-
memory_dtype = np.dtype("float64")
111-
memory_fill_value = float("NaN")
103+
chunks = encoding.get("chunks", undefined)
104+
if var is not None:
105+
if chunks is None:
106+
# May emit warning for large shapes
107+
chunks = var_metadata.shape
108+
if chunks is not undefined:
109+
var = var.chunk(chunks=chunks)
112110
else:
113-
# Fill value is not defined, so we use the data type
114-
# defined in the encoding, if any and fill memory with zeros.
115-
memory_dtype = encoding.get("dtype", np.dtype("float64"))
116-
memory_fill_value = 0
117-
var = xr.DataArray(
118-
dask.array.full(
119-
var_metadata.shape,
120-
memory_fill_value,
121-
chunks=chunks,
122-
dtype=np.dtype(memory_dtype),
123-
),
124-
dims=var_metadata.dims,
125-
)
111+
logger.warning(
112+
f"Variable {var_name!r} not found in slice dataset; creating it."
113+
)
114+
if chunks is None or chunks is undefined:
115+
# May emit warning for large shapes
116+
chunks = var_metadata.shape
117+
if encoding.get("_FillValue") is not None:
118+
# Since we have a defined fill value, the decoded in-memory
119+
# variable uses NaN where fill value will be stored.
120+
# This is also what xarray does if decode_cf=True.
121+
memory_dtype = np.dtype("float64")
122+
memory_fill_value = float("NaN")
123+
else:
124+
# Fill value is not defined, so we use the data type
125+
# defined in the encoding, if any and fill memory with zeros.
126+
memory_dtype = encoding.get("dtype", np.dtype("float64"))
127+
memory_fill_value = 0
128+
var = xr.DataArray(
129+
dask.array.full(
130+
var_metadata.shape,
131+
memory_fill_value,
132+
chunks=chunks,
133+
dtype=np.dtype(memory_dtype),
134+
),
135+
dims=var_metadata.dims,
136+
)
126137
dataset[var_name] = var
127138
return dataset

0 commit comments

Comments (0)