Skip to content

Commit 29ae24e

Browse files
committed
Start implementing shape check
1 parent fad5ec2 commit 29ae24e

File tree

2 files changed

+50
-1
lines changed

2 files changed

+50
-1
lines changed

climada/util/forecast.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,28 @@
1919
Define Forecast base class.
2020
"""
2121

22+
from typing import Any
23+
2224
import numpy as np
2325

2426

27+
def check_attribute_shapes(obj_act: Any, attr_act: str, obj_exp: Any, attr_exp: str):
28+
"""Compare the shapes of attributes of two objects.
29+
30+
Raises
31+
------
32+
ValueError
33+
If the shapes do not match
34+
"""
35+
shape_actual = getattr(obj_act, attr_act).shape
36+
shape_expected = getattr(obj_exp, attr_exp).shape
37+
if shape_actual != shape_expected:
38+
raise ValueError(
39+
f"Shape mismatch between {type(obj_act).__name__}.{attr_act} "
40+
f"{shape_actual} and {type(obj_exp).__name__}.{attr_exp} {shape_expected}"
41+
)
42+
43+
2544
class Forecast:
2645
"""Mixin class for forecast data.
2746

climada/util/test/test_forecast.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,9 @@
2222
import numpy as np
2323
import numpy.testing as npt
2424
import pandas as pd
25+
import pytest
2526

26-
from climada.util.forecast import Forecast
27+
from climada.util.forecast import Forecast, check_attribute_shapes
2728

2829

2930
def test_forecast_init():
@@ -50,3 +51,32 @@ def test_forecast_init():
5051
forecast = Forecast(lead_time=lead_times_seconds, member=[1, 2, 3])
5152
npt.assert_array_equal(forecast.lead_time, lead_times_seconds, strict=True)
5253
assert forecast.lead_time.dtype == np.dtype("timedelta64[ns]")
54+
55+
56+
class A:
57+
foo = np.array([[0, 1], [1, 0]])
58+
59+
60+
class B:
61+
bar = np.array([[1, 1], [1, 1]])
62+
63+
64+
class TestCheckCompareShapes:
65+
@pytest.fixture
66+
def a(self):
67+
return A()
68+
69+
@pytest.fixture
70+
def b(self):
71+
return B()
72+
73+
def test_pass(self, a, b):
74+
check_attribute_shapes(a, "foo", b, "bar")
75+
76+
def test_error(self, a, b):
77+
a.foo = np.array([0, 1])
78+
with pytest.raises(ValueError, match=r"A.foo \(2\,\)"):
79+
check_attribute_shapes(a, "foo", b, "bar")
80+
b.bar = np.array([0, 1, 2])
81+
with pytest.raises(ValueError, match=r"B.bar \(3\,\)"):
82+
check_attribute_shapes(a, "foo", b, "bar")

0 commit comments

Comments
 (0)