Merged
92 commits
df93c78
New API for CFI, PFI, LOCO
lionelkusch Sep 2, 2025
ccb60ed
fix test for new API
lionelkusch Sep 2, 2025
7c827ad
fix example
lionelkusch Sep 2, 2025
82d61e6
add test for new check
lionelkusch Sep 2, 2025
28593e4
add pvalue and fit_importance and function
lionelkusch Sep 2, 2025
cabfb63
Add new function
lionelkusch Sep 2, 2025
7d7fd7d
fix docstring
lionelkusch Sep 2, 2025
b958cc7
Improve cross validation
lionelkusch Sep 3, 2025
1f97d60
update docstring
lionelkusch Sep 3, 2025
db96bb6
update docstring
lionelkusch Sep 3, 2025
d656f17
fix error
lionelkusch Sep 3, 2025
0493b6f
fix docstring
lionelkusch Sep 3, 2025
9c54e1b
Apply suggestions from code review
lionelkusch Sep 5, 2025
7bf75e4
Update default
lionelkusch Sep 5, 2025
b3cd78a
fix tests
lionelkusch Sep 5, 2025
7825490
Apply suggestions from code review
lionelkusch Sep 8, 2025
084ad24
change group by features_groups
lionelkusch Sep 8, 2025
7379ec1
fix format
lionelkusch Sep 8, 2025
02ae5ba
improve test
lionelkusch Sep 8, 2025
1e91c65
fix docstring
lionelkusch Sep 8, 2025
46b8fa5
Merge branch 'main' into PR_CFI
lionelkusch Sep 8, 2025
58a57f8
fix test
lionelkusch Sep 8, 2025
03b919a
Merge branch 'main' into PR_CFI
lionelkusch Sep 10, 2025
c4ea731
improve loco
lionelkusch Sep 11, 2025
43d3f99
fix computation of pvalues
lionelkusch Sep 11, 2025
5fd99b0
Update src/hidimstat/_utils/utils.py
lionelkusch Sep 12, 2025
3c52789
change name
lionelkusch Sep 12, 2025
c93f14c
remove the cross validation in fit_importance
lionelkusch Sep 24, 2025
aa583d5
change fit_importance
lionelkusch Sep 24, 2025
01cbc44
more flexible for the computation of the statistic
lionelkusch Sep 24, 2025
b1c5f40
update the computation of pvalue for loco
lionelkusch Sep 24, 2025
1a47330
Merge branch 'main' into PR_CFI
lionelkusch Oct 10, 2025
1ef69a6
fix merge
lionelkusch Oct 13, 2025
d00566b
Merge branch 'main' into PR_CFI
lionelkusch Oct 13, 2025
83ae849
fix example
lionelkusch Oct 13, 2025
a3cd681
fix example
lionelkusch Oct 13, 2025
b3f336a
fix import
lionelkusch Oct 13, 2025
a364a93
Merge branch 'main' into PR_CFI
lionelkusch Oct 14, 2025
6dc8d67
Update src/hidimstat/leave_one_covariate_out.py
lionelkusch Oct 14, 2025
afd03cb
Update src/hidimstat/base_perturbation.py
lionelkusch Oct 14, 2025
3fa9d01
fix modification
lionelkusch Oct 14, 2025
1202426
Remove the wrong merge
lionelkusch Oct 14, 2025
75d6578
Add check_test_statistic
lionelkusch Oct 15, 2025
ae3dfa9
change name
lionelkusch Oct 15, 2025
035f4c8
fix import
lionelkusch Oct 15, 2025
84ff550
Merge branch 'main' into PR_CFI
lionelkusch Oct 15, 2025
251584c
change name
lionelkusch Oct 15, 2025
1c15c5b
Update src/hidimstat/_utils/utils.py
lionelkusch Oct 16, 2025
9414368
Update src/hidimstat/conditional_feature_importance.py
lionelkusch Oct 16, 2025
8a8587f
Update src/hidimstat/conditional_feature_importance.py
lionelkusch Oct 16, 2025
631ee83
Update src/hidimstat/permutation_feature_importance.py
lionelkusch Oct 16, 2025
39ebce1
Update src/hidimstat/_utils/utils.py
lionelkusch Oct 16, 2025
971da61
Update test/test_conditional_feature_importance.py
lionelkusch Oct 16, 2025
c7adec9
Update test/test_conditional_feature_importance.py
lionelkusch Oct 16, 2025
a3c7906
Update src/hidimstat/conditional_feature_importance.py
lionelkusch Oct 16, 2025
f4c8ce4
Update src/hidimstat/leave_one_covariate_out.py
lionelkusch Oct 16, 2025
c983eab
Update src/hidimstat/leave_one_covariate_out.py
lionelkusch Oct 16, 2025
f91ac6c
Update src/hidimstat/leave_one_covariate_out.py
lionelkusch Oct 16, 2025
38c6bc7
Update src/hidimstat/permutation_feature_importance.py
lionelkusch Oct 16, 2025
0bb880e
Update src/hidimstat/permutation_feature_importance.py
lionelkusch Oct 16, 2025
a92ea83
Update src/hidimstat/permutation_feature_importance.py
lionelkusch Oct 16, 2025
2166f2c
fix format
lionelkusch Oct 16, 2025
4cc3598
fix modification
lionelkusch Oct 16, 2025
0e6b929
fix order import
lionelkusch Oct 16, 2025
4389796
remove unnecessary merge
lionelkusch Oct 16, 2025
9a640b1
Update src/hidimstat/_utils/utils.py
lionelkusch Oct 16, 2025
431d9a6
update error
lionelkusch Oct 16, 2025
e7bbb30
Merge branch 'main' into PR_CFI
lionelkusch Oct 22, 2025
e003345
add the NB-ttest as default
lionelkusch Oct 22, 2025
1973bcc
move sampler in separate module
lionelkusch Oct 22, 2025
57ca067
move sampler in a separate folder
lionelkusch Oct 22, 2025
e4d686c
fix import
lionelkusch Oct 22, 2025
a9cd709
fix tests
lionelkusch Oct 22, 2025
75e81d7
Merge branch 'main' into PR_CFI
lionelkusch Oct 22, 2025
9ab658a
fix tests
lionelkusch Oct 22, 2025
329bf43
fix example
lionelkusch Oct 22, 2025
9f488d6
change name of nb-test
lionelkusch Oct 22, 2025
17f8d6e
fix import order
lionelkusch Oct 22, 2025
d33205c
fix assert and add assert
lionelkusch Oct 22, 2025
cfc12d0
Update src/hidimstat/base_perturbation.py
bthirion Oct 23, 2025
7a4f44a
Merge branch 'main' into PR_CFI
lionelkusch Oct 24, 2025
d14b835
Remove unnecessary check
lionelkusch Oct 24, 2025
87e2029
update loco
lionelkusch Oct 24, 2025
c61bb44
make ttest the default without CV
jpaillard Oct 26, 2025
b0c4ec0
Merge branch 'main' into PR_CFI
jpaillard Oct 29, 2025
359118a
Merge branch 'main' of github.com:mind-inria/hidimstat into PR_CFI
jpaillard Nov 6, 2025
1deae93
rename functions
jpaillard Nov 6, 2025
911f11c
fix import
jpaillard Nov 6, 2025
bc4ee65
Update src/hidimstat/_utils/utils.py
jpaillard Nov 7, 2025
e93b97f
add test_frac
jpaillard Nov 7, 2025
e03266e
Merge branch 'main' of github.com:mind-inria/hidimstat into PR_CFI
jpaillard Nov 7, 2025
eca0c0c
init
jpaillard Nov 7, 2025
2 changes: 1 addition & 1 deletion examples/plot_conditional_vs_marginal_xor_data.py
@@ -117,7 +117,7 @@
random_state=0,
)
vim.fit(X_train, y_train)
importances.append(vim.importance(X_test, y_test)["importance"])
importances.append(vim.importance(X_test, y_test))

importances = np.array(importances).T

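The hunk above shows the central API change in this PR: `importance()` now returns the importance array directly instead of a dict keyed by `"importance"`. A minimal sketch of the new call pattern, assuming `CFI` accepts a fitted estimator with otherwise default arguments (the constructor signature is not part of this diff):

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
from hidimstat import CFI

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 5))
y = X[:, 0] + 0.5 * rng.normal(size=200)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

model = LinearRegression().fit(X_train, y_train)
vim = CFI(model)  # assumed minimal construction
vim.fit(X_train, y_train)
scores = vim.importance(X_test, y_test)  # ndarray, one score per feature
# Values previously returned in the dict are now attributes set by importance():
# vim.loss_reference_, vim.loss_, vim.importances_, vim.pvalues_
```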
12 changes: 6 additions & 6 deletions examples/plot_diabetes_variable_importance_example.py
@@ -184,14 +184,14 @@ def compute_pval(vim):
# -------------------


cfi_vim_arr = np.array([x["importance"] for x in cfi_importance_list]) / 2
cfi_vim_arr = np.array(cfi_importance_list) / 2
cfi_pval = compute_pval(cfi_vim_arr)

vim = [
pd.DataFrame(
{
"var": np.arange(cfi_vim_arr.shape[1]),
"importance": x["importance"],
"importance": x,
"fold": i,
"pval": cfi_pval,
"method": "CFI",
@@ -200,14 +200,14 @@ def compute_pval(vim):
for x in cfi_importance_list
]

loco_vim_arr = np.array([x["importance"] for x in loco_importance_list])
loco_vim_arr = np.array(loco_importance_list)
loco_pval = compute_pval(loco_vim_arr)

vim += [
pd.DataFrame(
{
"var": np.arange(loco_vim_arr.shape[1]),
"importance": x["importance"],
"importance": x,
"fold": i,
"pval": loco_pval,
"method": "LOCO",
@@ -216,14 +216,14 @@ def compute_pval(vim):
for x in loco_importance_list
]

pfi_vim_arr = np.array([x["importance"] for x in pfi_importance_list])
pfi_vim_arr = np.array(pfi_importance_list)
pfi_pval = compute_pval(pfi_vim_arr)

vim += [
pd.DataFrame(
{
"var": np.arange(pfi_vim_arr.shape[1]),
"importance": x["importance"],
"importance": x,
"fold": i,
"pval": pfi_pval,
"method": "PFI",
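`compute_pval` is defined earlier in this example and does not appear in these hunks; a plausible reconstruction, assuming it applies the same across-fold one-sample t-test used elsewhere in this PR:

```python
from scipy.stats import ttest_1samp

def compute_pval(vim_arr):
    # vim_arr has shape (n_folds, n_features); test mean importance > 0 per feature.
    return ttest_1samp(vim_arr, 0.0, axis=0, alternative="greater").pvalue
```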
2 changes: 1 addition & 1 deletion examples/plot_importance_classification_iris.py
@@ -93,7 +93,7 @@ def run_one_fold(X, y, model, train_index, test_index, vim_name="CFI", groups=No
)

vim.fit(X[train_index], y[train_index], groups=groups)
importance = vim.importance(X[test_index], y[test_index])["importance"]
importance = vim.importance(X[test_index], y[test_index])

return pd.DataFrame(
{
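The `groups` argument passed to `fit` maps a group key to the columns it covers, and `importance()` then returns one score per group rather than per feature. A sketch on iris, assuming `PFI` forwards the constructor arguments documented in base_perturbation.py below:

```python
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import log_loss
from hidimstat import PFI

X, y = load_iris(return_X_y=True)
model = LogisticRegression(max_iter=1000).fit(X, y)

groups = {"sepal": [0, 1], "petal": [2, 3]}  # key -> column indices
vim = PFI(model, method="predict_proba", loss=log_loss)
vim.fit(X, y, groups=groups)
scores = vim.importance(X, y)  # one score per group: shape (2,)
```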
6 changes: 2 additions & 4 deletions examples/plot_model_agnostic_importance.py
@@ -108,10 +108,8 @@
vim_linear.fit(X[train], y[train])
vim_non_linear.fit(X[train], y[train])

importances_linear.append(vim_linear.importance(X[test], y[test])["importance"])
importances_non_linear.append(
vim_non_linear.importance(X[test], y[test])["importance"]
)
importances_linear.append(vim_linear.importance(X[test], y[test]))
importances_non_linear.append(vim_non_linear.importance(X[test], y[test]))


################################################################################
4 changes: 2 additions & 2 deletions examples/plot_pitfalls_permutation_importance.py
@@ -132,7 +132,7 @@
)
pfi.fit(X_test, y_test)

permutation_importances.append(pfi.importance(X_test, y_test)["importance"])
permutation_importances.append(pfi.importance(X_test, y_test))
permutation_importances = np.stack(permutation_importances)
pval_pfi = ttest_1samp(
permutation_importances, 0.0, axis=0, alternative="greater"
@@ -200,7 +200,7 @@
)
cfi.fit(X_test, y_test)

conditional_importances.append(cfi.importance(X_test, y_test)["importance"])
conditional_importances.append(cfi.importance(X_test, y_test))


cfi_pval = ttest_1samp(
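Both hunks follow the same pattern: stack one importance array per fold, then t-test each feature across folds. The arithmetic in isolation, with toy numbers:

```python
import numpy as np
from scipy.stats import ttest_1samp

fold_importances = np.array([[0.30, 0.01],   # one row per CV fold,
                             [0.42, -0.05],  # one column per feature
                             [0.25, 0.03]])
pvals = ttest_1samp(fold_importances, 0.0, axis=0, alternative="greater").pvalue
print(pvals.shape)  # (2,) -> one p-value per feature
```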
9 changes: 6 additions & 3 deletions src/hidimstat/__init__.py
@@ -14,16 +14,16 @@
desparsified_group_lasso_pvalue,
)
from .distilled_conditional_randomization_test import d0crt, D0CRT
from .conditional_feature_importance import CFI
from .conditional_feature_importance import cfi, CFI
from .knockoffs import (
model_x_knockoff,
model_x_knockoff_pvalue,
model_x_knockoff_bootstrap_quantile,
model_x_knockoff_bootstrap_e_value,
)
from .leave_one_covariate_out import LOCO
from .leave_one_covariate_out import loco, LOCO
from .noise_std import reid
from .permutation_feature_importance import PFI
from .permutation_feature_importance import pfi, PFI

from .statistical_tools.aggregation import quantile_aggregation

@@ -49,6 +49,9 @@
"model_x_knockoff_bootstrap_quantile",
"model_x_knockoff_bootstrap_e_value",
"CFI",
"cfi",
"LOCO",
"loco",
"PFI",
"pfi",
]
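Beyond re-exporting the classes, the PR adds lowercase functional entry points. Their signatures are not visible in this diff, so only the import surface is shown:

```python
# Public API after this PR: classes plus new function counterparts.
from hidimstat import CFI, cfi, LOCO, loco, PFI, pfi
```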
115 changes: 88 additions & 27 deletions src/hidimstat/base_perturbation.py
@@ -1,9 +1,12 @@
import warnings

import numpy as np
import pandas as pd
from joblib import Parallel, delayed
from sklearn.base import check_is_fitted
from scipy.stats import ttest_1samp
from sklearn.base import check_is_fitted, clone
from sklearn.metrics import root_mean_squared_error
import warnings
from sklearn.model_selection import KFold

from hidimstat._utils.utils import _check_vim_predict_method
from hidimstat._utils.exception import InternalError
@@ -14,9 +17,9 @@ class BasePerturbation(BaseVariableImportance):
def __init__(
self,
estimator,
method: str = "predict",
loss: callable = root_mean_squared_error,
n_permutations: int = 50,
method: str = "predict",
n_jobs: int = 1,
):
"""
Expand All @@ -27,6 +30,10 @@ def __init__(
----------
estimator : sklearn compatible estimator, optional
The estimator to use for the prediction.
method : str, default="predict"
The method used for making predictions. This determines the predictions
passed to the loss function. Supported methods are "predict",
"predict_proba", "decision_function", "transform".
loss : callable, default=root_mean_squared_error
The function to compute the loss when comparing the perturbed model
to the original model.
@@ -35,10 +42,6 @@ def __init__(
Specifies the number of times the variable group (residual for CFI) is
permuted. For each permutation, the perturbed model's loss is calculated
and averaged over all permutations.
method : str, default="predict"
The method used for making predictions. This determines the predictions
passed to the loss function. Supported methods are "predict",
"predict_proba", "decision_function", "transform".
n_jobs : int, default=1
The number of parallel jobs to run. Parallelization is done over the
variables or groups of variables.
@@ -50,9 +53,18 @@ def __init__(
self.loss = loss
_check_vim_predict_method(method)
self.method = method
self.n_jobs = n_jobs
self.n_permutations = n_permutations
self.n_groups = None
self.n_jobs = n_jobs
# variable set in fit
self.groups = None
# variable set in importance
self.loss_reference_ = None
self.loss_ = None
# variable set in fit_importance
self.importances_cv_ = None
# internal variables
self._n_groups = None
self._groups_ids = None
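A construction sketch using the parameters documented above, assuming a subclass such as `PFI` forwards them unchanged to `BasePerturbation`:

```python
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_absolute_error
from hidimstat import PFI

vim = PFI(
    RandomForestRegressor(random_state=0),
    method="predict",          # feeds the loss; "predict_proba",
                               # "decision_function", "transform" also supported
    loss=mean_absolute_error,  # any callable comparing y to predictions
    n_permutations=50,         # permutations averaged per variable group
    n_jobs=2,                  # parallelism over variable groups
)
```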

def fit(self, X, y=None, groups=None):
"""Base fit method for perturbation-based methods. Identifies the groups.
@@ -69,11 +81,11 @@ def fit(self, X, y=None, groups=None):
identified based on the columns of X.
"""
if groups is None:
self.n_groups = X.shape[1]
self.groups = {j: [j] for j in range(self.n_groups)}
self._n_groups = X.shape[1]
self.groups = {j: [j] for j in range(self._n_groups)}
self._groups_ids = np.array(list(self.groups.values()), dtype=int)
elif isinstance(groups, dict):
self.n_groups = len(groups)
self._n_groups = len(groups)
self.groups = groups
if isinstance(X, pd.DataFrame):
self._groups_ids = []
@@ -91,6 +103,7 @@ def fit(self, X, y=None, groups=None):
]
else:
raise ValueError("groups needs to be a dictionnary")
return self
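When `X` is a DataFrame, the elided branch above resolves group entries through `X.columns`, so groups can presumably be given as column names rather than positions; a sketch under that assumption:

```python
import numpy as np
import pandas as pd

X_df = pd.DataFrame(np.random.randn(100, 4),
                    columns=["age", "bmi", "bp", "s1"])
groups = {"clinical": ["age", "bmi", "bp"], "serum": ["s1"]}
# vim.fit(X_df, y, groups=groups)  # names are mapped to positions internally
```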

def predict(self, X):
"""
@@ -139,27 +152,69 @@ def importance(self, X, y):
"""
self._check_fit(X)

out_dict = dict()

y_pred = getattr(self.estimator, self.method)(X)
loss_reference = self.loss(y, y_pred)
out_dict["loss_reference"] = loss_reference
self.loss_reference_ = self.loss(y, y_pred)

y_pred = self.predict(X)
out_dict["loss"] = dict()
self.loss_ = dict()
for j, y_pred_j in enumerate(y_pred):
list_loss = []
for y_pred_perm in y_pred_j:
list_loss.append(self.loss(y, y_pred_perm))
out_dict["loss"][j] = np.array(list_loss)
self.loss_[j] = np.array(list_loss)

out_dict["importance"] = np.array(
self.importances_ = np.array(
[
np.mean(out_dict["loss"][j]) - loss_reference
for j in range(self.n_groups)
np.mean(self.loss_[j]) - self.loss_reference_
for j in range(self._n_groups)
]
)
return out_dict
self.pvalues_ = ttest_1samp(
self.importances_, 0.0, axis=0, alternative="greater"
).pvalue
return self.importances_
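The score for group j is the mean loss across its permutations minus the reference loss of the intact model; the arithmetic in isolation:

```python
import numpy as np

loss_reference = 1.00                         # loss of the unperturbed model
perm_losses_j = np.array([1.20, 1.30, 1.25])  # loss after each permutation of group j
importance_j = perm_losses_j.mean() - loss_reference
print(importance_j)  # ~0.25 -> perturbing the group degrades the model
```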

def fit_importance(
self, X, y, cv=KFold(n_splits=5, shuffle=True, random_state=0), **fit_kwargs
):
"""
Compute feature importance scores using cross-validation.

Parameters
----------
X : array-like of shape (n_samples, n_features)
Training data.
y : array-like of shape (n_samples,)
Target values.
cv : cross-validation generator or iterable, default=KFold(n_splits=5, shuffle=True, random_state=0)
Determines the cross-validation splitting strategy.
**fit_kwargs : dict
Additional arguments passed to the fit method during variable group identification.

Returns
-------
importances : ndarray of shape (n_groups,)
Mean feature importance scores across CV folds.

Notes
-----
For each CV fold:
1. Clones and fits the estimator on the training fold
2. Identifies variable groups on the training fold
3. Computes feature importances on the test fold
4. Averages the importances across all folds

The importances for each fold are stored in self.importances_cv\_; their mean is stored in self.importances\_
"""
importances = []
for train, test in cv.split(X):
estimator = clone(self.estimator)
estimator.fit(X[train], y[train])
self.fit(X[train], y[train], **fit_kwargs)
importances.append(self.importance(X[test], y[test]))
self.importances_cv_ = importances
self.importances_ = np.mean(importances, axis=0)
return self.importances_
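A usage sketch for `fit_importance`, assuming `LOCO` accepts a bare (unfitted) estimator; as documented above, the estimator is cloned and refitted on each training fold:

```python
import numpy as np
from sklearn.linear_model import Ridge
from sklearn.model_selection import KFold
from hidimstat import LOCO

rng = np.random.default_rng(0)
X = rng.normal(size=(120, 4))
y = X[:, 1] + 0.1 * rng.normal(size=120)

vim = LOCO(Ridge())
mean_scores = vim.fit_importance(
    X, y, cv=KFold(n_splits=3, shuffle=True, random_state=0)
)
per_fold = vim.importances_cv_  # one importance array per fold
```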

def _check_fit(self, X):
"""
@@ -183,11 +238,7 @@ def _check_fit(self, X):
If the number of features in X does not match the total number
of features in the grouped variables.
"""
if (
self.n_groups is None
or not hasattr(self, "groups")
or not hasattr(self, "_groups_ids")
):
if self._n_groups is None or self.groups is None or self._groups_ids is None:
raise ValueError(
"The class is not fitted. The fit method must be called"
" to set variable groups. If no grouping is needed,"
@@ -231,6 +282,16 @@ def _check_fit(self, X):
f"{number_unique_feature_in_groups}"
)

def _check_importance(self):
"""
Checks whether the losses have been computed.
"""
super()._check_importance()
if self.loss_reference_ is None or self.loss_ is None:
raise ValueError(
"The importances need to be called before calling this method"
)

def _joblib_predict_one_group(self, X, group_id, group_key):
"""
Compute the predictions after perturbation of the data for a given