Skip to content

Commit ef9418b

Browse files
committed
added additional slice compatibility checks
1 parent caa8fb7 commit ef9418b

File tree

3 files changed

+141
-8
lines changed

3 files changed

+141
-8
lines changed

tests/test_metadata.py

Lines changed: 107 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,7 @@ def test_it_raises_on_dim_not_found(self):
342342
)
343343

344344
# noinspection PyMethodMayBeStatic
345-
def test_it_raises_on_on_missing_dtype_or_fill_value(self):
345+
def test_it_raises_on_missing_dtype_or_fill_value(self):
346346
with pytest.raises(ValueError,
347347
match="Missing 'dtype' in encoding configuration"
348348
" of variable 'b'"):
@@ -358,3 +358,109 @@ def test_it_raises_on_on_missing_dtype_or_fill_value(self):
358358
},
359359
}
360360
)
361+
362+
363+
class DatasetMetadataSliceCompatibilityTest(unittest.TestCase):
364+
365+
# noinspection PyMethodMayBeStatic
366+
def test_compatible(self):
367+
target_md = DatasetMetadata.from_dataset(
368+
xr.Dataset({
369+
"a": xr.DataArray(np.zeros((12, 3, 4)),
370+
dims=("time", "y", "x")),
371+
"b": xr.DataArray(np.zeros((12, 3, 4)),
372+
dims=("time", "y", "x")),
373+
}),
374+
{},
375+
)
376+
slice_md = DatasetMetadata.from_dataset(
377+
xr.Dataset({
378+
"a": xr.DataArray(np.zeros((1, 3, 4)),
379+
dims=("time", "y", "x")),
380+
"b": xr.DataArray(np.zeros((1, 3, 4)),
381+
dims=("time", "y", "x")),
382+
}),
383+
{},
384+
)
385+
386+
# Should not raise
387+
target_md.assert_compatible_slice(slice_md, "time")
388+
389+
# noinspection PyMethodMayBeStatic
390+
def test_raise_on_missing_dimension(self):
391+
target_md = DatasetMetadata.from_dataset(
392+
xr.Dataset({
393+
"a": xr.DataArray(np.zeros((12, 3, 4)),
394+
dims=("time", "y", "x")),
395+
"b": xr.DataArray(np.zeros((12, 3, 4)),
396+
dims=("time", "y", "x")),
397+
}),
398+
{},
399+
)
400+
slice_md = DatasetMetadata.from_dataset(
401+
xr.Dataset({
402+
"a": xr.DataArray(np.zeros((1, 3)),
403+
dims=("time", "y")),
404+
"b": xr.DataArray(np.zeros((1, 3)),
405+
dims=("time", "y")),
406+
}),
407+
{},
408+
)
409+
410+
with pytest.raises(ValueError,
411+
match="Missing dimension 'x' in slice dataset"):
412+
target_md.assert_compatible_slice(slice_md, "time")
413+
414+
# noinspection PyMethodMayBeStatic
415+
def test_raise_on_wrong_dimension_size(self):
416+
target_md = DatasetMetadata.from_dataset(
417+
xr.Dataset({
418+
"a": xr.DataArray(np.zeros((12, 3, 4)),
419+
dims=("time", "y", "x")),
420+
"b": xr.DataArray(np.zeros((12, 3, 4)),
421+
dims=("time", "y", "x")),
422+
}),
423+
{},
424+
)
425+
slice_md = DatasetMetadata.from_dataset(
426+
xr.Dataset({
427+
"a": xr.DataArray(np.zeros((12, 4, 4)),
428+
dims=("time", "y", "x")),
429+
"b": xr.DataArray(np.zeros((12, 4, 4)),
430+
dims=("time", "y", "x")),
431+
}),
432+
{},
433+
)
434+
435+
with pytest.raises(ValueError,
436+
match="Wrong size for dimension 'y'"
437+
" in slice dataset: expected 3, but found 4"):
438+
target_md.assert_compatible_slice(slice_md, "time")
439+
440+
# noinspection PyMethodMayBeStatic
441+
def test_raise_on_wrong_var_dimensions(self):
442+
target_md = DatasetMetadata.from_dataset(
443+
xr.Dataset({
444+
"a": xr.DataArray(np.zeros((12, 3, 4)),
445+
dims=("time", "y", "x")),
446+
"b": xr.DataArray(np.zeros((12, 3, 4)),
447+
dims=("time", "y", "x")),
448+
}),
449+
{},
450+
)
451+
slice_md = DatasetMetadata.from_dataset(
452+
xr.Dataset({
453+
"a": xr.DataArray(np.zeros((1, 3)),
454+
dims=("time", "y")),
455+
"b": xr.DataArray(np.zeros((1, 3, 4)),
456+
dims=("time", "y", "x")),
457+
}),
458+
{},
459+
)
460+
461+
with pytest.raises(ValueError,
462+
match="Wrong dimensions for variable 'a' in"
463+
" slice dataset:"
464+
" expected \\('time', 'y', 'x'\\),"
465+
" but found \\('time', 'y'\\)"):
466+
target_md.assert_compatible_slice(slice_md, "time")

