11import numpy as np
2- import pandas as pd
32from joblib import Parallel , delayed
43from sklearn .base import check_is_fitted
54from sklearn .metrics import root_mean_squared_error
6- import warnings
75
86from hidimstat ._utils .utils import _check_vim_predict_method
9- from hidimstat ._utils .exception import InternalError
10- from hidimstat .base_variable_importance import BaseVariableImportance
7+ from hidimstat .base_variable_importance import (
8+ BaseVariableImportance ,
9+ VariableImportanceGroup ,
10+ )
1111
1212
13- class BasePerturbation (BaseVariableImportance ):
13+ class BasePerturbation (BaseVariableImportance , VariableImportanceGroup ):
1414 def __init__ (
1515 self ,
1616 estimator ,
@@ -43,6 +43,7 @@ def __init__(
4343 The number of parallel jobs to run. Parallelization is done over the
4444 variables or groups of variables.
4545 """
46+ super ().__init__ ()
4647 check_is_fitted (estimator )
4748 assert n_permutations > 0 , "n_permutations must be positive"
4849 self .estimator = estimator
@@ -51,45 +52,6 @@ def __init__(
5152 self .method = method
5253 self .n_jobs = n_jobs
5354 self .n_permutations = n_permutations
54- self .n_groups = None
55-
56- def fit (self , X , y = None , groups = None ):
57- """Base fit method for perturbation-based methods. Identifies the groups.
58-
59- Parameters
60- ----------
61- X: array-like of shape (n_samples, n_features)
62- The input samples.
63- y: array-like of shape (n_samples,)
64- Not used, only present for consistency with the sklearn API.
65- groups: dict, optional
66- A dictionary where the keys are the group names and the values are the
67- list of column names corresponding to each group. If None, the groups are
68- identified based on the columns of X.
69- """
70- if groups is None :
71- self .n_groups = X .shape [1 ]
72- self .groups = {j : [j ] for j in range (self .n_groups )}
73- self ._groups_ids = np .array (list (self .groups .values ()), dtype = int )
74- elif isinstance (groups , dict ):
75- self .n_groups = len (groups )
76- self .groups = groups
77- if isinstance (X , pd .DataFrame ):
78- self ._groups_ids = []
79- for group_key in self .groups .keys ():
80- self ._groups_ids .append (
81- [
82- i
83- for i , col in enumerate (X .columns )
84- if col in self .groups [group_key ]
85- ]
86- )
87- else :
88- self ._groups_ids = [
89- np .array (ids , dtype = int ) for ids in list (self .groups .values ())
90- ]
91- else :
92- raise ValueError ("groups needs to be a dictionnary" )
9355
9456 def predict (self , X ):
9557 """
@@ -111,8 +73,12 @@ def predict(self, X):
11173
11274 # Parallelize the computation of the importance scores for each group
11375 out_list = Parallel (n_jobs = self .n_jobs )(
114- delayed (self ._joblib_predict_one_group )(X_ , group_id , group_key )
115- for group_id , group_key in enumerate (self .groups .keys ())
76+ delayed (self ._joblib_predict_one_features_group )(
77+ X_ , features_group_id , features_group_key
78+ )
79+ for features_group_id , features_group_key in enumerate (
80+ self .features_groups .keys ()
81+ )
11682 )
11783 return np .stack (out_list , axis = 0 )
11884
@@ -160,77 +126,7 @@ def importance(self, X, y):
160126 )
161127 return out_dict
162128
163- def _check_fit (self , X ):
164- """
165- Check if the perturbation method has been properly fitted.
166-
167- This method verifies that the perturbation method has been fitted by checking
168- if required attributes are set and if the number of features matches
169- the grouped variables.
170-
171- Parameters
172- ----------
173- X : array-like of shape (n_samples, n_features)
174- Input data to validate against the fitted model.
175-
176- Raises
177- ------
178- ValueError
179- If the method has not been fitted (i.e., if n_groups, groups,
180- or _groups_ids attributes are missing).
181- AssertionError
182- If the number of features in X does not match the total number
183- of features in the grouped variables.
184- """
185- if (
186- self .n_groups is None
187- or not hasattr (self , "groups" )
188- or not hasattr (self , "_groups_ids" )
189- ):
190- raise ValueError (
191- "The class is not fitted. The fit method must be called"
192- " to set variable groups. If no grouping is needed,"
193- " call fit with groups=None"
194- )
195- if isinstance (X , pd .DataFrame ):
196- names = list (X .columns )
197- elif isinstance (X , np .ndarray ) and X .dtype .names is not None :
198- names = X .dtype .names
199- # transform Structured Array in pandas array for a better manipulation
200- X = pd .DataFrame (X )
201- elif isinstance (X , np .ndarray ):
202- names = None
203- else :
204- raise ValueError ("X should be a pandas dataframe or a numpy array." )
205- number_columns = X .shape [1 ]
206- for index_variables in self .groups .values ():
207- if type (index_variables [0 ]) is int or np .issubdtype (
208- type (index_variables [0 ]), int
209- ):
210- assert np .all (
211- np .array (index_variables , dtype = int ) < number_columns
212- ), "X does not correspond to the fitting data."
213- elif type (index_variables [0 ]) is str or np .issubdtype (
214- type (index_variables [0 ]), str
215- ):
216- assert np .all (
217- [name in names for name in index_variables ]
218- ), f"The array is missing at least one of the following columns { index_variables } ."
219- else :
220- raise InternalError (
221- "A problem with indexing has happened during the fit."
222- )
223- number_unique_feature_in_groups = np .unique (
224- np .concatenate ([values for values in self .groups .values ()])
225- ).shape [0 ]
226- if X .shape [1 ] != number_unique_feature_in_groups :
227- warnings .warn (
228- f"The number of features in X: { X .shape [1 ]} differs from the"
229- " number of features for which importance is computed: "
230- f"{ number_unique_feature_in_groups } "
231- )
232-
233- def _joblib_predict_one_group (self , X , group_id , group_key ):
129+ def _joblib_predict_one_group (self , X , features_group_id , features_group_key ):
234130 """
235131 Compute the predictions after perturbation of the data for a given
236132 group of variables. This function is parallelized.
@@ -244,13 +140,15 @@ def _joblib_predict_one_group(self, X, group_id, group_key):
244140 group_key: str, int
245141 The key of the group of variables. (parameter use for debugging)
246142 """
247- group_ids = self ._groups_ids [group_id ]
248- non_group_ids = np .delete (np .arange (X .shape [1 ]), group_ids )
143+ features_group_ids = self ._groups_ids [features_group_id ]
144+ non_features_group_ids = np .delete (np .arange (X .shape [1 ]), features_group_ids )
249145 # Create an array X_perm_j of shape (n_permutations, n_samples, n_features)
250146 # where the j-th group of covariates is permuted
251147 X_perm = np .empty ((self .n_permutations , X .shape [0 ], X .shape [1 ]))
252- X_perm [:, :, non_group_ids ] = np .delete (X , group_ids , axis = 1 )
253- X_perm [:, :, group_ids ] = self ._permutation (X , group_id = group_id )
148+ X_perm [:, :, non_features_group_ids ] = np .delete (X , features_group_ids , axis = 1 )
149+ X_perm [:, :, features_group_ids ] = self ._permutation (
150+ X , features_group_id = features_group_id
151+ )
254152 # Reshape X_perm to allow for batch prediction
255153 X_perm_batch = X_perm .reshape (- 1 , X .shape [1 ])
256154 y_pred_perm = getattr (self .estimator , self .method )(X_perm_batch )
@@ -264,6 +162,6 @@ def _joblib_predict_one_group(self, X, group_id, group_key):
264162 )
265163 return y_pred_perm
266164
267- def _permutation (self , X , group_id ):
165+ def _permutation (self , X , features_group_id ):
268166 """Method for creating the permuted data for the j-th group of covariates."""
269167 raise NotImplementedError
0 commit comments