@@ -92,58 +92,90 @@ def test_impact_forecast_from_impact(
9292 self .assert_impact_kwargs (impact_forecast , ** impact_kwargs )
9393
9494
95- @pytest .mark .parametrize (
96- "var, var_select" ,
97- [("event_id" , "event_ids" ), ("event_name" , "event_names" ), ("date" , "dates" )],
98- )
99- def test_impact_forecast_select_events (
100- impact_forecast , lead_time , member , impact_kwargs , var , var_select
101- ):
102- """Check if Impact.select works on the derived class"""
103- select_mask = np .array ([2 , 1 ])
104- ordered_select_mask = np .array ([1 , 2 ])
105- if var == "date" :
106- # Date needs to be a valid delta
107- select_mask = np .array ([1 , 2 ])
108- ordered_select_mask = np .array ([1 , 2 ])
95+ class TestSelect :
10996
110- var_value = np .array (impact_kwargs [var ])[select_mask ]
111- # event_name is a list, convert to numpy array for indexing
112- impact_fc = impact_forecast .select (** {var_select : var_value })
113- # NOTE: Events keep their original order
114- npt .assert_array_equal (
115- impact_fc .event_id ,
116- impact_forecast .event_id [ordered_select_mask ],
117- )
118- npt .assert_array_equal (
119- impact_fc .event_name ,
120- np .array (impact_forecast .event_name )[ordered_select_mask ],
121- )
122- npt .assert_array_equal (impact_fc .date , impact_forecast .date [ordered_select_mask ])
123- npt .assert_array_equal (
124- impact_fc .frequency , impact_forecast .frequency [ordered_select_mask ]
125- )
126- npt .assert_array_equal (impact_fc .member , member [ordered_select_mask ])
127- npt .assert_array_equal (impact_fc .lead_time , lead_time [ordered_select_mask ])
128- npt .assert_array_equal (
129- impact_fc .imp_mat .todense (),
130- impact_forecast .imp_mat .todense ()[ordered_select_mask ],
97+ @pytest .mark .parametrize (
98+ "var, var_select" ,
99+ [("event_id" , "event_ids" ), ("event_name" , "event_names" ), ("date" , "dates" )],
131100 )
101+ def test_base_class_select (
102+ self , impact_forecast , lead_time , member , impact_kwargs , var , var_select
103+ ):
104+ """Check if Impact.select works on the derived class"""
105+ select_mask = np .array ([2 , 1 ])
106+ ordered_select_mask = np .array ([1 , 2 ])
107+ if var == "date" :
108+ # Date needs to be a valid delta
109+ select_mask = np .array ([1 , 2 ])
110+ ordered_select_mask = np .array ([1 , 2 ])
111+
112+ var_value = np .array (impact_kwargs [var ])[select_mask ]
113+ # event_name is a list, convert to numpy array for indexing
114+ impact_fc = impact_forecast .select (** {var_select : var_value })
115+ # NOTE: Events keep their original order
116+ npt .assert_array_equal (
117+ impact_fc .event_id ,
118+ impact_forecast .event_id [ordered_select_mask ],
119+ )
120+ npt .assert_array_equal (
121+ impact_fc .event_name ,
122+ np .array (impact_forecast .event_name )[ordered_select_mask ],
123+ )
124+ npt .assert_array_equal (
125+ impact_fc .date , impact_forecast .date [ordered_select_mask ]
126+ )
127+ npt .assert_array_equal (
128+ impact_fc .frequency , impact_forecast .frequency [ordered_select_mask ]
129+ )
130+ npt .assert_array_equal (impact_fc .member , member [ordered_select_mask ])
131+ npt .assert_array_equal (impact_fc .lead_time , lead_time [ordered_select_mask ])
132+ npt .assert_array_equal (
133+ impact_fc .imp_mat .todense (),
134+ impact_forecast .imp_mat .todense ()[ordered_select_mask ],
135+ )
132136
137+ def test_impact_forecast_select_exposure (
138+ self , impact_forecast , lead_time , member , impact_kwargs
139+ ):
140+ """Check if Impact.select works on the derived class"""
141+ exp_col = 0
142+ select_mask = np .array ([exp_col ])
143+ coord_exp = impact_kwargs ["coord_exp" ][select_mask ]
144+ impact_fc = impact_forecast .select (coord_exp = coord_exp )
145+ npt .assert_array_equal (impact_fc .member , member )
146+ npt .assert_array_equal (impact_fc .lead_time , lead_time )
147+ npt .assert_array_equal (
148+ impact_fc .imp_mat .todense (), impact_forecast .imp_mat .todense ()[:, exp_col ]
149+ )
133150
134- def test_impact_forecast_select_exposure (
135- impact_forecast , lead_time , member , impact_kwargs
136- ):
137- """Check if Impact.select works on the derived class"""
138- exp_col = 0
139- select_mask = np .array ([exp_col ])
140- coord_exp = impact_kwargs ["coord_exp" ][select_mask ]
141- impact_fc = impact_forecast .select (coord_exp = coord_exp )
142- npt .assert_array_equal (impact_fc .member , member )
143- npt .assert_array_equal (impact_fc .lead_time , lead_time )
144- npt .assert_array_equal (
145- impact_fc .imp_mat .todense (), impact_forecast .imp_mat .todense ()[:, exp_col ]
146- )
151+ def test_derived_select (self , haz_fc , lead_time , member , haz_kwargs ):
152+ haz_fc_select = haz_fc .select (member = [3 , 0 ])
153+ idx = np .array ([0 , 3 ])
154+ npt .assert_array_equal (haz_fc_select .event_id , haz_fc .event_id [idx ])
155+ npt .assert_array_equal (haz_fc_select .member , member [idx ])
156+ npt .assert_array_equal (haz_fc_select .lead_time , lead_time [idx ])
157+
158+ haz_fc_select = haz_fc .select (lead_time = lead_time [np .array ([3 , 0 ])])
159+ npt .assert_array_equal (haz_fc_select .event_id , haz_fc .event_id [idx ])
160+ npt .assert_array_equal (haz_fc_select .member , member [idx ])
161+ npt .assert_array_equal (haz_fc_select .lead_time , lead_time [idx ])
162+
163+ # Test intersections
164+ haz_fc_select = haz_fc .select (event_id = [1 , 4 ], member = [0 , 1 , 2 ])
165+ npt .assert_array_equal (haz_fc_select .event_id , haz_fc .event_id [np .array ([0 ])])
166+
167+ haz_fc_select = haz_fc .select (
168+ event_id = [1 , 2 , 4 ], member = [0 , 1 , 2 ], lead_time = lead_time [1 :3 ]
169+ )
170+ npt .assert_array_equal (haz_fc_select .event_id , haz_fc .event_id [np .array ([1 ])])
171+
172+ # Test "outer"
173+ haz_fc2 = HazardForecast (
174+ lead_time = lead_time , member = np .zeros_like (member , dtype = "int" ), ** haz_kwargs
175+ )
176+ haz_fc_select = haz_fc2 .select (event_id = [1 , 2 , 4 ], member = [0 ])
177+ npt .assert_array_equal (haz_fc_select .event_id , [1 , 2 , 4 ])
178+ npt .assert_array_equal (haz_fc_select .member , [0 , 0 , 0 ])
147179
148180
149181@pytest .mark .skip ("Concat from base class does not work" )
0 commit comments