From df93c7819fa2109ad06f6c556b265f917b54e44e Mon Sep 17 00:00:00 2001
From: kusch lionel
Date: Tue, 2 Sep 2025 16:23:48 +0200
Subject: [PATCH 01/80] New API for CFI, PFI, LOCO

---
 src/hidimstat/base_perturbation.py            | 62 +++++++-----
 .../conditional_feature_importance.py         | 99 ++++++++++---------
 src/hidimstat/conditional_sampling.py         |  8 +-
 src/hidimstat/leave_one_covariate_out.py      | 76 +++++++-------
 .../permutation_feature_importance.py         | 70 ++++++-------
 5 files changed, 168 insertions(+), 147 deletions(-)

diff --git a/src/hidimstat/base_perturbation.py b/src/hidimstat/base_perturbation.py
index ef3c58343..741c527cd 100644
--- a/src/hidimstat/base_perturbation.py
+++ b/src/hidimstat/base_perturbation.py
@@ -14,9 +14,9 @@ class BasePerturbation(BaseVariableImportance):
     def __init__(
         self,
         estimator,
+        method: str = "predict",
         loss: callable = root_mean_squared_error,
         n_permutations: int = 50,
-        method: str = "predict",
         n_jobs: int = 1,
     ):
         """
@@ -27,6 +27,10 @@ def __init__(
         ----------
         estimator : sklearn compatible estimator, optional
             The estimator to use for the prediction.
+        method : str, default="predict"
+            The method used for making predictions. This determines the predictions
+            passed to the loss function. Supported methods are "predict",
+            "predict_proba", "decision_function", "transform".
         loss : callable, default=root_mean_squared_error
             The function to compute the loss when comparing the perturbed model
             to the original model.
@@ -35,10 +39,6 @@ def __init__(
             Specifies the number of times the variable group (residual for CFI) is
             permuted. For each permutation, the perturbed model's loss is calculated
             and averaged over all permutations.
-        method : str, default="predict"
-            The method used for making predictions. This determines the predictions
-            passed to the loss function. Supported methods are "predict",
-            "predict_proba", "decision_function", "transform".
         n_jobs : int, default=1
             The number of parallel jobs to run. Parallelization is done over the
             variables or groups of variables.
@@ -50,9 +50,16 @@ def __init__(
         self.loss = loss
         _check_vim_predict_method(method)
         self.method = method
-        self.n_jobs = n_jobs
         self.n_permutations = n_permutations
-        self.n_groups = None
+        self.n_jobs = n_jobs
+        # variable set in fit
+        self.groups = None
+        # variable set in importance
+        self.loss_reference_ = None
+        self.loss_ = None
+        # internal variables
+        self._n_groups = None
+        self._groups_ids = None

     def fit(self, X, y=None, groups=None):
         """Base fit method for perturbation-based methods. Identifies the groups.
@@ -69,11 +76,11 @@ def fit(self, X, y=None, groups=None):
             identified based on the columns of X.
""" if groups is None: - self.n_groups = X.shape[1] - self.groups = {j: [j] for j in range(self.n_groups)} + self._n_groups = X.shape[1] + self.groups = {j: [j] for j in range(self._n_groups)} self._groups_ids = np.array(list(self.groups.values()), dtype=int) elif isinstance(groups, dict): - self.n_groups = len(groups) + self._n_groups = len(groups) self.groups = groups if isinstance(X, pd.DataFrame): self._groups_ids = [] @@ -91,6 +98,7 @@ def fit(self, X, y=None, groups=None): ] else: raise ValueError("groups needs to be a dictionnary") + return self def predict(self, X): """ @@ -139,27 +147,25 @@ def importance(self, X, y): """ self._check_fit(X) - out_dict = dict() - y_pred = getattr(self.estimator, self.method)(X) - loss_reference = self.loss(y, y_pred) - out_dict["loss_reference"] = loss_reference + self.loss_reference_ = self.loss(y, y_pred) y_pred = self.predict(X) - out_dict["loss"] = dict() + self.loss_ = dict() for j, y_pred_j in enumerate(y_pred): list_loss = [] for y_pred_perm in y_pred_j: list_loss.append(self.loss(y, y_pred_perm)) - out_dict["loss"][j] = np.array(list_loss) + self.loss_[j] = np.array(list_loss) - out_dict["importance"] = np.array( + self.importances_ = np.array( [ - np.mean(out_dict["loss"][j]) - loss_reference - for j in range(self.n_groups) + np.mean(self.loss_[j]) - self.loss_reference_ + for j in range(self._n_groups) ] ) - return out_dict + self.pvalues_ = None + return self.importances_ def _check_fit(self, X): """ @@ -183,11 +189,7 @@ def _check_fit(self, X): If the number of features in X does not match the total number of features in the grouped variables. """ - if ( - self.n_groups is None - or not hasattr(self, "groups") - or not hasattr(self, "_groups_ids") - ): + if self._n_groups is None or self.groups is None or self._groups_ids is None: raise ValueError( "The class is not fitted. The fit method must be called" " to set variable groups. If no grouping is needed," @@ -231,6 +233,16 @@ def _check_fit(self, X): f"{number_unique_feature_in_groups}" ) + def _check_importance(self): + """ + Checks if the loss have been computed. + """ + super()._check_importance() + if self.loss_reference_ is None or self.loss_ is None: + raise ValueError( + "The importances need to be called before calling this method" + ) + def _joblib_predict_one_group(self, X, group_id, group_key): """ Compute the predictions after perturbation of the data for a given diff --git a/src/hidimstat/conditional_feature_importance.py b/src/hidimstat/conditional_feature_importance.py index 9b0e7905f..0f433ea4c 100644 --- a/src/hidimstat/conditional_feature_importance.py +++ b/src/hidimstat/conditional_feature_importance.py @@ -9,63 +9,65 @@ class CFI(BasePerturbation): + """ + Conditional Feature Importance (CFI) algorithm. + :footcite:t:`Chamma_NeurIPS2023` and for group-level see + :footcite:t:`Chamma_AAAI2024`. + + Parameters + ---------- + estimator : sklearn compatible estimator, optional + The estimator to use for the prediction. + method : str, default="predict" + The method to use for the prediction. This determines the predictions passed + to the loss function. Supported methods are "predict", "predict_proba", + "decision_function", "transform". + loss : callable, default=root_mean_squared_error + The loss function to use when comparing the perturbed model to the full + model. + n_permutations : int, default=50 + The number of permutations to perform. For each variable/group of variables, + the mean of the losses over the `n_permutations` is computed. 
+ imputation_model_continuous : sklearn compatible estimator, optional + The model used to estimate the conditional distribution of a given + continuous variable/group of variables given the others. + imputation_model_categorical : sklearn compatible estimator, optional + The model used to estimate the conditional distribution of a given + categorical variable/group of variables given the others. Binary is + considered as a special case of categorical. + categorical_max_cardinality : int, default=10 + The maximum cardinality of a variable to be considered as categorical + when the variable type is inferred (set to "auto" or not provided). + random_state : int, default=None + The random state to use for sampling. + n_jobs : int, default=1 + The number of jobs to run in parallel. Parallelization is done over the + variables or groups of variables. + + References + ---------- + .. footbibliography:: + """ + def __init__( self, estimator, - loss: callable = root_mean_squared_error, method: str = "predict", - n_jobs: int = 1, + loss: callable = root_mean_squared_error, n_permutations: int = 50, imputation_model_continuous=None, imputation_model_categorical=None, - random_state: int = None, categorical_max_cardinality: int = 10, + random_state: int = None, + n_jobs: int = 1, ): - """ - Conditional Feature Importance (CFI) algorithm. - :footcite:t:`Chamma_NeurIPS2023` and for group-level see - :footcite:t:`Chamma_AAAI2024`. - Parameters - ---------- - estimator : sklearn compatible estimator, optional - The estimator to use for the prediction. - loss : callable, default=root_mean_squared_error - The loss function to use when comparing the perturbed model to the full - model. - method : str, default="predict" - The method to use for the prediction. This determines the predictions passed - to the loss function. Supported methods are "predict", "predict_proba", - "decision_function", "transform". - n_jobs : int, default=1 - The number of jobs to run in parallel. Parallelization is done over the - variables or groups of variables. - n_permutations : int, default=50 - The number of permutations to perform. For each variable/group of variables, - the mean of the losses over the `n_permutations` is computed. - imputation_model_continuous : sklearn compatible estimator, optional - The model used to estimate the conditional distribution of a given - continuous variable/group of variables given the others. - imputation_model_categorical : sklearn compatible estimator, optional - The model used to estimate the conditional distribution of a given - categorical variable/group of variables given the others. Binary is - considered as a special case of categorical. - random_state : int, default=None - The random state to use for sampling. - categorical_max_cardinality : int, default=10 - The maximum cardinality of a variable to be considered as categorical - when the variable type is inferred (set to "auto" or not provided). - - References - ---------- - .. footbibliography:: - """ super().__init__( estimator=estimator, - loss=loss, method=method, - n_jobs=n_jobs, + loss=loss, n_permutations=n_permutations, + n_jobs=n_jobs, ) # check the validity of the inputs @@ -83,7 +85,8 @@ def __init__( self.random_state = random_state def fit(self, X, y=None, groups=None, var_type="auto"): - """Fit the imputation models. + """ + Fit the imputation models. 

         Parameters
         ----------
@@ -107,13 +110,11 @@ def fit(self, X, y=None, groups=None, var_type="auto"):
         self.random_state = check_random_state(self.random_state)
         super().fit(X, None, groups=groups)
         if isinstance(var_type, str):
-            self.var_type = [var_type for _ in range(self.n_groups)]
-        else:
-            self.var_type = var_type
+            var_type = [var_type for _ in range(self._n_groups)]

         self._list_imputation_models = [
             ConditionalSampler(
-                data_type=self.var_type[groupd_id],
+                data_type=var_type[group_id],
                 model_regression=(
                     None
                     if self.imputation_model_continuous is None
@@ -127,7 +130,7 @@ def fit(self, X, y=None, groups=None, var_type="auto"):
                 random_state=self.random_state,
                 categorical_max_cardinality=self.categorical_max_cardinality,
             )
-            for groupd_id in range(self.n_groups)
+            for group_id in range(self._n_groups)
         ]

         # Parallelize the fitting of the covariate estimators
diff --git a/src/hidimstat/conditional_sampling.py b/src/hidimstat/conditional_sampling.py
index f8920581a..609c7ea5d 100644
--- a/src/hidimstat/conditional_sampling.py
+++ b/src/hidimstat/conditional_sampling.py
@@ -45,8 +45,8 @@ def __init__(
         model_regression=None,
         model_categorical=None,
         data_type: str = "auto",
-        random_state=None,
         categorical_max_cardinality=10,
+        random_state=None,
     ):
         """
         Class used to sample from the conditional distribution $p(X^j | X^{-j})$.
@@ -62,11 +62,11 @@ def __init__(
             The variable type. Supported types include "auto", "continuous", and
             "categorical". If "auto", the type is inferred from the cardinality
             of the unique values passed to the `fit` method.
-        random_state : int, optional
-            The random state to use for sampling.
         categorical_max_cardinality : int, default=10
             The maximum cardinality of a variable to be considered as categorical
             when `data_type` is "auto".
+        random_state : int, optional
+            The random state to use for sampling.
         """

         # check the validity of the inputs
@@ -79,8 +79,8 @@ def __init__(
         self.data_type = data_type
         self.model_regression = model_regression
         self.model_categorical = model_categorical
-        self.rng = check_random_state(random_state)
         self.categorical_max_cardinality = categorical_max_cardinality
+        self.rng = check_random_state(random_state)

     def fit(self, X: np.ndarray, y: np.ndarray):
         r"""
diff --git a/src/hidimstat/leave_one_covariate_out.py b/src/hidimstat/leave_one_covariate_out.py
index c9c64c464..8e946a55a 100644
--- a/src/hidimstat/leave_one_covariate_out.py
+++ b/src/hidimstat/leave_one_covariate_out.py
@@ -8,55 +8,59 @@


 class LOCO(BasePerturbation):
+    """
+    Leave-One-Covariate-Out (LOCO) as presented in
+    :footcite:t:`lei2018distribution` and :footcite:t:`verdinelli2024feature`.
+    The model is re-fitted for each variable/group of variables. The importance is
+    then computed as the difference between the loss of the full model and the loss
+    of the model without the variable/group.
+
+    Parameters
+    ----------
+    estimator : sklearn compatible estimator, optional
+        The estimator to use for the prediction.
+    method : str, default="predict"
+        The method to use for the prediction. This determines the predictions passed
+        to the loss function. Supported methods are "predict", "predict_proba",
+        "decision_function", "transform".
+    loss : callable, default=root_mean_squared_error
+        The loss function to use when comparing the perturbed model to the full
+        model.
+    n_jobs : int, default=1
+        The number of jobs to run in parallel. Parallelization is done over the
+        variables or groups of variables.
+
+    Notes
+    -----
+    :footcite:t:`Williamson_General_2023` also presented a LOCO method with an
+    additional data splitting strategy.
+
+    References
+    ----------
+    .. footbibliography::
+    """
+
     def __init__(
         self,
         estimator,
+        method: str = "predict",
         loss: callable = root_mean_squared_error,
-        method: str = "predict",
         n_jobs: int = 1,
     ):
-        """
-        Leave-One-Covariate-Out (LOCO) as presented in
-        :footcite:t:`lei2018distribution` and :footcite:t:`verdinelli2024feature`.
-        The model is re-fitted for each variable/group of variables. The importance is
-        then computed as the difference between the loss of the full model and the loss
-        of the model without the variable/group.
-
-        Parameters
-        ----------
-        estimator : sklearn compatible estimator, optional
-            The estimator to use for the prediction.
-        loss : callable, default=root_mean_squared_error
-            The loss function to use when comparing the perturbed model to the full
-            model.
-        method : str, default="predict"
-            The method to use for the prediction. This determines the predictions passed
-            to the loss function. Supported methods are "predict", "predict_proba",
-            "decision_function", "transform".
-        n_jobs : int, default=1
-            The number of jobs to run in parallel. Parallelization is done over the
-            variables or groups of variables.
-
-        Notes
-        -----
-        :footcite:t:`Williamson_General_2023` also presented a LOCO method with an
-        additional data splitting strategy.
-
-        References
-        ----------
-        .. footbibliography::
-        """
         super().__init__(
             estimator=estimator,
-            loss=loss,
             method=method,
-            n_jobs=n_jobs,
+            loss=loss,
             n_permutations=1,
+            n_jobs=n_jobs,
         )
+        # internal variable
         self._list_estimators = []

     def fit(self, X, y, groups=None):
-        """Fit a model after removing each covariate/group of covariates.
+        """
+        Fit a model after removing each covariate/group of covariates.

         Parameters
         ----------
@@ -75,7 +79,7 @@ def fit(self, X, y, groups=None):
         """
         super().fit(X, y, groups)
         # create a list of covariate estimators for each group if not provided
-        self._list_estimators = [clone(self.estimator) for _ in range(self.n_groups)]
+        self._list_estimators = [clone(self.estimator) for _ in range(self._n_groups)]

         # Parallelize the fitting of the covariate estimators
         self._list_estimators = Parallel(n_jobs=self.n_jobs)(
@@ -93,7 +97,7 @@ def _joblib_fit_one_group(self, estimator, X, y, key_groups):
         estimator.fit(X_minus_j, y)
         return estimator

-    def _joblib_predict_one_group(self, X, group_id, key_groups):
+    def _joblib_predict_one_group(self, X, group_id, group_key):
         """Predict the target variable after removing a group of covariates.
         Used in parallel."""
         X_minus_j = np.delete(X, self._groups_ids[group_id], axis=1)
diff --git a/src/hidimstat/permutation_feature_importance.py b/src/hidimstat/permutation_feature_importance.py
index 29d007656..e6fa08a44 100644
--- a/src/hidimstat/permutation_feature_importance.py
+++ b/src/hidimstat/permutation_feature_importance.py
@@ -6,52 +6,54 @@


 class PFI(BasePerturbation):
+    """
+    Permutation Feature Importance algorithm as presented in
+    :footcite:t:`breimanRandomForests2001`. For each variable/group of variables,
+    the importance is computed as the difference between the loss of the initial
+    model and the loss of the model with the variable/group permuted.
+    The method was also used in :footcite:t:`mi2021permutation`.
+
+    Parameters
+    ----------
+    estimator : sklearn compatible estimator, optional
+        The estimator to use for the prediction.
+ method : str, default="predict" + The method to use for the prediction. This determines the predictions passed + to the loss function. Supported methods are "predict", "predict_proba", + "decision_function", "transform". + loss : callable, default=root_mean_squared_error + The loss function to use when comparing the perturbed model to the full + model. + n_permutations : int, default=50 + The number of permutations to perform. For each variable/group of variables, + the mean of the losses over the `n_permutations` is computed. + random_state : int, default=None + The random state to use for sampling. + n_jobs : int, default=1 + The number of jobs to run in parallel. Parallelization is done over the + variables or groups of variables. + + References + ---------- + .. footbibliography:: + """ + def __init__( self, estimator, - loss: callable = root_mean_squared_error, method: str = "predict", - n_jobs: int = 1, + loss: callable = root_mean_squared_error, n_permutations: int = 50, random_state: int = None, + n_jobs: int = 1, ): - """ - Permutation Feature Importance algorithm as presented in - :footcite:t:`breimanRandomForests2001`. For each variable/group of variables, - the importance is computed as the difference between the loss of the initial - model and the loss of the model with the variable/group permuted. - The method was also used in :footcite:t:`mi2021permutation` - - Parameters - ---------- - estimator : sklearn compatible estimator, optionals - The estimator to use for the prediction. - loss : callable, default=root_mean_squared_error - The loss function to use when comparing the perturbed model to the full - model. - method : str, default="predict" - The method to use for the prediction. This determines the predictions passed - to the loss function. Supported methods are "predict", "predict_proba", - "decision_function", "transform". - n_jobs : int, default=1 - The number of jobs to run in parallel. Parallelization is done over the - variables or groups of variables. - n_permutations : int, default=50 - The number of permutations to perform. For each variable/group of variables, - the mean of the losses over the `n_permutations` is computed. - random_state : int, default=None - The random state to use for sampling. - References - ---------- - .. 
footbibliography::
-        """
         super().__init__(
             estimator=estimator,
-            loss=loss,
             method=method,
-            n_jobs=n_jobs,
+            loss=loss,
             n_permutations=n_permutations,
+            n_jobs=n_jobs,
         )
         self.random_state = random_state

From ccb60ed4b5229e96aea3e337cea3e3bdd0e9f2a3 Mon Sep 17 00:00:00 2001
From: kusch lionel
Date: Tue, 2 Sep 2025 16:29:54 +0200
Subject: [PATCH 02/80] fix test for new API

---
 test/test_conditional_feature_importance.py | 15 ++++++---------
 test/test_leave_one_covariate_out.py        |  9 +++------
 test/test_permutation_feature_importance.py |  9 +++------
 3 files changed, 12 insertions(+), 21 deletions(-)

diff --git a/test/test_conditional_feature_importance.py b/test/test_conditional_feature_importance.py
index f3bec9735..529cf891d 100644
--- a/test/test_conditional_feature_importance.py
+++ b/test/test_conditional_feature_importance.py
@@ -66,8 +66,7 @@ def run_cfi(X, y, n_permutation, seed):
         var_type="auto",
     )
     # calculate feature importance using the test set
-    vim = cfi.importance(X_test, y_test)
-    importance = vim["importance"]
+    importance = cfi.importance(X_test, y_test)
     return importance

@@ -199,9 +198,8 @@ def test_group(data_generator):
     )
     # Warning expected since column names in pandas are not considered
     with pytest.warns(UserWarning, match="X does not have valid feature names, but"):
-        vim = cfi.importance(X_test_df, y_test)
+        importance = cfi.importance(X_test_df, y_test)

-    importance = vim["importance"]
     # Check if importance scores are computed for each feature
     assert importance.shape == (2,)
     # Verify that important feature group has higher score
@@ -248,8 +246,7 @@ def test_classication(data_generator):
         groups=None,
         var_type=["continuous"] * X.shape[1],
     )
-    vim = cfi.importance(X_test, y_test_clf)
-    importance = vim["importance"]
+    importance = cfi.importance(X_test, y_test_clf)
     # Check that importance scores are defined for each feature
     assert importance.shape == (X.shape[1],)
     # Check that important features have higher mean importance scores
@@ -297,13 +294,13 @@ def test_fit(self, data_generator):
         # Test fit with auto var_type
         cfi.fit(X)
         assert len(cfi._list_imputation_models) == X.shape[1]
-        assert cfi.n_groups == X.shape[1]
+        assert cfi._n_groups == X.shape[1]

         # Test fit with specified groups
         groups = {"g1": [0, 1], "g2": [2, 3, 4]}
         cfi.fit(X, groups=groups)
         assert len(cfi._list_imputation_models) == 2
-        assert cfi.n_groups == 2
+        assert cfi._n_groups == 2

     def test_categorical(
         self,
@@ -334,7 +331,7 @@ def test_categorical(
         var_type = ["continuous", "continuous", "categorical"]
         cfi.fit(X, y, var_type=var_type)

-        importances = cfi.importance(X, y)["importance"]
+        importances = cfi.importance(X, y)

         assert len(importances) == 3
         assert np.all(importances >= 0)
diff --git a/test/test_leave_one_covariate_out.py b/test/test_leave_one_covariate_out.py
index d8fd2a763..f6f4b2319 100644
--- a/test/test_leave_one_covariate_out.py
+++ b/test/test_leave_one_covariate_out.py
@@ -38,9 +38,8 @@ def test_loco():
         y_train,
         groups=None,
     )
-    vim = loco.importance(X_test, y_test)
+    importance = loco.importance(X_test, y_test)

-    importance = vim["importance"]
     assert importance.shape == (X.shape[1],)
     assert (
         importance[important_features].mean()
@@ -67,9 +66,8 @@ def test_loco():
     )
     # warnings because we don't consider the column names of pandas
     with pytest.warns(UserWarning, match="X does not have valid feature names, but"):
-        vim = loco.importance(X_test_df, y_test)
+        importance = loco.importance(X_test_df, y_test)

-    importance = vim["importance"]
     assert importance[0].mean() >
importance[1].mean()

     # Classification case
@@ -89,9 +87,8 @@ def test_loco():
         y_train_clf,
         groups={"group_0": important_features, "the_group_1": non_important_features},
     )
-    vim_clf = loco_clf.importance(X_test, y_test_clf)
+    importance_clf = loco_clf.importance(X_test, y_test_clf)

-    importance_clf = vim_clf["importance"]
     assert importance_clf.shape == (2,)
     assert importance[0].mean() > importance[1].mean()

diff --git a/test/test_permutation_feature_importance.py b/test/test_permutation_feature_importance.py
index b9639f359..ee0a870c1 100644
--- a/test/test_permutation_feature_importance.py
+++ b/test/test_permutation_feature_importance.py
@@ -39,9 +39,8 @@ def test_permutation_importance():
         y_train,
         groups=None,
     )
-    vim = pfi.importance(X_test, y_test)
+    importance = pfi.importance(X_test, y_test)

-    importance = vim["importance"]
     assert importance.shape == (X.shape[1],)
     assert (
         importance[important_features].mean()
@@ -70,9 +69,8 @@ def test_permutation_importance():
     )
     # warnings because we don't consider the column names of pandas
     with pytest.warns(UserWarning, match="X does not have valid feature names, but"):
-        vim = pfi.importance(X_test_df, y_test)
+        importance = pfi.importance(X_test_df, y_test)

-    importance = vim["importance"]
     assert importance[0].mean() > importance[1].mean()

     # Classification case
@@ -95,7 +93,6 @@ def test_permutation_importance():
         y_train_clf,
         groups=None,
     )
-    vim_clf = pfi_clf.importance(X_test, y_test_clf)
+    importance_clf = pfi_clf.importance(X_test, y_test_clf)

-    importance_clf = vim_clf["importance"]
     assert importance_clf.shape == (X.shape[1],)

From 7c827ad57ee14269692847f1b01eb0d1b1f8615a Mon Sep 17 00:00:00 2001
From: kusch lionel
Date: Tue, 2 Sep 2025 16:51:42 +0200
Subject: [PATCH 03/80] fix example

---
 examples/plot_conditional_vs_marginal_xor_data.py |  2 +-
 .../plot_diabetes_variable_importance_example.py  | 12 ++++++------
 examples/plot_importance_classification_iris.py   |  2 +-
 examples/plot_model_agnostic_importance.py        |  6 ++----
 examples/plot_pitfalls_permutation_importance.py  |  4 ++--
 src/hidimstat/leave_one_covariate_out.py          |  5 +++--
 src/hidimstat/permutation_feature_importance.py   | 10 ++++++----
 7 files changed, 21 insertions(+), 20 deletions(-)

diff --git a/examples/plot_conditional_vs_marginal_xor_data.py b/examples/plot_conditional_vs_marginal_xor_data.py
index e9ea09ec9..55a095fc5 100644
--- a/examples/plot_conditional_vs_marginal_xor_data.py
+++ b/examples/plot_conditional_vs_marginal_xor_data.py
@@ -117,7 +117,7 @@
         random_state=0,
     )
     vim.fit(X_train, y_train)
-    importances.append(vim.importance(X_test, y_test)["importance"])
+    importances.append(vim.importance(X_test, y_test))

 importances = np.array(importances).T

diff --git a/examples/plot_diabetes_variable_importance_example.py b/examples/plot_diabetes_variable_importance_example.py
index 0340d9a3d..17e933802 100644
--- a/examples/plot_diabetes_variable_importance_example.py
+++ b/examples/plot_diabetes_variable_importance_example.py
@@ -184,14 +184,14 @@ def compute_pval(vim):
 # -------------------

-cfi_vim_arr = np.array([x["importance"] for x in cfi_importance_list]) / 2
+cfi_vim_arr = np.array(cfi_importance_list) / 2
 cfi_pval = compute_pval(cfi_vim_arr)

 vim = [
     pd.DataFrame(
         {
             "var": np.arange(cfi_vim_arr.shape[1]),
-            "importance": x["importance"],
+            "importance": x,
             "fold": i,
             "pval": cfi_pval,
             "method": "CFI",
@@ -200,14 +200,14 @@ def compute_pval(vim):
     for x in cfi_importance_list
 ]

-loco_vim_arr = np.array([x["importance"] for x in loco_importance_list])
+loco_vim_arr =
np.array(loco_importance_list) loco_pval = compute_pval(loco_vim_arr) vim += [ pd.DataFrame( { "var": np.arange(loco_vim_arr.shape[1]), - "importance": x["importance"], + "importance": x, "fold": i, "pval": loco_pval, "method": "LOCO", @@ -216,14 +216,14 @@ def compute_pval(vim): for x in loco_importance_list ] -pfi_vim_arr = np.array([x["importance"] for x in pfi_importance_list]) +pfi_vim_arr = np.array(pfi_importance_list) pfi_pval = compute_pval(pfi_vim_arr) vim += [ pd.DataFrame( { "var": np.arange(pfi_vim_arr.shape[1]), - "importance": x["importance"], + "importance": x, "fold": i, "pval": pfi_pval, "method": "PFI", diff --git a/examples/plot_importance_classification_iris.py b/examples/plot_importance_classification_iris.py index eb92d7abf..9a2be2b72 100644 --- a/examples/plot_importance_classification_iris.py +++ b/examples/plot_importance_classification_iris.py @@ -93,7 +93,7 @@ def run_one_fold(X, y, model, train_index, test_index, vim_name="CFI", groups=No ) vim.fit(X[train_index], y[train_index], groups=groups) - importance = vim.importance(X[test_index], y[test_index])["importance"] + importance = vim.importance(X[test_index], y[test_index]) return pd.DataFrame( { diff --git a/examples/plot_model_agnostic_importance.py b/examples/plot_model_agnostic_importance.py index 9ace442d4..03f8ceded 100644 --- a/examples/plot_model_agnostic_importance.py +++ b/examples/plot_model_agnostic_importance.py @@ -108,10 +108,8 @@ vim_linear.fit(X[train], y[train]) vim_non_linear.fit(X[train], y[train]) - importances_linear.append(vim_linear.importance(X[test], y[test])["importance"]) - importances_non_linear.append( - vim_non_linear.importance(X[test], y[test])["importance"] - ) + importances_linear.append(vim_linear.importance(X[test], y[test])) + importances_non_linear.append(vim_non_linear.importance(X[test], y[test])) ################################################################################ diff --git a/examples/plot_pitfalls_permutation_importance.py b/examples/plot_pitfalls_permutation_importance.py index af4deb83e..e7a041d8e 100644 --- a/examples/plot_pitfalls_permutation_importance.py +++ b/examples/plot_pitfalls_permutation_importance.py @@ -132,7 +132,7 @@ ) pfi.fit(X_test, y_test) - permutation_importances.append(pfi.importance(X_test, y_test)["importance"]) + permutation_importances.append(pfi.importance(X_test, y_test)) permutation_importances = np.stack(permutation_importances) pval_pfi = ttest_1samp( permutation_importances, 0.0, axis=0, alternative="greater" @@ -200,7 +200,7 @@ ) cfi.fit(X_test, y_test) - conditional_importances.append(cfi.importance(X_test, y_test)["importance"]) + conditional_importances.append(cfi.importance(X_test, y_test)) cfi_pval = ttest_1samp( diff --git a/src/hidimstat/leave_one_covariate_out.py b/src/hidimstat/leave_one_covariate_out.py index 8e946a55a..5a865cb71 100644 --- a/src/hidimstat/leave_one_covariate_out.py +++ b/src/hidimstat/leave_one_covariate_out.py @@ -9,8 +9,9 @@ class LOCO(BasePerturbation): """ - Leave-One-Covariate-Out (LOCO) as presented in - :footcite:t:`lei2018distribution` and :footcite:t:`verdinelli2024feature`. + Leave-One-Covariate-Out (LOCO) algorithm + + This method is presented in :footcite:t:`lei2018distribution` and :footcite:t:`verdinelli2024feature`. The model is re-fitted for each variable/group of variables. The importance is then computed as the difference between the loss of the full model and the loss of the model without the variable/group. 
diff --git a/src/hidimstat/permutation_feature_importance.py b/src/hidimstat/permutation_feature_importance.py
index e6fa08a44..10f18289f 100644
--- a/src/hidimstat/permutation_feature_importance.py
+++ b/src/hidimstat/permutation_feature_importance.py
@@ -7,10 +7,12 @@

 class PFI(BasePerturbation):
     """
-    Permutation Feature Importance algorithm as presented in
-    :footcite:t:`breimanRandomForests2001`. For each variable/group of variables,
-    the importance is computed as the difference between the loss of the initial
-    model and the loss of the model with the variable/group permuted.
+    Permutation Feature Importance algorithm
+
+    This method is presented in :footcite:t:`breimanRandomForests2001`.
+    For each variable/group of variables, the importance is computed as
+    the difference between the loss of the initial model and the loss of
+    the model with the variable/group permuted.
     The method was also used in :footcite:t:`mi2021permutation`.

     Parameters

From 82d61e6120eb55c83a280c3e4944e6ab1ef311f6 Mon Sep 17 00:00:00 2001
From: kusch lionel
Date: Tue, 2 Sep 2025 17:23:02 +0200
Subject: [PATCH 04/80] add test for new check

---
 test/test_base_perturbation.py | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/test/test_base_perturbation.py b/test/test_base_perturbation.py
index dd3ff6d6c..9198cf72d 100644
--- a/test/test_base_perturbation.py
+++ b/test/test_base_perturbation.py
@@ -12,3 +12,16 @@ def test_no_implemented_methods():
     basic_class = BasePerturbation(estimator=estimator)
     with pytest.raises(NotImplementedError):
         basic_class._permutation(X, group_id=None)
+
+
+def test_check_importance():
+    """test that selection raises an error when the importances have not been computed"""
+    X = np.random.randint(0, 2, size=(100, 2, 1))
+    estimator = LinearRegression()
+    estimator.fit(X[:, 0], X[:, 1])
+    basic_class = BasePerturbation(estimator=estimator)
+    basic_class.importances_ = []
+    with pytest.raises(
+        ValueError, match="The importances need to be computed before calling this method"
+    ):
+        basic_class.selection()

From 28593e49e77b9310861b4dc6fca3e4931d687ea8 Mon Sep 17 00:00:00 2001
From: kusch lionel
Date: Tue, 2 Sep 2025 18:13:55 +0200
Subject: [PATCH 05/80] add pvalue, fit_importance and functions

---
 src/hidimstat/base_perturbation.py            | 52 +++++++++++++-
 .../conditional_feature_importance.py         | 70 +++++++++++++++++++
 src/hidimstat/leave_one_covariate_out.py      | 63 +++++++++++++++++
 .../permutation_feature_importance.py         | 62 ++++++++++++++++
 4 files changed, 244 insertions(+), 3 deletions(-)

diff --git a/src/hidimstat/base_perturbation.py b/src/hidimstat/base_perturbation.py
index 741c527cd..fa6f7fde2 100644
--- a/src/hidimstat/base_perturbation.py
+++ b/src/hidimstat/base_perturbation.py
@@ -1,9 +1,12 @@
+import warnings
+
 import numpy as np
 import pandas as pd
 from joblib import Parallel, delayed
-from sklearn.base import check_is_fitted
+from scipy.stats import ttest_1samp
+from sklearn.base import check_is_fitted, clone
 from sklearn.metrics import root_mean_squared_error
-import warnings
+from sklearn.model_selection import KFold

 from hidimstat._utils.utils import _check_vim_predict_method
 from hidimstat._utils.exception import InternalError
@@ -164,9 +167,52 @@ def importance(self, X, y):
             for j in range(self._n_groups)
             ]
         )
-        self.pvalues_ = None
+        self.pvalues_ = ttest_1samp(
+            self.importances_, 0.0, axis=0, alternative="greater"
+        ).pvalue
         return self.importances_

+    def fit_importance(
+        self, X, y, cv=KFold(n_splits=5, shuffle=True, random_state=0), **fit_kwargs
+    ):
+        """
+        Compute feature
importance scores using cross-validation. + + Parameters + ---------- + X : array-like of shape (n_samples, n_features) + Training data. + y : array-like of shape (n_samples,) + Target values. + cv : cross-validation generator or iterable, default=KFold(n_splits=5, shuffle=True, random_state=0) + Determines the cross-validation splitting strategy. + **fit_kwargs : dict + Additional arguments passed to the fit method during variable group identification. + + Returns + ------- + importances : float + Mean feature importance scores across CV folds. + + Notes + ----- + For each CV fold: + 1. Clones and fits the estimator on training fold + 2. Identifies variable groups on training fold + 3. Computes feature importances on test fold + 4. Returns average importance across all folds + + The importances for each fold are stored in self.importances_ + """ + importances = [] + for train, test in cv.split(X): + estimator = clone(self.estimator) + estimator.fit(X[train], y[train]) + self.fit(X[train], y[train], **fit_kwargs) + importances.append(self.importance(X[test], y[test])) + self.importances_ = importances + return np.mean(importances) + def _check_fit(self, X): """ Check if the perturbation method has been properly fitted. diff --git a/src/hidimstat/conditional_feature_importance.py b/src/hidimstat/conditional_feature_importance.py index 0f433ea4c..f45e8d402 100644 --- a/src/hidimstat/conditional_feature_importance.py +++ b/src/hidimstat/conditional_feature_importance.py @@ -2,10 +2,12 @@ from joblib import Parallel, delayed from sklearn.base import check_is_fitted, clone, BaseEstimator from sklearn.metrics import root_mean_squared_error +from sklearn.model_selection import KFold from sklearn.utils.validation import check_random_state from hidimstat.base_perturbation import BasePerturbation from hidimstat.conditional_sampling import ConditionalSampler +from hidimstat._utils.docstring import _aggregate_docstring class CFI(BasePerturbation): @@ -191,3 +193,71 @@ def _permutation(self, X, group_id): return self._list_imputation_models[group_id].sample( X_minus_j, X_j, n_samples=self.n_permutations ) + + +def cfi( + estimator, + X, + y, + cv=KFold(n_splits=5, shuffle=True, random_state=0), + groups: dict = None, + var_type: str = "auto", + method: str = "predict", + loss: callable = root_mean_squared_error, + n_permutations: int = 50, + imputation_model_continuous=None, + imputation_model_categorical=None, + categorical_max_cardinality: int = 10, + k_best=None, + percentile=None, + threshold=None, + threshold_pvalue=None, + random_state: int = None, + n_jobs: int = 1, +): + methods = CFI( + estimator=estimator, + method=method, + loss=loss, + n_permutations=n_permutations, + imputation_model_continuous=imputation_model_continuous, + imputation_model_categorical=imputation_model_categorical, + categorical_max_cardinality=categorical_max_cardinality, + random_state=random_state, + n_jobs=n_jobs, + ) + methods.fit_importance( + X, + y, + cv=cv, + groups=groups, + var_type=var_type, + ) + selection = methods.selection( + k_best=k_best, + percentile=percentile, + threshold=threshold, + threshold_pvalue=threshold_pvalue, + ) + return selection, methods.importances_, methods.pvalues_ + + +# use the docstring of the class for the function +cfi.__doc__ = _aggregate_docstring( + [ + CFI.__doc__, + CFI.__init__.__doc__, + CFI.fit_importance.__doc__, + CFI.selection.__doc__, + ], + """ + Returns + ------- + selection : ndarray of shape (n_features,) + Boolean array indicating selected features (True = 
selected) + importances : ndarray of shape (n_features,) + Feature importance scores/test statistics. + pvalues : ndarray of shape (n_features,) + + """, +) diff --git a/src/hidimstat/leave_one_covariate_out.py b/src/hidimstat/leave_one_covariate_out.py index 5a865cb71..6a9f2dce4 100644 --- a/src/hidimstat/leave_one_covariate_out.py +++ b/src/hidimstat/leave_one_covariate_out.py @@ -2,9 +2,11 @@ import pandas as pd from joblib import Parallel, delayed from sklearn.base import check_is_fitted, clone +from sklearn.model_selection import KFold from sklearn.metrics import root_mean_squared_error from hidimstat.base_perturbation import BasePerturbation +from hidimstat._utils.docstring import _aggregate_docstring class LOCO(BasePerturbation): @@ -89,6 +91,11 @@ def fit(self, X, y, groups=None): ) return self + def importance(self, X, y): + super().importance(X, y) + self.pvalues_ = None + return self.importances_ + def _joblib_fit_one_group(self, estimator, X, y, key_groups): """Fit the estimator after removing a group of covariates. Used in parallel.""" if isinstance(X, pd.DataFrame): @@ -116,3 +123,59 @@ def _check_fit(self, X): raise ValueError("The estimators require to be fit before to use them") for m in self._list_estimators: check_is_fitted(m) + + +def loco( + estimator, + X, + y, + cv=KFold(n_splits=5, shuffle=True, random_state=0), + groups: dict = None, + method: str = "predict", + loss: callable = root_mean_squared_error, + k_best=None, + percentile=None, + threshold=None, + threshold_pvalue=None, + n_jobs: int = 1, +): + methods = LOCO( + estimator=estimator, + method=method, + loss=loss, + n_jobs=n_jobs, + ) + methods.fit_importance( + X, + y, + cv=cv, + groups=groups, + ) + selection = methods.selection( + k_best=k_best, + percentile=percentile, + threshold=threshold, + threshold_pvalue=threshold_pvalue, + ) + return selection, methods.importances_, methods.pvalues_ + + +# use the docstring of the class for the function +loco.__doc__ = _aggregate_docstring( + [ + LOCO.__doc__, + LOCO.__init__.__doc__, + LOCO.fit_importance.__doc__, + LOCO.selection.__doc__, + ], + """ + Returns + ------- + selection : ndarray of shape (n_features,) + Boolean array indicating selected features (True = selected) + importances : ndarray of shape (n_features,) + Feature importance scores/test statistics. 
+ pvalues : ndarray of shape (n_features,) + + """, +) diff --git a/src/hidimstat/permutation_feature_importance.py b/src/hidimstat/permutation_feature_importance.py index 10f18289f..14f02fcc5 100644 --- a/src/hidimstat/permutation_feature_importance.py +++ b/src/hidimstat/permutation_feature_importance.py @@ -1,8 +1,10 @@ import numpy as np from sklearn.metrics import root_mean_squared_error +from sklearn.model_selection import KFold from sklearn.utils import check_random_state from hidimstat.base_perturbation import BasePerturbation +from hidimstat._utils.docstring import _aggregate_docstring class PFI(BasePerturbation): @@ -69,3 +71,63 @@ def _permutation(self, X, group_id): ] ) return X_perm_j + + +def pfi( + estimator, + X, + y, + cv=KFold(n_splits=5, shuffle=True, random_state=0), + groups: dict = None, + method: str = "predict", + loss: callable = root_mean_squared_error, + n_permutations: int = 50, + k_best=None, + percentile=None, + threshold=None, + threshold_pvalue=None, + random_state: int = None, + n_jobs: int = 1, +): + methods = PFI( + estimator=estimator, + method=method, + loss=loss, + n_permutations=n_permutations, + random_state=random_state, + n_jobs=n_jobs, + ) + methods.fit_importance( + X, + y, + cv=cv, + groups=groups, + ) + selection = methods.selection( + k_best=k_best, + percentile=percentile, + threshold=threshold, + threshold_pvalue=threshold_pvalue, + ) + return selection, methods.importances_, methods.pvalues_ + + +# use the docstring of the class for the function +pfi.__doc__ = _aggregate_docstring( + [ + PFI.__doc__, + PFI.__init__.__doc__, + PFI.fit_importance.__doc__, + PFI.selection.__doc__, + ], + """ + Returns + ------- + selection : ndarray of shape (n_features,) + Boolean array indicating selected features (True = selected) + importances : ndarray of shape (n_features,) + Feature importance scores/test statistics. 
+    pvalues : ndarray of shape (n_features,)
+
+    """,
+)

From cabfb6301056fb8468f6802ca493636e35c66451 Mon Sep 17 00:00:00 2001
From: kusch lionel
Date: Tue, 2 Sep 2025 18:46:47 +0200
Subject: [PATCH 06/80] Add new function

---
 src/hidimstat/__init__.py                   |  9 ++++--
 src/hidimstat/base_perturbation.py          |  7 ++--
 test/test_conditional_feature_importance.py | 21 +++++++++++-
 test/test_leave_one_covariate_out.py        | 34 ++++++++++++++++++-
 test/test_permutation_feature_importance.py | 36 ++++++++++++++++++++-
 5 files changed, 99 insertions(+), 8 deletions(-)

diff --git a/src/hidimstat/__init__.py b/src/hidimstat/__init__.py
index 81d5a0cce..3037bbcc9 100644
--- a/src/hidimstat/__init__.py
+++ b/src/hidimstat/__init__.py
@@ -14,16 +14,16 @@
     desparsified_group_lasso_pvalue,
 )
 from .distilled_conditional_randomization_test import d0crt, D0CRT
-from .conditional_feature_importance import CFI
+from .conditional_feature_importance import cfi, CFI
 from .knockoffs import (
     model_x_knockoff,
     model_x_knockoff_pvalue,
     model_x_knockoff_bootstrap_quantile,
     model_x_knockoff_bootstrap_e_value,
 )
-from .leave_one_covariate_out import LOCO
+from .leave_one_covariate_out import loco, LOCO
 from .noise_std import reid
-from .permutation_feature_importance import PFI
+from .permutation_feature_importance import pfi, PFI

 from .statistical_tools.aggregation import quantile_aggregation

@@ -49,6 +49,9 @@
     "model_x_knockoff_bootstrap_quantile",
     "model_x_knockoff_bootstrap_e_value",
     "CFI",
+    "cfi",
     "LOCO",
+    "loco",
     "PFI",
+    "pfi",
 ]
diff --git a/src/hidimstat/base_perturbation.py b/src/hidimstat/base_perturbation.py
index fa6f7fde2..c42892fbe 100644
--- a/src/hidimstat/base_perturbation.py
+++ b/src/hidimstat/base_perturbation.py
@@ -60,6 +60,8 @@ def __init__(
         # variable set in importance
         self.loss_reference_ = None
         self.loss_ = None
+        # variable set in fit_importance
+        self.importances_cv_ = None
         # internal variables
         self._n_groups = None
         self._groups_ids = None
@@ -210,8 +212,9 @@ def fit_importance(
             estimator.fit(X[train], y[train])
             self.fit(X[train], y[train], **fit_kwargs)
             importances.append(self.importance(X[test], y[test]))
-        self.importances_ = importances
-        return np.mean(importances)
+        self.importances_cv_ = importances
+        self.importances_ = np.mean(importances, axis=0)
+        return self.importances_

     def _check_fit(self, X):
diff --git a/test/test_conditional_feature_importance.py b/test/test_conditional_feature_importance.py
index 529cf891d..ef00be5b7 100644
--- a/test/test_conditional_feature_importance.py
+++ b/test/test_conditional_feature_importance.py
@@ -8,7 +8,7 @@
 from sklearn.model_selection import train_test_split
 from sklearn.metrics import root_mean_squared_error

-from hidimstat import CFI, BasePerturbation
+from hidimstat import cfi, CFI, BasePerturbation
 from hidimstat._utils.exception import InternalError

@@ -565,3 +565,22 @@
             " number of features for which importance is computed: 4",
         ):
             cfi.importance(X, y)
+
+
+@pytest.mark.parametrize(
+    "n_samples, n_features, support_size, rho, seed, value, signal_noise_ratio, rho_serial",
+    [(150, 200, 10, 0.2, 42, 1.0, 1.0, 0.0)],
+    ids=["high level noise"],
+)
+@pytest.mark.parametrize("n_permutation, cfi_seed", [(20, 0)], ids=["default_cfi"])
+def test_function_cfi(data_generator, n_permutation, cfi_seed):
+    """Test CFI function"""
+    X, y, _, _ = data_generator
+    cfi(
+        LinearRegression().fit(X, y),
+        X,
+        y,
+        imputation_model_continuous=LinearRegression(),
+        n_permutations=n_permutation,
+        random_state=cfi_seed,
+ ) diff --git a/test/test_leave_one_covariate_out.py b/test/test_leave_one_covariate_out.py index f6f4b2319..a875f98d0 100644 --- a/test/test_leave_one_covariate_out.py +++ b/test/test_leave_one_covariate_out.py @@ -7,7 +7,7 @@ from sklearn.model_selection import train_test_split from hidimstat._utils.scenario import multivariate_simulation -from hidimstat import LOCO, BasePerturbation +from hidimstat import loco, LOCO, BasePerturbation def test_loco(): @@ -135,3 +135,35 @@ def test_raises_value_error(): ) BasePerturbation.fit(loco, X, y) loco.importance(X, y) + + +def test_loco_function(): + """Test the function of LOCO algorithm on a linear scenario.""" + X, y, beta, noise = multivariate_simulation( + n_samples=150, + n_features=200, + support_size=10, + shuffle=False, + seed=42, + ) + important_features = np.where(beta != 0)[0] + non_important_features = np.where(beta == 0)[0] + + X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0) + + regression_model = LinearRegression() + regression_model.fit(X_train, y_train) + + selection, importance, pvalue = loco( + regression_model, + X, + y, + method="predict", + n_jobs=1, + ) + + assert importance.shape == (X.shape[1],) + assert ( + importance[important_features].mean() + > importance[non_important_features].mean() + ) diff --git a/test/test_permutation_feature_importance.py b/test/test_permutation_feature_importance.py index ee0a870c1..d38dfe1a2 100644 --- a/test/test_permutation_feature_importance.py +++ b/test/test_permutation_feature_importance.py @@ -5,7 +5,7 @@ from sklearn.model_selection import train_test_split import pytest -from hidimstat import PFI +from hidimstat import PFI, pfi from hidimstat._utils.scenario import multivariate_simulation @@ -96,3 +96,37 @@ def test_permutation_importance(): importance_clf = pfi_clf.importance(X_test, y_test_clf) assert importance_clf.shape == (X.shape[1],) + + +def test_permutation_importance_function(): + """Test the function of Permutation Importance algorithm on a linear scenario.""" + X, y, beta, noise = multivariate_simulation( + n_samples=150, + n_features=200, + support_size=10, + shuffle=False, + seed=42, + ) + important_features = np.where(beta != 0)[0] + non_important_features = np.where(beta == 0)[0] + + X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0) + + regression_model = LinearRegression() + regression_model.fit(X_train, y_train) + + selection, importance, pvalue = pfi( + regression_model, + X, + y, + n_permutations=20, + method="predict", + random_state=0, + n_jobs=1, + ) + + assert importance.shape == (X.shape[1],) + assert ( + importance[important_features].mean() + > importance[non_important_features].mean() + ) From 7d7fd7d9b0b507b165497eab1963ce6750fca1f8 Mon Sep 17 00:00:00 2001 From: kusch lionel Date: Tue, 2 Sep 2025 18:55:41 +0200 Subject: [PATCH 07/80] fix docstring --- src/hidimstat/base_perturbation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hidimstat/base_perturbation.py b/src/hidimstat/base_perturbation.py index c42892fbe..bbddcf765 100644 --- a/src/hidimstat/base_perturbation.py +++ b/src/hidimstat/base_perturbation.py @@ -204,7 +204,7 @@ def fit_importance( 3. Computes feature importances on test fold 4. 
Returns average importance across all folds

-        The importances for each fold are stored in self.importances_
+        The importances for each fold are stored in self.importances\_
         """
         importances = []
         for train, test in cv.split(X):

From b958cc7b013abb849b0e6e9bca6762ac0047f9f1 Mon Sep 17 00:00:00 2001
From: kusch lionel
Date: Wed, 3 Sep 2025 15:19:11 +0200
Subject: [PATCH 08/80] Improve cross validation

---
 src/hidimstat/_utils/utils.py      | 31 ++++++++++++++++++++++++++
 src/hidimstat/base_perturbation.py | 24 ++++++++++++++++------
 test/_utils/test_utils.py          | 16 +++++++++++++++
 3 files changed, 64 insertions(+), 7 deletions(-)
 create mode 100644 test/_utils/test_utils.py

diff --git a/src/hidimstat/_utils/utils.py b/src/hidimstat/_utils/utils.py
index 66166f444..57cc38f70 100644
--- a/src/hidimstat/_utils/utils.py
+++ b/src/hidimstat/_utils/utils.py
@@ -25,3 +25,34 @@ def _check_vim_predict_method(method):
             "The method {} is not a valid method "
             "for variable importance measure prediction".format(method)
         )
+
+
+def get_generated_attributes(cls):
+    """
+    Get all attributes from a class that end with a single underscore
+    and do not start with an underscore.
+
+    Parameters
+    ----------
+    cls : class
+        The class to inspect for attributes.
+
+    Returns
+    -------
+    list
+        A list of attribute names that end with a single underscore but not a double underscore.
+    """
+    # Get all attributes and methods of the class
+    all_attributes = dir(cls)
+
+    # Filter out attributes that start with an underscore
+    filtered_attributes = [attr for attr in all_attributes if not attr.startswith("_")]
+
+    # Filter out attributes that do not end with a single underscore
+    result = [
+        attr
+        for attr in filtered_attributes
+        if attr.endswith("_") and not attr.endswith("__")
+    ]
+
+    return result
diff --git a/src/hidimstat/base_perturbation.py b/src/hidimstat/base_perturbation.py
index bbddcf765..3cd3ae2b5 100644
--- a/src/hidimstat/base_perturbation.py
+++ b/src/hidimstat/base_perturbation.py
@@ -11,6 +11,7 @@
 from hidimstat._utils.utils import _check_vim_predict_method
 from hidimstat._utils.exception import InternalError
 from hidimstat.base_variable_importance import BaseVariableImportance
+from hidimstat._utils.utils import get_generated_attributes


 class BasePerturbation(BaseVariableImportance):
@@ -60,8 +61,6 @@ def __init__(
         # variable set in importance
         self.loss_reference_ = None
         self.loss_ = None
-        # variable set in fit_importance
-        self.importances_cv_ = None
         # internal variables
         self._n_groups = None
         self._groups_ids = None
@@ -206,14 +205,23 @@ def fit_importance(
         The importances for each fold are stored in self.importances\_
         """
-        importances = []
+        name_attribute_save = get_generated_attributes(self)
+        for name in name_attribute_save:
+            setattr(self, name + "cv_", [])
+        self.estimators_cv_ = []
+
         for train, test in cv.split(X):
             estimator = clone(self.estimator)
             estimator.fit(X[train], y[train])
             self.fit(X[train], y[train], **fit_kwargs)
-            importances.append(self.importance(X[test], y[test]))
-        self.importances_cv_ = importances
-        self.importances_ = np.mean(importances, axis=0)
+            self.importance(X[test], y[test])
+            # save result of each cv
+            for name in name_attribute_save:
+                getattr(self, name + "cv_").append(getattr(self, name))
+                setattr(self, name, None)
+            self.estimators_cv_.append(estimator)
+        self.importances_ = np.mean(self.importances_cv_, axis=0)
+        self.pvalues_ = np.mean(self.pvalues_cv_, axis=0)
         return self.importances_

     def _check_importance(self):
         """
         Checks
if the losses have been computed.
         """
         super()._check_importance()
-        if self.loss_reference_ is None or self.loss_ is None:
+        if (
+            self.loss_reference_ is None and not hasattr(self, "loss_reference_cv_")
+        ) or (self.loss_ is None and not hasattr(self, "loss_cv_")):
             raise ValueError(
                 "The importances need to be computed before calling this method"
             )
diff --git a/test/_utils/test_utils.py b/test/_utils/test_utils.py
new file mode 100644
index 000000000..192e08149
--- /dev/null
+++ b/test/_utils/test_utils.py
@@ -0,0 +1,16 @@
+from hidimstat._utils.utils import get_generated_attributes
+
+
+def test_generated_attributes():
+    """Test function for getting generated attributes"""
+
+    class MyClass:
+        def __init__(self):
+            self.attr1 = 1
+            self.attr2_ = 2
+            self._attr3 = 3
+            self.attr4__ = 4
+            self.attr5_ = 5
+
+    attributes = get_generated_attributes(MyClass())
+    assert attributes == ["attr2_", "attr5_"]

From 1f97d60c3259c989983e63960d67a182b7ad6680 Mon Sep 17 00:00:00 2001
From: kusch lionel
Date: Wed, 3 Sep 2025 15:26:24 +0200
Subject: [PATCH 09/80] update docstring

---
 src/hidimstat/base_perturbation.py | 62 ++++++++++++------
 1 file changed, 42 insertions(+), 26 deletions(-)

diff --git a/src/hidimstat/base_perturbation.py b/src/hidimstat/base_perturbation.py
index 3cd3ae2b5..cb6a546d2 100644
--- a/src/hidimstat/base_perturbation.py
+++ b/src/hidimstat/base_perturbation.py
@@ -15,6 +15,45 @@ class BasePerturbation(BaseVariableImportance):
+    """
+    Base class for model-agnostic variable importance measures based on
+    perturbation.
+
+    Parameters
+    ----------
+    estimator : sklearn compatible estimator
+        The estimator to use for the prediction.
+    method : str, default="predict"
+        The method used for making predictions. This determines the predictions
+        passed to the loss function. Supported methods are "predict",
+        "predict_proba", "decision_function", "transform".
+    loss : callable, default=root_mean_squared_error
+        Loss function to compute the difference between original and perturbed predictions.
+    n_permutations : int, default=50
+        Number of permutations to perform for calculating variable importance.
+        Higher values give more stable results but increase computation time.
+    n_jobs : int, default=1
+        Number of parallel jobs to run. -1 means using all processors.
+
+    Attributes
+    ----------
+    groups : dict
+        Mapping of feature groups identified during fit.
+    importances_ : ndarray
+        Computed importance scores for each feature group.
+    loss_reference_ : float
+        Loss of the original model without perturbation.
+    loss_ : dict
+        Loss values for each perturbed feature group.
+    pvalues_ : ndarray
+        P-values for importance scores.
+
+    Notes
+    -----
+    This is an abstract base class. Concrete implementations must override
+    the _permutation method.
+    """
+
     def __init__(
         self,
         estimator,
@@ -23,30 +62,7 @@ def __init__(
         n_permutations: int = 50,
         n_jobs: int = 1,
     ):
-        """
-        Base class for model-agnostic variable importance measures based on
-        perturbation.
-
-        Parameters
-        ----------
-        estimator : sklearn compatible estimator, optional
-            The estimator to use for the prediction.
-        method : str, default="predict"
-            The method used for making predictions. This determines the predictions
-            passed to the loss function. Supported methods are "predict",
-            "predict_proba", "decision_function", "transform".
-        loss : callable, default=root_mean_squared_error
-            The function to compute the loss when comparing the perturbed model
-            to the original model.
-        n_permutations : int, default=50
-            This parameter is relevant only for PFI or CFI.
-            Specifies the number of times the variable group (residual for CFI) is
-            permuted. For each permutation, the perturbed model's loss is calculated
-            and averaged over all permutations.
-        n_jobs : int, default=1
-            The number of parallel jobs to run. Parallelization is done over the
-            variables or groups of variables.
-        """
+
         super().__init__()
@@ -151,19 +167,17 @@ def importance(self, X, y):

         Parameters
         ----------
-        X: array-like of shape (n_samples, n_features)
-            The input samples.
-        y: array-like of shape (n_samples,)
+        X : array-like of shape (n_samples, n_features)
+            The input samples to compute importance scores for.
+        y : array-like of shape (n_samples,)
             The target values.

         Returns
         -------
-        out_dict: dict
-            A dictionary containing the following keys:
-            - 'loss_reference': the loss of the model with the original data.
-            - 'loss': a dictionary containing the loss of the perturbed model
-            for each group.
-            - 'importance': the importance scores for each group.
+        importances_ : ndarray of shape (n_features,)
+            Importance scores for each feature.
+ + Attributes + ---------- + loss_reference_ : float + The loss of the model with the original (non-perturbed) data. + loss_ : dict + Dictionary with indices as keys and arrays of perturbed losses as values. + Contains the loss values for each permutation of each group. + importances_ : ndarray of shape (n_groups,) + The calculated importance scores for each group. + pvalues_ : ndarray of shape (n_groups,) + P-values from one-sided t-test testing if importance scores are + significantly greater than 0. + + Notes + ----- + The importance score for each group is calculated as the mean increase in loss + when that group is perturbed, compared to the reference loss. + A higher importance score indicates that perturbing that group leads to + worse model performance, suggesting those features are more important. """ self._check_fit(X) @@ -208,18 +227,31 @@ def fit_importance( Returns ------- - importances : float - Mean feature importance scores across CV folds. + importances_ : ndarray + Average importance scores for each feature group across CV folds. + + Attributes + ---------- + estimators_cv_ : list + List of fitted estimators for each CV fold. + importances_cv_ : list + List of importance scores for each CV fold. + pvalues_cv_ : list + List of p-values for each CV fold. + loss_cv_ : list + List of loss values for each CV fold. + loss_reference_cv_ : list + List of reference loss values for each CV fold. Notes ----- For each CV fold: - 1. Clones and fits the estimator on training fold - 2. Identifies variable groups on training fold - 3. Computes feature importances on test fold - 4. Returns average importance across all folds + 1. Fits a clone of the base estimator on the training fold + 2. Identifies variable groups on the training fold + 3. Computes feature importances using the test fold + 4. Stores results for each fold in respective cv_ attributes - The importances for each fold are stored in self.importances\_ + Final importances_ and pvalues_ are averaged across all CV folds. """ name_attribute_save = get_generated_attributes(self) for name in name_attribute_save: diff --git a/src/hidimstat/leave_one_covariate_out.py b/src/hidimstat/leave_one_covariate_out.py index 048be4216..8b67739cc 100644 --- a/src/hidimstat/leave_one_covariate_out.py +++ b/src/hidimstat/leave_one_covariate_out.py @@ -92,6 +92,24 @@ def fit(self, X, y, groups=None): return self def importance(self, X, y): + """ + Compute the importance scores for each group of covariates. + + Parameters + ---------- + X : array-like of shape (n_samples, n_features) + The input samples to compute importance scores for. + y : array-like of shape (n_samples,) + + importances_ : ndarray of shape (n_groups,) + The importance scores for each group of covariates. + A higher score indicates greater importance of that group. + + Returns + ------- + importances_ : ndarray of shape (n_features,) + Importance scores for each feature. 
+ """ super().importance(X, y) self.pvalues_ = None return self.importances_ From d656f17213246471e040311f93c99c23db5581df Mon Sep 17 00:00:00 2001 From: kusch lionel Date: Wed, 3 Sep 2025 15:48:10 +0200 Subject: [PATCH 11/80] fix error --- src/hidimstat/base_perturbation.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/hidimstat/base_perturbation.py b/src/hidimstat/base_perturbation.py index cb6a546d2..7fe4eb681 100644 --- a/src/hidimstat/base_perturbation.py +++ b/src/hidimstat/base_perturbation.py @@ -269,7 +269,9 @@ def fit_importance( setattr(self, name, None) self.estimators_cv_.append(estimator) self.importances_ = np.mean(self.importances_cv_, axis=0) - self.pvalues_ = np.mean(self.pvalues_cv_, axis=0) + self.pvalues_ = ( + None if self.pvalues_cv_[0] is None else np.mean(self.pvalues_cv_, axis=0) + ) return self.importances_ def _check_fit(self, X): From 0493b6f05446a7a3deaa997f7a18486c5dc7595f Mon Sep 17 00:00:00 2001 From: kusch lionel Date: Wed, 3 Sep 2025 16:48:04 +0200 Subject: [PATCH 12/80] fix docstring --- src/hidimstat/base_perturbation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/hidimstat/base_perturbation.py b/src/hidimstat/base_perturbation.py index 7fe4eb681..e1d5726e5 100644 --- a/src/hidimstat/base_perturbation.py +++ b/src/hidimstat/base_perturbation.py @@ -249,9 +249,9 @@ def fit_importance( 1. Fits a clone of the base estimator on the training fold 2. Identifies variable groups on the training fold 3. Computes feature importances using the test fold - 4. Stores results for each fold in respective cv_ attributes + 4. Stores results for each fold in respective cv\_ attributes - Final importances_ and pvalues_ are averaged across all CV folds. + Final importances\_ and pvalues\_ are averaged across all CV folds. """ name_attribute_save = get_generated_attributes(self) for name in name_attribute_save: From 9c54e1bd26c7cdbbe2f947e58392a3f3a490d037 Mon Sep 17 00:00:00 2001 From: lionel kusch Date: Fri, 5 Sep 2025 11:52:43 +0200 Subject: [PATCH 13/80] Apply suggestions from code review Co-authored-by: Joseph Paillard --- src/hidimstat/conditional_feature_importance.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/hidimstat/conditional_feature_importance.py b/src/hidimstat/conditional_feature_importance.py index 42686bc5d..42a09a813 100644 --- a/src/hidimstat/conditional_feature_importance.py +++ b/src/hidimstat/conditional_feature_importance.py @@ -30,10 +30,10 @@ class CFI(BasePerturbation): n_permutations : int, default=50 The number of permutations to perform. For each variable/group of variables, the mean of the losses over the `n_permutations` is computed. - imputation_model_continuous : sklearn compatible estimator, optional + imputation_model_continuous : sklearn compatible estimator, default=RidgeCV() The model used to estimate the conditional distribution of a given continuous variable/group of variables given the others. - imputation_model_categorical : sklearn compatible estimator, optional + imputation_model_categorical : sklearn compatible estimator, default=LogisticRegressionCV() The model used to estimate the conditional distribution of a given categorical variable/group of variables given the others. Binary is considered as a special case of categorical. 
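
For orientation between these commits: a minimal usage sketch of the CFI API that
patches 09-14 converge on. This sketch is illustrative only and is not part of the
patch series; the toy data and the LinearRegression base estimator are placeholders,
and RidgeCV is passed explicitly to mirror the default introduced in PATCH 14/80
just below.

    import numpy as np
    from sklearn.linear_model import LinearRegression, RidgeCV
    from sklearn.model_selection import train_test_split

    from hidimstat import CFI

    # Toy regression data: only the first two features carry signal.
    rng = np.random.default_rng(0)
    X = rng.normal(size=(200, 5))
    y = X[:, 0] + 0.5 * X[:, 1] + rng.normal(scale=0.1, size=200)
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

    # The estimator must already be fitted (CFI.__init__ calls check_is_fitted).
    model = LinearRegression().fit(X_train, y_train)

    cfi = CFI(
        estimator=model,
        imputation_model_continuous=RidgeCV(),  # explicit; matches the new default
        n_permutations=50,
        random_state=0,
    )
    cfi.fit(X_train)  # learns the conditional samplers, one per feature
    importances = cfi.importance(X_test, y_test)  # ndarray, one score per feature

Passing a groups dictionary to fit() (e.g. groups={"g1": [0, 1], "g2": [2, 3, 4]})
would instead produce one score per group; that grouping machinery is the part
reworked under the features_groups name in PATCH 17/80.
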
From 7bf75e4c7600b4ac633282f7d6fa6f3078916e9d Mon Sep 17 00:00:00 2001 From: kusch lionel Date: Fri, 5 Sep 2025 11:55:04 +0200 Subject: [PATCH 14/80] Update default --- src/hidimstat/conditional_feature_importance.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/hidimstat/conditional_feature_importance.py b/src/hidimstat/conditional_feature_importance.py index 42a09a813..4b1e2cdf3 100644 --- a/src/hidimstat/conditional_feature_importance.py +++ b/src/hidimstat/conditional_feature_importance.py @@ -3,6 +3,7 @@ from sklearn.base import check_is_fitted, clone, BaseEstimator from sklearn.metrics import root_mean_squared_error from sklearn.model_selection import KFold +from sklearn.linear_model import RidgeCV, LogisticRegressionCV from sklearn.utils.validation import check_random_state from hidimstat.base_perturbation import BasePerturbation @@ -57,8 +58,8 @@ def __init__( method: str = "predict", loss: callable = root_mean_squared_error, n_permutations: int = 50, - imputation_model_continuous=None, - imputation_model_categorical=None, + imputation_model_continuous=RidgeCV(), + imputation_model_categorical=LogisticRegressionCV(), categorical_max_cardinality: int = 10, random_state: int = None, n_jobs: int = 1, From b3cd78a3ea3378e7830ef49478cdd5049347ba18 Mon Sep 17 00:00:00 2001 From: kusch lionel Date: Fri, 5 Sep 2025 12:47:21 +0200 Subject: [PATCH 15/80] fix tests --- test/test_conditional_feature_importance.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/test_conditional_feature_importance.py b/test/test_conditional_feature_importance.py index ef00be5b7..54013e2f0 100644 --- a/test/test_conditional_feature_importance.py +++ b/test/test_conditional_feature_importance.py @@ -7,6 +7,7 @@ from sklearn.metrics import log_loss from sklearn.model_selection import train_test_split from sklearn.metrics import root_mean_squared_error +from sklearn.linear_model import RidgeCV, LogisticRegressionCV from hidimstat import cfi, CFI, BasePerturbation from hidimstat._utils.exception import InternalError @@ -278,8 +279,8 @@ def test_init(self, data_generator): assert cfi.loss == root_mean_squared_error assert cfi.method == "predict" assert cfi.categorical_max_cardinality == 10 - assert cfi.imputation_model_categorical is None - assert cfi.imputation_model_continuous is None + assert isinstance(cfi.imputation_model_categorical, LogisticRegressionCV) + assert isinstance(cfi.imputation_model_continuous, RidgeCV) def test_fit(self, data_generator): """Test fitting CFI""" From 782549018a88462d4220c2cbbe983028d98f9807 Mon Sep 17 00:00:00 2001 From: lionel kusch Date: Mon, 8 Sep 2025 10:59:18 +0200 Subject: [PATCH 16/80] Apply suggestions from code review Co-authored-by: bthirion --- src/hidimstat/base_perturbation.py | 4 ++-- test/test_permutation_feature_importance.py | 5 ++--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/hidimstat/base_perturbation.py b/src/hidimstat/base_perturbation.py index e1d5726e5..98f3a30cf 100644 --- a/src/hidimstat/base_perturbation.py +++ b/src/hidimstat/base_perturbation.py @@ -342,14 +342,14 @@ def _check_fit(self, X): def _check_importance(self): """ - Checks if the loss have been computed. + Checks if the loss has been computed. 
""" super()._check_importance() if ( self.loss_reference_ is None and not hasattr(self, "loss_reference_cv_") ) or (self.loss_ is None and not hasattr(self, "loss_cv_")): raise ValueError( - "The importances need to be called before calling this method" + "The importance method has not yet been called." ) def _joblib_predict_one_group(self, X, group_id, group_key): diff --git a/test/test_permutation_feature_importance.py b/test/test_permutation_feature_importance.py index d38dfe1a2..0d6ea7b24 100644 --- a/test/test_permutation_feature_importance.py +++ b/test/test_permutation_feature_importance.py @@ -107,8 +107,7 @@ def test_permutation_importance_function(): shuffle=False, seed=42, ) - important_features = np.where(beta != 0)[0] - non_important_features = np.where(beta == 0)[0] + important_features = beta != 0 X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0) @@ -128,5 +127,5 @@ def test_permutation_importance_function(): assert importance.shape == (X.shape[1],) assert ( importance[important_features].mean() - > importance[non_important_features].mean() + > importance[1 - important_features].mean() ) From 084ad245beca4d99d7046da12d1dd771a9a2acb5 Mon Sep 17 00:00:00 2001 From: kusch lionel Date: Mon, 8 Sep 2025 10:59:33 +0200 Subject: [PATCH 17/80] chnage group by features_groups --- src/hidimstat/base_perturbation.py | 27 ++++++++++++--------- src/hidimstat/leave_one_covariate_out.py | 15 ++++++------ test/test_conditional_feature_importance.py | 2 +- 3 files changed, 24 insertions(+), 20 deletions(-) diff --git a/src/hidimstat/base_perturbation.py b/src/hidimstat/base_perturbation.py index 98f3a30cf..e88c56c2b 100644 --- a/src/hidimstat/base_perturbation.py +++ b/src/hidimstat/base_perturbation.py @@ -73,7 +73,7 @@ def __init__( self.n_permutations = n_permutations self.n_jobs = n_jobs # variable set in fit - self.groups = None + self.features_groups = None # varaible set in importance self.loss_reference_ = None self.loss_ = None @@ -97,24 +97,25 @@ def fit(self, X, y=None, groups=None): """ if groups is None: self._n_groups = X.shape[1] - self.groups = {j: [j] for j in range(self._n_groups)} - self._groups_ids = np.array(list(self.groups.values()), dtype=int) + self.features_groups = {j: [j] for j in range(self._n_groups)} + self._groups_ids = np.array(list(self.features_groups.values()), dtype=int) elif isinstance(groups, dict): self._n_groups = len(groups) - self.groups = groups + self.features_groups = groups if isinstance(X, pd.DataFrame): self._groups_ids = [] - for group_key in self.groups.keys(): + for group_key in self.features_groups.keys(): self._groups_ids.append( [ i for i, col in enumerate(X.columns) - if col in self.groups[group_key] + if col in self.features_groups[group_key] ] ) else: self._groups_ids = [ - np.array(ids, dtype=int) for ids in list(self.groups.values()) + np.array(ids, dtype=int) + for ids in list(self.features_groups.values()) ] else: raise ValueError("groups needs to be a dictionnary") @@ -141,7 +142,7 @@ def predict(self, X): # Parallelize the computation of the importance scores for each group out_list = Parallel(n_jobs=self.n_jobs)( delayed(self._joblib_predict_one_group)(X_, group_id, group_key) - for group_id, group_key in enumerate(self.groups.keys()) + for group_id, group_key in enumerate(self.features_groups.keys()) ) return np.stack(out_list, axis=0) @@ -296,7 +297,11 @@ def _check_fit(self, X): If the number of features in X does not match the total number of features in the grouped variables. 
""" - if self._n_groups is None or self.groups is None or self._groups_ids is None: + if ( + self._n_groups is None + or self.features_groups is None + or self._groups_ids is None + ): raise ValueError( "The class is not fitted. The fit method must be called" " to set variable groups. If no grouping is needed," @@ -313,7 +318,7 @@ def _check_fit(self, X): else: raise ValueError("X should be a pandas dataframe or a numpy array.") number_columns = X.shape[1] - for index_variables in self.groups.values(): + for index_variables in self.features_groups.values(): if type(index_variables[0]) is int or np.issubdtype( type(index_variables[0]), int ): @@ -331,7 +336,7 @@ def _check_fit(self, X): "A problem with indexing has happened during the fit." ) number_unique_feature_in_groups = np.unique( - np.concatenate([values for values in self.groups.values()]) + np.concatenate([values for values in self.features_groups.values()]) ).shape[0] if X.shape[1] != number_unique_feature_in_groups: warnings.warn( diff --git a/src/hidimstat/leave_one_covariate_out.py b/src/hidimstat/leave_one_covariate_out.py index 8b67739cc..3bc2ad876 100644 --- a/src/hidimstat/leave_one_covariate_out.py +++ b/src/hidimstat/leave_one_covariate_out.py @@ -87,7 +87,9 @@ def fit(self, X, y, groups=None): # Parallelize the fitting of the covariate estimators self._list_estimators = Parallel(n_jobs=self.n_jobs)( delayed(self._joblib_fit_one_group)(estimator, X, y, key_groups) - for key_groups, estimator in zip(self.groups.keys(), self._list_estimators) + for key_groups, estimator in zip( + self.features_groups.keys(), self._list_estimators + ) ) return self @@ -101,14 +103,11 @@ def importance(self, X, y): The input samples to compute importance scores for. y : array-like of shape (n_samples,) - importances_ : ndarray of shape (n_groups,) - The importance scores for each group of covariates. - A higher score indicates greater importance of that group. - Returns ------- importances_ : ndarray of shape (n_features,) - Importance scores for each feature. + The importance scores for each group of covariates. + A higher score indicates greater importance of that group. """ super().importance(X, y) self.pvalues_ = None @@ -117,9 +116,9 @@ def importance(self, X, y): def _joblib_fit_one_group(self, estimator, X, y, key_groups): """Fit the estimator after removing a group of covariates. 
Used in parallel.""" if isinstance(X, pd.DataFrame): - X_minus_j = X.drop(columns=self.groups[key_groups]) + X_minus_j = X.drop(columns=self.features_groups[key_groups]) else: - X_minus_j = np.delete(X, self.groups[key_groups], axis=1) + X_minus_j = np.delete(X, self.features_groups[key_groups], axis=1) estimator.fit(X_minus_j, y) return estimator diff --git a/test/test_conditional_feature_importance.py b/test/test_conditional_feature_importance.py index 54013e2f0..30a0bda32 100644 --- a/test/test_conditional_feature_importance.py +++ b/test/test_conditional_feature_importance.py @@ -498,7 +498,7 @@ def test_internal_error(self, data_generator): ], } cfi.fit(X, groups=subgroups, var_type="auto") - cfi.groups["group1"] = [None for i in range(100)] + cfi.features_groups["group1"] = [None for i in range(100)] X = X.to_records(index=False) X = np.array(X, dtype=X.dtype.descr) From 7379ec1d2d991e2198a99820ac611a0f0488513a Mon Sep 17 00:00:00 2001 From: kusch lionel Date: Mon, 8 Sep 2025 11:01:07 +0200 Subject: [PATCH 18/80] fix format --- src/hidimstat/base_perturbation.py | 4 +--- test/test_permutation_feature_importance.py | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/hidimstat/base_perturbation.py b/src/hidimstat/base_perturbation.py index e88c56c2b..c19bc607c 100644 --- a/src/hidimstat/base_perturbation.py +++ b/src/hidimstat/base_perturbation.py @@ -353,9 +353,7 @@ def _check_importance(self): if ( self.loss_reference_ is None and not hasattr(self, "loss_reference_cv_") ) or (self.loss_ is None and not hasattr(self, "loss_cv_")): - raise ValueError( - "The importance method has not yet been called." - ) + raise ValueError("The importance method has not yet been called.") def _joblib_predict_one_group(self, X, group_id, group_key): """ diff --git a/test/test_permutation_feature_importance.py b/test/test_permutation_feature_importance.py index 0d6ea7b24..1421cab0b 100644 --- a/test/test_permutation_feature_importance.py +++ b/test/test_permutation_feature_importance.py @@ -127,5 +127,5 @@ def test_permutation_importance_function(): assert importance.shape == (X.shape[1],) assert ( importance[important_features].mean() - > importance[1 - important_features].mean() + > importance[1 - important_features].mean() ) From 02ae5ba2ed95f92d10a5b5b10f3ab15a626146f7 Mon Sep 17 00:00:00 2001 From: kusch lionel Date: Mon, 8 Sep 2025 11:04:10 +0200 Subject: [PATCH 19/80] improve test --- test/test_permutation_feature_importance.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/test_permutation_feature_importance.py b/test/test_permutation_feature_importance.py index 1421cab0b..42b1e93fe 100644 --- a/test/test_permutation_feature_importance.py +++ b/test/test_permutation_feature_importance.py @@ -126,6 +126,5 @@ def test_permutation_importance_function(): assert importance.shape == (X.shape[1],) assert ( - importance[important_features].mean() - > importance[1 - important_features].mean() + importance[important_features].mean() > importance[~important_features].mean() ) From 1e91c65cc5e5e0ead7ca68f66a0eef93a9b3c66b Mon Sep 17 00:00:00 2001 From: kusch lionel Date: Mon, 8 Sep 2025 11:05:02 +0200 Subject: [PATCH 20/80] fix docstring --- src/hidimstat/base_perturbation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hidimstat/base_perturbation.py b/src/hidimstat/base_perturbation.py index c19bc607c..3c5d0bc06 100644 --- a/src/hidimstat/base_perturbation.py +++ b/src/hidimstat/base_perturbation.py @@ -37,7 +37,7 @@ class 
BasePerturbation(BaseVariableImportance): Attributes ---------- - groups : dict + features_groups : dict Mapping of feature groups identified during fit. importances_ : ndarray Computed importance scores for each feature group. From 58a57f8b455806bbab73d3ec074aae761fc24ee4 Mon Sep 17 00:00:00 2001 From: kusch lionel Date: Mon, 8 Sep 2025 11:09:25 +0200 Subject: [PATCH 21/80] fix test --- test/test_base_perturbation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_base_perturbation.py b/test/test_base_perturbation.py index 9198cf72d..d21cf1958 100644 --- a/test/test_base_perturbation.py +++ b/test/test_base_perturbation.py @@ -22,6 +22,6 @@ def test_chek_importance(): basic_class = BasePerturbation(estimator=estimator) basic_class.importances_ = [] with pytest.raises( - ValueError, match="The importances need to be called before calling this method" + ValueError, match="The importance method has not yet been called." ): basic_class.selection() From c4ea7318eeb33946139a2d177efae9f87a39cf07 Mon Sep 17 00:00:00 2001 From: kusch lionel Date: Thu, 11 Sep 2025 15:10:51 +0200 Subject: [PATCH 22/80] improve loco --- src/hidimstat/leave_one_covariate_out.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/hidimstat/leave_one_covariate_out.py b/src/hidimstat/leave_one_covariate_out.py index d5625441b..6deba0c9f 100644 --- a/src/hidimstat/leave_one_covariate_out.py +++ b/src/hidimstat/leave_one_covariate_out.py @@ -59,7 +59,7 @@ def __init__( n_jobs=n_jobs, ) # internal variable - self._list_estimators = [] + self._list_estimators = None def fit(self, X, y, groups=None): """ @@ -136,7 +136,7 @@ def _check_fit(self, X): covariates.""" super()._check_fit(X) check_is_fitted(self.estimator) - if len(self._list_estimators) == 0: + if self._list_estimators is None: raise ValueError("The estimators require to be fit before to use them") for m in self._list_estimators: check_is_fitted(m) From 43d3f997a8bf32ddf599e4cf739797ec7887d643 Mon Sep 17 00:00:00 2001 From: kusch lionel Date: Thu, 11 Sep 2025 15:45:17 +0200 Subject: [PATCH 23/80] fix computation of pvalues --- src/hidimstat/base_perturbation.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/hidimstat/base_perturbation.py b/src/hidimstat/base_perturbation.py index 3c5d0bc06..0b63ef361 100644 --- a/src/hidimstat/base_perturbation.py +++ b/src/hidimstat/base_perturbation.py @@ -198,14 +198,12 @@ def importance(self, X, y): list_loss.append(self.loss(y, y_pred_perm)) self.loss_[j] = np.array(list_loss) - self.importances_ = np.array( - [ - np.mean(self.loss_[j]) - self.loss_reference_ - for j in range(self._n_groups) - ] + test_result = np.array( + [self.loss_[j] - self.loss_reference_ for j in range(self._n_groups)] ) + self.importances_ = np.mean(test_result, axis=1) self.pvalues_ = ttest_1samp( - self.importances_, 0.0, axis=0, alternative="greater" + test_result, 0.0, axis=1, alternative="greater" ).pvalue return self.importances_ From 5fd99b0342442f7fb306ccda6a56f6a32cfff80c Mon Sep 17 00:00:00 2001 From: lionel kusch Date: Fri, 12 Sep 2025 17:46:21 +0200 Subject: [PATCH 24/80] Update src/hidimstat/_utils/utils.py Co-authored-by: Joseph Paillard --- src/hidimstat/_utils/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hidimstat/_utils/utils.py b/src/hidimstat/_utils/utils.py index bc38b6217..fb2190c86 100644 --- a/src/hidimstat/_utils/utils.py +++ b/src/hidimstat/_utils/utils.py @@ -27,7 +27,7 @@ def _check_vim_predict_method(method): 
) -def get_generated_attributes(cls): +def get_fitted_attributes(cls): """ Get all attributes from a class that end with a single underscore and doesn't start with one underscore. From 3c527893ecac90bc01d51ee060e18e20593d8af3 Mon Sep 17 00:00:00 2001 From: kusch lionel Date: Fri, 12 Sep 2025 17:48:05 +0200 Subject: [PATCH 25/80] change name --- src/hidimstat/base_perturbation.py | 4 ++-- test/_utils/test_utils.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/hidimstat/base_perturbation.py b/src/hidimstat/base_perturbation.py index 0b63ef361..62f3cfe66 100644 --- a/src/hidimstat/base_perturbation.py +++ b/src/hidimstat/base_perturbation.py @@ -11,7 +11,7 @@ from hidimstat._utils.utils import _check_vim_predict_method from hidimstat._utils.exception import InternalError from hidimstat.base_variable_importance import BaseVariableImportance -from hidimstat._utils.utils import get_generated_attributes +from hidimstat._utils.utils import get_fitted_attributes class BasePerturbation(BaseVariableImportance): @@ -252,7 +252,7 @@ def fit_importance( Final importances\_ and pvalues\_ are averaged across all CV folds. """ - name_attribute_save = get_generated_attributes(self) + name_attribute_save = get_fitted_attributes(self) for name in name_attribute_save: setattr(self, name + "cv_", []) self.estimators_cv_ = [] diff --git a/test/_utils/test_utils.py b/test/_utils/test_utils.py index 192e08149..ab06bbd8e 100644 --- a/test/_utils/test_utils.py +++ b/test/_utils/test_utils.py @@ -1,4 +1,4 @@ -from hidimstat._utils.utils import get_generated_attributes +from hidimstat._utils.utils import get_fitted_attributes def test_generated_attributes(): @@ -12,5 +12,5 @@ def __init__(self): self.attr4__ = 4 self.attr5_ = 5 - attributes = get_generated_attributes(MyClass()) + attributes = get_fitted_attributes(MyClass()) assert attributes == ["attr2_", "attr5_"] From c93f14cd57fab0f113caa67d730164ce1a4c7eed Mon Sep 17 00:00:00 2001 From: kusch lionel Date: Wed, 24 Sep 2025 18:06:47 +0200 Subject: [PATCH 26/80] remove the cross validation in fit_importance --- src/hidimstat/base_perturbation.py | 64 +++++------------------------- 1 file changed, 11 insertions(+), 53 deletions(-) diff --git a/src/hidimstat/base_perturbation.py b/src/hidimstat/base_perturbation.py index 62f3cfe66..5f791f7d8 100644 --- a/src/hidimstat/base_perturbation.py +++ b/src/hidimstat/base_perturbation.py @@ -4,14 +4,12 @@ import pandas as pd from joblib import Parallel, delayed from scipy.stats import ttest_1samp -from sklearn.base import check_is_fitted, clone +from sklearn.base import check_is_fitted from sklearn.metrics import root_mean_squared_error -from sklearn.model_selection import KFold from hidimstat._utils.utils import _check_vim_predict_method from hidimstat._utils.exception import InternalError from hidimstat.base_variable_importance import BaseVariableImportance -from hidimstat._utils.utils import get_fitted_attributes class BasePerturbation(BaseVariableImportance): @@ -207,11 +205,10 @@ def importance(self, X, y): ).pvalue return self.importances_ - def fit_importance( - self, X, y, cv=KFold(n_splits=5, shuffle=True, random_state=0), **fit_kwargs - ): + def fit_importance(self, X, y): """ - Compute feature importance scores using cross-validation. + Fits the model to the data and computes feature importance scores. + Convenience method that combines fit() and importance() into a single call. Parameters ---------- @@ -219,59 +216,20 @@ def fit_importance( Training data. 
y : array-like of shape (n_samples,) Target values. - cv : cross-validation generator or iterable, default=KFold(n_splits=5, shuffle=True, random_state=0) - Determines the cross-validation splitting strategy. - **fit_kwargs : dict - Additional arguments passed to the fit method during variable group identification. Returns ------- - importances_ : ndarray - Average importance scores for each feature group across CV folds. - - Attributes - ---------- - estimators_cv_ : list - List of fitted estimators for each CV fold. - importances_cv_ : list - List of importance scores for each CV fold. - pvalues_cv_ : list - List of p-values for each CV fold. - loss_cv_ : list - List of loss values for each CV fold. - loss_reference_cv_ : list - List of reference loss values for each CV fold. + importances_ : ndarray of shape (n_groups,) + The calculated importance scores for each feature group. + Higher values indicate greater importance. Notes ----- - For each CV fold: - 1. Fits a clone of the base estimator on the training fold - 2. Identifies variable groups on the training fold - 3. Computes feature importances using the test fold - 4. Stores results for each fold in respective cv\_ attributes - - Final importances\_ and pvalues\_ are averaged across all CV folds. + This method first calls fit() to identify feature groups, then calls + importance() to compute the importance scores for each group. """ - name_attribute_save = get_fitted_attributes(self) - for name in name_attribute_save: - setattr(self, name + "cv_", []) - self.estimators_cv_ = [] - - for train, test in cv.split(X): - estimator = clone(self.estimator) - estimator.fit(X[train], y[train]) - self.fit(X[train], y[train], **fit_kwargs) - self.importance(X[test], y[test]) - # save result of each cv - for name in name_attribute_save: - getattr(self, name + "cv_").append(getattr(self, name)) - setattr(self, name, None) - self.estimators_cv_.append(estimator) - self.importances_ = np.mean(self.importances_cv_, axis=0) - self.pvalues_ = ( - None if self.pvalues_cv_[0] is None else np.mean(self.pvalues_cv_, axis=0) - ) - return self.importances_ + self.fit(X, y) + return self.importance(X, y) def _check_fit(self, X): """ From aa583d5c286a7967ab828eeb21b0d3930ebf033a Mon Sep 17 00:00:00 2001 From: kusch lionel Date: Wed, 24 Sep 2025 18:16:58 +0200 Subject: [PATCH 27/80] change fit_importance --- src/hidimstat/base_perturbation.py | 12 +++--- .../conditional_feature_importance.py | 37 +++++++++++++++++-- src/hidimstat/leave_one_covariate_out.py | 2 - .../permutation_feature_importance.py | 3 -- 4 files changed, 41 insertions(+), 13 deletions(-) diff --git a/src/hidimstat/base_perturbation.py b/src/hidimstat/base_perturbation.py index 5f791f7d8..aae76d9ee 100644 --- a/src/hidimstat/base_perturbation.py +++ b/src/hidimstat/base_perturbation.py @@ -205,7 +205,7 @@ def importance(self, X, y): ).pvalue return self.importances_ - def fit_importance(self, X, y): + def fit_importance(self, X, y, groups=None): """ Fits the model to the data and computes feature importance scores. Convenience method that combines fit() and importance() into a single call. @@ -216,6 +216,10 @@ def fit_importance(self, X, y): Training data. y : array-like of shape (n_samples,) Target values. + groups: dict, optional + A dictionary where the keys are the group names and the values are the + list of column names corresponding to each group. If None, the groups are + identified based on the columns of X. 
Returns ------- @@ -228,7 +232,7 @@ def fit_importance(self, X, y): This method first calls fit() to identify feature groups, then calls importance() to compute the importance scores for each group. """ - self.fit(X, y) + self.fit(X, y, groups) return self.importance(X, y) def _check_fit(self, X): @@ -306,9 +310,7 @@ def _check_importance(self): Checks if the loss has been computed. """ super()._check_importance() - if ( - self.loss_reference_ is None and not hasattr(self, "loss_reference_cv_") - ) or (self.loss_ is None and not hasattr(self, "loss_cv_")): + if (self.loss_reference_ is None) or (self.loss_ is None): raise ValueError("The importance method has not yet been called.") def _joblib_predict_one_group(self, X, group_id, group_key): diff --git a/src/hidimstat/conditional_feature_importance.py b/src/hidimstat/conditional_feature_importance.py index c11ab714f..e971fc70f 100644 --- a/src/hidimstat/conditional_feature_importance.py +++ b/src/hidimstat/conditional_feature_importance.py @@ -2,7 +2,6 @@ from joblib import Parallel, delayed from sklearn.base import check_is_fitted, clone, BaseEstimator from sklearn.metrics import root_mean_squared_error -from sklearn.model_selection import KFold from sklearn.linear_model import RidgeCV, LogisticRegressionCV from sklearn.utils.validation import check_random_state @@ -147,6 +146,40 @@ def fit(self, X, y=None, groups=None, var_type="auto"): return self + def fit_importance(self, X, y, groups=None, var_type="auto"): + """ + Fits the model to the data and computes feature importance scores. + Convenience method that combines fit() and importance() into a single call. + + Parameters + ---------- + X : array-like of shape (n_samples, n_features) + Training data. + y : array-like of shape (n_samples,) + Target values. + groups: dict, optional + A dictionary where the keys are the group names and the values are the + list of column names corresponding to each group. If None, the groups are + identified based on the columns of X. + var_type: str or list, default="auto" + The variable type. Supported types include "auto", "continuous", and + "categorical". If "auto", the type is inferred from the cardinality + of the unique values passed to the `fit` method. + + Returns + ------- + importances_ : ndarray of shape (n_groups,) + The calculated importance scores for each feature group. + Higher values indicate greater importance. + + Notes + ----- + This method first calls fit() to identify feature groups, then calls + importance() to compute the importance scores for each group. + """ + self.fit(X, y, groups, var_type) + return self.importance(X, y) + def _joblib_fit_one_group(self, estimator, X, groups_ids): """Fit a single imputation model, for a single group of variables. 
This method is parallelized.""" @@ -200,7 +233,6 @@ def cfi( estimator, X, y, - cv=KFold(n_splits=5, shuffle=True, random_state=0), groups: dict = None, var_type: str = "auto", method: str = "predict", @@ -230,7 +262,6 @@ def cfi( methods.fit_importance( X, y, - cv=cv, groups=groups, var_type=var_type, ) diff --git a/src/hidimstat/leave_one_covariate_out.py b/src/hidimstat/leave_one_covariate_out.py index 6deba0c9f..caf287174 100644 --- a/src/hidimstat/leave_one_covariate_out.py +++ b/src/hidimstat/leave_one_covariate_out.py @@ -146,7 +146,6 @@ def loco( estimator, X, y, - cv=KFold(n_splits=5, shuffle=True, random_state=0), groups: dict = None, method: str = "predict", loss: callable = root_mean_squared_error, @@ -165,7 +164,6 @@ def loco( methods.fit_importance( X, y, - cv=cv, groups=groups, ) selection = methods.selection( diff --git a/src/hidimstat/permutation_feature_importance.py b/src/hidimstat/permutation_feature_importance.py index 6d6af8673..f9d97b698 100644 --- a/src/hidimstat/permutation_feature_importance.py +++ b/src/hidimstat/permutation_feature_importance.py @@ -1,6 +1,5 @@ import numpy as np from sklearn.metrics import root_mean_squared_error -from sklearn.model_selection import KFold from sklearn.utils import check_random_state from hidimstat.base_perturbation import BasePerturbation @@ -77,7 +76,6 @@ def pfi( estimator, X, y, - cv=KFold(n_splits=5, shuffle=True, random_state=0), groups: dict = None, method: str = "predict", loss: callable = root_mean_squared_error, @@ -100,7 +98,6 @@ def pfi( methods.fit_importance( X, y, - cv=cv, groups=groups, ) selection = methods.selection( From 01cbc44adc2d34af216d632fb570d627eba1f163 Mon Sep 17 00:00:00 2001 From: kusch lionel Date: Wed, 24 Sep 2025 18:33:38 +0200 Subject: [PATCH 28/80] more flexible for the computation of the statistic --- src/hidimstat/base_perturbation.py | 10 ++++++---- src/hidimstat/conditional_feature_importance.py | 5 +++++ src/hidimstat/permutation_feature_importance.py | 5 +++++ 3 files changed, 16 insertions(+), 4 deletions(-) diff --git a/src/hidimstat/base_perturbation.py b/src/hidimstat/base_perturbation.py index aae76d9ee..069d0e4b3 100644 --- a/src/hidimstat/base_perturbation.py +++ b/src/hidimstat/base_perturbation.py @@ -1,9 +1,11 @@ +from functools import partial import warnings + import numpy as np import pandas as pd from joblib import Parallel, delayed -from scipy.stats import ttest_1samp +from scipy.stats import wilcoxon from sklearn.base import check_is_fitted from sklearn.metrics import root_mean_squared_error @@ -58,6 +60,7 @@ def __init__( method: str = "predict", loss: callable = root_mean_squared_error, n_permutations: int = 50, + test_statict=partial(wilcoxon, axis=1), n_jobs: int = 1, ): @@ -69,6 +72,7 @@ def __init__( _check_vim_predict_method(method) self.method = method self.n_permutations = n_permutations + self.test_statistic = test_statict self.n_jobs = n_jobs # variable set in fit self.features_groups = None @@ -200,9 +204,7 @@ def importance(self, X, y): [self.loss_[j] - self.loss_reference_ for j in range(self._n_groups)] ) self.importances_ = np.mean(test_result, axis=1) - self.pvalues_ = ttest_1samp( - test_result, 0.0, axis=1, alternative="greater" - ).pvalue + self.pvalues_ = self.test_statistic(test_result).pvalue return self.importances_ def fit_importance(self, X, y, groups=None): diff --git a/src/hidimstat/conditional_feature_importance.py b/src/hidimstat/conditional_feature_importance.py index e971fc70f..54f972aad 100644 --- 
a/src/hidimstat/conditional_feature_importance.py
+++ b/src/hidimstat/conditional_feature_importance.py
@@ -1,5 +1,8 @@
+from functools import partial
+
 import numpy as np
 from joblib import Parallel, delayed
+from scipy.stats import wilcoxon
 from sklearn.base import check_is_fitted, clone, BaseEstimator
 from sklearn.metrics import root_mean_squared_error
 from sklearn.linear_model import RidgeCV, LogisticRegressionCV
 from sklearn.utils.validation import check_random_state
@@ -60,6 +63,7 @@ def __init__(
         imputation_model_continuous=RidgeCV(),
         imputation_model_categorical=LogisticRegressionCV(),
         categorical_max_cardinality: int = 10,
+        test_statict=partial(wilcoxon, axis=1),
         random_state: int = None,
         n_jobs: int = 1,
     ):
@@ -69,6 +73,7 @@ def __init__(
             method=method,
             loss=loss,
             n_permutations=n_permutations,
+            test_statict=test_statict,
             n_jobs=n_jobs,
         )
diff --git a/src/hidimstat/permutation_feature_importance.py b/src/hidimstat/permutation_feature_importance.py
index f9d97b698..0aabe3da2 100644
--- a/src/hidimstat/permutation_feature_importance.py
+++ b/src/hidimstat/permutation_feature_importance.py
@@ -1,4 +1,7 @@
+from functools import partial
+
 import numpy as np
+from scipy.stats import wilcoxon
 from sklearn.metrics import root_mean_squared_error
 from sklearn.utils import check_random_state
@@ -47,6 +50,7 @@ def __init__(
         method: str = "predict",
         loss: callable = root_mean_squared_error,
         n_permutations: int = 50,
+        test_statict=partial(wilcoxon, axis=1),
         random_state: int = None,
         n_jobs: int = 1,
     ):
@@ -56,6 +60,7 @@ def __init__(
             method=method,
             loss=loss,
             n_permutations=n_permutations,
+            test_statict=test_statict,
             n_jobs=n_jobs,
         )
         self.random_state = random_state

From b1c5f400d0d27889496517028a8c2292a9ecc644 Mon Sep 17 00:00:00 2001
From: kusch lionel
Date: Wed, 24 Sep 2025 18:34:05 +0200
Subject: [PATCH 29/80] update the computation of pvalue for loco

---
 src/hidimstat/leave_one_covariate_out.py | 52 +++++++++++++++++++++---
 1 file changed, 47 insertions(+), 5 deletions(-)

diff --git a/src/hidimstat/leave_one_covariate_out.py b/src/hidimstat/leave_one_covariate_out.py
index caf287174..15ddf4a47 100644
--- a/src/hidimstat/leave_one_covariate_out.py
+++ b/src/hidimstat/leave_one_covariate_out.py
@@ -1,8 +1,10 @@
+from functools import partial
+
 import numpy as np
 import pandas as pd
 from joblib import Parallel, delayed
+from scipy.stats import wilcoxon
 from sklearn.base import check_is_fitted, clone
-from sklearn.model_selection import KFold
 from sklearn.metrics import root_mean_squared_error
 
 from hidimstat.base_perturbation import BasePerturbation
@@ -48,6 +50,7 @@ def __init__(
         estimator,
         method: str = "predict",
         loss: callable = root_mean_squared_error,
+        test_statict=partial(wilcoxon, axis=1),
         n_jobs: int = 1,
     ):
@@ -56,6 +59,7 @@ def __init__(
             method=method,
             loss=loss,
             n_permutations=1,
+            test_statict=test_statict,
             n_jobs=n_jobs,
         )
         # internal variable
@@ -103,14 +107,52 @@ def importance(self, X, y):
             The input samples to compute importance scores for.
         y : array-like of shape (n_samples,)
+            The target values.
 
         Returns
         -------
         importances_ : ndarray of shape (n_features,)
             The importance scores for each group of covariates.
             A higher score indicates greater importance of that group.
+
+        Attributes
+        ----------
+        loss_reference_ : float
+            The loss of the model with the original (non-perturbed) data.
+        loss_ : dict
+            Dictionary with indices as keys and arrays of perturbed losses as values.
+            Contains the loss values for each permutation of each group.
+        importances_ : ndarray of shape (n_groups,)
+            The calculated importance scores for each group.
+        pvalues_ : ndarray of shape (n_groups,)
+            P-values from the test statistic (a Wilcoxon signed-rank test by
+            default) assessing whether the importance scores are greater than 0.
+
+        Notes
+        -----
+        The importance score for each group is calculated as the mean increase in loss
+        when that group is left out and the model refit, compared to the reference
+        loss. A higher importance score indicates that removing that group leads to
+        worse model performance, suggesting those features are more important.
         """
-        super().importance(X, y)
-        self.pvalues_ = None
+        self._check_fit(X)
+
+        y_pred = getattr(self.estimator, self.method)(X)
+        self.loss_reference_ = self.loss(y, y_pred)
+
+        y_pred = self.predict(X)
+        test_result = []
+        self.loss_ = dict()
+        for j, y_pred_j in enumerate(y_pred):
+            self.loss_[j] = np.array([self.loss(y, y_pred_j[0])])
+            test_result.append(y - y_pred_j[0])
+
+        self.importances_ = np.mean(
+            [self.loss_[j] - self.loss_reference_ for j in range(self._n_groups)],
+            axis=1,
+        )
+        self.pvalues_ = self.test_statistic(test_result).pvalue
         return self.importances_

From 1ef69a6d0a069336645b84a9c19fa10e4e6e386c Mon Sep 17 00:00:00 2001
From: kusch lionel
Date: Mon, 13 Oct 2025 11:12:20 +0200
Subject: [PATCH 30/80] fix merge

---
 src/hidimstat/base_perturbation.py          | 27 ++++++-----
 src/hidimstat/base_variable_importance.py   |  1 -
 .../conditional_feature_importance.py       | 20 ++++++--
 src/hidimstat/leave_one_covariate_out.py    | 25 ++++++----
 .../permutation_feature_importance.py       | 14 +++---
 test/test_base_perturbation.py              |  4 +-
 test/test_conditional_feature_importance.py | 48 ++++++++-----------
 test/test_leave_one_covariate_out.py        |  6 +--
 test/test_permutation_feature_importance.py | 22 ++++-----
 9 files changed, 93 insertions(+), 74 deletions(-)

diff --git a/src/hidimstat/base_perturbation.py b/src/hidimstat/base_perturbation.py
index 0cd6bef20..07020be3c 100644
--- a/src/hidimstat/base_perturbation.py
+++ b/src/hidimstat/base_perturbation.py
@@ -34,8 +34,10 @@ class BasePerturbation(BaseVariableImportance, GroupVariableImportanceMixin):
         Specifies the number of times the variable group (residual for CFI) is
         permuted. For each permutation, the perturbed model's loss is calculated
         and averaged over all permutations.
-    test_statict :
-
+    test_statistic : callable, default=partial(wilcoxon, axis=1)
+        Statistical test function used to compute p-values for importance scores.
+        Must accept an array of values and return an object with a 'pvalue' attribute.
+        Default is Wilcoxon signed-rank test.
     features_groups: dict or None, default=None
         A dictionary where the keys are the group names and the values are the
         list of column names corresponding to each features group. If None,
         the features_groups are identified based on the columns of X.
     n_jobs : int, default=1
         The number of parallel jobs to run. Parallelization is done over the
         variables or groups of variables.
-    random_state : int, default=None
-        The random state to use for sampling.
+    random_state : int or None, default=None
+        Controls random number generation for permutations. Use an int for
+        repeatable results.
 
     Attributes
     ----------
     features_groups : dict
         Mapping of feature groups identified during fit.
- importances_ : ndarray + importances_ : ndarray (n_groups,) Computed importance scores for each feature group. loss_reference_ : float Loss of the original model without perturbation. loss_ : dict Loss values for each perturbed feature group. - pvalues_ : ndarray - P-values for importance scores. + pvalues_ : ndarray (n_groups,) + P-values for importance scores from the specified test_statistic. Notes ----- This is an abstract base class. Concrete implementations must override - the _permutation method. + the `_permutation` method to define how features are perturbed. """ def __init__( @@ -71,7 +74,7 @@ def __init__( method: str = "predict", loss: callable = root_mean_squared_error, n_permutations: int = 50, - test_statict=partial(wilcoxon, axis=1), + test_statistic=partial(wilcoxon, axis=1), features_groups=None, n_jobs: int = 1, random_state=None, @@ -85,10 +88,9 @@ def __init__( _check_vim_predict_method(method) self.method = method self.n_permutations = n_permutations - self.test_statistic = test_statict + self.test_statistic = test_statistic self.n_jobs = n_jobs - # variable set in fit - self.features_groups = None + # variable set in importance self.loss_reference_ = None self.loss_ = None @@ -199,6 +201,7 @@ def importance(self, X, y): """ GroupVariableImportanceMixin._check_fit(self) GroupVariableImportanceMixin._check_compatibility(self, X) + self._check_fit() y_pred = getattr(self.estimator, self.method)(X) self.loss_reference_ = self.loss(y, y_pred) diff --git a/src/hidimstat/base_variable_importance.py b/src/hidimstat/base_variable_importance.py index e7b9a7e25..7d44ff149 100644 --- a/src/hidimstat/base_variable_importance.py +++ b/src/hidimstat/base_variable_importance.py @@ -409,7 +409,6 @@ class GroupVariableImportanceMixin: """ def __init__(self, features_groups=None): - super().__init__() self.features_groups = features_groups self.n_features_groups_ = None self._features_groups_ids = None diff --git a/src/hidimstat/conditional_feature_importance.py b/src/hidimstat/conditional_feature_importance.py index ff0c421cf..81154e56f 100644 --- a/src/hidimstat/conditional_feature_importance.py +++ b/src/hidimstat/conditional_feature_importance.py @@ -39,9 +39,21 @@ class CFI(BasePerturbation): The model used to estimate the conditional distribution of a given categorical variable/group of variables given the others. Binary is considered as a special case of categorical. + features_groups: dict or None, default=None + A dictionary where the keys are the group names and the values are the + list of column names corresponding to each features group. If None, + the features_groups are identified based on the columns of X. + feature_types: str or list, default="auto" + The feature type. Supported types include "auto", "continuous", and + "categorical". If "auto", the type is inferred from the cardinality + of the unique values passed to the `fit` method. categorical_max_cardinality : int, default=10 The maximum cardinality of a variable to be considered as categorical when the variable type is inferred (set to "auto" or not provided). + test_statistic : callable, default=partial(wilcoxon, axis=1) + Statistical test function used to compute p-values for importance scores. + Must accept an array of values and return an object with a 'pvalue' attribute. + Default is Wilcoxon signed-rank test. random_state : int or None, default=None The random state to use for sampling. 
n_jobs : int, default=1 @@ -64,7 +76,7 @@ def __init__( features_groups=None, feature_types="auto", categorical_max_cardinality: int = 10, - test_statict=partial(wilcoxon, axis=1), + test_statistic=partial(wilcoxon, axis=1), random_state: int = None, n_jobs: int = 1, ): @@ -73,7 +85,7 @@ def __init__( method=method, loss=loss, n_permutations=n_permutations, - test_statict=test_statict, + test_statistic=test_statistic, n_jobs=n_jobs, features_groups=features_groups, random_state=random_state, @@ -228,7 +240,7 @@ def cfi( features_groups=None, feature_types="auto", categorical_max_cardinality: int = 10, - test_statict=partial(wilcoxon, axis=1), + test_statistic=partial(wilcoxon, axis=1), k_best=None, percentile=None, threshold_max=None, @@ -246,7 +258,7 @@ def cfi( features_groups=features_groups, feature_types=feature_types, categorical_max_cardinality=categorical_max_cardinality, - test_statict=test_statict, + test_statistic=test_statistic, random_state=random_state, n_jobs=n_jobs, ) diff --git a/src/hidimstat/leave_one_covariate_out.py b/src/hidimstat/leave_one_covariate_out.py index 1ce217924..2a1e9ceab 100644 --- a/src/hidimstat/leave_one_covariate_out.py +++ b/src/hidimstat/leave_one_covariate_out.py @@ -32,8 +32,10 @@ class LOCO(BasePerturbation): loss : callable, default=root_mean_squared_error The loss function to use when comparing the perturbed model to the full model. - test_statict : - + test_statistic : callable, default=partial(wilcoxon, axis=1) + Statistical test function used to compute p-values for importance scores. + Must accept an array of values and return an object with a 'pvalue' attribute. + Default is Wilcoxon signed-rank test. features_groups: dict or None, default=None A dictionary where the keys are the group names and the values are the list of column names corresponding to each features group. 
If None, @@ -57,7 +59,7 @@ def __init__( estimator, method: str = "predict", loss: callable = root_mean_squared_error, - test_statict=partial(wilcoxon, axis=1), + test_statistic=partial(wilcoxon, axis=1), features_groups=None, n_jobs: int = 1, ): @@ -66,7 +68,7 @@ def __init__( method=method, loss=loss, n_permutations=1, - test_statict=test_statict, + test_statistic=test_statistic, features_groups=features_groups, n_jobs=n_jobs, ) @@ -147,6 +149,7 @@ def importance(self, X, y): """ GroupVariableImportanceMixin._check_fit(self) GroupVariableImportanceMixin._check_compatibility(self, X) + self._check_fit() y_pred = getattr(self.estimator, self.method)(X) self.loss_reference_ = self.loss(y, y_pred) @@ -156,10 +159,16 @@ def importance(self, X, y): self.loss_ = dict() for j, y_pred_j in enumerate(y_pred): self.loss_[j] = np.array([self.loss(y, y_pred_j[0])]) - test_result.append(y - y_pred_j[0]) + if np.all(np.equal(y.shape, y_pred_j[0].shape)): + test_result.append(y - y_pred_j[0]) + else: + test_result.append(y - np.unique(y)[np.argmax(y_pred_j[0], axis=-1)]) self.importances_ = np.mean( - [self.loss_[j] - self.loss_reference_ for j in range(self._n_groups)], + [ + self.loss_[j] - self.loss_reference_ + for j in range(self.n_features_groups_) + ], axis=1, ) self.pvalues_ = self.test_statistic(test_result).pvalue @@ -205,7 +214,7 @@ def loco( method: str = "predict", loss: callable = root_mean_squared_error, features_groups=None, - test_statict=partial(wilcoxon, axis=1), + test_statistic=partial(wilcoxon, axis=1), k_best=None, percentile=None, threshold_min=None, @@ -216,7 +225,7 @@ def loco( estimator=estimator, method=method, loss=loss, - test_statict=test_statict, + test_statistic=test_statistic, features_groups=features_groups, n_jobs=n_jobs, ) diff --git a/src/hidimstat/permutation_feature_importance.py b/src/hidimstat/permutation_feature_importance.py index ab6515b0a..1c91fd229 100644 --- a/src/hidimstat/permutation_feature_importance.py +++ b/src/hidimstat/permutation_feature_importance.py @@ -33,8 +33,10 @@ class PFI(BasePerturbation): n_permutations : int, default=50 The number of permutations to perform. For each variable/group of variables, the mean of the losses over the `n_permutations` is computed. - test_statict : - + test_statistic : callable, default=partial(wilcoxon, axis=1) + Statistical test function used to compute p-values for importance scores. + Must accept an array of values and return an object with a 'pvalue' attribute. + Default is Wilcoxon signed-rank test. features_groups: dict or None, default=None A dictionary where the keys are the group names and the values are the list of column names corresponding to each features group. 
If None, @@ -56,7 +58,7 @@ def __init__( method: str = "predict", loss: callable = root_mean_squared_error, n_permutations: int = 50, - test_statict=partial(wilcoxon, axis=1), + test_statistic=partial(wilcoxon, axis=1), features_groups=None, random_state: int = None, n_jobs: int = 1, @@ -66,7 +68,7 @@ def __init__( method=method, loss=loss, n_permutations=n_permutations, - test_statict=test_statict, + test_statistic=test_statistic, features_groups=features_groups, random_state=random_state, n_jobs=n_jobs, @@ -93,7 +95,7 @@ def pfi( method: str = "predict", loss: callable = root_mean_squared_error, n_permutations: int = 50, - test_statict=partial(wilcoxon, axis=1), + test_statistic=partial(wilcoxon, axis=1), features_groups=None, k_best=None, percentile=None, @@ -107,7 +109,7 @@ def pfi( method=method, loss=loss, n_permutations=n_permutations, - test_statict=test_statict, + test_statistic=test_statistic, features_groups=features_groups, random_state=random_state, n_jobs=n_jobs, diff --git a/test/test_base_perturbation.py b/test/test_base_perturbation.py index ca78ca9a0..8ed0b8202 100644 --- a/test/test_base_perturbation.py +++ b/test/test_base_perturbation.py @@ -15,7 +15,7 @@ def test_no_implemented_methods(): basic_class._permutation(X, features_group_id=None) -def test_chek_importance(): +def test_check_importance(): """test that the methods are not implemented in the base class""" X = np.random.randint(0, 2, size=(100, 2, 1)) estimator = LinearRegression() @@ -25,4 +25,4 @@ def test_chek_importance(): with pytest.raises( ValueError, match="The importance method has not yet been called." ): - basic_class.selection() + basic_class.importance_selection() diff --git a/test/test_conditional_feature_importance.py b/test/test_conditional_feature_importance.py index ea0e204e9..3e2c7eaee 100644 --- a/test/test_conditional_feature_importance.py +++ b/test/test_conditional_feature_importance.py @@ -376,18 +376,6 @@ def test_unknown_predict_method(self, data_generator): method="unknown method", ) - def test_unfitted_predict(self, data_generator): - """Test predict method with unfitted model""" - X, y, _, _ = data_generator - fitted_model = LinearRegression().fit(X, y) - cfi = CFI( - estimator=fitted_model, - method="predict", - ) - - with pytest.raises(ValueError, match="The class is not fitted."): - cfi.predict(X) - def test_unfitted_importance(self, data_generator): """Test importance method with unfitted model""" X, y, _, _ = data_generator @@ -640,6 +628,8 @@ def test_cfi_plot(data_generator): random_state=0, ) cfi.fit(X_train, y_train) + cfi.loss_reference_ = [] + cfi.loss_ = [] # Make the plot independent of data / randomness to test only the plotting function cfi.importances_ = np.arange(X.shape[1]) fig, ax = plt.subplots(figsize=(6, 3)) @@ -667,6 +657,8 @@ def test_cfi_plot_2d_imp(data_generator): random_state=0, ) cfi.fit(X_train, y_train) + cfi.loss_reference_ = [] + cfi.loss_ = [] # Make the plot independent of data / randomness to test only the plotting function cfi.importances_ = np.stack( [ @@ -699,6 +691,8 @@ def test_cfi_plot_coverage(data_generator): random_state=0, ) cfi.fit(X_train, y_train) + cfi.loss_reference_ = [] + cfi.loss_ = [] # Make the plot independent of data / randomness to test only the plotting function cfi.importances_ = np.arange(X.shape[1]) _, ax = plt.subplots(figsize=(6, 3)) @@ -750,9 +744,9 @@ def test_cfi_repeatibility(cfi_test_data): X_train, X_test, y_test, cfi_default_parameters = cfi_test_data cfi = CFI(**cfi_default_parameters) cfi.fit(X_train) - vim = 
cfi.importance(X_test, y_test)["importance"] + vim = cfi.importance(X_test, y_test) # repeat - vim_repeat = cfi.importance(X_test, y_test)["importance"] + vim_repeat = cfi.importance(X_test, y_test) assert not np.array_equal(vim, vim_repeat) @@ -763,20 +757,20 @@ def test_cfi_randomness_with_none(cfi_test_data): X_train, X_test, y_test, cfi_default_parameters = cfi_test_data cfi = CFI(random_state=None, **cfi_default_parameters) cfi.fit(X_train) - vim = cfi.importance(X_test, y_test)["importance"] + vim = cfi.importance(X_test, y_test) # repeat importance - vim_repeat = cfi.importance(X_test, y_test)["importance"] + vim_repeat = cfi.importance(X_test, y_test) assert not np.array_equal(vim, vim_repeat) # refit cfi.fit(X_train) - vim_refit = cfi.importance(X_test, y_test)["importance"] + vim_refit = cfi.importance(X_test, y_test) assert not np.array_equal(vim, vim_refit) # Reproducibility cfi_2 = CFI(random_state=None, **cfi_default_parameters) cfi_2.fit(X_train) - vim_reproducibility = cfi_2.importance(X_test, y_test)["importance"] + vim_reproducibility = cfi_2.importance(X_test, y_test) assert not np.array_equal(vim, vim_reproducibility) @@ -787,20 +781,20 @@ def test_cfi_reproducibility_with_integer(cfi_test_data): X_train, X_test, y_test, cfi_default_parameters = cfi_test_data cfi = CFI(random_state=42, **cfi_default_parameters) cfi.fit(X_train) - vim = cfi.importance(X_test, y_test)["importance"] + vim = cfi.importance(X_test, y_test) # repeat importance - vim_repeat = cfi.importance(X_test, y_test)["importance"] + vim_repeat = cfi.importance(X_test, y_test) assert np.array_equal(vim, vim_repeat) # refit cfi.fit(X_train) - vim_refit = cfi.importance(X_test, y_test)["importance"] + vim_refit = cfi.importance(X_test, y_test) assert np.array_equal(vim, vim_refit) # Reproducibility cfi_2 = CFI(random_state=42, **cfi_default_parameters) cfi_2.fit(X_train) - vim_reproducibility = cfi_2.importance(X_test, y_test)["importance"] + vim_reproducibility = cfi_2.importance(X_test, y_test) assert np.array_equal(vim, vim_reproducibility) @@ -814,25 +808,25 @@ def test_cfi_reproducibility_with_rng(cfi_test_data): rng = np.random.default_rng(0) cfi = CFI(random_state=rng, **cfi_default_parameters) cfi.fit(X_train) - vim = cfi.importance(X_test, y_test)["importance"] + vim = cfi.importance(X_test, y_test) # repeat importance - vim_repeat = cfi.importance(X_test, y_test)["importance"] + vim_repeat = cfi.importance(X_test, y_test) assert not np.array_equal(vim, vim_repeat) # refit cfi.fit(X_train) - vim_refit = cfi.importance(X_test, y_test)["importance"] + vim_refit = cfi.importance(X_test, y_test) assert not np.array_equal(vim, vim_refit) # refit repeatability rng = np.random.default_rng(0) cfi.random_state = rng cfi.fit(X_train) - vim_refit_2 = cfi.importance(X_test, y_test)["importance"] + vim_refit_2 = cfi.importance(X_test, y_test) assert np.array_equal(vim, vim_refit_2) # Reproducibility cfi_2 = CFI(random_state=np.random.default_rng(0), **cfi_default_parameters) cfi_2.fit(X_train) - vim_reproducibility = cfi_2.importance(X_test, y_test)["importance"] + vim_reproducibility = cfi_2.importance(X_test, y_test) assert np.array_equal(vim, vim_reproducibility) diff --git a/test/test_leave_one_covariate_out.py b/test/test_leave_one_covariate_out.py index 3accc4aff..fc70d5d04 100644 --- a/test/test_leave_one_covariate_out.py +++ b/test/test_leave_one_covariate_out.py @@ -94,7 +94,7 @@ def test_loco(): importance_clf = loco_clf.importance(X_test, y_test_clf) assert importance_clf.shape == (2,) - assert 
importance[0].mean() > importance[1].mean() + assert importance_clf[0].mean() > importance_clf[1].mean() def test_raises_value_error(): @@ -120,7 +120,7 @@ def test_raises_value_error(): estimator=fitted_model, method="predict", ) - loco.predict(X) + loco.importance(X, None) with pytest.raises(ValueError, match="The class is not fitted."): fitted_model = LinearRegression().fit(X, y) loco = LOCO( @@ -145,7 +145,7 @@ def test_loco_function(): """Test the function of LOCO algorithm on a linear scenario.""" X, y, beta, noise = multivariate_simulation( n_samples=150, - n_features=200, + n_features=100, support_size=10, shuffle=False, seed=42, diff --git a/test/test_permutation_feature_importance.py b/test/test_permutation_feature_importance.py index 46ec86136..a007004a4 100644 --- a/test/test_permutation_feature_importance.py +++ b/test/test_permutation_feature_importance.py @@ -166,8 +166,8 @@ def test_pfi_repeatability(pfi_test_data): X_train, X_test, y_train, y_test, pfi_default_parameters = pfi_test_data pfi = PFI(**pfi_default_parameters, random_state=0) pfi.fit(X_train, y_train) - vim = pfi.importance(X_test, y_test)["importance"] - vim_reproducible = pfi.importance(X_test, y_test)["importance"] + vim = pfi.importance(X_test, y_test) + vim_reproducible = pfi.importance(X_test, y_test) assert np.array_equal(vim, vim_reproducible) @@ -179,17 +179,17 @@ def test_pfi_randomness_with_none(pfi_test_data): X_train, X_test, y_train, y_test, pfi_default_parameters = pfi_test_data pfi_fixed = PFI(**pfi_default_parameters, random_state=0) pfi_fixed.fit(X_train, y_train) - vim_fixed = pfi_fixed.importance(X_test, y_test)["importance"] + vim_fixed = pfi_fixed.importance(X_test, y_test) pfi_new_state = PFI(**pfi_default_parameters, random_state=1) pfi_new_state.fit(X_train, y_train) - vim_new_state = pfi_new_state.importance(X_test, y_test)["importance"] + vim_new_state = pfi_new_state.importance(X_test, y_test) assert not np.array_equal(vim_fixed, vim_new_state) pfi_none_state = PFI(**pfi_default_parameters, random_state=None) pfi_none_state.fit(X_train, y_train) - vim_none_state_1 = pfi_none_state.importance(X_test, y_test)["importance"] - vim_none_state_2 = pfi_none_state.importance(X_test, y_test)["importance"] + vim_none_state_1 = pfi_none_state.importance(X_test, y_test) + vim_none_state_2 = pfi_none_state.importance(X_test, y_test) assert not np.array_equal(vim_none_state_1, vim_none_state_2) @@ -201,11 +201,11 @@ def test_pfi_reproducibility_with_integer(pfi_test_data): X_train, X_test, y_train, y_test, pfi_default_parameters = pfi_test_data pfi_1 = PFI(**pfi_default_parameters, random_state=0) pfi_1.fit(X_train, y_train) - vim_1 = pfi_1.importance(X_test, y_test)["importance"] + vim_1 = pfi_1.importance(X_test, y_test) pfi_2 = PFI(**pfi_default_parameters, random_state=0) pfi_2.fit(X_train, y_train) - vim_2 = pfi_2.importance(X_test, y_test)["importance"] + vim_2 = pfi_2.importance(X_test, y_test) assert np.array_equal(vim_1, vim_2) @@ -219,13 +219,13 @@ def test_pfi_reproducibility_with_rng(pfi_test_data): rng = np.random.default_rng(0) pfi = PFI(**pfi_default_parameters, random_state=rng) pfi.fit(X_train, y_train) - vim = pfi.importance(X_test, y_test)["importance"] - vim_repeat = pfi.importance(X_test, y_test)["importance"] + vim = pfi.importance(X_test, y_test) + vim_repeat = pfi.importance(X_test, y_test) assert not np.array_equal(vim, vim_repeat) # Refit with same rng rng = np.random.default_rng(0) pfi_reproducibility = PFI(**pfi_default_parameters, random_state=rng) 
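The tests above pin down the reworked `importance()` contract: it now returns the importance array directly instead of a dict keyed by "importance", and an integer `random_state` makes repeated calls reproducible. A minimal usage sketch under that contract (the data, estimator, and parameter values are illustrative, and the top-level `PFI` import is an assumption, mirroring the `from hidimstat import CFI` imports elsewhere in this series):

    import numpy as np
    from sklearn.linear_model import LinearRegression
    from hidimstat import PFI  # assumed public export

    rng = np.random.default_rng(0)
    X = rng.normal(size=(100, 5))
    y = X[:, 0] + rng.normal(size=100)
    model = LinearRegression().fit(X, y)

    pfi = PFI(model, n_permutations=20, random_state=0)  # integer seed
    pfi.fit(X, y)
    vim = pfi.importance(X, y)         # ndarray of shape (n_features,)
    vim_repeat = pfi.importance(X, y)  # same integer seed -> identical array
    assert np.array_equal(vim, vim_repeat)

With `random_state=None` or a shared `np.random.default_rng` instance, consecutive calls draw fresh permutations instead, which is exactly what the assertions in these tests check.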
pfi_reproducibility.fit(X_train, y_train) - vim_reproducibility = pfi_reproducibility.importance(X_test, y_test)["importance"] + vim_reproducibility = pfi_reproducibility.importance(X_test, y_test) assert np.array_equal(vim, vim_reproducibility) From 83ae84915071ad875c8cd845432bc8a1be4437ad Mon Sep 17 00:00:00 2001 From: kusch lionel Date: Mon, 13 Oct 2025 19:18:33 +0200 Subject: [PATCH 31/80] fix example --- examples/plot_loco.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/plot_loco.py b/examples/plot_loco.py index 3ac3c7d8c..dbd9e62a5 100644 --- a/examples/plot_loco.py +++ b/examples/plot_loco.py @@ -82,7 +82,7 @@ # importance. This process is repeated for all features to assess their individual # contributions. loco.fit(X_train, y_train) - importances = loco.importance(X_test, y_test)["importance"] + importances = loco.importance(X_test, y_test) df_list.append( pd.DataFrame( { From a3cd681e95523d756d397225fc017d3dda77683d Mon Sep 17 00:00:00 2001 From: kusch lionel Date: Mon, 13 Oct 2025 19:41:54 +0200 Subject: [PATCH 32/80] fix example --- examples/plot_diabetes_variable_importance_example.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/examples/plot_diabetes_variable_importance_example.py b/examples/plot_diabetes_variable_importance_example.py index bf2011bd1..959c9b43f 100644 --- a/examples/plot_diabetes_variable_importance_example.py +++ b/examples/plot_diabetes_variable_importance_example.py @@ -161,6 +161,9 @@ # %% # Define a function to compute the p-value from importance values # --------------------------------------------------------------- +from scipy.stat import norm + + def compute_pval(vim): mean_vim = np.mean(vim, axis=0) std_vim = np.std(vim, axis=0) From b3f336ab34db255384d617392304cd6d528be8f9 Mon Sep 17 00:00:00 2001 From: kusch lionel Date: Mon, 13 Oct 2025 19:48:56 +0200 Subject: [PATCH 33/80] fix import --- examples/plot_diabetes_variable_importance_example.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/plot_diabetes_variable_importance_example.py b/examples/plot_diabetes_variable_importance_example.py index 959c9b43f..e7cade8b8 100644 --- a/examples/plot_diabetes_variable_importance_example.py +++ b/examples/plot_diabetes_variable_importance_example.py @@ -161,7 +161,7 @@ # %% # Define a function to compute the p-value from importance values # --------------------------------------------------------------- -from scipy.stat import norm +from scipy.stats import norm def compute_pval(vim): From 6dc8d67148048a9142989d7c578ab7a4664225ed Mon Sep 17 00:00:00 2001 From: lionel kusch Date: Tue, 14 Oct 2025 10:01:52 +0200 Subject: [PATCH 34/80] Update src/hidimstat/leave_one_covariate_out.py Co-authored-by: bthirion --- src/hidimstat/leave_one_covariate_out.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hidimstat/leave_one_covariate_out.py b/src/hidimstat/leave_one_covariate_out.py index 2a1e9ceab..a1e2849f7 100644 --- a/src/hidimstat/leave_one_covariate_out.py +++ b/src/hidimstat/leave_one_covariate_out.py @@ -221,7 +221,7 @@ def loco( threshold_max=None, n_jobs: int = 1, ): - methods = LOCO( + method = LOCO( estimator=estimator, method=method, loss=loss, From afd03cb383857313fbb406f48cfdf1855536fd2a Mon Sep 17 00:00:00 2001 From: lionel kusch Date: Tue, 14 Oct 2025 10:02:04 +0200 Subject: [PATCH 35/80] Update src/hidimstat/base_perturbation.py Co-authored-by: bthirion --- src/hidimstat/base_perturbation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/src/hidimstat/base_perturbation.py b/src/hidimstat/base_perturbation.py index 07020be3c..2b4f827b0 100644 --- a/src/hidimstat/base_perturbation.py +++ b/src/hidimstat/base_perturbation.py @@ -189,7 +189,7 @@ def importance(self, X, y): importances_ : ndarray of shape (n_groups,) The calculated importance scores for each group. pvalues_ : ndarray of shape (n_groups,) - P-values from one-sided t-test testing if importance scores are + P-values from one-sample t-test testing if importance scores are significantly greater than 0. Notes From 3fa9d01aa6e1ddfa8bf86913f4c649e4fe2c20b6 Mon Sep 17 00:00:00 2001 From: kusch lionel Date: Tue, 14 Oct 2025 10:09:30 +0200 Subject: [PATCH 36/80] fix modification --- src/hidimstat/leave_one_covariate_out.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/hidimstat/leave_one_covariate_out.py b/src/hidimstat/leave_one_covariate_out.py index a1e2849f7..0271762ef 100644 --- a/src/hidimstat/leave_one_covariate_out.py +++ b/src/hidimstat/leave_one_covariate_out.py @@ -229,14 +229,14 @@ def loco( features_groups=features_groups, n_jobs=n_jobs, ) - methods.fit_importance(X, y) - selection = methods.importance_selection( + method.fit_importance(X, y) + selection = method.importance_selection( k_best=k_best, percentile=percentile, threshold_min=threshold_min, threshold_max=threshold_max, ) - return selection, methods.importances_, methods.pvalues_ + return selection, method.importances_, method.pvalues_ # use the docstring of the class for the function From 120242641ec21909e74e53d299a19fa25b4dc484 Mon Sep 17 00:00:00 2001 From: kusch lionel Date: Tue, 14 Oct 2025 10:20:40 +0200 Subject: [PATCH 37/80] Remove the wrong merge --- ...ot_diabetes_variable_importance_example.py | 19 +++---------------- 1 file changed, 3 insertions(+), 16 deletions(-) diff --git a/examples/plot_diabetes_variable_importance_example.py b/examples/plot_diabetes_variable_importance_example.py index e7cade8b8..01ddd1f59 100644 --- a/examples/plot_diabetes_variable_importance_example.py +++ b/examples/plot_diabetes_variable_importance_example.py @@ -158,19 +158,6 @@ pfi_importance_list.append(importance) -# %% -# Define a function to compute the p-value from importance values -# --------------------------------------------------------------- -from scipy.stats import norm - - -def compute_pval(vim): - mean_vim = np.mean(vim, axis=0) - std_vim = np.std(vim, axis=0) - pval = norm.sf(mean_vim / std_vim) - return np.clip(pval, 1e-10, 1 - 1e-10) - - # %% # Analyze the results # ------------------- @@ -180,7 +167,7 @@ def compute_pval(vim): from scipy.stats import ttest_1samp cfi_vim_arr = np.array(cfi_importance_list) / 2 -cfi_pval = compute_pval(cfi_vim_arr) +cfi_pval = ttest_1samp(cfi_vim_arr, 0, alternative="greater").pvalue vim = [ pd.DataFrame( @@ -196,7 +183,7 @@ def compute_pval(vim): ] loco_vim_arr = np.array(loco_importance_list) -loco_pval = compute_pval(loco_vim_arr) +loco_pval = ttest_1samp(loco_vim_arr, 0, alternative="greater").pvalue vim += [ pd.DataFrame( @@ -212,7 +199,7 @@ def compute_pval(vim): ] pfi_vim_arr = np.array(pfi_importance_list) -pfi_pval = compute_pval(pfi_vim_arr) +pfi_pval = ttest_1samp(pfi_vim_arr, 0, alternative="greater").pvalue vim += [ pd.DataFrame( From 75d65786669e2e01752df4e96c42f33f267c0f4c Mon Sep 17 00:00:00 2001 From: kusch lionel Date: Wed, 15 Oct 2025 18:59:27 +0200 Subject: [PATCH 38/80] Add check_test_statistic --- src/hidimstat/_utils/utils.py | 39 ++++++++++++++++++++++++++++++ src/hidimstat/base_perturbation.py | 13 
+++++----- test/_utils/test_utils.py | 27 ++++++++++++++++++++- 3 files changed, 72 insertions(+), 7 deletions(-) diff --git a/src/hidimstat/_utils/utils.py b/src/hidimstat/_utils/utils.py index 0f57b45ca..c4cdf3e52 100644 --- a/src/hidimstat/_utils/utils.py +++ b/src/hidimstat/_utils/utils.py @@ -1,7 +1,9 @@ +from functools import partial import numbers import numpy as np from numpy.random import RandomState +from scipy.stats import ttest_1samp, wilcoxon def _check_vim_predict_method(method): @@ -136,3 +138,40 @@ def seed_estimator(estimator, random_state=None): setattr(value, "random_state", RandomState(rng.bit_generator)) return estimator + + +def check_test_statistic(test): + """ + Validates and returns a test statistic function. + + Parameters + ---------- + test : str or callable + If str, must be either 'ttest' or 'wilcoxon'. + If callable, must be a function that can be used as a test statistic. + + Returns + ------- + callable + A function that can be used as a test statistic. + For string inputs, returns a partial function of either ttest_1samp or wilcoxon. + For callable inputs, returns the input function. + + Raises + ------ + ValueError + If test is a string but not one of the supported test names ('ttest' or 'wilcoxon'). + ValueError + If test is neither a string nor a callable. + """ + if isinstance(test, str): + if test == "ttest": + return partial(ttest_1samp, axis=1) + elif test == "wilcoxon": + return partial(wilcoxon, axis=1) + else: + raise ValueError(f"the test '{test}' is not supported") + elif callable(test): + return test + else: + raise ValueError("The test '{}' is not a valid test".format(test)) diff --git a/src/hidimstat/base_perturbation.py b/src/hidimstat/base_perturbation.py index 2b4f827b0..98a789b97 100644 --- a/src/hidimstat/base_perturbation.py +++ b/src/hidimstat/base_perturbation.py @@ -1,12 +1,13 @@ -from functools import partial - import numpy as np from joblib import Parallel, delayed -from scipy.stats import wilcoxon from sklearn.base import check_is_fitted from sklearn.metrics import root_mean_squared_error -from hidimstat._utils.utils import _check_vim_predict_method, check_random_state +from hidimstat._utils.utils import ( + _check_vim_predict_method, + check_random_state, + check_test_statistic, +) from hidimstat.base_variable_importance import ( BaseVariableImportance, GroupVariableImportanceMixin, @@ -74,7 +75,7 @@ def __init__( method: str = "predict", loss: callable = root_mean_squared_error, n_permutations: int = 50, - test_statistic=partial(wilcoxon, axis=1), + test_statistic="wilcoxon", features_groups=None, n_jobs: int = 1, random_state=None, @@ -88,7 +89,7 @@ def __init__( _check_vim_predict_method(method) self.method = method self.n_permutations = n_permutations - self.test_statistic = test_statistic + self.test_statistic = check_test_statistic(test_statistic) self.n_jobs = n_jobs # variable set in importance diff --git a/test/_utils/test_utils.py b/test/_utils/test_utils.py index 6794f425e..e9573bf2c 100644 --- a/test/_utils/test_utils.py +++ b/test/_utils/test_utils.py @@ -1,7 +1,14 @@ +from functools import partial + import numpy as np import pytest +from scipy.stats import ttest_1samp, wilcoxon -from hidimstat._utils.utils import check_random_state, get_fitted_attributes +from hidimstat._utils.utils import ( + check_random_state, + get_fitted_attributes, + check_test_statistic, +) def test_generated_attributes(): @@ -56,3 +63,21 @@ def test_error(): ValueError, match="cannot be used to seed a numpy.random.Generator instance" ): 
check_random_state(random_state) + + +def test_check_test_statistic(): + "test the function of check" + test_func = check_test_statistic("wilcoxon") + assert test_func.func == wilcoxon + test_func = check_test_statistic("ttest") + assert test_func.func == ttest_1samp + test_func = check_test_statistic(print) + assert test_func == print + + +def test_check_test_statistic_warning(): + "test the exception" + with pytest.raises(ValueError, match="the test 'test' is not supported"): + check_test_statistic("test") + with pytest.raises(ValueError, match="is not a valid test"): + check_test_statistic([]) From ae3dfa9193157994cdfdd9038ec86141c03d8c76 Mon Sep 17 00:00:00 2001 From: kusch lionel Date: Wed, 15 Oct 2025 19:23:31 +0200 Subject: [PATCH 39/80] change name --- src/hidimstat/_utils/utils.py | 2 +- src/hidimstat/base_perturbation.py | 16 ++++++++-------- src/hidimstat/conditional_feature_importance.py | 12 +++++------- src/hidimstat/leave_one_covariate_out.py | 14 ++++++-------- src/hidimstat/permutation_feature_importance.py | 12 +++++------- test/_utils/test_utils.py | 12 ++++++------ 6 files changed, 31 insertions(+), 37 deletions(-) diff --git a/src/hidimstat/_utils/utils.py b/src/hidimstat/_utils/utils.py index c4cdf3e52..58b809d6a 100644 --- a/src/hidimstat/_utils/utils.py +++ b/src/hidimstat/_utils/utils.py @@ -140,7 +140,7 @@ def seed_estimator(estimator, random_state=None): return estimator -def check_test_statistic(test): +def check_statistical_test(test): """ Validates and returns a test statistic function. diff --git a/src/hidimstat/base_perturbation.py b/src/hidimstat/base_perturbation.py index 98a789b97..e66b1beff 100644 --- a/src/hidimstat/base_perturbation.py +++ b/src/hidimstat/base_perturbation.py @@ -1,12 +1,12 @@ import numpy as np from joblib import Parallel, delayed from sklearn.base import check_is_fitted -from sklearn.metrics import root_mean_squared_error +from sklearn.metrics import mean_squared_error from hidimstat._utils.utils import ( _check_vim_predict_method, check_random_state, - check_test_statistic, + check_statistical_test, ) from hidimstat.base_variable_importance import ( BaseVariableImportance, @@ -27,7 +27,7 @@ class BasePerturbation(BaseVariableImportance, GroupVariableImportanceMixin): The method used for making predictions. This determines the predictions passed to the loss function. Supported methods are "predict", "predict_proba", "decision_function", "transform". - loss : callable, default=root_mean_squared_error + loss : callable, default=mean_squared_error The function to compute the loss when comparing the perturbed model to the original model. n_permutations : int, default=50 @@ -35,7 +35,7 @@ class BasePerturbation(BaseVariableImportance, GroupVariableImportanceMixin): Specifies the number of times the variable group (residual for CFI) is permuted. For each permutation, the perturbed model's loss is calculated and averaged over all permutations. - test_statistic : callable, default=partial(wilcoxon, axis=1) + statistical_test : callable, default=partial(wilcoxon, axis=1) Statistical test function used to compute p-values for importance scores. Must accept an array of values and return an object with a 'pvalue' attribute. Default is Wilcoxon signed-rank test. 
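As the docstring above describes, the validated test is applied row-wise to the matrix of per-permutation score differences, one row per feature group, and the class keeps the result's `.pvalue`. A standalone sketch of that resolution step (the array values are invented; the import path matches this patch):

    import numpy as np
    from hidimstat._utils.utils import check_statistical_test

    # Resolves the string to a scipy wilcoxon partial applied along axis=1:
    stat_test = check_statistical_test("wilcoxon")
    # One row per feature group, one column per permutation:
    test_result = np.array([[0.9, 1.1, 1.0, 0.8, 1.2],
                            [0.02, -0.10, 0.08, 0.05, -0.03]])
    importances = np.mean(test_result, axis=1)  # as in importance() above
    pvalues = stat_test(test_result).pvalue     # one p-value per group

Accepting a string keeps the constructor signature readable, while a callable (anything with a scipy.stats-style result object carrying `.pvalue`) remains available for custom tests.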
@@ -73,9 +73,9 @@ def __init__( self, estimator, method: str = "predict", - loss: callable = root_mean_squared_error, + loss: callable = mean_squared_error, n_permutations: int = 50, - test_statistic="wilcoxon", + statistical_test="wilcoxon", features_groups=None, n_jobs: int = 1, random_state=None, @@ -89,7 +89,7 @@ def __init__( _check_vim_predict_method(method) self.method = method self.n_permutations = n_permutations - self.test_statistic = check_test_statistic(test_statistic) + self.statistical_test = check_statistical_test(statistical_test) self.n_jobs = n_jobs # variable set in importance @@ -222,7 +222,7 @@ def importance(self, X, y): ] ) self.importances_ = np.mean(test_result, axis=1) - self.pvalues_ = self.test_statistic(test_result).pvalue + self.pvalues_ = self.statistical_test(test_result).pvalue return self.importances_ def fit_importance(self, X, y): diff --git a/src/hidimstat/conditional_feature_importance.py b/src/hidimstat/conditional_feature_importance.py index 81154e56f..cefbd2963 100644 --- a/src/hidimstat/conditional_feature_importance.py +++ b/src/hidimstat/conditional_feature_importance.py @@ -50,10 +50,8 @@ class CFI(BasePerturbation): categorical_max_cardinality : int, default=10 The maximum cardinality of a variable to be considered as categorical when the variable type is inferred (set to "auto" or not provided). - test_statistic : callable, default=partial(wilcoxon, axis=1) - Statistical test function used to compute p-values for importance scores. - Must accept an array of values and return an object with a 'pvalue' attribute. - Default is Wilcoxon signed-rank test. + statistical_test : callable or str, default="wilcoxon" + Statistical test function for computing p-values of importance scores. random_state : int or None, default=None The random state to use for sampling. n_jobs : int, default=1 @@ -76,7 +74,7 @@ def __init__( features_groups=None, feature_types="auto", categorical_max_cardinality: int = 10, - test_statistic=partial(wilcoxon, axis=1), + statistical_test=partial(wilcoxon, axis=1), random_state: int = None, n_jobs: int = 1, ): @@ -85,7 +83,7 @@ def __init__( method=method, loss=loss, n_permutations=n_permutations, - test_statistic=test_statistic, + statistical_test=statistical_test, n_jobs=n_jobs, features_groups=features_groups, random_state=random_state, @@ -258,7 +256,7 @@ def cfi( features_groups=features_groups, feature_types=feature_types, categorical_max_cardinality=categorical_max_cardinality, - test_statistic=test_statistic, + statistical_test=test_statistic, random_state=random_state, n_jobs=n_jobs, ) diff --git a/src/hidimstat/leave_one_covariate_out.py b/src/hidimstat/leave_one_covariate_out.py index 0271762ef..dd519f085 100644 --- a/src/hidimstat/leave_one_covariate_out.py +++ b/src/hidimstat/leave_one_covariate_out.py @@ -32,10 +32,8 @@ class LOCO(BasePerturbation): loss : callable, default=root_mean_squared_error The loss function to use when comparing the perturbed model to the full model. - test_statistic : callable, default=partial(wilcoxon, axis=1) - Statistical test function used to compute p-values for importance scores. - Must accept an array of values and return an object with a 'pvalue' attribute. - Default is Wilcoxon signed-rank test. + statistical_test : callable or str, default="wilcoxon" + Statistical test function for computing p-values of importance scores. 
features_groups: dict or None, default=None A dictionary where the keys are the group names and the values are the list of column names corresponding to each features group. If None, @@ -59,7 +57,7 @@ def __init__( estimator, method: str = "predict", loss: callable = root_mean_squared_error, - test_statistic=partial(wilcoxon, axis=1), + statistical_test=partial(wilcoxon, axis=1), features_groups=None, n_jobs: int = 1, ): @@ -68,7 +66,7 @@ def __init__( method=method, loss=loss, n_permutations=1, - test_statistic=test_statistic, + statistical_test=statistical_test, features_groups=features_groups, n_jobs=n_jobs, ) @@ -171,7 +169,7 @@ def importance(self, X, y): ], axis=1, ) - self.pvalues_ = self.test_statistic(test_result).pvalue + self.pvalues_ = self.statistical_test(test_result).pvalue return self.importances_ def _joblib_fit_one_features_group(self, estimator, X, y, key_features_group): @@ -225,7 +223,7 @@ def loco( estimator=estimator, method=method, loss=loss, - test_statistic=test_statistic, + statistical_test=test_statistic, features_groups=features_groups, n_jobs=n_jobs, ) diff --git a/src/hidimstat/permutation_feature_importance.py b/src/hidimstat/permutation_feature_importance.py index 1c91fd229..c7dd9276e 100644 --- a/src/hidimstat/permutation_feature_importance.py +++ b/src/hidimstat/permutation_feature_importance.py @@ -33,10 +33,8 @@ class PFI(BasePerturbation): n_permutations : int, default=50 The number of permutations to perform. For each variable/group of variables, the mean of the losses over the `n_permutations` is computed. - test_statistic : callable, default=partial(wilcoxon, axis=1) - Statistical test function used to compute p-values for importance scores. - Must accept an array of values and return an object with a 'pvalue' attribute. - Default is Wilcoxon signed-rank test. + statistical_test : callable or str, default="wilcoxon" + Statistical test function for computing p-values of importance scores. features_groups: dict or None, default=None A dictionary where the keys are the group names and the values are the list of column names corresponding to each features group. 
If None, @@ -58,7 +56,7 @@ def __init__( method: str = "predict", loss: callable = root_mean_squared_error, n_permutations: int = 50, - test_statistic=partial(wilcoxon, axis=1), + statistical_test="wilcoxon", features_groups=None, random_state: int = None, n_jobs: int = 1, @@ -68,7 +66,7 @@ def __init__( method=method, loss=loss, n_permutations=n_permutations, - test_statistic=test_statistic, + statistical_test=statistical_test, features_groups=features_groups, random_state=random_state, n_jobs=n_jobs, @@ -109,7 +107,7 @@ def pfi( method=method, loss=loss, n_permutations=n_permutations, - test_statistic=test_statistic, + statistical_test=test_statistic, features_groups=features_groups, random_state=random_state, n_jobs=n_jobs, diff --git a/test/_utils/test_utils.py b/test/_utils/test_utils.py index e9573bf2c..94aed2b3a 100644 --- a/test/_utils/test_utils.py +++ b/test/_utils/test_utils.py @@ -7,7 +7,7 @@ from hidimstat._utils.utils import ( check_random_state, get_fitted_attributes, - check_test_statistic, + check_statistical_test, ) @@ -67,17 +67,17 @@ def test_error(): def test_check_test_statistic(): "test the function of check" - test_func = check_test_statistic("wilcoxon") + test_func = check_statistical_test("wilcoxon") assert test_func.func == wilcoxon - test_func = check_test_statistic("ttest") + test_func = check_statistical_test("ttest") assert test_func.func == ttest_1samp - test_func = check_test_statistic(print) + test_func = check_statistical_test(print) assert test_func == print def test_check_test_statistic_warning(): "test the exception" with pytest.raises(ValueError, match="the test 'test' is not supported"): - check_test_statistic("test") + check_statistical_test("test") with pytest.raises(ValueError, match="is not a valid test"): - check_test_statistic([]) + check_statistical_test([]) From 035f4c829e10089dbe8927222dbe2ed3992f73a3 Mon Sep 17 00:00:00 2001 From: kusch lionel Date: Wed, 15 Oct 2025 19:29:37 +0200 Subject: [PATCH 40/80] fix import --- src/hidimstat/_utils/utils.py | 2 +- test/_utils/test_utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/hidimstat/_utils/utils.py b/src/hidimstat/_utils/utils.py index 58b809d6a..ebeb3b1be 100644 --- a/src/hidimstat/_utils/utils.py +++ b/src/hidimstat/_utils/utils.py @@ -1,5 +1,5 @@ -from functools import partial import numbers +from functools import partial import numpy as np from numpy.random import RandomState diff --git a/test/_utils/test_utils.py b/test/_utils/test_utils.py index 94aed2b3a..509502670 100644 --- a/test/_utils/test_utils.py +++ b/test/_utils/test_utils.py @@ -6,8 +6,8 @@ from hidimstat._utils.utils import ( check_random_state, - get_fitted_attributes, check_statistical_test, + get_fitted_attributes, ) From 251584c8adc7ac9984e6d911d08911fefc12da16 Mon Sep 17 00:00:00 2001 From: kusch lionel Date: Wed, 15 Oct 2025 19:33:13 +0200 Subject: [PATCH 41/80] change name --- src/hidimstat/base_perturbation.py | 45 ++++++++++++------------------ 1 file changed, 18 insertions(+), 27 deletions(-) diff --git a/src/hidimstat/base_perturbation.py b/src/hidimstat/base_perturbation.py index e66b1beff..9312fa4ed 100644 --- a/src/hidimstat/base_perturbation.py +++ b/src/hidimstat/base_perturbation.py @@ -16,13 +16,13 @@ class BasePerturbation(BaseVariableImportance, GroupVariableImportanceMixin): """ - Base class for model-agnostic variable importance measures based on - perturbation. + Abstract base class for model-agnostic variable importance measures using + perturbation techniques. 
Parameters ---------- - estimator : sklearn compatible estimator - The estimator to use for the prediction. + estimator : sklearn-compatible estimator + The fitted estimator used for predictions. method : str, default="predict" The method used for making predictions. This determines the predictions passed to the loss function. Supported methods are "predict", @@ -31,42 +31,33 @@ class BasePerturbation(BaseVariableImportance, GroupVariableImportanceMixin): The function to compute the loss when comparing the perturbed model to the original model. n_permutations : int, default=50 - This parameter is relevant only for PFI or CFI. - Specifies the number of times the variable group (residual for CFI) is - permuted. For each permutation, the perturbed model's loss is calculated - and averaged over all permutations. - statistical_test : callable, default=partial(wilcoxon, axis=1) - Statistical test function used to compute p-values for importance scores. - Must accept an array of values and return an object with a 'pvalue' attribute. - Default is Wilcoxon signed-rank test. - features_groups: dict or None, default=None - A dictionary where the keys are the group names and the values are the - list of column names corresponding to each features group. If None, - the features_groups are identified based on the columns of X. + Number of permutations for each feature group. + statistical_test : callable or str, default="wilcoxon" + Statistical test function for computing p-values of importance scores. + features_groups : dict or None, default=None + Mapping of group names to lists of feature indices or names. If None, groups are inferred. n_jobs : int, default=1 - The number of parallel jobs to run. Parallelization is done over the - variables or groups of variables. + Number of parallel jobs for computation. random_state : int or None, default=None - Controls random number generation for permutations. Use an int for - repeatable results. + Seed for reproducible permutations. Attributes ---------- features_groups : dict Mapping of feature groups identified during fit. importances_ : ndarray (n_groups,) - Computed importance scores for each feature group. + Importance scores for each feature group. loss_reference_ : float - Loss of the original model without perturbation. + Loss on original (non-perturbed) data. loss_ : dict - Loss values for each perturbed feature group. - pvalues_ : ndarray (n_groups,) - P-values for importance scores from the specified test_statistic. + Loss values for each permutation of each group. + pvalues_ : ndarray of shape (n_groups,) + P-values for importance scores. Notes ----- - This is an abstract base class. Concrete implementations must override - the `_permutation` method to define how features are perturbed. + This class is abstract. Subclasses must implement the `_permutation` method + to define how feature groups are perturbed. 
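To make that contract concrete, here is a toy subclass sketch (illustrative only: the expected return shape of `_permutation` and the group-lookup attribute are assumptions inferred from the tests in this series, not guaranteed by this hunk):

    import numpy as np
    from hidimstat.base_perturbation import BasePerturbation

    class RowShufflePerturbation(BasePerturbation):
        """Toy perturbation: shuffle the rows of one feature group at a time."""

        def _permutation(self, X, features_group_id=None):
            # Assumed contract: return n_permutations perturbed copies of the
            # selected group's columns, stacked along a leading axis.
            rng = np.random.default_rng(0)
            columns = self._features_groups_ids[features_group_id]  # assumed internal
            copies = []
            for _ in range(self.n_permutations):
                block = np.array(X)[:, columns].copy()
                rng.shuffle(block, axis=0)  # break the link with y and X^{-j}
                copies.append(block)
            return np.stack(copies)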
""" def __init__( From 1c15c5bd01cca83a7e57b22b24e2b5f31abe1e77 Mon Sep 17 00:00:00 2001 From: lionel kusch Date: Thu, 16 Oct 2025 11:50:34 +0200 Subject: [PATCH 42/80] Update src/hidimstat/_utils/utils.py Co-authored-by: Joseph Paillard --- src/hidimstat/_utils/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hidimstat/_utils/utils.py b/src/hidimstat/_utils/utils.py index ebeb3b1be..c5fa5c823 100644 --- a/src/hidimstat/_utils/utils.py +++ b/src/hidimstat/_utils/utils.py @@ -166,7 +166,7 @@ def check_statistical_test(test): """ if isinstance(test, str): if test == "ttest": - return partial(ttest_1samp, axis=1) + return partial(ttest_1samp, popmean=0, alternative='greater', axis=1) elif test == "wilcoxon": return partial(wilcoxon, axis=1) else: From 9414368aeb176a37b40ce7715823a6335d223eb7 Mon Sep 17 00:00:00 2001 From: lionel kusch Date: Thu, 16 Oct 2025 11:51:09 +0200 Subject: [PATCH 43/80] Update src/hidimstat/conditional_feature_importance.py Co-authored-by: Joseph Paillard --- src/hidimstat/conditional_feature_importance.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hidimstat/conditional_feature_importance.py b/src/hidimstat/conditional_feature_importance.py index 8e888d475..467f9dc9b 100644 --- a/src/hidimstat/conditional_feature_importance.py +++ b/src/hidimstat/conditional_feature_importance.py @@ -5,7 +5,7 @@ from scipy.stats import wilcoxon from sklearn.base import BaseEstimator, check_is_fitted, clone from sklearn.linear_model import LogisticRegressionCV, RidgeCV -from sklearn.metrics import root_mean_squared_error +from sklearn.metrics import mean_squared_error from hidimstat._utils.docstring import _aggregate_docstring from hidimstat.base_perturbation import BasePerturbation From 8a8587faffaeed973ed7ad414bd38b84144ad0c8 Mon Sep 17 00:00:00 2001 From: lionel kusch Date: Thu, 16 Oct 2025 11:52:06 +0200 Subject: [PATCH 44/80] Update src/hidimstat/conditional_feature_importance.py Co-authored-by: Joseph Paillard --- src/hidimstat/conditional_feature_importance.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hidimstat/conditional_feature_importance.py b/src/hidimstat/conditional_feature_importance.py index 467f9dc9b..4aa4e1872 100644 --- a/src/hidimstat/conditional_feature_importance.py +++ b/src/hidimstat/conditional_feature_importance.py @@ -26,7 +26,7 @@ class CFI(BasePerturbation): The method to use for the prediction. This determines the predictions passed to the loss function. Supported methods are "predict", "predict_proba" or "decision_function". - loss : callable, default=root_mean_squared_error + loss : callable, default=mean_squared_error The loss function to use when comparing the perturbed model to the full model. 
n_permutations : int, default=50 From 631ee83b0c4a7f103a2092158193aafe2026c789 Mon Sep 17 00:00:00 2001 From: lionel kusch Date: Thu, 16 Oct 2025 11:52:18 +0200 Subject: [PATCH 45/80] Update src/hidimstat/permutation_feature_importance.py Co-authored-by: Joseph Paillard --- src/hidimstat/permutation_feature_importance.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hidimstat/permutation_feature_importance.py b/src/hidimstat/permutation_feature_importance.py index c7dd9276e..ca29e148f 100644 --- a/src/hidimstat/permutation_feature_importance.py +++ b/src/hidimstat/permutation_feature_importance.py @@ -91,7 +91,7 @@ def pfi( X, y, method: str = "predict", - loss: callable = root_mean_squared_error, + loss: callable = mean_squared_error, n_permutations: int = 50, test_statistic=partial(wilcoxon, axis=1), features_groups=None, From 39ebce1a769dcebd163aadad5c34d597ecf0ae7c Mon Sep 17 00:00:00 2001 From: lionel kusch Date: Thu, 16 Oct 2025 11:52:34 +0200 Subject: [PATCH 46/80] Update src/hidimstat/_utils/utils.py Co-authored-by: Joseph Paillard --- src/hidimstat/_utils/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hidimstat/_utils/utils.py b/src/hidimstat/_utils/utils.py index c5fa5c823..4d2d22c60 100644 --- a/src/hidimstat/_utils/utils.py +++ b/src/hidimstat/_utils/utils.py @@ -168,7 +168,7 @@ def check_statistical_test(test): if test == "ttest": return partial(ttest_1samp, popmean=0, alternative='greater', axis=1) elif test == "wilcoxon": - return partial(wilcoxon, axis=1) + return partial(wilcoxon, alternative='greater', axis=1) else: raise ValueError(f"the test '{test}' is not supported") elif callable(test): From 971da6143780402be1363b4dbc38663e69296170 Mon Sep 17 00:00:00 2001 From: lionel kusch Date: Thu, 16 Oct 2025 11:52:46 +0200 Subject: [PATCH 47/80] Update test/test_conditional_feature_importance.py Co-authored-by: Joseph Paillard --- test/test_conditional_feature_importance.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_conditional_feature_importance.py b/test/test_conditional_feature_importance.py index 3e2c7eaee..767516c87 100644 --- a/test/test_conditional_feature_importance.py +++ b/test/test_conditional_feature_importance.py @@ -277,7 +277,7 @@ def test_init(self, data_generator): ) assert cfi.n_jobs == 1 assert cfi.n_permutations == 50 - assert cfi.loss == root_mean_squared_error + assert cfi.loss == mean_squared_error assert cfi.method == "predict" assert cfi.categorical_max_cardinality == 10 assert isinstance(cfi.imputation_model_categorical, LogisticRegressionCV) From c7adec914c5d08ae0fc397944c2ea19ad3205d5c Mon Sep 17 00:00:00 2001 From: lionel kusch Date: Thu, 16 Oct 2025 11:53:00 +0200 Subject: [PATCH 48/80] Update test/test_conditional_feature_importance.py Co-authored-by: Joseph Paillard --- test/test_conditional_feature_importance.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_conditional_feature_importance.py b/test/test_conditional_feature_importance.py index 767516c87..ab867695a 100644 --- a/test/test_conditional_feature_importance.py +++ b/test/test_conditional_feature_importance.py @@ -11,7 +11,7 @@ LogisticRegressionCV, RidgeCV, ) -from sklearn.metrics import log_loss, root_mean_squared_error +from sklearn.metrics import log_loss, mean_squared_error from sklearn.model_selection import train_test_split from hidimstat import CFI, cfi From a3c79063f3810ca6184cd1916ff9311eb8cbdcc7 Mon Sep 17 00:00:00 2001 From: lionel kusch Date: Thu, 16 Oct 2025 
11:53:20 +0200 Subject: [PATCH 49/80] Update src/hidimstat/conditional_feature_importance.py Co-authored-by: Joseph Paillard --- src/hidimstat/conditional_feature_importance.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hidimstat/conditional_feature_importance.py b/src/hidimstat/conditional_feature_importance.py index 4aa4e1872..414a0090f 100644 --- a/src/hidimstat/conditional_feature_importance.py +++ b/src/hidimstat/conditional_feature_importance.py @@ -67,7 +67,7 @@ def __init__( self, estimator, method: str = "predict", - loss: callable = root_mean_squared_error, + loss: callable = mean_squared_error, n_permutations: int = 50, imputation_model_continuous=RidgeCV(), imputation_model_categorical=LogisticRegressionCV(), From f4c8ce4d891dc8d3eede66dbec46ed570dbc1588 Mon Sep 17 00:00:00 2001 From: lionel kusch Date: Thu, 16 Oct 2025 11:53:38 +0200 Subject: [PATCH 50/80] Update src/hidimstat/leave_one_covariate_out.py Co-authored-by: Joseph Paillard --- src/hidimstat/leave_one_covariate_out.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hidimstat/leave_one_covariate_out.py b/src/hidimstat/leave_one_covariate_out.py index dd519f085..8c4b1db15 100644 --- a/src/hidimstat/leave_one_covariate_out.py +++ b/src/hidimstat/leave_one_covariate_out.py @@ -5,7 +5,7 @@ from joblib import Parallel, delayed from scipy.stats import wilcoxon from sklearn.base import check_is_fitted, clone -from sklearn.metrics import root_mean_squared_error +from sklearn.metrics import mean_squared_error from hidimstat._utils.docstring import _aggregate_docstring from hidimstat.base_perturbation import BasePerturbation From c983eab13fb4ca40b43fec19bd2a698e4c78f7db Mon Sep 17 00:00:00 2001 From: lionel kusch Date: Thu, 16 Oct 2025 11:53:59 +0200 Subject: [PATCH 51/80] Update src/hidimstat/leave_one_covariate_out.py Co-authored-by: Joseph Paillard --- src/hidimstat/leave_one_covariate_out.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hidimstat/leave_one_covariate_out.py b/src/hidimstat/leave_one_covariate_out.py index 8c4b1db15..23503d60c 100644 --- a/src/hidimstat/leave_one_covariate_out.py +++ b/src/hidimstat/leave_one_covariate_out.py @@ -29,7 +29,7 @@ class LOCO(BasePerturbation): The method to use for the prediction. This determines the predictions passed to the loss function. Supported methods are "predict", "predict_proba" or "decision_function". - loss : callable, default=root_mean_squared_error + loss : callable, default=mean_squared_error The loss function to use when comparing the perturbed model to the full model. 
statistical_test : callable or str, default="wilcoxon" From f91ac6cf43112987e8df9ef2083dfd2c4efe9706 Mon Sep 17 00:00:00 2001 From: lionel kusch Date: Thu, 16 Oct 2025 11:54:16 +0200 Subject: [PATCH 52/80] Update src/hidimstat/leave_one_covariate_out.py Co-authored-by: Joseph Paillard --- src/hidimstat/leave_one_covariate_out.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hidimstat/leave_one_covariate_out.py b/src/hidimstat/leave_one_covariate_out.py index 23503d60c..9f0840a12 100644 --- a/src/hidimstat/leave_one_covariate_out.py +++ b/src/hidimstat/leave_one_covariate_out.py @@ -56,7 +56,7 @@ def __init__( self, estimator, method: str = "predict", - loss: callable = root_mean_squared_error, + loss: callable = mean_squared_error, statistical_test=partial(wilcoxon, axis=1), features_groups=None, n_jobs: int = 1, From 38c6bc7e60dac4972d2e82821408d5f7bf790622 Mon Sep 17 00:00:00 2001 From: lionel kusch Date: Thu, 16 Oct 2025 11:54:35 +0200 Subject: [PATCH 53/80] Update src/hidimstat/permutation_feature_importance.py Co-authored-by: Joseph Paillard --- src/hidimstat/permutation_feature_importance.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hidimstat/permutation_feature_importance.py b/src/hidimstat/permutation_feature_importance.py index ca29e148f..7434db953 100644 --- a/src/hidimstat/permutation_feature_importance.py +++ b/src/hidimstat/permutation_feature_importance.py @@ -27,7 +27,7 @@ class PFI(BasePerturbation): The method to use for the prediction. This determines the predictions passed to the loss function. Supported methods are "predict", "predict_proba" or "decision_function". - loss : callable, default=root_mean_squared_error + loss : callable, default=mean_squared_error The loss function to use when comparing the perturbed model to the full model. 
n_permutations : int, default=50 From 0bb880e2aca887e45faeebf9309fdbb5573041db Mon Sep 17 00:00:00 2001 From: lionel kusch Date: Thu, 16 Oct 2025 11:54:53 +0200 Subject: [PATCH 54/80] Update src/hidimstat/permutation_feature_importance.py Co-authored-by: Joseph Paillard --- src/hidimstat/permutation_feature_importance.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hidimstat/permutation_feature_importance.py b/src/hidimstat/permutation_feature_importance.py index 7434db953..6f545f5bc 100644 --- a/src/hidimstat/permutation_feature_importance.py +++ b/src/hidimstat/permutation_feature_importance.py @@ -54,7 +54,7 @@ def __init__( self, estimator, method: str = "predict", - loss: callable = root_mean_squared_error, + loss: callable = mean_squared_error, n_permutations: int = 50, statistical_test="wilcoxon", features_groups=None, From a92ea830b6e095231fd60378440f51d19be57ca6 Mon Sep 17 00:00:00 2001 From: lionel kusch Date: Thu, 16 Oct 2025 11:55:13 +0200 Subject: [PATCH 55/80] Update src/hidimstat/permutation_feature_importance.py Co-authored-by: Joseph Paillard --- src/hidimstat/permutation_feature_importance.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hidimstat/permutation_feature_importance.py b/src/hidimstat/permutation_feature_importance.py index 6f545f5bc..e088d74c5 100644 --- a/src/hidimstat/permutation_feature_importance.py +++ b/src/hidimstat/permutation_feature_importance.py @@ -2,7 +2,7 @@ import numpy as np from scipy.stats import wilcoxon -from sklearn.metrics import root_mean_squared_error +from sklearn.metrics import mean_squared_error from hidimstat._utils.docstring import _aggregate_docstring from hidimstat._utils.utils import check_random_state From 2166f2c36155495dd5729ec7ff484e0f278fff02 Mon Sep 17 00:00:00 2001 From: kusch lionel Date: Thu, 16 Oct 2025 12:00:40 +0200 Subject: [PATCH 56/80] fix format --- src/hidimstat/_utils/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/hidimstat/_utils/utils.py b/src/hidimstat/_utils/utils.py index 4d2d22c60..382396469 100644 --- a/src/hidimstat/_utils/utils.py +++ b/src/hidimstat/_utils/utils.py @@ -166,9 +166,9 @@ def check_statistical_test(test): """ if isinstance(test, str): if test == "ttest": - return partial(ttest_1samp, popmean=0, alternative='greater', axis=1) + return partial(ttest_1samp, popmean=0, alternative="greater", axis=1) elif test == "wilcoxon": - return partial(wilcoxon, alternative='greater', axis=1) + return partial(wilcoxon, alternative="greater", axis=1) else: raise ValueError(f"the test '{test}' is not supported") elif callable(test): From 4cc3598de8a93eb0c1d23914caa2793f9a86e2d2 Mon Sep 17 00:00:00 2001 From: kusch lionel Date: Thu, 16 Oct 2025 12:07:07 +0200 Subject: [PATCH 57/80] fix modification --- examples/plot_diabetes_variable_importance_example.py | 4 ++-- src/hidimstat/conditional_feature_importance.py | 2 +- src/hidimstat/leave_one_covariate_out.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/plot_diabetes_variable_importance_example.py b/examples/plot_diabetes_variable_importance_example.py index 01ddd1f59..b14da8ef4 100644 --- a/examples/plot_diabetes_variable_importance_example.py +++ b/examples/plot_diabetes_variable_importance_example.py @@ -63,7 +63,7 @@ import numpy as np from sklearn.base import clone from sklearn.linear_model import LogisticRegressionCV, RidgeCV -from sklearn.metrics import r2_score, root_mean_squared_error +from sklearn.metrics import r2_score, 
mean_squared_error from sklearn.model_selection import KFold n_folds = 5 @@ -78,7 +78,7 @@ score = r2_score( y_true=y[test_index], y_pred=regressor_list[i].predict(X[test_index]) ) - mse = root_mean_squared_error( + mse = mean_squared_error( y_true=y[test_index], y_pred=regressor_list[i].predict(X[test_index]) ) diff --git a/src/hidimstat/conditional_feature_importance.py b/src/hidimstat/conditional_feature_importance.py index 414a0090f..832572d13 100644 --- a/src/hidimstat/conditional_feature_importance.py +++ b/src/hidimstat/conditional_feature_importance.py @@ -231,7 +231,7 @@ def cfi( X, y, method: str = "predict", - loss: callable = root_mean_squared_error, + loss: callable = mean_squared_error, n_permutations: int = 50, imputation_model_continuous=RidgeCV(), imputation_model_categorical=LogisticRegressionCV(), diff --git a/src/hidimstat/leave_one_covariate_out.py b/src/hidimstat/leave_one_covariate_out.py index 9f0840a12..546f8a033 100644 --- a/src/hidimstat/leave_one_covariate_out.py +++ b/src/hidimstat/leave_one_covariate_out.py @@ -210,7 +210,7 @@ def loco( X, y, method: str = "predict", - loss: callable = root_mean_squared_error, + loss: callable = mean_squared_error, features_groups=None, test_statistic=partial(wilcoxon, axis=1), k_best=None, From 0e6b9298ecd3ae206db2a4be5beaf86f01602866 Mon Sep 17 00:00:00 2001 From: kusch lionel Date: Thu, 16 Oct 2025 13:52:59 +0200 Subject: [PATCH 58/80] fix order import --- examples/plot_diabetes_variable_importance_example.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/plot_diabetes_variable_importance_example.py b/examples/plot_diabetes_variable_importance_example.py index b14da8ef4..9a03da1a9 100644 --- a/examples/plot_diabetes_variable_importance_example.py +++ b/examples/plot_diabetes_variable_importance_example.py @@ -63,7 +63,7 @@ import numpy as np from sklearn.base import clone from sklearn.linear_model import LogisticRegressionCV, RidgeCV -from sklearn.metrics import r2_score, mean_squared_error +from sklearn.metrics import mean_squared_error, r2_score from sklearn.model_selection import KFold n_folds = 5 From 438979664d78fd7fc42ec6bb856ff18438742476 Mon Sep 17 00:00:00 2001 From: kusch lionel Date: Thu, 16 Oct 2025 14:08:17 +0200 Subject: [PATCH 59/80] remove unnecessary merge --- src/hidimstat/statistical_tools/conditional_sampling.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/hidimstat/statistical_tools/conditional_sampling.py b/src/hidimstat/statistical_tools/conditional_sampling.py index 6a886a323..153a6d924 100644 --- a/src/hidimstat/statistical_tools/conditional_sampling.py +++ b/src/hidimstat/statistical_tools/conditional_sampling.py @@ -47,7 +47,6 @@ def __init__( model_categorical=None, data_type: str = "auto", categorical_max_cardinality=10, - random_state=None, ): """ Class use to sample from the conditional distribution $p(X^j | X^{-j})$. @@ -66,8 +65,6 @@ def __init__( categorical_max_cardinality : int, default=10 The maximum cardinality of a variable to be considered as categorical when `data_type` is "auto". - random_state : int, optional - The random state to use for sampling.
""" # check the validity of the inputs @@ -81,7 +78,6 @@ def __init__( self.model_regression = model_regression self.model_categorical = model_categorical self.categorical_max_cardinality = categorical_max_cardinality - self.rng = check_random_state(random_state) def fit(self, X: np.ndarray, y: np.ndarray): r""" From 9a640b1d7bafea6ea3fe985fd41c32ae4e8fe080 Mon Sep 17 00:00:00 2001 From: lionel kusch Date: Thu, 16 Oct 2025 14:10:13 +0200 Subject: [PATCH 60/80] Update src/hidimstat/_utils/utils.py Co-authored-by: Joseph Paillard --- src/hidimstat/_utils/utils.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/hidimstat/_utils/utils.py b/src/hidimstat/_utils/utils.py index 382396469..b661a2c30 100644 --- a/src/hidimstat/_utils/utils.py +++ b/src/hidimstat/_utils/utils.py @@ -174,4 +174,10 @@ def check_statistical_test(test): elif callable(test): return test else: - raise ValueError("The test '{}' is not a valid test".format(test)) + raise ValueError( + f"Unsupported value for 'statistical_test': {}.".format(test) + f"The provided argument was '{statistical_test}'. " + f"Please choose from the following valid options: " + f"string values ('ttest', 'wilcoxon', 'NB-ttest') " + f"or a custom callable function with a `scipy.stats` API-compatible signature." + ) From 431d9a63fd1ac6dcacba23bae7ab37ee63aec7b6 Mon Sep 17 00:00:00 2001 From: kusch lionel Date: Thu, 16 Oct 2025 14:21:10 +0200 Subject: [PATCH 61/80] update error --- src/hidimstat/_utils/utils.py | 24 ++++++++++++------------ test/_utils/test_utils.py | 2 +- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/hidimstat/_utils/utils.py b/src/hidimstat/_utils/utils.py index b661a2c30..802eb69c9 100644 --- a/src/hidimstat/_utils/utils.py +++ b/src/hidimstat/_utils/utils.py @@ -140,7 +140,7 @@ def seed_estimator(estimator, random_state=None): return estimator -def check_statistical_test(test): +def check_statistical_test(statistical_test): """ Validates and returns a test statistic function. @@ -164,20 +164,20 @@ def check_statistical_test(test): ValueError If test is neither a string nor a callable. """ - if isinstance(test, str): - if test == "ttest": + if isinstance(statistical_test, str): + if statistical_test == "ttest": return partial(ttest_1samp, popmean=0, alternative="greater", axis=1) - elif test == "wilcoxon": + elif statistical_test == "wilcoxon": return partial(wilcoxon, alternative="greater", axis=1) else: - raise ValueError(f"the test '{test}' is not supported") - elif callable(test): - return test + raise ValueError(f"the test '{statistical_test}' is not supported") + elif callable(statistical_test): + return statistical_test else: raise ValueError( - f"Unsupported value for 'statistical_test': {}.".format(test) - f"The provided argument was '{statistical_test}'. " - f"Please choose from the following valid options: " - f"string values ('ttest', 'wilcoxon', 'NB-ttest') " - f"or a custom callable function with a `scipy.stats` API-compatible signature." + f"Unsupported value for 'statistical_test'." + f"The provided argument was '{statistical_test}'. " + f"Please choose from the following valid options: " + f"string values ('ttest', 'wilcoxon', 'NB-ttest') " + f"or a custom callable function with a `scipy.stats` API-compatible signature." 
) diff --git a/test/_utils/test_utils.py b/test/_utils/test_utils.py index 509502670..27e6511f6 100644 --- a/test/_utils/test_utils.py +++ b/test/_utils/test_utils.py @@ -79,5 +79,5 @@ def test_check_test_statistic_warning(): "test the exception" with pytest.raises(ValueError, match="the test 'test' is not supported"): check_statistical_test("test") - with pytest.raises(ValueError, match="is not a valid test"): + with pytest.raises(ValueError, match="Unsupported value for 'statistical_test'."): check_statistical_test([]) From e003345d2d902735526fcb81d0c16f2af9fe0c63 Mon Sep 17 00:00:00 2001 From: kusch lionel Date: Wed, 22 Oct 2025 11:29:49 +0200 Subject: [PATCH 62/80] add the NB-ttest as default --- src/hidimstat/_utils/utils.py | 10 ++++++++++ src/hidimstat/base_perturbation.py | 2 +- src/hidimstat/conditional_feature_importance.py | 2 +- src/hidimstat/leave_one_covariate_out.py | 2 +- src/hidimstat/permutation_feature_importance.py | 2 +- 5 files changed, 14 insertions(+), 4 deletions(-) diff --git a/src/hidimstat/_utils/utils.py b/src/hidimstat/_utils/utils.py index 802eb69c9..abf4f86dc 100644 --- a/src/hidimstat/_utils/utils.py +++ b/src/hidimstat/_utils/utils.py @@ -5,6 +5,8 @@ from numpy.random import RandomState from scipy.stats import ttest_1samp, wilcoxon +from hidimstat.statistical_tools import nadeau_bengio_ttest + def _check_vim_predict_method(method): """ @@ -169,6 +171,14 @@ def check_statistical_test(statistical_test): return partial(ttest_1samp, popmean=0, alternative="greater", axis=1) elif statistical_test == "wilcoxon": return partial(wilcoxon, alternative="greater", axis=1) + elif statistical_test == "NB-ttest": + return partial( + nadeau_bengio_ttest, + popmean=0, + test_frac=0.1 / 0.9, + alternative="greater", + axis=1, + ) else: raise ValueError(f"the test '{statistical_test}' is not supported") elif callable(statistical_test): diff --git a/src/hidimstat/base_perturbation.py b/src/hidimstat/base_perturbation.py index 9312fa4ed..d5e824442 100644 --- a/src/hidimstat/base_perturbation.py +++ b/src/hidimstat/base_perturbation.py @@ -66,7 +66,7 @@ def __init__( method: str = "predict", loss: callable = mean_squared_error, n_permutations: int = 50, - statistical_test="wilcoxon", + statistical_test="NB-ttest", features_groups=None, n_jobs: int = 1, random_state=None, diff --git a/src/hidimstat/conditional_feature_importance.py b/src/hidimstat/conditional_feature_importance.py index eadec69e4..832572d13 100644 --- a/src/hidimstat/conditional_feature_importance.py +++ b/src/hidimstat/conditional_feature_importance.py @@ -74,7 +74,7 @@ def __init__( features_groups=None, feature_types="auto", categorical_max_cardinality: int = 10, - statistical_test=partial(wilcoxon, axis=1), + statistical_test="NB-ttest", random_state: int = None, n_jobs: int = 1, ): diff --git a/src/hidimstat/leave_one_covariate_out.py b/src/hidimstat/leave_one_covariate_out.py index 546f8a033..a1179b24a 100644 --- a/src/hidimstat/leave_one_covariate_out.py +++ b/src/hidimstat/leave_one_covariate_out.py @@ -57,7 +57,7 @@ def __init__( estimator, method: str = "predict", loss: callable = mean_squared_error, - statistical_test=partial(wilcoxon, axis=1), + statistical_test="NB-ttest", features_groups=None, n_jobs: int = 1, ): diff --git a/src/hidimstat/permutation_feature_importance.py b/src/hidimstat/permutation_feature_importance.py index e088d74c5..6a1091fb4 100644 --- a/src/hidimstat/permutation_feature_importance.py +++ b/src/hidimstat/permutation_feature_importance.py @@ -56,7 +56,7 @@ def __init__( 
method: str = "predict", loss: callable = mean_squared_error, n_permutations: int = 50, - statistical_test="wilcoxon", + statistical_test="NB-ttest", features_groups=None, random_state: int = None, n_jobs: int = 1, From 1973bccbfe96003ea2d02712b5b0db1acb94ac97 Mon Sep 17 00:00:00 2001 From: kusch lionel Date: Wed, 22 Oct 2025 12:01:52 +0200 Subject: [PATCH 63/80] move sampler in separate module --- docs/src/api.rst | 4 ++-- src/hidimstat/knockoffs.py | 2 +- .../{statistical_tools => samplers}/conditional_sampling.py | 0 .../{statistical_tools => samplers}/gaussian_knockoffs.py | 0 src/hidimstat/statistical_tools/__init__.py | 6 +----- test/statistical_tools/test_conditional_sampling.py | 2 +- test/statistical_tools/test_gaussian_knockoffs.py | 2 +- 7 files changed, 6 insertions(+), 10 deletions(-) rename src/hidimstat/{statistical_tools => samplers}/conditional_sampling.py (100%) rename src/hidimstat/{statistical_tools => samplers}/gaussian_knockoffs.py (100%) diff --git a/docs/src/api.rst b/docs/src/api.rst index be79515c0..d2e49daf3 100644 --- a/docs/src/api.rst +++ b/docs/src/api.rst @@ -52,8 +52,8 @@ Samplers :toctree: ./generated/api/class/ :template: class.rst - ~statistical_tools.ConditionalSampler - ~statistical_tools.GaussianKnockoffs + ~samplers.ConditionalSampler + ~samplers.GaussianKnockoffs Helper Functions ================ diff --git a/src/hidimstat/knockoffs.py b/src/hidimstat/knockoffs.py index fd2d967ff..f4aa95e75 100644 --- a/src/hidimstat/knockoffs.py +++ b/src/hidimstat/knockoffs.py @@ -7,8 +7,8 @@ from sklearn.preprocessing import StandardScaler from sklearn.utils.validation import check_memory +from hidimstat.samplers.gaussian_knockoffs import GaussianKnockoffs from hidimstat.statistical_tools.aggregation import quantile_aggregation -from hidimstat.statistical_tools.gaussian_knockoffs import GaussianKnockoffs from hidimstat.statistical_tools.multiple_testing import fdr_threshold diff --git a/src/hidimstat/statistical_tools/conditional_sampling.py b/src/hidimstat/samplers/conditional_sampling.py similarity index 100% rename from src/hidimstat/statistical_tools/conditional_sampling.py rename to src/hidimstat/samplers/conditional_sampling.py diff --git a/src/hidimstat/statistical_tools/gaussian_knockoffs.py b/src/hidimstat/samplers/gaussian_knockoffs.py similarity index 100% rename from src/hidimstat/statistical_tools/gaussian_knockoffs.py rename to src/hidimstat/samplers/gaussian_knockoffs.py diff --git a/src/hidimstat/statistical_tools/__init__.py b/src/hidimstat/statistical_tools/__init__.py index 4012354d2..8147294cd 100644 --- a/src/hidimstat/statistical_tools/__init__.py +++ b/src/hidimstat/statistical_tools/__init__.py @@ -1,9 +1,5 @@ -from .conditional_sampling import ConditionalSampler -from .gaussian_knockoffs import GaussianKnockoffs -from .nadeau_bengio_ttest import nadeau_bengio_ttest +from .nadeau_bengio_ttest_ import nadeau_bengio_ttest __all__ = [ - "ConditionalSampler", - "GaussianKnockoffs", "nadeau_bengio_ttest", ] diff --git a/test/statistical_tools/test_conditional_sampling.py b/test/statistical_tools/test_conditional_sampling.py index a3de06970..e2509f9df 100644 --- a/test/statistical_tools/test_conditional_sampling.py +++ b/test/statistical_tools/test_conditional_sampling.py @@ -11,7 +11,7 @@ from sklearn.multiclass import OneVsRestClassifier from sklearn.preprocessing import StandardScaler -from hidimstat.statistical_tools.conditional_sampling import ConditionalSampler +from hidimstat.sampler.conditional_sampling import ConditionalSampler def 
test_continuous_case(): diff --git a/test/statistical_tools/test_gaussian_knockoffs.py b/test/statistical_tools/test_gaussian_knockoffs.py index db07ff4e6..f4378be88 100644 --- a/test/statistical_tools/test_gaussian_knockoffs.py +++ b/test/statistical_tools/test_gaussian_knockoffs.py @@ -3,7 +3,7 @@ from sklearn.covariance import LedoitWolf from hidimstat._utils.scenario import multivariate_simulation -from hidimstat.statistical_tools.gaussian_knockoffs import ( +from hidimstat.samplers.gaussian_knockoffs import ( GaussianKnockoffs, _s_equi, ) From 57ca0678033de93794e63887beecf3e0febb8795 Mon Sep 17 00:00:00 2001 From: kusch lionel Date: Wed, 22 Oct 2025 12:04:02 +0200 Subject: [PATCH 64/80] move sampler in a separate folder --- src/hidimstat/_utils/utils.py | 2 +- src/hidimstat/conditional_feature_importance.py | 2 +- src/hidimstat/samplers/__init__.py | 7 +++++++ 3 files changed, 9 insertions(+), 2 deletions(-) create mode 100644 src/hidimstat/samplers/__init__.py diff --git a/src/hidimstat/_utils/utils.py b/src/hidimstat/_utils/utils.py index abf4f86dc..de6fa9643 100644 --- a/src/hidimstat/_utils/utils.py +++ b/src/hidimstat/_utils/utils.py @@ -5,7 +5,7 @@ from numpy.random import RandomState from scipy.stats import ttest_1samp, wilcoxon -from hidimstat.statistical_tools import nadeau_bengio_ttest +from hidimstat.statistical_tools.nadeau_bengio_ttest import nadeau_bengio_ttest def _check_vim_predict_method(method): diff --git a/src/hidimstat/conditional_feature_importance.py b/src/hidimstat/conditional_feature_importance.py index 832572d13..e8a3fe17c 100644 --- a/src/hidimstat/conditional_feature_importance.py +++ b/src/hidimstat/conditional_feature_importance.py @@ -9,7 +9,7 @@ from hidimstat._utils.docstring import _aggregate_docstring from hidimstat.base_perturbation import BasePerturbation -from hidimstat.statistical_tools.conditional_sampling import ConditionalSampler +from hidimstat.samplers.conditional_sampling import ConditionalSampler class CFI(BasePerturbation): diff --git a/src/hidimstat/samplers/__init__.py b/src/hidimstat/samplers/__init__.py new file mode 100644 index 000000000..6cb04420a --- /dev/null +++ b/src/hidimstat/samplers/__init__.py @@ -0,0 +1,7 @@ +from .conditional_sampling import ConditionalSampler +from .gaussian_knockoffs import GaussianKnockoffs + +__all__ = [ + "ConditionalSampler", + "GaussianKnockoffs", +] From e4d686cae8069bf0fbfd833f26c760a087a2dd62 Mon Sep 17 00:00:00 2001 From: kusch lionel Date: Wed, 22 Oct 2025 13:57:24 +0200 Subject: [PATCH 65/80] fix import --- src/hidimstat/statistical_tools/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hidimstat/statistical_tools/__init__.py b/src/hidimstat/statistical_tools/__init__.py index 8147294cd..b713a79f6 100644 --- a/src/hidimstat/statistical_tools/__init__.py +++ b/src/hidimstat/statistical_tools/__init__.py @@ -1,4 +1,4 @@ -from .nadeau_bengio_ttest_ import nadeau_bengio_ttest +from .nadeau_bengio_ttest import nadeau_bengio_ttest __all__ = [ "nadeau_bengio_ttest", From a9cd7094b0d83f0f53d9c1ebde23f99214a8e9a6 Mon Sep 17 00:00:00 2001 From: kusch lionel Date: Wed, 22 Oct 2025 14:09:57 +0200 Subject: [PATCH 66/80] fix tests --- .../test_conditional_sampling.py | 2 +- test/{statistical_tools => samplers}/test_gaussian_knockoffs.py | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename test/{statistical_tools => samplers}/test_conditional_sampling.py (99%) rename test/{statistical_tools => samplers}/test_gaussian_knockoffs.py (100%) diff --git 
a/test/statistical_tools/test_conditional_sampling.py b/test/samplers/test_conditional_sampling.py similarity index 99% rename from test/statistical_tools/test_conditional_sampling.py rename to test/samplers/test_conditional_sampling.py index e2509f9df..888493334 100644 --- a/test/statistical_tools/test_conditional_sampling.py +++ b/test/samplers/test_conditional_sampling.py @@ -11,7 +11,7 @@ from sklearn.multiclass import OneVsRestClassifier from sklearn.preprocessing import StandardScaler -from hidimstat.sampler.conditional_sampling import ConditionalSampler +from hidimstat.samplers.conditional_sampling import ConditionalSampler def test_continuous_case(): diff --git a/test/statistical_tools/test_gaussian_knockoffs.py b/test/samplers/test_gaussian_knockoffs.py similarity index 100% rename from test/statistical_tools/test_gaussian_knockoffs.py rename to test/samplers/test_gaussian_knockoffs.py From 9ab658a8ecb7501ef4c2169a6ef0737c3ec9d344 Mon Sep 17 00:00:00 2001 From: kusch lionel Date: Wed, 22 Oct 2025 14:22:04 +0200 Subject: [PATCH 67/80] fix tests --- src/hidimstat/leave_one_covariate_out.py | 2 +- test/statistical_tools/test_nadeau_bengio_ttest.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/hidimstat/leave_one_covariate_out.py b/src/hidimstat/leave_one_covariate_out.py index a1179b24a..ef9d4cafa 100644 --- a/src/hidimstat/leave_one_covariate_out.py +++ b/src/hidimstat/leave_one_covariate_out.py @@ -169,7 +169,7 @@ def importance(self, X, y): ], axis=1, ) - self.pvalues_ = self.statistical_test(test_result).pvalue + self.pvalues_ = self.statistical_test(np.array(test_result)).pvalue return self.importances_ def _joblib_fit_one_features_group(self, estimator, X, y, key_features_group): diff --git a/test/statistical_tools/test_nadeau_bengio_ttest.py b/test/statistical_tools/test_nadeau_bengio_ttest.py index 2771a651c..fafffcf82 100644 --- a/test/statistical_tools/test_nadeau_bengio_ttest.py +++ b/test/statistical_tools/test_nadeau_bengio_ttest.py @@ -48,7 +48,7 @@ def test_ttest_1samp_corrected_NB(data_generator): ) vim.fit(X_train, y_train) importances = vim.importance(X_test, y_test) - importance_list.append(importances["importance"]) + importance_list.append(importances) importance_array = np.array(importance_list) pvalue_corr = nadeau_bengio_ttest(importance_array, 0, test_frac=0.2).pvalue From 329bf43035dd2d3c42758854e25a99d6705b42ae Mon Sep 17 00:00:00 2001 From: kusch lionel Date: Wed, 22 Oct 2025 14:42:58 +0200 Subject: [PATCH 68/80] fix example --- examples/plot_pitfalls_permutation_importance.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/plot_pitfalls_permutation_importance.py b/examples/plot_pitfalls_permutation_importance.py index 70660234f..00185f905 100644 --- a/examples/plot_pitfalls_permutation_importance.py +++ b/examples/plot_pitfalls_permutation_importance.py @@ -267,7 +267,7 @@ from matplotlib.lines import Line2D -from hidimstat.statistical_tools.conditional_sampling import ConditionalSampler +from hidimstat.samplers.conditional_sampling import ConditionalSampler X_train, X_test = train_test_split( X, From 9f488d6040cde7211fe66a102720a292c53f0147 Mon Sep 17 00:00:00 2001 From: kusch lionel Date: Wed, 22 Oct 2025 15:54:03 +0200 Subject: [PATCH 69/80] change name of nb-ttest --- src/hidimstat/_utils/utils.py | 4 ++-- src/hidimstat/base_perturbation.py | 7 +++++-- .../conditional_feature_importance.py | 4 ++-- src/hidimstat/leave_one_covariate_out.py | 4 ++-- .../permutation_feature_importance.py | 4 ++--
test/_utils/test_utils.py | 7 +++++-- test/test_conditional_feature_importance.py | 18 ++++++++++++++++++ 7 files changed, 36 insertions(+), 12 deletions(-) diff --git a/src/hidimstat/_utils/utils.py b/src/hidimstat/_utils/utils.py index de6fa9643..e88d3bc92 100644 --- a/src/hidimstat/_utils/utils.py +++ b/src/hidimstat/_utils/utils.py @@ -171,7 +171,7 @@ def check_statistical_test(statistical_test): return partial(ttest_1samp, popmean=0, alternative="greater", axis=1) elif statistical_test == "wilcoxon": return partial(wilcoxon, alternative="greater", axis=1) - elif statistical_test == "NB-ttest": + elif statistical_test == "nb-ttest": return partial( nadeau_bengio_ttest, popmean=0, @@ -188,6 +188,6 @@ def check_statistical_test(statistical_test): f"Unsupported value for 'statistical_test'." f"The provided argument was '{statistical_test}'. " f"Please choose from the following valid options: " - f"string values ('ttest', 'wilcoxon', 'NB-ttest') " + f"string values ('ttest', 'wilcoxon', 'nb-ttest') " f"or a custom callable function with a `scipy.stats` API-compatible signature." ) diff --git a/src/hidimstat/base_perturbation.py b/src/hidimstat/base_perturbation.py index d5e824442..0c00495a2 100644 --- a/src/hidimstat/base_perturbation.py +++ b/src/hidimstat/base_perturbation.py @@ -32,7 +32,7 @@ class BasePerturbation(BaseVariableImportance, GroupVariableImportanceMixin): to the original model. n_permutations : int, default=50 Number of permutations for each feature group. - statistical_test : callable or str, default="wilcoxon" + statistical_test : callable or str, default="nb-ttest" Statistical test function for computing p-values of importance scores. features_groups : dict or None, default=None Mapping of group names to lists of feature indices or names. If None, groups are inferred. @@ -66,7 +66,7 @@ def __init__( method: str = "predict", loss: callable = mean_squared_error, n_permutations: int = 50, - statistical_test="NB-ttest", + statistical_test="nb-ttest", features_groups=None, n_jobs: int = 1, random_state=None, @@ -214,6 +214,9 @@ def importance(self, X, y): ) self.importances_ = np.mean(test_result, axis=1) self.pvalues_ = self.statistical_test(test_result).pvalue + assert ( + self.pvalues_.shape == y_pred.shape + ), "The statistical test doesn't provide the correct dimension." return self.importances_ def fit_importance(self, X, y): diff --git a/src/hidimstat/conditional_feature_importance.py b/src/hidimstat/conditional_feature_importance.py index e8a3fe17c..677308e18 100644 --- a/src/hidimstat/conditional_feature_importance.py +++ b/src/hidimstat/conditional_feature_importance.py @@ -50,7 +50,7 @@ class CFI(BasePerturbation): categorical_max_cardinality : int, default=10 The maximum cardinality of a variable to be considered as categorical when the variable type is inferred (set to "auto" or not provided). - statistical_test : callable or str, default="wilcoxon" + statistical_test : callable or str, default="nb-ttest" Statistical test function for computing p-values of importance scores. random_state : int or None, default=None The random state to use for sampling. 
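Note for reviewers: a minimal usage sketch of the "nb-ttest" option wired in above. The synthetic data and the LinearRegression estimator are illustrative assumptions; the CFI arguments and the "nb-ttest" string follow the diffs in this series.

# Illustrative sketch, not part of the patch: toy data and estimator are assumptions.
import numpy as np
from sklearn.linear_model import LinearRegression

from hidimstat import CFI

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 5))
y = X[:, 0] + rng.normal(size=100)
model = LinearRegression().fit(X, y)

cfi = CFI(
    estimator=model,
    imputation_model_continuous=LinearRegression(),
    statistical_test="nb-ttest",  # string option registered in check_statistical_test
)
cfi.fit(X, y)
importances = cfi.importance(X, y)  # also populates cfi.pvalues_ via the chosen test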
@@ -74,7 +74,7 @@ def __init__( features_groups=None, feature_types="auto", categorical_max_cardinality: int = 10, - statistical_test="NB-ttest", + statistical_test="nb-ttest", random_state: int = None, n_jobs: int = 1, ): diff --git a/src/hidimstat/leave_one_covariate_out.py b/src/hidimstat/leave_one_covariate_out.py index ef9d4cafa..d54cf9a06 100644 --- a/src/hidimstat/leave_one_covariate_out.py +++ b/src/hidimstat/leave_one_covariate_out.py @@ -32,7 +32,7 @@ class LOCO(BasePerturbation): loss : callable, default=mean_squared_error The loss function to use when comparing the perturbed model to the full model. - statistical_test : callable or str, default="wilcoxon" + statistical_test : callable or str, default="nb-ttest" Statistical test function for computing p-values of importance scores. features_groups: dict or None, default=None A dictionary where the keys are the group names and the values are the @@ -57,7 +57,7 @@ def __init__( estimator, method: str = "predict", loss: callable = mean_squared_error, - statistical_test="NB-ttest", + statistical_test="nb-ttest", features_groups=None, n_jobs: int = 1, ): diff --git a/src/hidimstat/permutation_feature_importance.py b/src/hidimstat/permutation_feature_importance.py index 6a1091fb4..38fc2341a 100644 --- a/src/hidimstat/permutation_feature_importance.py +++ b/src/hidimstat/permutation_feature_importance.py @@ -33,7 +33,7 @@ class PFI(BasePerturbation): n_permutations : int, default=50 The number of permutations to perform. For each variable/group of variables, the mean of the losses over the `n_permutations` is computed. - statistical_test : callable or str, default="wilcoxon" + statistical_test : callable or str, default="nb-ttest" Statistical test function for computing p-values of importance scores. 
features_groups: dict or None, default=None A dictionary where the keys are the group names and the values are the @@ -56,7 +56,7 @@ def __init__( method: str = "predict", loss: callable = mean_squared_error, n_permutations: int = 50, - statistical_test="NB-ttest", + statistical_test="nb-ttest", features_groups=None, random_state: int = None, n_jobs: int = 1, diff --git a/test/_utils/test_utils.py b/test/_utils/test_utils.py index 27e6511f6..4a4abae5c 100644 --- a/test/_utils/test_utils.py +++ b/test/_utils/test_utils.py @@ -1,5 +1,3 @@ -from functools import partial - import numpy as np import pytest from scipy.stats import ttest_1samp, wilcoxon @@ -9,6 +7,7 @@ check_statistical_test, get_fitted_attributes, ) +from hidimstat.statistical_tools import nadeau_bengio_ttest def test_generated_attributes(): @@ -71,8 +70,12 @@ def test_check_test_statistic(): assert test_func.func == wilcoxon test_func = check_statistical_test("ttest") assert test_func.func == ttest_1samp + test_func = check_statistical_test("nb-ttest") + assert test_func.func == nadeau_bengio_ttest test_func = check_statistical_test(print) assert test_func == print + test_func = check_statistical_test(lambda x: x) + assert test_func.__class__.__name__ == "function" def test_check_test_statistic_warning(): diff --git a/test/test_conditional_feature_importance.py b/test/test_conditional_feature_importance.py index ab867695a..4f599271d 100644 --- a/test/test_conditional_feature_importance.py +++ b/test/test_conditional_feature_importance.py @@ -1,9 +1,11 @@ +from functools import partial from copy import deepcopy import matplotlib.pyplot as plt import numpy as np import pandas as pd import pytest +from scipy.stats import ttest_1samp from sklearn.exceptions import NotFittedError from sklearn.linear_model import ( LinearRegression, @@ -588,6 +590,22 @@ def test_groups_warning(self, data_generator): ): cfi.importance(X, y) + def test_assert_dimension_pvalue(self, data_generator): + """Test that an AssertionError is raised when the statistical test returns the wrong dimension.""" + X, y, _, _ = data_generator + fitted_model = LinearRegression().fit(X, y) + cfi = CFI( + estimator=fitted_model, + imputation_model_continuous=LinearRegression(), + statistical_test=partial(ttest_1samp, popmean=0, axis=0), + ) + cfi.fit(X, y) + with pytest.raises( + AssertionError, + match="The statistical test doesn't provide the correct dimension.", + ): + cfi.importance(X, y) + @pytest.mark.parametrize( "n_samples, n_features, support_size, rho, seed, value, signal_noise_ratio, rho_serial", From 17f8d6ee5fe7992763570bf626f5ca21ad8bc5b5 Mon Sep 17 00:00:00 2001 From: kusch lionel Date: Wed, 22 Oct 2025 16:03:11 +0200 Subject: [PATCH 70/80] fix import order --- src/hidimstat/leave_one_covariate_out.py | 3 +-- test/samplers/test_gaussian_knockoffs.py | 5 +---- test/test_conditional_feature_importance.py | 2 +- 3 files changed, 3 insertions(+), 7 deletions(-) diff --git a/src/hidimstat/leave_one_covariate_out.py b/src/hidimstat/leave_one_covariate_out.py index d54cf9a06..9a898e00d 100644 --- a/src/hidimstat/leave_one_covariate_out.py +++ b/src/hidimstat/leave_one_covariate_out.py @@ -145,9 +145,8 @@ def importance(self, X, y): A higher importance score indicates that perturbing that group leads to worse model performance, suggesting those features are more important.
""" - GroupVariableImportanceMixin._check_fit(self) - GroupVariableImportanceMixin._check_compatibility(self, X) self._check_fit() + GroupVariableImportanceMixin._check_compatibility(self, X) y_pred = getattr(self.estimator, self.method)(X) self.loss_reference_ = self.loss(y, y_pred) diff --git a/test/samplers/test_gaussian_knockoffs.py b/test/samplers/test_gaussian_knockoffs.py index f4378be88..0ff587c1e 100644 --- a/test/samplers/test_gaussian_knockoffs.py +++ b/test/samplers/test_gaussian_knockoffs.py @@ -3,10 +3,7 @@ from sklearn.covariance import LedoitWolf from hidimstat._utils.scenario import multivariate_simulation -from hidimstat.samplers.gaussian_knockoffs import ( - GaussianKnockoffs, - _s_equi, -) +from hidimstat.samplers.gaussian_knockoffs import GaussianKnockoffs, _s_equi def test_gaussian_equi(): diff --git a/test/test_conditional_feature_importance.py b/test/test_conditional_feature_importance.py index 4f599271d..714dd8cde 100644 --- a/test/test_conditional_feature_importance.py +++ b/test/test_conditional_feature_importance.py @@ -1,5 +1,5 @@ -from functools import partial from copy import deepcopy +from functools import partial import matplotlib.pyplot as plt import numpy as np From d33205c39bcee57a28ee3ab23623ac2ce01f8565 Mon Sep 17 00:00:00 2001 From: kusch lionel Date: Wed, 22 Oct 2025 16:16:41 +0200 Subject: [PATCH 71/80] fix assert and add assert --- src/hidimstat/base_perturbation.py | 2 +- src/hidimstat/leave_one_covariate_out.py | 3 +++ test/test_leave_one_covariate_out.py | 14 ++++++++++++++ 3 files changed, 18 insertions(+), 1 deletion(-) diff --git a/src/hidimstat/base_perturbation.py b/src/hidimstat/base_perturbation.py index 0c00495a2..4e8645e94 100644 --- a/src/hidimstat/base_perturbation.py +++ b/src/hidimstat/base_perturbation.py @@ -215,7 +215,7 @@ def importance(self, X, y): self.importances_ = np.mean(test_result, axis=1) self.pvalues_ = self.statistical_test(test_result).pvalue assert ( - self.pvalues_.shape == y_pred.shape + self.pvalues_.shape[0] == y_pred.shape[0] ), "The statistical test doesn't provide the correct dimension." return self.importances_ diff --git a/src/hidimstat/leave_one_covariate_out.py b/src/hidimstat/leave_one_covariate_out.py index 9a898e00d..63eaed2d3 100644 --- a/src/hidimstat/leave_one_covariate_out.py +++ b/src/hidimstat/leave_one_covariate_out.py @@ -169,6 +169,9 @@ def importance(self, X, y): axis=1, ) self.pvalues_ = self.statistical_test(np.array(test_result)).pvalue + assert ( + self.pvalues_.shape[0] == y_pred.shape[0] + ), "The statistical test doesn't provide the correct dimension." 
return self.importances_ def _joblib_fit_one_features_group(self, estimator, X, y, key_features_group): diff --git a/test/test_leave_one_covariate_out.py b/test/test_leave_one_covariate_out.py index fc70d5d04..dc2ff38d8 100644 --- a/test/test_leave_one_covariate_out.py +++ b/test/test_leave_one_covariate_out.py @@ -1,6 +1,9 @@ +from functools import partial + import numpy as np import pandas as pd import pytest +from scipy.stats import ttest_1samp from sklearn.exceptions import NotFittedError from sklearn.linear_model import LinearRegression, LogisticRegression from sklearn.metrics import log_loss from sklearn.model_selection import train_test_split @@ -140,6 +143,17 @@ def test_raises_value_error(): BasePerturbation.fit(loco, X, y) loco.importance(X, y) + with pytest.raises( + AssertionError, + match="The statistical test doesn't provide the correct dimension.", + ): + fitted_model = LinearRegression().fit(X, y) + loco = LOCO( + estimator=fitted_model, + statistical_test=partial(ttest_1samp, popmean=0, axis=0), + ).fit(X, y) + loco.importance(X, y) + def test_loco_function(): """Test the function of LOCO algorithm on a linear scenario.""" From cfc12d087d06686f833a44debb083c2c3bdc0843 Mon Sep 17 00:00:00 2001 From: bthirion Date: Thu, 23 Oct 2025 21:54:03 +0200 Subject: [PATCH 72/80] Update src/hidimstat/base_perturbation.py Co-authored-by: Joseph Paillard --- src/hidimstat/base_perturbation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hidimstat/base_perturbation.py b/src/hidimstat/base_perturbation.py index 4e8645e94..46255a46a 100644 --- a/src/hidimstat/base_perturbation.py +++ b/src/hidimstat/base_perturbation.py @@ -33,7 +33,7 @@ class BasePerturbation(BaseVariableImportance, GroupVariableImportanceMixin): n_permutations : int, default=50 Number of permutations for each feature group. statistical_test : callable or str, default="nb-ttest" - Statistical test function for computing p-values of importance scores. + Statistical test function for computing p-values from importance scores. features_groups : dict or None, default=None Mapping of group names to lists of feature indices or names. If None, groups are inferred. n_jobs : int, default=1 From d14b835c212ae856f946259b9a937baeac67f228 Mon Sep 17 00:00:00 2001 From: kusch lionel Date: Fri, 24 Oct 2025 17:24:08 +0200 Subject: [PATCH 73/80] Remove unnecessary check --- src/hidimstat/base_perturbation.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/hidimstat/base_perturbation.py b/src/hidimstat/base_perturbation.py index 46255a46a..44ef0e147 100644 --- a/src/hidimstat/base_perturbation.py +++ b/src/hidimstat/base_perturbation.py @@ -191,9 +191,8 @@ def importance(self, X, y): A higher importance score indicates that perturbing that group leads to worse model performance, suggesting those features are more important.
""" - GroupVariableImportanceMixin._check_fit(self) - GroupVariableImportanceMixin._check_compatibility(self, X) self._check_fit() + self._check_compatibility(X) y_pred = getattr(self.estimator, self.method)(X) self.loss_reference_ = self.loss(y, y_pred) From 87e20292e335c2d279ccf8cd1d76d57202216448 Mon Sep 17 00:00:00 2001 From: kusch lionel Date: Fri, 24 Oct 2025 17:25:08 +0200 Subject: [PATCH 74/80] update loco --- src/hidimstat/leave_one_covariate_out.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hidimstat/leave_one_covariate_out.py b/src/hidimstat/leave_one_covariate_out.py index 63eaed2d3..8e427e9f1 100644 --- a/src/hidimstat/leave_one_covariate_out.py +++ b/src/hidimstat/leave_one_covariate_out.py @@ -146,7 +146,7 @@ def importance(self, X, y): worse model performance, suggesting those features are more important. """ self._check_fit() - GroupVariableImportanceMixin._check_compatibility(self, X) + self._check_compatibility(X) y_pred = getattr(self.estimator, self.method)(X) self.loss_reference_ = self.loss(y, y_pred) From c61bb4490a41652c80a58428d9cc34c6d6d45cd5 Mon Sep 17 00:00:00 2001 From: jpaillard Date: Sun, 26 Oct 2025 19:57:19 +0100 Subject: [PATCH 75/80] make ttest the default without CV --- src/hidimstat/_utils/utils.py | 4 ++-- src/hidimstat/base_perturbation.py | 7 ++++--- src/hidimstat/conditional_feature_importance.py | 6 +++--- src/hidimstat/leave_one_covariate_out.py | 10 ++++++---- src/hidimstat/permutation_feature_importance.py | 6 +++--- test/statistical_tools/test_nadeau_bengio_ttest.py | 2 ++ 6 files changed, 20 insertions(+), 15 deletions(-) diff --git a/src/hidimstat/_utils/utils.py b/src/hidimstat/_utils/utils.py index e88d3bc92..8bee24709 100644 --- a/src/hidimstat/_utils/utils.py +++ b/src/hidimstat/_utils/utils.py @@ -142,7 +142,7 @@ def seed_estimator(estimator, random_state=None): return estimator -def check_statistical_test(statistical_test): +def check_statistical_test(statistical_test, test_frac=None): """ Validates and returns a test statistic function. 
@@ -175,7 +175,7 @@ def check_statistical_test(statistical_test): return partial( nadeau_bengio_ttest, popmean=0, - test_frac=0.1 / 0.9, + test_frac=test_frac, alternative="greater", axis=1, ) diff --git a/src/hidimstat/base_perturbation.py b/src/hidimstat/base_perturbation.py index 44ef0e147..35b49bb29 100644 --- a/src/hidimstat/base_perturbation.py +++ b/src/hidimstat/base_perturbation.py @@ -66,7 +66,7 @@ def __init__( method: str = "predict", loss: callable = mean_squared_error, n_permutations: int = 50, - statistical_test="nb-ttest", + statistical_test="ttest", features_groups=None, n_jobs: int = 1, random_state=None, @@ -80,7 +80,7 @@ def __init__( _check_vim_predict_method(method) self.method = method self.n_permutations = n_permutations - self.statistical_test = check_statistical_test(statistical_test) + self.statistical_test = statistical_test self.n_jobs = n_jobs # variable set in importance @@ -193,6 +193,7 @@ def importance(self, X, y): """ self._check_fit() self._check_compatibility(X) + statistical_test = check_statistical_test(self.statistical_test) y_pred = getattr(self.estimator, self.method)(X) self.loss_reference_ = self.loss(y, y_pred) @@ -212,7 +213,7 @@ def importance(self, X, y): ] ) self.importances_ = np.mean(test_result, axis=1) - self.pvalues_ = self.statistical_test(test_result).pvalue + self.pvalues_ = statistical_test(test_result).pvalue assert ( self.pvalues_.shape[0] == y_pred.shape[0] ), "The statistical test doesn't provide the correct dimension." diff --git a/src/hidimstat/conditional_feature_importance.py b/src/hidimstat/conditional_feature_importance.py index 677308e18..121a96c68 100644 --- a/src/hidimstat/conditional_feature_importance.py +++ b/src/hidimstat/conditional_feature_importance.py @@ -50,7 +50,7 @@ class CFI(BasePerturbation): categorical_max_cardinality : int, default=10 The maximum cardinality of a variable to be considered as categorical when the variable type is inferred (set to "auto" or not provided). - statistical_test : callable or str, default="nb-ttest" + statistical_test : callable or str, default="ttest" Statistical test function for computing p-values of importance scores. random_state : int or None, default=None The random state to use for sampling. @@ -74,7 +74,7 @@ def __init__( features_groups=None, feature_types="auto", categorical_max_cardinality: int = 10, - statistical_test="nb-ttest", + statistical_test="ttest", random_state: int = None, n_jobs: int = 1, ): @@ -238,7 +238,7 @@ def cfi( features_groups=None, feature_types="auto", categorical_max_cardinality: int = 10, - test_statistic=partial(wilcoxon, axis=1), + test_statistic="ttest", k_best=None, percentile=None, threshold_max=None, diff --git a/src/hidimstat/leave_one_covariate_out.py b/src/hidimstat/leave_one_covariate_out.py index 8e427e9f1..86b8db68a 100644 --- a/src/hidimstat/leave_one_covariate_out.py +++ b/src/hidimstat/leave_one_covariate_out.py @@ -8,6 +8,7 @@ from sklearn.metrics import mean_squared_error from hidimstat._utils.docstring import _aggregate_docstring +from hidimstat._utils.utils import check_statistical_test from hidimstat.base_perturbation import BasePerturbation from hidimstat.base_variable_importance import GroupVariableImportanceMixin @@ -32,7 +33,7 @@ class LOCO(BasePerturbation): loss : callable, default=mean_squared_error The loss function to use when comparing the perturbed model to the full model. 
- statistical_test : callable or str, default="nb-ttest" + statistical_test : callable or str, default="ttest" Statistical test function for computing p-values of importance scores. features_groups: dict or None, default=None A dictionary where the keys are the group names and the values are the @@ -57,7 +58,7 @@ def __init__( estimator, method: str = "predict", loss: callable = mean_squared_error, - statistical_test="nb-ttest", + statistical_test="ttest", features_groups=None, n_jobs: int = 1, ): @@ -147,6 +148,7 @@ def importance(self, X, y): """ self._check_fit() self._check_compatibility(X) + statistical_test = check_statistical_test(self.statistical_test) y_pred = getattr(self.estimator, self.method)(X) self.loss_reference_ = self.loss(y, y_pred) @@ -168,7 +170,7 @@ def importance(self, X, y): ], axis=1, ) - self.pvalues_ = self.statistical_test(np.array(test_result)).pvalue + self.pvalues_ = statistical_test(np.array(test_result)).pvalue assert ( self.pvalues_.shape[0] == y_pred.shape[0] ), "The statistical test doesn't provide the correct dimension." @@ -214,7 +216,7 @@ def loco( method: str = "predict", loss: callable = mean_squared_error, features_groups=None, - test_statistic=partial(wilcoxon, axis=1), + test_statistic="ttest", k_best=None, percentile=None, threshold_min=None, diff --git a/src/hidimstat/permutation_feature_importance.py b/src/hidimstat/permutation_feature_importance.py index 38fc2341a..3f2bd3a48 100644 --- a/src/hidimstat/permutation_feature_importance.py +++ b/src/hidimstat/permutation_feature_importance.py @@ -33,7 +33,7 @@ class PFI(BasePerturbation): n_permutations : int, default=50 The number of permutations to perform. For each variable/group of variables, the mean of the losses over the `n_permutations` is computed. - statistical_test : callable or str, default="nb-ttest" + statistical_test : callable or str, default="ttest" Statistical test function for computing p-values of importance scores. 
features_groups: dict or None, default=None A dictionary where the keys are the group names and the values are the @@ -56,7 +56,7 @@ def __init__( method: str = "predict", loss: callable = mean_squared_error, n_permutations: int = 50, - statistical_test="nb-ttest", + statistical_test="ttest", features_groups=None, random_state: int = None, n_jobs: int = 1, @@ -93,7 +93,7 @@ def pfi( method: str = "predict", loss: callable = mean_squared_error, n_permutations: int = 50, - test_statistic=partial(wilcoxon, axis=1), + test_statistic="ttest", features_groups=None, k_best=None, percentile=None, diff --git a/test/statistical_tools/test_nadeau_bengio_ttest.py b/test/statistical_tools/test_nadeau_bengio_ttest.py index fafffcf82..1f9223934 100644 --- a/test/statistical_tools/test_nadeau_bengio_ttest.py +++ b/test/statistical_tools/test_nadeau_bengio_ttest.py @@ -1,3 +1,5 @@ +from functools import partial + import numpy as np import numpy.ma.testutils as ma_npt import numpy.testing as npt From 1deae93f424293c3b71598268d047d2b2ee05e6e Mon Sep 17 00:00:00 2001 From: jpaillard Date: Thu, 6 Nov 2025 10:03:34 +0100 Subject: [PATCH 76/80] rename functions --- docs/src/api.rst | 5 ++++- src/hidimstat/__init__.py | 12 ++++++------ src/hidimstat/conditional_feature_importance.py | 4 ++-- src/hidimstat/knockoffs.py | 1 + src/hidimstat/leave_one_covariate_out.py | 4 ++-- src/hidimstat/permutation_feature_importance.py | 4 ++-- test/test_conditional_feature_importance.py | 4 ++-- test/test_knockoff.py | 2 +- test/test_leave_one_covariate_out.py | 4 ++-- test/test_permutation_feature_importance.py | 4 ++-- 10 files changed, 24 insertions(+), 20 deletions(-) diff --git a/docs/src/api.rst b/docs/src/api.rst index 58885f769..1b6bca41b 100644 --- a/docs/src/api.rst +++ b/docs/src/api.rst @@ -36,7 +36,8 @@ Feature Importance functions .. 
autosummary:: :toctree: ./generated/api/class/ :template: function.rst - + + cfi_analysis clustered_inference clustered_inference_pvalue desparsified_lasso @@ -44,6 +45,8 @@ Feature Importance functions desparsified_group_lasso_pvalue ensemble_clustered_inference ensemble_clustered_inference_pvalue + loco_analysis + pfi_analysis Visualization ============= diff --git a/src/hidimstat/__init__.py b/src/hidimstat/__init__.py index e3c834c10..37f6e762c 100644 --- a/src/hidimstat/__init__.py +++ b/src/hidimstat/__init__.py @@ -1,4 +1,4 @@ -from .conditional_feature_importance import CFI, cfi +from .conditional_feature_importance import CFI, cfi_analysis from .desparsified_lasso import ( desparsified_group_lasso_pvalue, desparsified_lasso, @@ -12,9 +12,9 @@ ensemble_clustered_inference_pvalue, ) from .knockoffs import ModelXKnockoff -from .leave_one_covariate_out import LOCO +from .leave_one_covariate_out import LOCO, loco_analysis from .noise_std import reid -from .permutation_feature_importance import PFI, pfi +from .permutation_feature_importance import PFI, pfi_analysis from .statistical_tools.aggregation import quantile_aggregation try: @@ -36,9 +36,9 @@ "reid", "ModelXKnockoff", "CFI", - "cfi", + "cfi_analysis", "LOCO", - "loco", + "loco_analysis", "PFI", - "pfi", + "pfi_analysis", ] diff --git a/src/hidimstat/conditional_feature_importance.py b/src/hidimstat/conditional_feature_importance.py index 121a96c68..a2fa2072e 100644 --- a/src/hidimstat/conditional_feature_importance.py +++ b/src/hidimstat/conditional_feature_importance.py @@ -226,7 +226,7 @@ def _permutation(self, X, features_group_id, random_state=None): ) -def cfi( +def cfi_analysis( estimator, X, y, @@ -271,7 +271,7 @@ def cfi( # use the docstring of the class for the function -cfi.__doc__ = _aggregate_docstring( +cfi_analysis.__doc__ = _aggregate_docstring( [ CFI.__doc__, CFI.__init__.__doc__, diff --git a/src/hidimstat/knockoffs.py b/src/hidimstat/knockoffs.py index 90413a12f..4a04b8039 100644 --- a/src/hidimstat/knockoffs.py +++ b/src/hidimstat/knockoffs.py @@ -11,6 +11,7 @@ from hidimstat._utils.docstring import _aggregate_docstring from hidimstat._utils.utils import check_random_state, seed_estimator from hidimstat.base_variable_importance import BaseVariableImportance +from hidimstat.samplers import GaussianKnockoffs from hidimstat.statistical_tools.aggregation import quantile_aggregation from hidimstat.statistical_tools.multiple_testing import fdr_threshold diff --git a/src/hidimstat/leave_one_covariate_out.py b/src/hidimstat/leave_one_covariate_out.py index 86b8db68a..dc48f3bfd 100644 --- a/src/hidimstat/leave_one_covariate_out.py +++ b/src/hidimstat/leave_one_covariate_out.py @@ -209,7 +209,7 @@ def _check_fit(self): check_is_fitted(m) -def loco( +def loco_analysis( estimator, X, y, @@ -242,7 +242,7 @@ def loco( # use the docstring of the class for the function -loco.__doc__ = _aggregate_docstring( +loco_analysis.__doc__ = _aggregate_docstring( [ LOCO.__doc__, LOCO.__init__.__doc__, diff --git a/src/hidimstat/permutation_feature_importance.py b/src/hidimstat/permutation_feature_importance.py index 3f2bd3a48..bb88927d4 100644 --- a/src/hidimstat/permutation_feature_importance.py +++ b/src/hidimstat/permutation_feature_importance.py @@ -86,7 +86,7 @@ def _permutation(self, X, features_group_id, random_state=None): return X_perm_j -def pfi( +def pfi_analysis( estimator, X, y, @@ -123,7 +123,7 @@ def pfi( # use the docstring of the class for the function -pfi.__doc__ = _aggregate_docstring( +pfi_analysis.__doc__ = 
_aggregate_docstring( [ PFI.__doc__, PFI.__init__.__doc__, diff --git a/test/test_conditional_feature_importance.py b/test/test_conditional_feature_importance.py index 714dd8cde..57d89611e 100644 --- a/test/test_conditional_feature_importance.py +++ b/test/test_conditional_feature_importance.py @@ -16,7 +16,7 @@ from sklearn.metrics import log_loss, mean_squared_error from sklearn.model_selection import train_test_split -from hidimstat import CFI, cfi +from hidimstat import CFI, cfi_analysis from hidimstat._utils.exception import InternalError from hidimstat._utils.scenario import multivariate_simulation from hidimstat.base_perturbation import BasePerturbation @@ -616,7 +616,7 @@ def test_assert_dimension_pvalue(self, data_generator): def test_function_cfi(data_generator, n_permutation, cfi_seed): """Test CFI function""" X, y, _, _ = data_generator - cfi( + cfi_analysis( LinearRegression().fit(X, y), X, y, diff --git a/test/test_knockoff.py b/test/test_knockoff.py index a19e2e56b..cf1ccdf14 100644 --- a/test/test_knockoff.py +++ b/test/test_knockoff.py @@ -11,7 +11,7 @@ model_x_knockoff, set_alpha_max_lasso_path, ) -from hidimstat.statistical_tools.gaussian_knockoffs import GaussianKnockoffs +from hidimstat.samplers import GaussianKnockoffs from hidimstat.statistical_tools.multiple_testing import fdp_power diff --git a/test/test_leave_one_covariate_out.py b/test/test_leave_one_covariate_out.py index dc2ff38d8..1c8c3b243 100644 --- a/test/test_leave_one_covariate_out.py +++ b/test/test_leave_one_covariate_out.py @@ -9,7 +9,7 @@ from sklearn.metrics import log_loss from sklearn.model_selection import train_test_split -from hidimstat import LOCO, loco +from hidimstat import LOCO, loco_analysis from hidimstat._utils.scenario import multivariate_simulation from hidimstat.base_perturbation import BasePerturbation @@ -172,7 +172,7 @@ def test_loco_function(): regression_model = LinearRegression() regression_model.fit(X_train, y_train) - selection, importance, pvalue = loco( + selection, importance, pvalue = loco_analysis( regression_model, X, y, diff --git a/test/test_permutation_feature_importance.py b/test/test_permutation_feature_importance.py index a007004a4..f806486ce 100644 --- a/test/test_permutation_feature_importance.py +++ b/test/test_permutation_feature_importance.py @@ -5,7 +5,7 @@ from sklearn.metrics import log_loss from sklearn.model_selection import train_test_split -from hidimstat import PFI, pfi +from hidimstat import PFI, pfi_analysis from hidimstat._utils.scenario import multivariate_simulation @@ -114,7 +114,7 @@ def test_permutation_importance_function(): regression_model = LinearRegression() regression_model.fit(X_train, y_train) - selection, importance, pvalue = pfi( + selection, importance, pvalue = pfi_analysis( regression_model, X, y, From 911f11c10bc6d134a23c0e0b90db740f4a1ffefc Mon Sep 17 00:00:00 2001 From: jpaillard Date: Thu, 6 Nov 2025 12:09:17 +0100 Subject: [PATCH 77/80] fix import --- examples/plot_knockoffs_wisconsin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/plot_knockoffs_wisconsin.py b/examples/plot_knockoffs_wisconsin.py index ad1a85c57..115e98e05 100644 --- a/examples/plot_knockoffs_wisconsin.py +++ b/examples/plot_knockoffs_wisconsin.py @@ -147,7 +147,7 @@ from sklearn.covariance import LedoitWolf from hidimstat import ModelXKnockoff -from hidimstat.statistical_tools.gaussian_knockoffs import GaussianKnockoffs +from hidimstat.samplers import GaussianKnockoffs model_x_knockoff = ModelXKnockoff( 
ko_generator=GaussianKnockoffs( From bc4ee65832c0fe0ed4b48770d5cb9938d8f126b6 Mon Sep 17 00:00:00 2001 From: Joseph Paillard Date: Fri, 7 Nov 2025 08:38:21 +0100 Subject: [PATCH 78/80] Update src/hidimstat/_utils/utils.py Co-authored-by: bthirion --- src/hidimstat/_utils/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hidimstat/_utils/utils.py b/src/hidimstat/_utils/utils.py index 6a8c84746..dc2b37816 100644 --- a/src/hidimstat/_utils/utils.py +++ b/src/hidimstat/_utils/utils.py @@ -148,7 +148,7 @@ def check_statistical_test(statistical_test, test_frac=None): Parameters ---------- - test : str or callable + statistical_test : str or callable If str, must be either 'ttest' or 'wilcoxon'. If callable, must be a function that can be used as a test statistic. From e93b97f122dc8e8eb28f7c4346876cff4c717479 Mon Sep 17 00:00:00 2001 From: jpaillard Date: Fri, 7 Nov 2025 08:39:50 +0100 Subject: [PATCH 79/80] add test_frac --- src/hidimstat/_utils/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/hidimstat/_utils/utils.py b/src/hidimstat/_utils/utils.py index dc2b37816..459afb0cb 100644 --- a/src/hidimstat/_utils/utils.py +++ b/src/hidimstat/_utils/utils.py @@ -151,6 +151,8 @@ def check_statistical_test(statistical_test, test_frac=None): statistical_test : str or callable If str, must be either 'ttest' or 'wilcoxon'. If callable, must be a function that can be used as a test statistic. + test_frac : float, optional + The fraction of data used for testing in the Nadeau-Bengio t-test. Returns ------- From eca0c0c0d35a712859464d92b482dd8bfa5d7bdf Mon Sep 17 00:00:00 2001 From: jpaillard Date: Fri, 7 Nov 2025 10:13:48 +0100 Subject: [PATCH 80/80] update __init__ exports --- src/hidimstat/__init__.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/hidimstat/__init__.py b/src/hidimstat/__init__.py index 806b08421..edebc8d18 100644 --- a/src/hidimstat/__init__.py +++ b/src/hidimstat/__init__.py @@ -1,4 +1,4 @@ -from .conditional_feature_importance import CFI +from .conditional_feature_importance import CFI, cfi_analysis from .desparsified_lasso import DesparsifiedLasso, desparsified_lasso, reid from .distilled_conditional_randomization_test import D0CRT, d0crt from .ensemble_clustered_inference import ( @@ -8,8 +8,8 @@ ensemble_clustered_inference_pvalue, ) from .knockoffs import ModelXKnockoff -from .leave_one_covariate_out import LOCO -from .permutation_feature_importance import PFI +from .leave_one_covariate_out import LOCO, loco_analysis +from .permutation_feature_importance import PFI, pfi_analysis from .statistical_tools.aggregation import quantile_aggregation try:
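Note for reviewers: to close the series, a minimal end-to-end sketch of the renamed functional API. The data, estimator, and train/test split are illustrative assumptions; the three-value unpacking of loco_analysis follows the updated tests above.

# Illustrative sketch, not part of the patch: toy regression problem.
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split

from hidimstat import loco_analysis

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 5))
y = X[:, 0] - 2 * X[:, 1] + rng.normal(size=200)

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
model = LinearRegression().fit(X_train, y_train)

# Refits one sub-model per left-out feature, then tests the loss increase
# with the default "ttest" option.
selection, importance, pvalue = loco_analysis(model, X_test, y_test)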