Skip to content

Commit 208c16b

Browse files
committed
fixed coverage
1 parent 067b7d3 commit 208c16b

File tree

2 files changed

+93
-8
lines changed

2 files changed

+93
-8
lines changed

tests/config/test_attrs.py

Lines changed: 81 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,14 @@
44

55
import json
66
import unittest
7+
from datetime import datetime
78

89
import numpy as np
910
import pytest
1011

1112
from zappend.config.attrs import ConfigAttrsUserFunctions
1213
from zappend.config.attrs import eval_dyn_config_attrs
14+
from zappend.config.attrs import eval_expr
1315
from zappend.config.attrs import get_dyn_config_attrs_env
1416
from zappend.config.attrs import has_dyn_config_attrs
1517
from ..helpers import make_test_dataset
@@ -19,7 +21,7 @@ class EvalDynConfigAttrsTest(unittest.TestCase):
1921
def setUp(self):
2022
ds = make_test_dataset()
2123
ds.attrs["title"] = "Ocean Colour"
22-
self.env = dict(ds=ds, N=10)
24+
self.env = get_dyn_config_attrs_env(ds, N=10)
2325

2426
def test_zero(self):
2527
self.assertEqual(
@@ -69,7 +71,7 @@ def test_x_min_max(self):
6971
)
7072
self.assertEqual(attrs, json.loads(json.dumps(attrs)))
7173

72-
def test_x_min_max_corr(self):
74+
def test_x_min_max_center(self):
7375
attrs = eval_dyn_config_attrs(
7476
{
7577
"x_min": "{{ ds.x[0] - (ds.x[1]-ds.x[0])/2 }}",
@@ -81,6 +83,18 @@ def test_x_min_max_corr(self):
8183
self.assertAlmostEqual(1.0, attrs.get("x_max"))
8284
self.assertEqual(attrs, json.loads(json.dumps(attrs)))
8385

86+
def test_x_bounds_center(self):
87+
attrs = eval_dyn_config_attrs(
88+
{
89+
"x_min": "{{ lower_bound(ds.x, ref='center') }}",
90+
"x_max": "{{ upper_bound(ds.x, ref='center') }}",
91+
},
92+
self.env,
93+
)
94+
self.assertAlmostEqual(0.0, attrs.get("x_min"))
95+
self.assertAlmostEqual(1.0, attrs.get("x_max"))
96+
self.assertEqual(attrs, json.loads(json.dumps(attrs)))
97+
8498
def test_time_min_max(self):
8599
attrs = eval_dyn_config_attrs(
86100
{"time_min": "{{ ds.time[0] }}", "time_max": "{{ ds.time[-1] }}"}, self.env
@@ -91,6 +105,20 @@ def test_time_min_max(self):
91105
)
92106
self.assertEqual(attrs, json.loads(json.dumps(attrs)))
93107

108+
def test_time_bounds(self):
109+
attrs = eval_dyn_config_attrs(
110+
{
111+
"time_min": "{{ lower_bound(ds.time) }}",
112+
"time_max": "{{ upper_bound(ds.time) }}",
113+
},
114+
self.env,
115+
)
116+
self.assertEqual(
117+
{"time_min": "2024-01-01T00:00:00", "time_max": "2024-01-04T00:00:00"},
118+
attrs,
119+
)
120+
self.assertEqual(attrs, json.loads(json.dumps(attrs)))
121+
94122

95123
class HasDynConfigAttrsTest(unittest.TestCase):
96124
def test_no(self):
@@ -119,6 +147,56 @@ def test_yes(self):
119147
)
120148

121149

150+
class EvalExprTest(unittest.TestCase):
151+
def test_all_cases(self):
152+
self.assertEqual(True, eval_expr("True", {}))
153+
self.assertEqual(13, eval_expr("13", {}))
154+
self.assertEqual(0.5, eval_expr("0.5", {}))
155+
self.assertEqual("ABC", eval_expr("'ABC'", {}))
156+
157+
switches = np.array([True, False])
158+
self.assertEqual(
159+
True,
160+
eval_expr("switches[0]", {"switches": switches}),
161+
)
162+
163+
levels = [3, 4, 5]
164+
self.assertIs(
165+
levels,
166+
eval_expr("levels", {"levels": levels}),
167+
)
168+
169+
levels = np.array([3, 4, 5])
170+
self.assertEqual(
171+
3,
172+
eval_expr("levels[0]", {"levels": levels}),
173+
)
174+
175+
lon = np.array([11.05, 11.15, 11.25])
176+
self.assertIs(
177+
lon,
178+
eval_expr("lon", {"lon": lon}),
179+
)
180+
181+
lon = np.array([11.05, 11.15, 11.25])
182+
self.assertEqual(
183+
11.05,
184+
eval_expr("lon[0]", {"lon": lon}),
185+
)
186+
187+
names = np.array(["A", "B"])
188+
self.assertEqual(
189+
"A",
190+
eval_expr("names[0]", {"names": names}),
191+
)
192+
193+
time = datetime.fromisoformat("2024-01-02T10:20:30")
194+
self.assertEqual(
195+
"2024-01-02T10:20:30",
196+
eval_expr("time", {"time": time}),
197+
)
198+
199+
122200
class GetDynConfigAttrsEnvTest(unittest.TestCase):
123201
def test_env(self):
124202
ds = make_test_dataset()
@@ -207,6 +285,7 @@ def _assert_bound_func_ok(
207285

208286
def test_upper_lower_bounds_fail_for_wrong_shape(self):
209287
self._assert_upper_lower_bounds_fail(np.array(3))
288+
self._assert_upper_lower_bounds_fail(np.array([]))
210289
self._assert_upper_lower_bounds_fail(np.array([[1, 2], [3, 4]]))
211290

212291
# noinspection PyMethodMayBeStatic

zappend/config/attrs.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -64,29 +64,35 @@ def eval_attr_value(attr_value: Any, env: dict[str, Any]) -> Any:
6464

6565
def eval_expr(expr: str, env: dict[str, Any]) -> Any:
6666
value = eval(expr, env)
67+
if isinstance(value, (bool, int, float, str, type(None))):
68+
return value
6769
if isinstance(value, datetime.datetime):
6870
return value.replace(microsecond=0).isoformat()
6971
try:
70-
if np.ndim(value) == 0:
71-
if np.issubdtype(value, int):
72+
if value.ndim == 0:
73+
if np.issubdtype(value.dtype, np.bool_):
74+
return bool(value)
75+
if np.issubdtype(value.dtype, np.integer):
7276
return int(value)
73-
if np.issubdtype(value, float):
77+
if np.issubdtype(value.dtype, np.floating):
7478
return float(value)
75-
if np.issubdtype(value, np.datetime64):
79+
if np.issubdtype(value.dtype, np.datetime64):
7680
return np.datetime_as_string(value, unit="s")
77-
except (TypeError, ValueError):
81+
except AttributeError:
7882
pass
83+
# TODO: this is not right, handle remaining cases or raise
7984
return value
8085

8186

82-
def get_dyn_config_attrs_env(ds: xr.Dataset):
87+
def get_dyn_config_attrs_env(ds: xr.Dataset, **kwargs):
8388
return dict(
8489
ds=ds,
8590
**{
8691
k: v
8792
for k, v in ConfigAttrsUserFunctions.__dict__.items()
8893
if not k.startswith("_")
8994
},
95+
**kwargs,
9096
)
9197

9298

0 commit comments

Comments
 (0)