@@ -107,46 +107,94 @@ def test_hazard_forecast_concat(haz_fc, lead_time, member):
107107 npt .assert_array_equal (haz_fc_concat .member , np .concatenate ([member , member ]))
108108
109109
110- @pytest .mark .parametrize (
111- "var, var_select" ,
112- [("event_id" , "event_id" ), ("event_name" , "event_names" ), ("date" , "date" )],
113- )
114- def test_hazard_forecast_select (haz_fc , lead_time , member , haz_kwargs , var , var_select ):
115- """Check if Hazard.select works on the derived class"""
116-
117- select_mask = np .array ([3 , 2 ])
118- ordered_select_mask = np .array ([3 , 2 ])
119- if var == "date" :
120- # Date needs to be a valid delta
121- select_mask = np .array ([2 , 3 ])
122- ordered_select_mask = np .array ([2 , 3 ])
123-
124- var_value = np .array (haz_kwargs [var ])[select_mask ]
125- # event_name is a list, convert to numpy array for indexing
126- haz_fc_sel = haz_fc .select (** {var_select : var_value })
127- # Note: order is preserved
128- npt .assert_array_equal (
129- haz_fc_sel .event_id ,
130- haz_fc .event_id [ordered_select_mask ],
131- )
132- npt .assert_array_equal (
133- haz_fc_sel .event_name ,
134- np .array (haz_fc .event_name )[ordered_select_mask ],
135- )
136- npt .assert_array_equal (haz_fc_sel .date , haz_fc .date [ordered_select_mask ])
137- npt .assert_array_equal (haz_fc_sel .frequency , haz_fc .frequency [ordered_select_mask ])
138- npt .assert_array_equal (haz_fc_sel .member , member [ordered_select_mask ])
139- npt .assert_array_equal (haz_fc_sel .lead_time , lead_time [ordered_select_mask ])
140- npt .assert_array_equal (
141- haz_fc_sel .intensity .todense (),
142- haz_fc .intensity .todense ()[ordered_select_mask ],
143- )
144- npt .assert_array_equal (
145- haz_fc_sel .fraction .todense (),
146- haz_fc .fraction .todense ()[ordered_select_mask ],
147- )
110+ class TestSelect :
148111
149- assert haz_fc_sel .centroids == haz_fc .centroids
112+ @pytest .mark .parametrize (
113+ "var, var_select" ,
114+ [("event_id" , "event_id" ), ("event_name" , "event_names" ), ("date" , "date" )],
115+ )
116+ def test_base_class_select (
117+ self , haz_fc , lead_time , member , haz_kwargs , var , var_select
118+ ):
119+ """Check if Hazard.select works on the derived class"""
120+
121+ select_mask = np .array ([3 , 2 ])
122+ ordered_select_mask = np .array ([3 , 2 ])
123+ if var == "date" :
124+ # Date needs to be a valid delta
125+ select_mask = np .array ([2 , 3 ])
126+ ordered_select_mask = np .array ([2 , 3 ])
127+
128+ var_value = np .array (haz_kwargs [var ])[select_mask ]
129+ # event_name is a list, convert to numpy array for indexing
130+ haz_fc_sel = haz_fc .select (** {var_select : var_value })
131+ # Note: order is preserved
132+ npt .assert_array_equal (
133+ haz_fc_sel .event_id ,
134+ haz_fc .event_id [ordered_select_mask ],
135+ )
136+ npt .assert_array_equal (
137+ haz_fc_sel .event_name ,
138+ np .array (haz_fc .event_name )[ordered_select_mask ],
139+ )
140+ npt .assert_array_equal (haz_fc_sel .date , haz_fc .date [ordered_select_mask ])
141+ npt .assert_array_equal (
142+ haz_fc_sel .frequency , haz_fc .frequency [ordered_select_mask ]
143+ )
144+ npt .assert_array_equal (haz_fc_sel .member , member [ordered_select_mask ])
145+ npt .assert_array_equal (haz_fc_sel .lead_time , lead_time [ordered_select_mask ])
146+ npt .assert_array_equal (
147+ haz_fc_sel .intensity .todense (),
148+ haz_fc .intensity .todense ()[ordered_select_mask ],
149+ )
150+ npt .assert_array_equal (
151+ haz_fc_sel .fraction .todense (),
152+ haz_fc .fraction .todense ()[ordered_select_mask ],
153+ )
154+
155+ assert haz_fc_sel .centroids == haz_fc .centroids
156+
157+ def test_derived_select_single (self , haz_fc , lead_time , member ):
158+ haz_fc_select = haz_fc .select (member = [3 , 0 ])
159+ idx = np .array ([0 , 3 ])
160+ npt .assert_array_equal (haz_fc_select .event_id , haz_fc .event_id [idx ])
161+ npt .assert_array_equal (haz_fc_select .member , member [idx ])
162+ npt .assert_array_equal (haz_fc_select .lead_time , lead_time [idx ])
163+
164+ haz_fc_select = haz_fc .select (lead_time = lead_time [np .array ([3 , 0 ])])
165+ npt .assert_array_equal (haz_fc_select .event_id , haz_fc .event_id [idx ])
166+ npt .assert_array_equal (haz_fc_select .member , member [idx ])
167+ npt .assert_array_equal (haz_fc_select .lead_time , lead_time [idx ])
168+
169+ def test_derived_select_intersections (self , haz_fc , lead_time , member , haz_kwargs ):
170+ haz_fc_select = haz_fc .select (event_id = [1 , 4 ], member = [0 , 1 , 2 ])
171+ npt .assert_array_equal (haz_fc_select .event_id , haz_fc .event_id [np .array ([0 ])])
172+
173+ haz_fc_select = haz_fc .select (
174+ event_id = [1 , 2 , 4 ], member = [0 , 1 , 2 ], lead_time = lead_time [1 :3 ]
175+ )
176+ npt .assert_array_equal (haz_fc_select .event_id , haz_fc .event_id [np .array ([1 ])])
177+
178+ # Test "outer"
179+ haz_fc2 = HazardForecast (
180+ lead_time = lead_time , member = np .zeros_like (member , dtype = "int" ), ** haz_kwargs
181+ )
182+ haz_fc_select = haz_fc2 .select (event_id = [1 , 2 , 4 ], member = [0 ])
183+ npt .assert_array_equal (haz_fc_select .event_id , [1 , 2 , 4 ])
184+ npt .assert_array_equal (haz_fc_select .member , [0 , 0 , 0 ])
185+
186+ def test_derived_select_null (self , haz_fc , haz_kwargs ):
187+ haz_fc_select = haz_fc .select ()
188+ assert_hazard_kwargs (haz_fc_select , ** haz_kwargs )
189+
190+ with pytest .raises (IndexError ):
191+ haz_fc .select (event_id = [- 1 ])
192+ with pytest .raises (IndexError ):
193+ haz_fc .select (member = [- 1 ])
194+ with pytest .raises (IndexError ):
195+ haz_fc .select (
196+ lead_time = [np .timedelta64 ("2" , "Y" ).astype ("timedelta64[ns]" )]
197+ )
150198
151199
152200def test_write_read_hazard_forecast (haz_fc , tmp_path ):
0 commit comments