diff --git a/src/hidimstat/base_perturbation.py b/src/hidimstat/base_perturbation.py index b025659e2..c8f755280 100644 --- a/src/hidimstat/base_perturbation.py +++ b/src/hidimstat/base_perturbation.py @@ -120,11 +120,9 @@ def predict(self, X): # Parallelize the computation of the importance scores for each group out_list = Parallel(n_jobs=self.n_jobs)( delayed(self._joblib_predict_one_group)( - X_, group_id, group_key, random_state=child_state - ) - for group_id, (group_key, child_state) in enumerate( - zip(self.groups.keys(), rng.spawn(self.n_groups)) + X_, group_id, random_state=child_state ) + for group_id, child_state in enumerate(rng.spawn(self.n_groups)) ) return np.stack(out_list, axis=0) @@ -241,7 +239,7 @@ def _check_fit(self, X): f"{number_unique_feature_in_groups}" ) - def _joblib_predict_one_group(self, X, group_id, group_key, random_state=None): + def _joblib_predict_one_group(self, X, group_id, random_state=None): """ Compute the predictions after perturbation of the data for a given group of variables. This function is parallelized. @@ -252,8 +250,6 @@ def _joblib_predict_one_group(self, X, group_id, group_key, random_state=None): The input samples. group_id: int The index of the group of variables. - group_key: str, int - The key of the group of variables. (parameter use for debugging) random_state: The random state to use for sampling. """ diff --git a/src/hidimstat/leave_one_covariate_out.py b/src/hidimstat/leave_one_covariate_out.py index a199c7928..94b993132 100644 --- a/src/hidimstat/leave_one_covariate_out.py +++ b/src/hidimstat/leave_one_covariate_out.py @@ -93,7 +93,7 @@ def _joblib_fit_one_group(self, estimator, X, y, key_groups): estimator.fit(X_minus_j, y) return estimator - def _joblib_predict_one_group(self, X, group_id, key_groups, random_state=None): + def _joblib_predict_one_group(self, X, group_id, random_state=None): """Predict the target variable after removing a group of covariates. Used in parallel.""" X_minus_j = np.delete(X, self._groups_ids[group_id], axis=1)