@@ -101,16 +101,15 @@ def transform(X):
101101
102102
103103@pytest .fixture (scope = "module" )
104- def penguins_df ():
105- df = load_penguins (as_frame = True ).dropna ()
106- X = df .drop (columns = "species" )
107-
104+ def penguins_df () -> pd .DataFrame :
105+ df : pd .DataFrame = load_penguins (as_frame = True )
106+ X = df .dropna ().drop (columns = "species" )
108107 return X
109108
110109
111110@pytest .fixture (scope = "module" )
112- def penguins (penguins_df ) :
113- return penguins_df .values
111+ def penguins (penguins_df : pd . DataFrame ) -> np . ndarray :
112+ return penguins_df .to_numpy ()
114113
115114
116115def test_all_groups_scaled (dataset_with_single_grouping , scaling_range ):
@@ -269,7 +268,7 @@ def test_array_with_strings():
269268
270269
271270@pytest .mark .parametrize ("frame_func" , [pd .DataFrame , pl .DataFrame , pa .table ])
272- def test_df (penguins_df , frame_func ):
271+ def test_df (penguins_df : pd . DataFrame , frame_func ):
273272 penguins_df = frame_func (penguins_df .to_dict (orient = "list" ))
274273 meta = GroupedTransformer (StandardScaler (), groups = ["island" , "sex" ])
275274
@@ -280,18 +279,17 @@ def test_df(penguins_df, frame_func):
280279
281280
282281@pytest .mark .parametrize ("frame_func" , [pd .DataFrame , pl .DataFrame , pa .table ])
283- def test_df_missing_group (penguins_df , frame_func ):
282+ def test_df_missing_group (penguins_df : pd . DataFrame , frame_func ):
284283 meta = GroupedTransformer (StandardScaler (), groups = ["island" , "sex" ])
285284
286285 # Otherwise the fixture is changed
287- X = penguins_df .copy ()
288- X .loc [0 , "island" ] = None
289- X = frame_func (X .to_dict (orient = "list" ))
286+ X = penguins_df .copy ().to_dict (orient = "list" )
287+ X ["island" ][0 ] = None
290288 with pytest .raises (ValueError ):
291- meta .fit_transform (X )
289+ meta .fit_transform (frame_func ( X ) )
292290
293291
294- def test_array_with_multiple_string_cols (penguins ):
292+ def test_array_with_multiple_string_cols (penguins : np . ndarray ):
295293 X = penguins
296294
297295 # BROKEN: Failing due to negative indexing... kind of an edge case
@@ -314,7 +312,7 @@ def test_grouping_column_not_in_array(penguins):
314312
315313
316314@pytest .mark .parametrize ("frame_func" , [pd .DataFrame , pl .DataFrame , pa .table ])
317- def test_grouping_column_not_in_df (penguins_df , frame_func ):
315+ def test_grouping_column_not_in_df (penguins_df : pd . DataFrame , frame_func ):
318316 meta = GroupedTransformer (StandardScaler (), groups = ["island" , "unexisting_column" ])
319317
320318 # This should raise ValueError
@@ -323,7 +321,7 @@ def test_grouping_column_not_in_df(penguins_df, frame_func):
323321
324322
325323@pytest .mark .parametrize ("frame_func" , [pd .DataFrame , pl .DataFrame , pa .table ])
326- def test_no_grouping (penguins_df , frame_func ):
324+ def test_no_grouping (penguins_df : pd . DataFrame , frame_func ):
327325 penguins_numeric = frame_func (
328326 penguins_df [["bill_length_mm" , "bill_depth_mm" , "flipper_length_mm" , "body_mass_g" ]].to_dict (orient = "list" )
329327 )
@@ -335,7 +333,7 @@ def test_no_grouping(penguins_df, frame_func):
335333
336334
337335@pytest .mark .parametrize ("frame_func" , [pd .DataFrame , pl .DataFrame , pa .table ])
338- def test_with_y (penguins_df , frame_func ):
336+ def test_with_y (penguins_df : pd . DataFrame , frame_func ):
339337 X = frame_func (penguins_df .drop (columns = ["sex" ]).to_dict (orient = "list" ))
340338 y = penguins_df ["sex" ].to_numpy ()
341339
@@ -400,7 +398,7 @@ def test_transform_with_y(transformer):
400398
401399
402400@pytest .mark .parametrize (("frame_func" , "transform_output" ), [(pd .DataFrame , "pandas" ), (pl .DataFrame , "polars" )])
403- def test_set_output (penguins_df , frame_func , transform_output ):
401+ def test_set_output (penguins_df : pd . DataFrame , frame_func , transform_output ):
404402 if transform_output == "polars" and sklearn .__version__ < "1.4.0" :
405403 pytest .skip ()
406404
@@ -417,7 +415,7 @@ def test_with_object_dtype():
417415
418416 data = {
419417 "big" : ["A" , "A" , "A" , "A" , "A" , "B" , "B" , "B" , "C" , "C" ],
420- "small" : ["a" , "a" , None , "a" , "a" , "b" , "b" , None , "C" , "C" ],
418+ "small" : ["a" , "a" , pd . NA , "a" , "a" , "b" , "b" , pd . NA , "C" , "C" ],
421419 "other" : [0.1 , 0.2 , 0.3 , 0.6 , 0.5 , 0.1 , 0.3 , 0.5 , 0.6 , 0.6 ],
422420 "y" : [1 , 1 , 0 , 1 , 0 , 1 , 1 , 0 , 0 , 0 ],
423421 }
@@ -426,7 +424,7 @@ def test_with_object_dtype():
426424
427425 result = (
428426 GroupedTransformer (
429- transformer = SimpleImputer (strategy = "most_frequent" , missing_values = None ),
427+ transformer = SimpleImputer (strategy = "most_frequent" , missing_values = pd . NA ),
430428 groups = ["big" ],
431429 check_X = False ,
432430 )
@@ -439,6 +437,6 @@ def test_with_object_dtype():
439437 "small" : ["a" , "a" , "a" , "a" , "a" , "b" , "b" , "b" , "C" , "C" ],
440438 "other" : [0.1 , 0.2 , 0.3 , 0.6 , 0.5 , 0.1 , 0.3 , 0.5 , 0.6 , 0.6 ],
441439 }
442- ). astype ( "object" )
440+ )
443441
444- assert_frame_equal (result , expected )
442+ assert_frame_equal (result , expected , check_dtype = False )
0 commit comments