From d71422bb56c82c16a6072163ba407d21ffbf7bec Mon Sep 17 00:00:00 2001 From: kusch lionel Date: Fri, 3 Oct 2025 11:20:51 +0200 Subject: [PATCH 1/2] remove the parameter group_key of _joblib_predict_one_group --- src/hidimstat/base_perturbation.py | 4 +--- src/hidimstat/leave_one_covariate_out.py | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/hidimstat/base_perturbation.py b/src/hidimstat/base_perturbation.py index b025659e2..c3e3cc8b6 100644 --- a/src/hidimstat/base_perturbation.py +++ b/src/hidimstat/base_perturbation.py @@ -241,7 +241,7 @@ def _check_fit(self, X): f"{number_unique_feature_in_groups}" ) - def _joblib_predict_one_group(self, X, group_id, group_key, random_state=None): + def _joblib_predict_one_group(self, X, group_id, random_state=None): """ Compute the predictions after perturbation of the data for a given group of variables. This function is parallelized. @@ -252,8 +252,6 @@ def _joblib_predict_one_group(self, X, group_id, group_key, random_state=None): The input samples. group_id: int The index of the group of variables. - group_key: str, int - The key of the group of variables. (parameter use for debugging) random_state: The random state to use for sampling. """ diff --git a/src/hidimstat/leave_one_covariate_out.py b/src/hidimstat/leave_one_covariate_out.py index a199c7928..94b993132 100644 --- a/src/hidimstat/leave_one_covariate_out.py +++ b/src/hidimstat/leave_one_covariate_out.py @@ -93,7 +93,7 @@ def _joblib_fit_one_group(self, estimator, X, y, key_groups): estimator.fit(X_minus_j, y) return estimator - def _joblib_predict_one_group(self, X, group_id, key_groups, random_state=None): + def _joblib_predict_one_group(self, X, group_id, random_state=None): """Predict the target variable after removing a group of covariates. Used in parallel.""" X_minus_j = np.delete(X, self._groups_ids[group_id], axis=1) From 7f9c65141ad98e3474cfb8f2dde49457729b8f55 Mon Sep 17 00:00:00 2001 From: kusch lionel Date: Fri, 3 Oct 2025 11:34:36 +0200 Subject: [PATCH 2/2] fix error --- src/hidimstat/base_perturbation.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/hidimstat/base_perturbation.py b/src/hidimstat/base_perturbation.py index c3e3cc8b6..c8f755280 100644 --- a/src/hidimstat/base_perturbation.py +++ b/src/hidimstat/base_perturbation.py @@ -120,11 +120,9 @@ def predict(self, X): # Parallelize the computation of the importance scores for each group out_list = Parallel(n_jobs=self.n_jobs)( delayed(self._joblib_predict_one_group)( - X_, group_id, group_key, random_state=child_state - ) - for group_id, (group_key, child_state) in enumerate( - zip(self.groups.keys(), rng.spawn(self.n_groups)) + X_, group_id, random_state=child_state ) + for group_id, child_state in enumerate(rng.spawn(self.n_groups)) ) return np.stack(out_list, axis=0)