@@ -40,35 +40,66 @@ def impact(impact_kwargs):
4040 return Impact (** impact_kwargs )
4141
4242
43- def assert_impact_kwargs (impact : Impact , ** kwargs ):
44- for key , value in kwargs .items ():
45- attr = getattr (impact , key )
46- if isinstance (value , (np .ndarray , list )):
47- npt .assert_array_equal (attr , value )
48- elif isinstance (value , csr_matrix ):
49- npt .assert_array_equal (attr .todense (), value .todense ())
50- else :
51- assert attr == value
43+ @pytest .fixture
44+ def lead_time ():
45+ return pd .timedelta_range (start = "1 day" , periods = 6 ).to_numpy ()
5246
5347
54- class TestImpactForecastInit :
55- lead_time = pd .date_range ("2000-01-01" , "2000-01-02" , periods = 6 ).to_numpy ()
56- member = np .arange (6 )
48+ @pytest .fixture
49+ def member ():
50+ return np .arange (6 )
51+
52+
53+ @pytest .fixture
54+ def impact_forecast (impact , lead_time , member ):
55+ return ImpactForecast .from_impact (impact , lead_time = lead_time , member = member )
56+
5757
58- def test_impact_forecast_init (self , impact_kwargs ):
58+ class TestImpactForecastInit :
59+ def assert_impact_kwargs (self , impact : Impact , ** kwargs ):
60+ for key , value in kwargs .items ():
61+ attr = getattr (impact , key )
62+ if isinstance (value , (np .ndarray , list )):
63+ npt .assert_array_equal (attr , value )
64+ elif isinstance (value , csr_matrix ):
65+ npt .assert_array_equal (attr .todense (), value .todense ())
66+ else :
67+ assert attr == value
68+
69+ def test_impact_forecast_init (self , impact_kwargs , lead_time , member ):
5970 forecast1 = ImpactForecast (
60- lead_time = self . lead_time ,
61- member = self . member ,
71+ lead_time = lead_time ,
72+ member = member ,
6273 ** impact_kwargs ,
6374 )
64- npt .assert_array_equal (forecast1 .lead_time , self .lead_time )
65- npt .assert_array_equal (forecast1 .member , self .member )
66- assert_impact_kwargs (forecast1 , ** impact_kwargs )
67-
68- def test_impact_forecast_from_impact (self , impact , impact_kwargs ):
69- forecast = ImpactForecast .from_impact (
70- impact , lead_time = self .lead_time , member = self .member
71- )
72- npt .assert_array_equal (forecast .lead_time , self .lead_time )
73- npt .assert_array_equal (forecast .member , self .member )
74- assert_impact_kwargs (forecast , ** impact_kwargs )
75+ npt .assert_array_equal (forecast1 .lead_time , lead_time )
76+ npt .assert_array_equal (forecast1 .member , member )
77+ self .assert_impact_kwargs (forecast1 , ** impact_kwargs )
78+
79+ def test_impact_forecast_from_impact (
80+ self , impact_forecast , impact_kwargs , lead_time , member
81+ ):
82+ npt .assert_array_equal (impact_forecast .lead_time , lead_time )
83+ npt .assert_array_equal (impact_forecast .member , member )
84+ self .assert_impact_kwargs (impact_forecast , ** impact_kwargs )
85+
86+
87+ def test_impact_forecast_select (impact_forecast , lead_time , member , impact_kwargs ):
88+ """Check if Impact.select works on the derived class"""
89+ event_ids = impact_kwargs ["event_id" ][np .array ([2 , 0 ])]
90+ impact_fc = impact_forecast .select (event_ids = event_ids )
91+ # NOTE: Events keep their original order
92+ npt .assert_array_equal (
93+ impact_fc .event_id , impact_forecast .event_id [np .array ([0 , 2 ])]
94+ )
95+ npt .assert_array_equal (impact_fc .member , member [np .array ([0 , 2 ])])
96+ npt .assert_array_equal (impact_fc .lead_time , lead_time [np .array ([0 , 2 ])])
97+
98+
99+ @pytest .mark .skip ("Concat from base class does not work" )
100+ def test_impact_forecast_concat (impact_forecast , member ):
101+ """Check if Impact.concat works on the derived class"""
102+ impact_fc = ImpactForecast .concat (
103+ [impact_forecast , impact_forecast ], reset_event_ids = True
104+ )
105+ npt .assert_array_equal (impact_fc .member , np .concatenate ([member , member ]))
0 commit comments