66from hidimstat ._utils .utils import _check_vim_predict_method
77from hidimstat .base_variable_importance import (
88 BaseVariableImportance ,
9- VariableImportanceGroup ,
9+ VariableImportanceFeatureGroup ,
1010)
1111
1212
13- class BasePerturbation (BaseVariableImportance , VariableImportanceGroup ):
13+ class BasePerturbation (BaseVariableImportance , VariableImportanceFeatureGroup ):
1414 def __init__ (
1515 self ,
1616 estimator ,
@@ -121,12 +121,14 @@ def importance(self, X, y):
121121 out_dict ["importance" ] = np .array (
122122 [
123123 np .mean (out_dict ["loss" ][j ]) - loss_reference
124- for j in range (self .n_groups )
124+ for j in range (self .n_features_groups )
125125 ]
126126 )
127127 return out_dict
128128
129- def _joblib_predict_one_group (self , X , features_group_id , features_group_key ):
129+ def _joblib_predict_one_features_group (
130+ self , X , features_group_id , features_group_key
131+ ):
130132 """
131133 Compute the predictions after perturbation of the data for a given
132134 group of variables. This function is parallelized.
@@ -140,7 +142,7 @@ def _joblib_predict_one_group(self, X, features_group_id, features_group_key):
140142 group_key: str, int
141143 The key of the group of variables. (parameter use for debugging)
142144 """
143- features_group_ids = self ._groups_ids [features_group_id ]
145+ features_group_ids = self ._features_groups_ids [features_group_id ]
144146 non_features_group_ids = np .delete (np .arange (X .shape [1 ]), features_group_ids )
145147 # Create an array X_perm_j of shape (n_permutations, n_samples, n_features)
146148 # where the j-th group of covariates is permuted
0 commit comments