Skip to content

Commit 45751fc

Browse files
committed
refactored slice tailoring
1 parent aab0f27 commit 45751fc

File tree

3 files changed

+163
-93
lines changed

3 files changed

+163
-93
lines changed

tests/test_tailoring.py

Lines changed: 136 additions & 66 deletions
Original file line number | Diff line number | Diff line change
@@ -8,14 +8,23 @@
88
import pyproj
99
import xarray as xr
1010

11-
from zappend.metadata import DatasetMetadata
11+
from zappend.context import Context
1212
from zappend.tailoring import tailor_target_dataset
1313
from zappend.tailoring import tailor_slice_dataset
1414

1515

16+
def make_context(config: dict, target_ds: xr.Dataset, write: bool = False) -> Context:
17+
target_dir = "memory://target.zarr"
18+
if write:
19+
target_ds.to_zarr(target_dir, mode="w")
20+
ctx = Context({"target_dir": target_dir, **config})
21+
ctx.target_metadata = ctx.get_dataset_metadata(target_ds)
22+
return ctx
23+
24+
1625
class TailorTargetDatasetTest(unittest.TestCase):
17-
def test_it_sets_metadata(self):
18-
ds = xr.Dataset(
26+
def test_it_sets_vars_metadata(self):
27+
slice_ds = xr.Dataset(
1928
{
2029
"a": xr.DataArray(
2130
np.zeros((2, 3, 4)),
@@ -29,18 +38,17 @@ def test_it_sets_metadata(self):
2938
),
3039
}
3140
)
32-
tailored_ds = tailor_target_dataset(
33-
ds,
34-
DatasetMetadata.from_dataset(
35-
ds,
36-
{
37-
"variables": {
38-
"a": {"encoding": {"dtype": "uint8", "fill_value": 0}},
39-
"b": {"encoding": {"dtype": "int8", "fill_value": -1}},
40-
}
41+
ctx = make_context(
42+
{
43+
"variables": {
44+
"a": {"encoding": {"dtype": "uint8", "fill_value": 0}},
45+
"b": {"encoding": {"dtype": "int8", "fill_value": -1}},
4146
},
42-
),
47+
},
48+
slice_ds,
4349
)
50+
51+
tailored_ds = tailor_target_dataset(ctx, slice_ds)
4452
self.assertIsInstance(tailored_ds, xr.Dataset)
4553
self.assertEqual({"a", "b"}, set(tailored_ds.variables.keys()))
4654

@@ -57,49 +65,45 @@ def test_it_sets_metadata(self):
5765
self.assertEqual({"units": "g/m^3"}, b.attrs)
5866

5967
def test_it_strips_vars(self):
60-
ds = xr.Dataset(
68+
slice_ds = xr.Dataset(
6169
{
6270
"a": xr.DataArray(np.zeros((2, 3, 4)), dims=("time", "y", "x")),
6371
"b": xr.DataArray(np.zeros((2, 3, 4)), dims=("time", "y", "x")),
6472
}
6573
)
6674

67-
tailored_ds = tailor_target_dataset(
68-
ds, DatasetMetadata.from_dataset(ds, {"included_variables": ["b"]})
69-
)
75+
ctx = make_context({"included_variables": ["b"]}, slice_ds)
76+
tailored_ds = tailor_target_dataset(ctx, slice_ds)
7077
self.assertEqual({"b"}, set(tailored_ds.variables.keys()))
7178

72-
tailored_ds = tailor_target_dataset(
73-
ds, DatasetMetadata.from_dataset(ds, {"excluded_variables": ["b"]})
74-
)
79+
ctx = make_context({"excluded_variables": ["b"]}, slice_ds)
80+
tailored_ds = tailor_target_dataset(ctx, slice_ds)
7581
self.assertEqual({"a"}, set(tailored_ds.variables.keys()))
7682

7783
def test_it_completes_vars(self):
78-
ds = xr.Dataset(
84+
slice_ds = xr.Dataset(
7985
{
8086
"a": xr.DataArray(np.zeros((2, 3, 4)), dims=("time", "y", "x")),
8187
}
8288
)
83-
84-
tailored_ds = tailor_target_dataset(
85-
ds,
86-
DatasetMetadata.from_dataset(
87-
ds,
88-
{
89-
"variables": {
90-
"a": {"dims": ["time", "y", "x"]},
91-
"b": {
92-
"dims": ["time", "y", "x"],
93-
"encoding": {"dtype": "int16", "fill_value": 0},
94-
},
95-
"c": {
96-
"dims": ["time", "y", "x"],
97-
"encoding": {"dtype": "uint32"},
98-
},
89+
ctx = make_context(
90+
{
91+
"variables": {
92+
"a": {"dims": ["time", "y", "x"]},
93+
"b": {
94+
"dims": ["time", "y", "x"],
95+
"encoding": {"dtype": "int16", "fill_value": 0},
96+
},
97+
"c": {
98+
"dims": ["time", "y", "x"],
99+
"encoding": {"dtype": "uint32"},
99100
},
100101
},
101-
),
102+
},
103+
slice_ds,
102104
)
105+
106+
tailored_ds = tailor_target_dataset(ctx, slice_ds)
103107
self.assertEqual({"a", "b", "c"}, set(tailored_ds.variables.keys()))
104108

105109
b = tailored_ds.b
@@ -110,12 +114,72 @@ def test_it_completes_vars(self):
110114
self.assertEqual(np.dtype("uint32"), c.dtype)
111115
self.assertEqual(np.dtype("uint32"), c.encoding.get("dtype"))
112116

113-
# noinspection PyMethodMayBeStatic
117+
def test_it_updates_attrs_according_to_update_mode(self):
118+
target_ds = xr.Dataset(
119+
{
120+
"a": xr.DataArray(
121+
np.zeros((2, 3, 4)),
122+
dims=("time", "y", "x"),
123+
),
124+
},
125+
attrs={"Conventions": "CF-1.8"},
126+
)
127+
128+
ctx = make_context(
129+
{"attrs_update_mode": "keep", "attrs": {"a": 12, "b": True}}, target_ds
130+
)
131+
tailored_ds = tailor_target_dataset(ctx, target_ds)
132+
self.assertEqual(
133+
{
134+
"Conventions": "CF-1.8",
135+
"a": 12,
136+
"b": True,
137+
},
138+
tailored_ds.attrs,
139+
)
140+
141+
ctx = make_context(
142+
{"attrs_update_mode": "replace", "attrs": {"a": 12, "b": True}}, target_ds
143+
)
144+
tailored_ds = tailor_target_dataset(ctx, target_ds)
145+
self.assertEqual(
146+
{
147+
"Conventions": "CF-1.8",
148+
"a": 12,
149+
"b": True,
150+
},
151+
tailored_ds.attrs,
152+
)
153+
154+
ctx = make_context(
155+
{"attrs_update_mode": "update", "attrs": {"a": 12, "b": True}}, target_ds
156+
)
157+
tailored_ds = tailor_target_dataset(ctx, target_ds)
158+
self.assertEqual(
159+
{
160+
"Conventions": "CF-1.8",
161+
"a": 12,
162+
"b": True,
163+
},
164+
tailored_ds.attrs,
165+
)
166+
167+
ctx = make_context(
168+
{"attrs_update_mode": "ignore", "attrs": {"a": 12, "b": True}}, target_ds
169+
)
170+
tailored_ds = tailor_target_dataset(ctx, target_ds)
171+
self.assertEqual(
172+
{
173+
"a": 12,
174+
"b": True,
175+
},
176+
tailored_ds.attrs,
177+
)
114178

115179

116180
class TailorSliceDatasetTest(unittest.TestCase):
117181
def test_it_drops_constant_variables(self):
118-
ds = xr.Dataset(
182+
slice_ds = xr.Dataset(
119183
{
120184
"a": xr.DataArray(np.zeros((2, 3, 4)), dims=("time", "y", "x")),
121185
"b": xr.DataArray(np.zeros((2, 3, 4)), dims=("time", "y", "x")),
@@ -128,14 +192,13 @@ def test_it_drops_constant_variables(self):
128192
"y": xr.DataArray(np.linspace(0.0, 1.0, 3), dims="y"),
129193
},
130194
)
131-
tailored_ds = tailor_slice_dataset(
132-
ds, DatasetMetadata.from_dataset(ds, {}), "time"
133-
)
195+
ctx = make_context({}, slice_ds)
196+
tailored_ds = tailor_slice_dataset(ctx, slice_ds)
134197
self.assertIsInstance(tailored_ds, xr.Dataset)
135198
self.assertEqual({"a", "b"}, set(tailored_ds.variables.keys()))
136199

137200
def test_it_clears_var_encoding_and_attrs(self):
138-
ds = xr.Dataset(
201+
slice_ds = xr.Dataset(
139202
{
140203
"a": xr.DataArray(
141204
np.zeros((2, 3, 4)),
@@ -149,19 +212,16 @@ def test_it_clears_var_encoding_and_attrs(self):
149212
),
150213
}
151214
)
152-
tailored_ds = tailor_slice_dataset(
153-
ds,
154-
DatasetMetadata.from_dataset(
155-
ds,
156-
{
157-
"variables": {
158-
"a": {"encoding": {"dtype": "uint8", "fill_value": 0}},
159-
"b": {"encoding": {"dtype": "int8", "fill_value": -1}},
160-
}
161-
},
162-
),
163-
"time",
215+
ctx = make_context(
216+
{
217+
"variables": {
218+
"a": {"encoding": {"dtype": "uint8", "fill_value": 0}},
219+
"b": {"encoding": {"dtype": "int8", "fill_value": -1}},
220+
}
221+
},
222+
slice_ds,
164223
)
224+
tailored_ds = tailor_slice_dataset(ctx, slice_ds)
165225
self.assertIsInstance(tailored_ds, xr.Dataset)
166226

167227
self.assertIn("a", tailored_ds.variables)
@@ -194,11 +254,12 @@ def test_it_updates_attrs_according_to_update_mode(self):
194254
attrs={"title": "OCC 2024"},
195255
)
196256

197-
target_md = DatasetMetadata.from_dataset(target_ds)
198-
199-
tailored_ds = tailor_slice_dataset(
200-
slice_ds, target_md, "time", "keep", {"a": 12, "b": True}
257+
ctx = make_context(
258+
{"attrs_update_mode": "keep", "attrs": {"a": 12, "b": True}},
259+
target_ds,
260+
True,
201261
)
262+
tailored_ds = tailor_slice_dataset(ctx, slice_ds)
202263
self.assertEqual(
203264
{
204265
"Conventions": "CF-1.8",
@@ -208,9 +269,12 @@ def test_it_updates_attrs_according_to_update_mode(self):
208269
tailored_ds.attrs,
209270
)
210271

211-
tailored_ds = tailor_slice_dataset(
212-
slice_ds, target_md, "time", "replace", {"a": 12, "b": True}
272+
ctx = make_context(
273+
{"attrs_update_mode": "replace", "attrs": {"a": 12, "b": True}},
274+
target_ds,
275+
True,
213276
)
277+
tailored_ds = tailor_slice_dataset(ctx, slice_ds)
214278
self.assertEqual(
215279
{
216280
"title": "OCC 2024",
@@ -220,9 +284,12 @@ def test_it_updates_attrs_according_to_update_mode(self):
220284
tailored_ds.attrs,
221285
)
222286

223-
tailored_ds = tailor_slice_dataset(
224-
slice_ds, target_md, "time", "update", {"a": 12, "b": True}
287+
ctx = make_context(
288+
{"attrs_update_mode": "update", "attrs": {"a": 12, "b": True}},
289+
target_ds,
290+
True,
225291
)
292+
tailored_ds = tailor_slice_dataset(ctx, slice_ds)
226293
self.assertEqual(
227294
{
228295
"Conventions": "CF-1.8",
@@ -233,9 +300,12 @@ def test_it_updates_attrs_according_to_update_mode(self):
233300
tailored_ds.attrs,
234301
)
235302

236-
tailored_ds = tailor_slice_dataset(
237-
slice_ds, target_md, "time", "ignore", {"a": 12, "b": True}
303+
ctx = make_context(
304+
{"attrs_update_mode": "ignore", "attrs": {"a": 12, "b": True}},
305+
target_ds,
306+
True,
238307
)
308+
tailored_ds = tailor_slice_dataset(ctx, slice_ds)
239309
self.assertEqual(
240310
{
241311
"a": 12,

zappend/processor.py

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -125,7 +125,7 @@ def create_target_from_slice(
125125
target_dir = ctx.target_dir
126126
logger.info(f"Creating target dataset {target_dir.uri}")
127127

128-
target_ds = tailor_target_dataset(slice_ds, ctx.target_metadata)
128+
target_ds = tailor_target_dataset(ctx, slice_ds)
129129

130130
if ctx.dry_run:
131131
return
@@ -152,9 +152,7 @@ def update_target_from_slice(
152152
target_dir = ctx.target_dir
153153
logger.info(f"Updating target dataset {target_dir.uri}")
154154

155-
slice_ds = tailor_slice_dataset(
156-
slice_ds, ctx.target_metadata, ctx.append_dim, ctx.attrs_update_mode, ctx.attrs
157-
)
155+
slice_ds = tailor_slice_dataset(ctx, slice_ds)
158156

159157
if ctx.dry_run:
160158
return

0 commit comments

Comments
 (0)