Skip to content

Commit 2a262f5

Browse files
committed
fix attribute serialization
1 parent 64f8d58 commit 2a262f5

File tree

2 files changed

+75
-34
lines changed

2 files changed

+75
-34
lines changed

tests/config/test_attrs.py

Lines changed: 33 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from datetime import datetime
88

99
import numpy as np
10+
import xarray as xr
1011
import pytest
1112

1213
from zappend.config.attrs import ConfigAttrsUserFunctions
@@ -149,51 +150,56 @@ def test_yes(self):
149150

150151
class EvalExprTest(unittest.TestCase):
151152
def test_all_cases(self):
153+
# scalars
154+
self.assertEqual(None, eval_expr("None", {}))
152155
self.assertEqual(True, eval_expr("True", {}))
153156
self.assertEqual(13, eval_expr("13", {}))
154157
self.assertEqual(0.5, eval_expr("0.5", {}))
155158
self.assertEqual("ABC", eval_expr("'ABC'", {}))
156-
157-
switches = np.array([True, False])
159+
time = datetime.fromisoformat("2024-01-02T10:20:30")
158160
self.assertEqual(
159-
True,
160-
eval_expr("switches[0]", {"switches": switches}),
161+
"2024-01-02T10:20:30",
162+
eval_expr("time", dict(time=time)),
161163
)
162164

163-
levels = [3, 4, 5]
164-
self.assertIs(
165-
levels,
166-
eval_expr("levels", {"levels": levels}),
167-
)
165+
# arrays
166+
self.assert_array_ok([True, False])
167+
self.assert_array_ok([3, 4, 5])
168+
self.assert_array_ok([11.05, 11.15, 11.25])
169+
self.assert_array_ok(["A", "B"])
170+
self.assert_array_ok(["2024-01-02T10:20:30"], dtype="datetime64[s]")
168171

169-
levels = np.array([3, 4, 5])
172+
def assert_array_ok(self, a: list, dtype=None):
173+
# Test list
170174
self.assertEqual(
171-
3,
172-
eval_expr("levels[0]", {"levels": levels}),
175+
a,
176+
eval_expr("a", dict(a=a)),
173177
)
174-
175-
lon = np.array([11.05, 11.15, 11.25])
176-
self.assertIs(
177-
lon,
178-
eval_expr("lon", {"lon": lon}),
178+
self.assertEqual(
179+
a[0],
180+
eval_expr("a[0]", dict(a=a)),
179181
)
180182

181-
lon = np.array([11.05, 11.15, 11.25])
183+
# Test numpy.ndarray
184+
np_a = np.array(a, dtype=dtype) if dtype is not None else np.array(a)
182185
self.assertEqual(
183-
11.05,
184-
eval_expr("lon[0]", {"lon": lon}),
186+
a,
187+
eval_expr("a", dict(a=np_a)),
185188
)
186-
187-
names = np.array(["A", "B"])
188189
self.assertEqual(
189-
"A",
190-
eval_expr("names[0]", {"names": names}),
190+
a[0],
191+
eval_expr("a[0]", dict(a=np_a)),
191192
)
192193

193-
time = datetime.fromisoformat("2024-01-02T10:20:30")
194+
# Test xarray-DataArray
195+
xr_a = xr.DataArray(np_a, dims="x")
194196
self.assertEqual(
195-
"2024-01-02T10:20:30",
196-
eval_expr("time", {"time": time}),
197+
a,
198+
eval_expr("a", dict(a=xr_a)),
199+
)
200+
self.assertEqual(
201+
a[0],
202+
eval_expr("a[0]", dict(a=xr_a)),
197203
)
198204

199205

zappend/config/attrs.py

Lines changed: 42 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# https://opensource.org/licenses/MIT.
44

55
import datetime
6+
import math
67
from typing import Any, Literal
78

89
import numpy as np
@@ -64,24 +65,58 @@ def eval_attr_value(attr_value: Any, env: dict[str, Any]) -> Any:
6465

6566
def eval_expr(expr: str, env: dict[str, Any]) -> Any:
6667
value = eval(expr, env)
67-
if isinstance(value, (bool, int, float, str, type(None))):
68+
return to_json(value)
69+
70+
71+
def to_json(value) -> Any:
72+
if isinstance(value, float):
73+
if math.isfinite(value):
74+
return value
75+
else:
76+
# TODO: case cover by test
77+
return str(value)
78+
if isinstance(value, (bool, int, str, type(None))):
6879
return value
69-
if isinstance(value, datetime.datetime):
70-
return value.replace(microsecond=0).isoformat()
80+
if isinstance(value, datetime.date):
81+
if isinstance(value, datetime.datetime):
82+
return value.replace(microsecond=0).isoformat()
83+
else:
84+
# TODO: cover case by test
85+
return value.isoformat()
86+
7187
try:
7288
if value.ndim == 0:
89+
try:
90+
# xarray.DataArray case
91+
value = value.values
92+
except AttributeError:
93+
pass
94+
if np.issubdtype(value.dtype, np.floating):
95+
if np.isfinite(value):
96+
return float(value)
97+
else:
98+
# TODO: cover case by test
99+
return str(value)
73100
if np.issubdtype(value.dtype, np.bool_):
74101
return bool(value)
102+
if np.issubdtype(value.dtype, np.str_):
103+
return str(value)
75104
if np.issubdtype(value.dtype, np.integer):
76105
return int(value)
77-
if np.issubdtype(value.dtype, np.floating):
78-
return float(value)
79106
if np.issubdtype(value.dtype, np.datetime64):
80107
return np.datetime_as_string(value, unit="s")
108+
# TODO: cover case by test
109+
raise ValueError(
110+
f"cannot serialize 0-d array of type {value.dtype}: {value}"
111+
)
81112
except AttributeError:
82113
pass
83-
# TODO: this is not right, handle remaining cases or raise
84-
return value
114+
115+
if isinstance(value, dict):
116+
# TODO: cover case by test
117+
return {k: to_json(v) for k, v in value.items()}
118+
119+
return [to_json(v) for v in value]
85120

86121

87122
def get_dyn_config_attrs_env(ds: xr.Dataset, **kwargs):

0 commit comments

Comments
 (0)