@@ -92,16 +92,58 @@ def test_impact_forecast_from_impact(
9292 self .assert_impact_kwargs (impact_forecast , ** impact_kwargs )
9393
9494
95- def test_impact_forecast_select (impact_forecast , lead_time , member , impact_kwargs ):
95+ @pytest .mark .parametrize (
96+ "var, var_select" ,
97+ [("event_id" , "event_ids" ), ("event_name" , "event_names" ), ("date" , "dates" )],
98+ )
99+ def test_impact_forecast_select_events (
100+ impact_forecast , lead_time , member , impact_kwargs , var , var_select
101+ ):
96102 """Check if Impact.select works on the derived class"""
97- event_ids = impact_kwargs ["event_id" ][np .array ([2 , 0 ])]
98- impact_fc = impact_forecast .select (event_ids = event_ids )
103+ select_mask = np .array ([2 , 1 ])
104+ ordered_select_mask = np .array ([1 , 2 ])
105+ if var == "date" :
106+ # Date needs to be a valid delta
107+ select_mask = np .array ([1 , 2 ])
108+ ordered_select_mask = np .array ([1 , 2 ])
109+
110+ var_value = np .array (impact_kwargs [var ])[select_mask ]
111+ # event_name is a list, convert to numpy array for indexing
112+ impact_fc = impact_forecast .select (** {var_select : var_value })
99113 # NOTE: Events keep their original order
100114 npt .assert_array_equal (
101- impact_fc .event_id , impact_forecast .event_id [np .array ([0 , 2 ])]
115+ impact_fc .event_id ,
116+ impact_forecast .event_id [ordered_select_mask ],
117+ )
118+ npt .assert_array_equal (
119+ impact_fc .event_name ,
120+ np .array (impact_forecast .event_name )[ordered_select_mask ],
121+ )
122+ npt .assert_array_equal (impact_fc .date , impact_forecast .date [ordered_select_mask ])
123+ npt .assert_array_equal (
124+ impact_fc .frequency , impact_forecast .frequency [ordered_select_mask ]
125+ )
126+ npt .assert_array_equal (impact_fc .member , member [ordered_select_mask ])
127+ npt .assert_array_equal (impact_fc .lead_time , lead_time [ordered_select_mask ])
128+ npt .assert_array_equal (
129+ impact_fc .imp_mat .todense (),
130+ impact_forecast .imp_mat .todense ()[ordered_select_mask ],
131+ )
132+
133+
134+ def test_impact_forecast_select_exposure (
135+ impact_forecast , lead_time , member , impact_kwargs
136+ ):
137+ """Check if Impact.select works on the derived class"""
138+ exp_col = 0
139+ select_mask = np .array ([exp_col ])
140+ coord_exp = impact_kwargs ["coord_exp" ][select_mask ]
141+ impact_fc = impact_forecast .select (coord_exp = coord_exp )
142+ npt .assert_array_equal (impact_fc .member , member )
143+ npt .assert_array_equal (impact_fc .lead_time , lead_time )
144+ npt .assert_array_equal (
145+ impact_fc .imp_mat .todense (), impact_forecast .imp_mat .todense ()[:, exp_col ]
102146 )
103- npt .assert_array_equal (impact_fc .member , member [np .array ([0 , 2 ])])
104- npt .assert_array_equal (impact_fc .lead_time , lead_time [np .array ([0 , 2 ])])
105147
106148
107149@pytest .mark .skip ("Concat from base class does not work" )
0 commit comments