@@ -62,8 +62,8 @@ def run_cfi(X, y, n_permutation, seed):
6262 # fit the model using the training set
6363 cfi .fit (
6464 X_train ,
65- groups = None ,
66- var_type = "auto" ,
65+ features_groups = None ,
66+ features_type = "auto" ,
6767 )
6868 # calculate feature importance using the test set
6969 vim = cfi .importance (X_test , y_test )
@@ -194,8 +194,8 @@ def test_group(data_generator):
194194 )
195195 cfi .fit (
196196 X_train_df ,
197- groups = groups ,
198- var_type = "continuous" ,
197+ features_groups = groups ,
198+ features_type = "continuous" ,
199199 )
200200 # Warning expected since column names in pandas are not considered
201201 with pytest .warns (UserWarning , match = "X does not have valid feature names, but" ):
@@ -245,8 +245,8 @@ def test_classication(data_generator):
245245 )
246246 cfi .fit (
247247 X_train ,
248- groups = None ,
249- var_type = ["continuous" ] * X .shape [1 ],
248+ features_groups = None ,
249+ features_type = ["continuous" ] * X .shape [1 ],
250250 )
251251 vim = cfi .importance (X_test , y_test_clf )
252252 importance = vim ["importance" ]
@@ -297,13 +297,13 @@ def test_fit(self, data_generator):
297297 # Test fit with auto var_type
298298 cfi .fit (X )
299299 assert len (cfi ._list_imputation_models ) == X .shape [1 ]
300- assert cfi .n_groups == X .shape [1 ]
300+ assert cfi .n_features_groups == X .shape [1 ]
301301
302302 # Test fit with specified groups
303303 groups = {"g1" : [0 , 1 ], "g2" : [2 , 3 , 4 ]}
304- cfi .fit (X , groups = groups )
304+ cfi .fit (X , features_groups = groups )
305305 assert len (cfi ._list_imputation_models ) == 2
306- assert cfi .n_groups == 2
306+ assert cfi .n_features_groups == 2
307307
308308 def test_categorical (
309309 self ,
@@ -331,8 +331,8 @@ def test_categorical(
331331 random_state = seed + 1 ,
332332 )
333333
334- var_type = ["continuous" , "continuous" , "categorical" ]
335- cfi .fit (X , y , var_type = var_type )
334+ features_type = ["continuous" , "continuous" , "categorical" ]
335+ cfi .fit (X , y , features_type = features_type )
336336
337337 importances = cfi .importance (X , y )["importance" ]
338338 assert len (importances ) == 3
@@ -415,7 +415,7 @@ def test_invalid_type(self, data_generator):
415415
416416 # Test error when passing invalid var_type
417417 with pytest .raises (ValueError , match = "type of data 'invalid' unknow." ):
418- cfi .fit (X , var_type = "invalid" )
418+ cfi .fit (X , features_type = "invalid" )
419419
420420 def test_invalid_n_permutations (self , data_generator ):
421421 """Test when invalid number of permutations is provided"""
@@ -434,7 +434,7 @@ def test_not_good_type_X(self, data_generator):
434434 imputation_model_continuous = LinearRegression (),
435435 method = "predict" ,
436436 )
437- cfi .fit (X , groups = None , var_type = "auto" )
437+ cfi .fit (X , features_groups = None , features_type = "auto" )
438438
439439 with pytest .raises (
440440 ValueError , match = "X should be a pandas dataframe or a numpy array."
@@ -450,7 +450,7 @@ def test_mismatched_features(self, data_generator):
450450 imputation_model_continuous = LinearRegression (),
451451 method = "predict" ,
452452 )
453- cfi .fit (X , groups = None , var_type = "auto" )
453+ cfi .fit (X , features_groups = None , features_type = "auto" )
454454
455455 with pytest .raises (
456456 AssertionError , match = "X does not correspond to the fitting data."
@@ -473,7 +473,7 @@ def test_mismatched_features_string(self, data_generator):
473473 "col_" + str (i ) for i in range (int (X .shape [1 ] / 2 ), X .shape [1 ] - 3 )
474474 ],
475475 }
476- cfi .fit (X , groups = subgroups , var_type = "auto" )
476+ cfi .fit (X , features_groups = subgroups , features_type = "auto" )
477477
478478 with pytest .raises (
479479 AssertionError ,
@@ -499,8 +499,8 @@ def test_internal_error(self, data_generator):
499499 "col_" + str (i ) for i in range (int (X .shape [1 ] / 2 ), X .shape [1 ] - 3 )
500500 ],
501501 }
502- cfi .fit (X , groups = subgroups , var_type = "auto" )
503- cfi .groups ["group1" ] = [None for i in range (100 )]
502+ cfi .fit (X , features_groups = subgroups , features_type = "auto" )
503+ cfi .features_groups ["group1" ] = [None for i in range (100 )]
504504
505505 X = X .to_records (index = False )
506506 X = np .array (X , dtype = X .dtype .descr )
@@ -517,7 +517,9 @@ def test_invalid_var_type(self, data_generator):
517517 cfi = CFI (estimator = fitted_model , method = "predict" )
518518
519519 with pytest .raises (ValueError , match = "type of data 'invalid_type' unknow." ):
520- cfi .fit (X , groups = None , var_type = ["invalid_type" ] * X .shape [1 ])
520+ cfi .fit (
521+ X , features_groups = None , features_type = ["invalid_type" ] * X .shape [1 ]
522+ )
521523
522524 def test_incompatible_imputer (self , data_generator ):
523525 """Test when incompatible imputer is provided"""
@@ -548,7 +550,7 @@ def test_invalid_groups_format(self, data_generator):
548550
549551 invalid_groups = ["group1" , "group2" ] # Should be dictionary
550552 with pytest .raises (ValueError , match = "groups needs to be a dictionnary" ):
551- cfi .fit (X , groups = invalid_groups , var_type = "auto" )
553+ cfi .fit (X , features_groups = invalid_groups , features_type = "auto" )
552554
553555 def test_groups_warning (self , data_generator ):
554556 """Test if a subgroup raise a warning"""
@@ -560,7 +562,7 @@ def test_groups_warning(self, data_generator):
560562 method = "predict" ,
561563 )
562564 subgroups = {"group1" : [0 , 1 ], "group2" : [2 , 3 ]}
563- cfi .fit (X , y , groups = subgroups , var_type = "auto" )
565+ cfi .fit (X , y , features_groups = subgroups , features_type = "auto" )
564566
565567 with pytest .warns (
566568 UserWarning ,
0 commit comments