@@ -86,14 +86,49 @@ def test_impact_forecast_from_impact(
8686
8787def test_impact_forecast_select (impact_forecast , lead_time , member , impact_kwargs ):
8888 """Check if Impact.select works on the derived class"""
89- event_ids = impact_kwargs ["event_id" ][np .array ([2 , 0 ])]
90- impact_fc = impact_forecast .select (event_ids = event_ids )
91- # NOTE: Events keep their original order
89+ select_mask = np .array ([1 , 2 ])
90+ ordered_select_mask = np .array ([1 , 2 ])
91+ vars_to_select = {
92+ "event_id" : "event_ids" ,
93+ "event_name" : "event_names" ,
94+ "date" : "dates" ,
95+ }
96+
97+ for var , var_select in vars_to_select .items ():
98+ var_value = np .array (impact_kwargs [var ])[select_mask ]
99+ # event_name is a list, convert to numpy array for indexing
100+ impact_fc = impact_forecast .select (** {var_select : var_value })
101+ # NOTE: Events keep their original order
102+ npt .assert_array_equal (
103+ impact_fc .event_id ,
104+ impact_forecast .event_id [ordered_select_mask ],
105+ )
106+ npt .assert_array_equal (
107+ impact_fc .event_name ,
108+ np .array (impact_forecast .event_name )[ordered_select_mask ],
109+ )
110+ npt .assert_array_equal (
111+ impact_fc .date , impact_forecast .date [ordered_select_mask ]
112+ )
113+ npt .assert_array_equal (
114+ impact_fc .frequency , impact_forecast .frequency [ordered_select_mask ]
115+ )
116+ npt .assert_array_equal (impact_fc .member , member [ordered_select_mask ])
117+ npt .assert_array_equal (impact_fc .lead_time , lead_time [ordered_select_mask ])
118+ npt .assert_array_equal (
119+ impact_fc .imp_mat .todense (),
120+ impact_forecast .imp_mat .todense ()[ordered_select_mask ],
121+ )
122+
123+ exp_col = 0
124+ select_mask = np .array ([exp_col ])
125+ coord_exp = impact_kwargs ["coord_exp" ][select_mask ]
126+ impact_fc = impact_forecast .select (coord_exp = coord_exp )
127+ npt .assert_array_equal (impact_fc .member , member )
128+ npt .assert_array_equal (impact_fc .lead_time , lead_time )
92129 npt .assert_array_equal (
93- impact_fc .event_id , impact_forecast .event_id [ np . array ([ 0 , 2 ]) ]
130+ impact_fc .imp_mat . todense () , impact_forecast .imp_mat . todense ()[:, exp_col ]
94131 )
95- npt .assert_array_equal (impact_fc .member , member [np .array ([0 , 2 ])])
96- npt .assert_array_equal (impact_fc .lead_time , lead_time [np .array ([0 , 2 ])])
97132
98133
99134@pytest .mark .skip ("Concat from base class does not work" )
0 commit comments