Skip to content

Commit 508ce53

Browse files
committed
fix some error
1 parent f3a5e0a commit 508ce53

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

src/hidimstat/base_perturbation.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@
66
from hidimstat._utils.utils import _check_vim_predict_method
77
from hidimstat.base_variable_importance import (
88
BaseVariableImportance,
9-
VariableImportanceGroup,
9+
VariableImportanceFeatureGroup,
1010
)
1111

1212

13-
class BasePerturbation(BaseVariableImportance, VariableImportanceGroup):
13+
class BasePerturbation(BaseVariableImportance, VariableImportanceFeatureGroup):
1414
def __init__(
1515
self,
1616
estimator,
@@ -121,12 +121,14 @@ def importance(self, X, y):
121121
out_dict["importance"] = np.array(
122122
[
123123
np.mean(out_dict["loss"][j]) - loss_reference
124-
for j in range(self.n_groups)
124+
for j in range(self.n_features_groups)
125125
]
126126
)
127127
return out_dict
128128

129-
def _joblib_predict_one_group(self, X, features_group_id, features_group_key):
129+
def _joblib_predict_one_features_group(
130+
self, X, features_group_id, features_group_key
131+
):
130132
"""
131133
Compute the predictions after perturbation of the data for a given
132134
group of variables. This function is parallelized.
@@ -140,7 +142,7 @@ def _joblib_predict_one_group(self, X, features_group_id, features_group_key):
140142
group_key: str, int
141143
The key of the group of variables. (parameter use for debugging)
142144
"""
143-
features_group_ids = self._groups_ids[features_group_id]
145+
features_group_ids = self._features_groups_ids[features_group_id]
144146
non_features_group_ids = np.delete(np.arange(X.shape[1]), features_group_ids)
145147
# Create an array X_perm_j of shape (n_permutations, n_samples, n_features)
146148
# where the j-th group of covariates is permuted

0 commit comments

Comments
 (0)