2222import numpy as np
2323import numpy .testing as npt
2424import pandas as pd
25+ import pytest
2526
26- from climada .util .forecast import Forecast
27+ from climada .util .forecast import Forecast , check_attribute_shapes
2728
2829
2930def test_forecast_init ():
@@ -50,3 +51,32 @@ def test_forecast_init():
5051 forecast = Forecast (lead_time = lead_times_seconds , member = [1 , 2 , 3 ])
5152 npt .assert_array_equal (forecast .lead_time , lead_times_seconds , strict = True )
5253 assert forecast .lead_time .dtype == np .dtype ("timedelta64[ns]" )
54+
55+
56+ class A :
57+ foo = np .array ([[0 , 1 ], [1 , 0 ]])
58+
59+
60+ class B :
61+ bar = np .array ([[1 , 1 ], [1 , 1 ]])
62+
63+
64+ class TestCheckCompareShapes :
65+ @pytest .fixture
66+ def a (self ):
67+ return A ()
68+
69+ @pytest .fixture
70+ def b (self ):
71+ return B ()
72+
73+ def test_pass (self , a , b ):
74+ check_attribute_shapes (a , "foo" , b , "bar" )
75+
76+ def test_error (self , a , b ):
77+ a .foo = np .array ([0 , 1 ])
78+ with pytest .raises (ValueError , match = r"A.foo \(2\,\)" ):
79+ check_attribute_shapes (a , "foo" , b , "bar" )
80+ b .bar = np .array ([0 , 1 , 2 ])
81+ with pytest .raises (ValueError , match = r"B.bar \(3\,\)" ):
82+ check_attribute_shapes (a , "foo" , b , "bar" )
0 commit comments