@@ -40,35 +40,59 @@ def impact(impact_kwargs):
4040 return Impact (** impact_kwargs )
4141
4242
43- def assert_impact_kwargs (impact : Impact , ** kwargs ):
44- for key , value in kwargs .items ():
45- attr = getattr (impact , key )
46- if isinstance (value , (np .ndarray , list )):
47- npt .assert_array_equal (attr , value )
48- elif isinstance (value , csr_matrix ):
49- npt .assert_array_equal (attr .todense (), value .todense ())
50- else :
51- assert attr == value
43+ @pytest .fixture
44+ def lead_time ():
45+ return pd .timedelta_range (start = "1 day" , periods = 6 ).to_numpy ()
46+
47+
48+ @pytest .fixture
49+ def member ():
50+ return np .arange (6 )
51+
52+
53+ @pytest .fixture
54+ def impact_forecast (impact , lead_time , member ):
55+ return ImpactForecast .from_impact (impact , lead_time = lead_time , member = member )
5256
5357
5458class TestImpactForecastInit :
55- lead_time = pd .timedelta_range (start = "1 day" , periods = 6 ).to_numpy ()
56- member = np .arange (6 )
5759
58- def test_impact_forecast_init (self , impact_kwargs ):
60+ def assert_impact_kwargs (self , impact : Impact , ** kwargs ):
61+ for key , value in kwargs .items ():
62+ attr = getattr (impact , key )
63+ if isinstance (value , (np .ndarray , list )):
64+ npt .assert_array_equal (attr , value )
65+ elif isinstance (value , csr_matrix ):
66+ npt .assert_array_equal (attr .todense (), value .todense ())
67+ else :
68+ assert attr == value
69+
70+ def test_impact_forecast_init (self , impact_kwargs , lead_time , member ):
5971 forecast1 = ImpactForecast (
60- lead_time = self . lead_time ,
61- member = self . member ,
72+ lead_time = lead_time ,
73+ member = member ,
6274 ** impact_kwargs ,
6375 )
64- npt .assert_array_equal (forecast1 .lead_time , self .lead_time )
65- npt .assert_array_equal (forecast1 .member , self .member )
66- assert_impact_kwargs (forecast1 , ** impact_kwargs )
67-
68- def test_impact_forecast_from_impact (self , impact , impact_kwargs ):
69- forecast = ImpactForecast .from_impact (
70- impact , lead_time = self .lead_time , member = self .member
71- )
72- npt .assert_array_equal (forecast .lead_time , self .lead_time )
73- npt .assert_array_equal (forecast .member , self .member )
74- assert_impact_kwargs (forecast , ** impact_kwargs )
76+ npt .assert_array_equal (forecast1 .lead_time , lead_time )
77+ npt .assert_array_equal (forecast1 .member , member )
78+ self .assert_impact_kwargs (forecast1 , ** impact_kwargs )
79+
80+ def test_impact_forecast_from_impact (
81+ self , impact_forecast , impact_kwargs , lead_time , member
82+ ):
83+ npt .assert_array_equal (impact_forecast .lead_time , lead_time )
84+ npt .assert_array_equal (impact_forecast .member , member )
85+ self .assert_impact_kwargs (impact_forecast , ** impact_kwargs )
86+
87+
88+ def test_impact_forecast_select (impact_forecast , lead_time , member , impact_kwargs ):
89+ """Check if Impact.select works on the derived class"""
90+
91+ event_ids = impact_kwargs ["event_id" ][np .array ([2 , 0 ])]
92+ impact_fc = impact_forecast .select (event_ids = event_ids )
93+ # NOTE: Events keep their original order
94+ npt .assert_array_equal (
95+ impact_fc .event_id , impact_forecast .event_id [np .array ([0 , 2 ])]
96+ )
97+ npt .assert_array_equal (impact_fc .member , member [np .array ([0 , 2 ])])
98+ npt .assert_array_equal (impact_fc .lead_time , lead_time [np .array ([0 , 2 ])])
0 commit comments