zappend/metadata.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,28 @@ def to_dict(self):
8787
for k, v in self.variables.items()},
8888
attrs=self.attrs)
8989

90+
def assert_compatible_slice(self,
91+
slice_metadata: "DatasetMetadata",
92+
append_dim: str):
93+
for dim_name, dim_size in self.dims.items():
94+
if dim_name not in slice_metadata.dims:
95+
raise ValueError(f"Missing dimension"
96+
f" {dim_name!r} in slice dataset")
97+
slice_dim_size = slice_metadata.dims[dim_name]
98+
if dim_name != append_dim and dim_size != slice_dim_size:
99+
raise ValueError(f"Wrong size for dimension {dim_name!r}"
100+
f" in slice dataset:"
101+
f" expected {dim_size},"
102+
f" but found {slice_dim_size}")
103+
for var_name, var_metadata in self.variables.items():
104+
slice_var_metadata = slice_metadata.variables.get(var_name)
105+
if (slice_var_metadata is not None
106+
and var_metadata.dims != slice_var_metadata.dims):
107+
raise ValueError(f"Wrong dimensions for variable {var_name!r}"
108+
f" in slice dataset:"
109+
f" expected {var_metadata.dims},"
110+
f" but found {slice_var_metadata.dims}")
111+
90112
@classmethod
91113
def from_dataset(cls,
92114
dataset: xr.Dataset,
@@ -137,12 +159,12 @@ def _get_effective_dims(dataset: xr.Dataset,
137159
return {str(k): v for k, v in dataset.dims.items()}
138160

139161

140-
def _get_effective_variables(dataset: xr.Dataset,
141-
config_included_variables: list[str] | None,
142-
config_excluded_variables: list[str] | None,
143-
config_variables: dict[str, dict[
144-
str, Any]] | None) -> dict[
145-
str, VariableMetadata]:
162+
def _get_effective_variables(
163+
dataset: xr.Dataset,
164+
config_included_variables: list[str] | None,
165+
config_excluded_variables: list[str] | None,
166+
config_variables: dict[str, dict[str, Any]] | None
167+
) -> dict[str, VariableMetadata]:
146168
config_variables = dict(config_variables or {})
147169
defaults = config_variables.pop("*", {})
148170
config_variables = {k: merge_configs(defaults, v)
@@ -167,7 +189,7 @@ def _get_effective_variables(dataset: xr.Dataset,
167189
variables = {}
168190

169191
for var_name in selected_var_names:
170-
config_var_def = dict(config_variables.get(var_name) or {})
192+
config_var_def: dict = dict(config_variables.get(var_name) or {})
171193
ds_var = dataset.variables.get(var_name)
172194
if ds_var is not None:
173195
# Variable found in dataset: use dataset variable to complement

zappend/processor.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,11 @@ def process_slice(self, slice_obj: str | xr.Dataset):
6161
slice_metadata = ctx.get_dataset_metadata(slice_dataset)
6262
if ctx.target_metadata is None:
6363
ctx.target_metadata = slice_metadata
64+
else:
65+
ctx.target_metadata.assert_compatible_slice(
66+
slice_metadata,
67+
ctx.append_dim_name
68+
)
6469

6570
with Transaction(ctx.target_dir, ctx.temp_dir) as rollback_cb:
6671
if ctx.target_metadata is slice_metadata:

0 commit comments

Comments
 (0)