Skip to content

Commit f3a5e0a

Browse files
committed
Separate the management of the feature groups from the base perturbation class
1 parent e958c3b commit f3a5e0a

File tree

6 files changed

+236
-177
lines changed

6 files changed

+236
-177
lines changed

src/hidimstat/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
from .base_variable_importance import BaseVariableImportance
1+
from .base_variable_importance import (
2+
BaseVariableImportance,
3+
VariableImportanceFeatureGroup,
4+
)
25
from .base_perturbation import BasePerturbation
36
from .ensemble_clustered_inference import (
47
clustered_inference,

src/hidimstat/base_perturbation.py

Lines changed: 20 additions & 122 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
import numpy as np
2-
import pandas as pd
32
from joblib import Parallel, delayed
43
from sklearn.base import check_is_fitted
54
from sklearn.metrics import root_mean_squared_error
6-
import warnings
75

86
from hidimstat._utils.utils import _check_vim_predict_method
9-
from hidimstat._utils.exception import InternalError
10-
from hidimstat.base_variable_importance import BaseVariableImportance
7+
from hidimstat.base_variable_importance import (
8+
BaseVariableImportance,
9+
VariableImportanceGroup,
10+
)
1111

1212

13-
class BasePerturbation(BaseVariableImportance):
13+
class BasePerturbation(BaseVariableImportance, VariableImportanceGroup):
1414
def __init__(
1515
self,
1616
estimator,
@@ -43,6 +43,7 @@ def __init__(
4343
The number of parallel jobs to run. Parallelization is done over the
4444
variables or groups of variables.
4545
"""
46+
super().__init__()
4647
check_is_fitted(estimator)
4748
assert n_permutations > 0, "n_permutations must be positive"
4849
self.estimator = estimator
@@ -51,45 +52,6 @@ def __init__(
5152
self.method = method
5253
self.n_jobs = n_jobs
5354
self.n_permutations = n_permutations
54-
self.n_groups = None
55-
56-
def fit(self, X, y=None, groups=None):
57-
"""Base fit method for perturbation-based methods. Identifies the groups.
58-
59-
Parameters
60-
----------
61-
X: array-like of shape (n_samples, n_features)
62-
The input samples.
63-
y: array-like of shape (n_samples,)
64-
Not used, only present for consistency with the sklearn API.
65-
groups: dict, optional
66-
A dictionary where the keys are the group names and the values are the
67-
list of column names corresponding to each group. If None, the groups are
68-
identified based on the columns of X.
69-
"""
70-
if groups is None:
71-
self.n_groups = X.shape[1]
72-
self.groups = {j: [j] for j in range(self.n_groups)}
73-
self._groups_ids = np.array(list(self.groups.values()), dtype=int)
74-
elif isinstance(groups, dict):
75-
self.n_groups = len(groups)
76-
self.groups = groups
77-
if isinstance(X, pd.DataFrame):
78-
self._groups_ids = []
79-
for group_key in self.groups.keys():
80-
self._groups_ids.append(
81-
[
82-
i
83-
for i, col in enumerate(X.columns)
84-
if col in self.groups[group_key]
85-
]
86-
)
87-
else:
88-
self._groups_ids = [
89-
np.array(ids, dtype=int) for ids in list(self.groups.values())
90-
]
91-
else:
92-
raise ValueError("groups needs to be a dictionnary")
9355

9456
def predict(self, X):
9557
"""
@@ -111,8 +73,12 @@ def predict(self, X):
11173

11274
# Parallelize the computation of the importance scores for each group
11375
out_list = Parallel(n_jobs=self.n_jobs)(
114-
delayed(self._joblib_predict_one_group)(X_, group_id, group_key)
115-
for group_id, group_key in enumerate(self.groups.keys())
76+
delayed(self._joblib_predict_one_features_group)(
77+
X_, features_group_id, features_group_key
78+
)
79+
for features_group_id, features_group_key in enumerate(
80+
self.features_groups.keys()
81+
)
11682
)
11783
return np.stack(out_list, axis=0)
11884

@@ -160,77 +126,7 @@ def importance(self, X, y):
160126
)
161127
return out_dict
162128

163-
def _check_fit(self, X):
164-
"""
165-
Check if the perturbation method has been properly fitted.
166-
167-
This method verifies that the perturbation method has been fitted by checking
168-
if required attributes are set and if the number of features matches
169-
the grouped variables.
170-
171-
Parameters
172-
----------
173-
X : array-like of shape (n_samples, n_features)
174-
Input data to validate against the fitted model.
175-
176-
Raises
177-
------
178-
ValueError
179-
If the method has not been fitted (i.e., if n_groups, groups,
180-
or _groups_ids attributes are missing).
181-
AssertionError
182-
If the number of features in X does not match the total number
183-
of features in the grouped variables.
184-
"""
185-
if (
186-
self.n_groups is None
187-
or not hasattr(self, "groups")
188-
or not hasattr(self, "_groups_ids")
189-
):
190-
raise ValueError(
191-
"The class is not fitted. The fit method must be called"
192-
" to set variable groups. If no grouping is needed,"
193-
" call fit with groups=None"
194-
)
195-
if isinstance(X, pd.DataFrame):
196-
names = list(X.columns)
197-
elif isinstance(X, np.ndarray) and X.dtype.names is not None:
198-
names = X.dtype.names
199-
# transform Structured Array in pandas array for a better manipulation
200-
X = pd.DataFrame(X)
201-
elif isinstance(X, np.ndarray):
202-
names = None
203-
else:
204-
raise ValueError("X should be a pandas dataframe or a numpy array.")
205-
number_columns = X.shape[1]
206-
for index_variables in self.groups.values():
207-
if type(index_variables[0]) is int or np.issubdtype(
208-
type(index_variables[0]), int
209-
):
210-
assert np.all(
211-
np.array(index_variables, dtype=int) < number_columns
212-
), "X does not correspond to the fitting data."
213-
elif type(index_variables[0]) is str or np.issubdtype(
214-
type(index_variables[0]), str
215-
):
216-
assert np.all(
217-
[name in names for name in index_variables]
218-
), f"The array is missing at least one of the following columns {index_variables}."
219-
else:
220-
raise InternalError(
221-
"A problem with indexing has happened during the fit."
222-
)
223-
number_unique_feature_in_groups = np.unique(
224-
np.concatenate([values for values in self.groups.values()])
225-
).shape[0]
226-
if X.shape[1] != number_unique_feature_in_groups:
227-
warnings.warn(
228-
f"The number of features in X: {X.shape[1]} differs from the"
229-
" number of features for which importance is computed: "
230-
f"{number_unique_feature_in_groups}"
231-
)
232-
233-
def _joblib_predict_one_group(self, X, group_id, group_key):
129+
def _joblib_predict_one_group(self, X, features_group_id, features_group_key):
234130
"""
235131
Compute the predictions after perturbation of the data for a given
236132
group of variables. This function is parallelized.
@@ -244,13 +140,15 @@ def _joblib_predict_one_group(self, X, group_id, group_key):
244140
group_key: str, int
245141
The key of the group of variables. (parameter use for debugging)
246142
"""
247-
group_ids = self._groups_ids[group_id]
248-
non_group_ids = np.delete(np.arange(X.shape[1]), group_ids)
143+
features_group_ids = self._groups_ids[features_group_id]
144+
non_features_group_ids = np.delete(np.arange(X.shape[1]), features_group_ids)
249145
# Create an array X_perm_j of shape (n_permutations, n_samples, n_features)
250146
# where the j-th group of covariates is permuted
251147
X_perm = np.empty((self.n_permutations, X.shape[0], X.shape[1]))
252-
X_perm[:, :, non_group_ids] = np.delete(X, group_ids, axis=1)
253-
X_perm[:, :, group_ids] = self._permutation(X, group_id=group_id)
148+
X_perm[:, :, non_features_group_ids] = np.delete(X, features_group_ids, axis=1)
149+
X_perm[:, :, features_group_ids] = self._permutation(
150+
X, features_group_id=features_group_id
151+
)
254152
# Reshape X_perm to allow for batch prediction
255153
X_perm_batch = X_perm.reshape(-1, X.shape[1])
256154
y_pred_perm = getattr(self.estimator, self.method)(X_perm_batch)
@@ -264,6 +162,6 @@ def _joblib_predict_one_group(self, X, group_id, group_key):
264162
)
265163
return y_pred_perm
266164

267-
def _permutation(self, X, group_id):
165+
def _permutation(self, X, features_group_id):
268166
"""Method for creating the permuted data for the j-th group of covariates."""
269167
raise NotImplementedError

src/hidimstat/base_variable_importance.py

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22

33
from sklearn.base import BaseEstimator
44
import numpy as np
5+
import pandas as pd
6+
7+
from hidimstat._utils.exception import InternalError
58

69

710
class BaseVariableImportance(BaseEstimator):
@@ -131,3 +134,146 @@ def _check_importance(self):
131134
raise ValueError(
132135
"The importances need to be called before calling this method"
133136
)
137+
138+
139+
class VariableImportanceFeatureGroup:
    """
    Mixin that manages groups of features for variable importance methods.

    The class has no base class of its own. Its ``__init__`` calls
    ``super().__init__()`` so that it cooperates with the other classes of a
    multiple-inheritance hierarchy.

    Attributes
    ----------
    n_features_groups : int, default=None
        The number of feature groups. ``None`` until ``fit`` is called.
    features_groups : dict, default=None
        A dictionary mapping group names or indices to lists of feature
        indices or names.
    _features_groups_ids : default=None
        Internal representation of the column indices of each group:
        a 2D integer array when no grouping is given, a list of lists of
        indices when fitting on a pandas DataFrame, and a list of integer
        arrays otherwise.

    Methods
    -------
    fit(X, y=None, features_groups=None)
        Identifies and stores feature groups based on input or provided grouping.
    _check_fit(X)
        Checks if the class has been fitted and validates group-feature
        correspondence.
    """

    def __init__(self):
        super().__init__()
        self.n_features_groups = None
        self.features_groups = None
        self._features_groups_ids = None

    def fit(self, X, y=None, features_groups=None):
        """
        Identify and store the feature groups.

        Parameters
        ----------
        X: array-like of shape (n_samples, n_features)
            The input samples.
        y: array-like of shape (n_samples,)
            Not used, only present for consistency with the sklearn API.
        features_groups: dict, optional
            A dictionary where the keys are the group names and the values are
            the lists of columns (names when X is a pandas DataFrame, integer
            positions otherwise) belonging to each group. If None, every
            column of X forms its own group.

        Raises
        ------
        ValueError
            If ``features_groups`` is neither None nor a dictionary.
        """
        if features_groups is None:
            # One singleton group per column, keyed by the column position.
            self.n_features_groups = X.shape[1]
            self.features_groups = {j: [j] for j in range(self.n_features_groups)}
            self._features_groups_ids = np.array(
                list(self.features_groups.values()), dtype=int
            )
        elif isinstance(features_groups, dict):
            self.n_features_groups = len(features_groups)
            self.features_groups = features_groups
            if isinstance(X, pd.DataFrame):
                # Translate column names into positions, in column order.
                self._features_groups_ids = [
                    [i for i, col in enumerate(X.columns) if col in members]
                    for members in features_groups.values()
                ]
            else:
                self._features_groups_ids = [
                    np.array(ids, dtype=int) for ids in features_groups.values()
                ]
        else:
            raise ValueError("features_groups needs to be a dictionnary")

    def _check_fit(self, X):
        """
        Check that ``fit`` was called and that X is compatible with the groups.

        Parameters
        ----------
        X : pandas DataFrame or numpy array of shape (n_samples, n_features)
            Input data to validate against the fitted groups. A numpy
            structured array is converted to a DataFrame for the checks.

        Raises
        ------
        ValueError
            If the method has not been fitted (i.e., if n_features_groups,
            features_groups or _features_groups_ids is missing or None), or
            if X is neither a pandas DataFrame nor a numpy array.
        AssertionError
            If a group refers to a column index that is out of range for X,
            or to a column name that X does not provide.
        InternalError
            If the first element of a group is neither an integer nor a string.

        Warns
        -----
        UserWarning
            If the number of columns of X differs from the number of distinct
            features listed in the groups.
        """
        # Local import: the module header does not import ``warnings``.
        import warnings

        # ``getattr`` with a default also covers instances whose ``__init__``
        # was never run (e.g. skipped in a multiple-inheritance chain).
        if (
            getattr(self, "n_features_groups", None) is None
            or getattr(self, "features_groups", None) is None
            or getattr(self, "_features_groups_ids", None) is None
        ):
            raise ValueError(
                "The class is not fitted. The fit method must be called"
                " to set variable features_groups. If no grouping is needed,"
                " call fit with features_groups=None"
            )

        if isinstance(X, pd.DataFrame):
            names = list(X.columns)
        elif isinstance(X, np.ndarray) and X.dtype.names is not None:
            names = X.dtype.names
            # transform Structured Array in pandas array for a better manipulation
            X = pd.DataFrame(X)
        elif isinstance(X, np.ndarray):
            # A plain array has no column names: any name lookup must fail
            # with the "missing columns" error rather than a TypeError.
            names = ()
        else:
            raise ValueError("X should be a pandas dataframe or a numpy array.")

        number_columns = X.shape[1]
        for index_variables in self.features_groups.values():
            first = index_variables[0]
            # AssertionError is raised explicitly (instead of ``assert``) so
            # that the validation is not stripped when Python runs with -O.
            if isinstance(first, (int, np.integer)):
                if not np.all(np.array(index_variables, dtype=int) < number_columns):
                    raise AssertionError("X does not correspond to the fitting data.")
            elif isinstance(first, str):
                if not all(name in names for name in index_variables):
                    raise AssertionError(
                        "The array is missing at least one of the following"
                        f" columns {index_variables}."
                    )
            else:
                raise InternalError(
                    "A problem with indexing has happened during the fit."
                )

        number_unique_feature_in_groups = np.unique(
            np.concatenate(list(self.features_groups.values()))
        ).shape[0]
        if X.shape[1] != number_unique_feature_in_groups:
            warnings.warn(
                f"The number of features in X: {X.shape[1]} differs from the"
                " number of features for which importance is computed: "
                f"{number_unique_feature_in_groups}"
            )

0 commit comments

Comments
 (0)