
Commit acb8e20

lionelkusch, jpaillard, and bthirion
authored
[API 2]: CFI, PFI, LOCO (#372)
* New API for CFI, PFI, LOCO
* fix test for new API
* fix example
* add test for new check
* add pvalue and fit_importance and function
* Add new function
* fix docstring
* Improve cross validation
* update docstring
* update doctring
* fix error
* fix docstring
* Apply suggestions from code review (Co-authored-by: Joseph Paillard <[email protected]>)
* Update default
* fix tests
* Apply suggestions from code review (Co-authored-by: bthirion <[email protected]>)
* chnage group by features_groups
* fix format
* improve test
* fix docstring
* fix test
* improve loco
* fix computation of pvalues
* Update src/hidimstat/_utils/utils.py (Co-authored-by: Joseph Paillard <[email protected]>)
* change name
* remove the cross validation in fit_importance
* change fit_importance
* more flexible for the computation of the statistic
* update the computation of pvalue for loco
* fix merge
* fix example
* fix example
* fix import
* Update src/hidimstat/leave_one_covariate_out.py (Co-authored-by: bthirion <[email protected]>)
* Update src/hidimstat/base_perturbation.py (Co-authored-by: bthirion <[email protected]>)
* fix modification
* Remove the wrong merge
* Add check_test_statistic
* change name
* fix import
* change name
* Update src/hidimstat/_utils/utils.py (Co-authored-by: Joseph Paillard <[email protected]>)
* Update src/hidimstat/conditional_feature_importance.py (Co-authored-by: Joseph Paillard <[email protected]>)
* Update src/hidimstat/conditional_feature_importance.py (Co-authored-by: Joseph Paillard <[email protected]>)
* Update src/hidimstat/permutation_feature_importance.py (Co-authored-by: Joseph Paillard <[email protected]>)
* Update src/hidimstat/_utils/utils.py (Co-authored-by: Joseph Paillard <[email protected]>)
* Update test/test_conditional_feature_importance.py (Co-authored-by: Joseph Paillard <[email protected]>)
* Update test/test_conditional_feature_importance.py (Co-authored-by: Joseph Paillard <[email protected]>)
* Update src/hidimstat/conditional_feature_importance.py (Co-authored-by: Joseph Paillard <[email protected]>)
* Update src/hidimstat/leave_one_covariate_out.py (Co-authored-by: Joseph Paillard <[email protected]>)
* Update src/hidimstat/leave_one_covariate_out.py (Co-authored-by: Joseph Paillard <[email protected]>)
* Update src/hidimstat/leave_one_covariate_out.py (Co-authored-by: Joseph Paillard <[email protected]>)
* Update src/hidimstat/permutation_feature_importance.py (Co-authored-by: Joseph Paillard <[email protected]>)
* Update src/hidimstat/permutation_feature_importance.py (Co-authored-by: Joseph Paillard <[email protected]>)
* Update src/hidimstat/permutation_feature_importance.py (Co-authored-by: Joseph Paillard <[email protected]>)
* fix format
* fix modification
* fix order import
* remove unecessary merge
* Update src/hidimstat/_utils/utils.py (Co-authored-by: Joseph Paillard <[email protected]>)
* update error
* add the NB-ttest as default
* move sampler in separate module
* move sampler in a separate folder
* fix import
* fix tests
* fix tests
* fix example
* change nane of nb-test
* fix import order
* fix assert and add assert
* Update src/hidimstat/base_perturbation.py (Co-authored-by: Joseph Paillard <[email protected]>)
* Remove unecessary check
* update loco
* make ttest the default without CV
* rename functions
* fix import
* Update src/hidimstat/_utils/utils.py (Co-authored-by: bthirion <[email protected]>)
* add test_frac
* init

---------

Co-authored-by: Joseph Paillard <[email protected]>
Co-authored-by: bthirion <[email protected]>
1 parent 1f97f5b commit acb8e20

29 files changed: +947 additions, -308 deletions

docs/src/api.rst

Lines changed: 6 additions & 3 deletions
@@ -37,11 +37,14 @@ Feature Importance functions
 .. autosummary::
    :toctree: ./generated/api/class/
    :template: function.rst
-
+
+   cfi_analysis
    clustered_inference
    clustered_inference_pvalue
    ensemble_clustered_inference
    ensemble_clustered_inference_pvalue
+   loco_analysis
+   pfi_analysis

 Visualization
 =============

@@ -60,8 +63,8 @@ Samplers
    :toctree: ./generated/api/class/
    :template: class.rst

-   ~statistical_tools.ConditionalSampler
-   ~statistical_tools.GaussianKnockoffs
+   ~samplers.ConditionalSampler
+   ~samplers.GaussianKnockoffs

 Helper Functions
 ================

examples/plot_conditional_vs_marginal_xor_data.py

Lines changed: 1 addition & 1 deletion
@@ -131,7 +131,7 @@
         random_state=0,
     )
     vim.fit(X_train, y_train)
-    importances.append(vim.importance(X_test, y_test)["importance"])
+    importances.append(vim.importance(X_test, y_test))

 importances = np.array(importances).T

examples/plot_diabetes_variable_importance_example.py

Lines changed: 8 additions & 8 deletions
@@ -63,7 +63,7 @@
 import numpy as np
 from sklearn.base import clone
 from sklearn.linear_model import LogisticRegressionCV, RidgeCV
-from sklearn.metrics import r2_score, root_mean_squared_error
+from sklearn.metrics import mean_squared_error, r2_score
 from sklearn.model_selection import KFold

 n_folds = 5

@@ -78,7 +78,7 @@
     score = r2_score(
         y_true=y[test_index], y_pred=regressor_list[i].predict(X[test_index])
     )
-    mse = root_mean_squared_error(
+    mse = mean_squared_error(
         y_true=y[test_index], y_pred=regressor_list[i].predict(X[test_index])
     )

@@ -166,14 +166,14 @@
 import pandas as pd
 from scipy.stats import ttest_1samp

-cfi_vim_arr = np.array([x["importance"] for x in cfi_importance_list]) / 2
+cfi_vim_arr = np.array(cfi_importance_list) / 2
 cfi_pval = ttest_1samp(cfi_vim_arr, 0, alternative="greater").pvalue

 vim = [
     pd.DataFrame(
         {
             "var": np.arange(cfi_vim_arr.shape[1]),
-            "importance": x["importance"],
+            "importance": x,
             "fold": i,
             "pval": cfi_pval,
             "method": "CFI",

@@ -182,14 +182,14 @@
     for x in cfi_importance_list
 ]

-loco_vim_arr = np.array([x["importance"] for x in loco_importance_list])
+loco_vim_arr = np.array(loco_importance_list)
 loco_pval = ttest_1samp(loco_vim_arr, 0, alternative="greater").pvalue

 vim += [
     pd.DataFrame(
         {
             "var": np.arange(loco_vim_arr.shape[1]),
-            "importance": x["importance"],
+            "importance": x,
             "fold": i,
             "pval": loco_pval,
             "method": "LOCO",

@@ -198,14 +198,14 @@
     for x in loco_importance_list
 ]

-pfi_vim_arr = np.array([x["importance"] for x in pfi_importance_list])
+pfi_vim_arr = np.array(pfi_importance_list)
 pfi_pval = ttest_1samp(pfi_vim_arr, 0, alternative="greater").pvalue

 vim += [
     pd.DataFrame(
         {
             "var": np.arange(pfi_vim_arr.shape[1]),
-            "importance": x["importance"],
+            "importance": x,
             "fold": i,
             "pval": pfi_pval,
             "method": "PFI",

examples/plot_importance_classification_iris.py

Lines changed: 1 addition & 1 deletion
@@ -107,7 +107,7 @@ def run_one_fold(
     )

     vim.fit(X[train_index], y[train_index])
-    importance = vim.importance(X[test_index], y[test_index])["importance"]
+    importance = vim.importance(X[test_index], y[test_index])

     return pd.DataFrame(
         {

examples/plot_knockoffs_wisconsin.py

Lines changed: 1 addition & 1 deletion
@@ -147,7 +147,7 @@
 from sklearn.covariance import LedoitWolf

 from hidimstat import ModelXKnockoff
-from hidimstat.statistical_tools.gaussian_knockoffs import GaussianKnockoffs
+from hidimstat.samplers import GaussianKnockoffs

 model_x_knockoff = ModelXKnockoff(
     ko_generator=GaussianKnockoffs(

examples/plot_loco.py

Lines changed: 1 addition & 1 deletion
@@ -82,7 +82,7 @@
 # importance. This process is repeated for all features to assess their individual
 # contributions.
 loco.fit(X_train, y_train)
-importances = loco.importance(X_test, y_test)["importance"]
+importances = loco.importance(X_test, y_test)
 df_list.append(
     pd.DataFrame(
         {

examples/plot_model_agnostic_importance.py

Lines changed: 2 additions & 6 deletions
@@ -134,12 +134,8 @@
     vim_linear.fit(X[train], y[train])
     vim_non_linear.fit(X[train], y[train])

-    importances_linear.append(
-        vim_linear.importance(X[test], y[test])["importance"],
-    )
-    importances_non_linear.append(
-        vim_non_linear.importance(X[test], y[test])["importance"]
-    )
+    importances_linear.append(vim_linear.importance(X[test], y[test]))
+    importances_non_linear.append(vim_non_linear.importance(X[test], y[test]))


 # %%

examples/plot_pitfalls_permutation_importance.py

Lines changed: 3 additions & 3 deletions
@@ -144,7 +144,7 @@
     )
     pfi.fit(X_test, y_test)

-    permutation_importances.append(pfi.importance(X_test, y_test)["importance"])
+    permutation_importances.append(pfi.importance(X_test, y_test))
 permutation_importances = np.stack(permutation_importances)
 pval_pfi = ttest_1samp(
     permutation_importances, 0.0, axis=0, alternative="greater"

@@ -216,7 +216,7 @@
     )
     cfi.fit(X_test, y_test)

-    conditional_importances.append(cfi.importance(X_test, y_test)["importance"])
+    conditional_importances.append(cfi.importance(X_test, y_test))


 cfi_pval = ttest_1samp(

@@ -267,7 +267,7 @@

 from matplotlib.lines import Line2D

-from hidimstat.statistical_tools.conditional_sampling import ConditionalSampler
+from hidimstat.samplers.conditional_sampling import ConditionalSampler

 X_train, X_test = train_test_split(
     X,

src/hidimstat/__init__.py

Lines changed: 6 additions & 3 deletions
@@ -1,4 +1,4 @@
-from .conditional_feature_importance import CFI
+from .conditional_feature_importance import CFI, cfi_analysis
 from .desparsified_lasso import DesparsifiedLasso, desparsified_lasso, reid
 from .distilled_conditional_randomization_test import D0CRT, d0crt
 from .ensemble_clustered_inference import (

@@ -8,8 +8,8 @@
     ensemble_clustered_inference_pvalue,
 )
 from .knockoffs import ModelXKnockoff
-from .leave_one_covariate_out import LOCO
-from .permutation_feature_importance import PFI
+from .leave_one_covariate_out import LOCO, loco_analysis
+from .permutation_feature_importance import PFI, pfi_analysis
 from .statistical_tools.aggregation import quantile_aggregation

 try:

@@ -30,6 +30,9 @@
     "reid",
     "ModelXKnockoff",
     "CFI",
+    "cfi_analysis",
     "LOCO",
+    "loco_analysis",
     "PFI",
+    "pfi_analysis",
 ]

src/hidimstat/_utils/utils.py

Lines changed: 88 additions & 0 deletions
@@ -1,7 +1,11 @@
 import numbers
+from functools import partial

 import numpy as np
 from numpy.random import RandomState
+from scipy.stats import ttest_1samp, wilcoxon
+
+from hidimstat.statistical_tools.nadeau_bengio_ttest import nadeau_bengio_ttest


 def _check_vim_predict_method(method):
@@ -33,6 +37,37 @@ def _check_vim_predict_method(method):
     )


+def get_fitted_attributes(cls):
+    """
+    Get all attributes of a class that end with a single underscore
+    and do not start with an underscore.
+
+    Parameters
+    ----------
+    cls : class
+        The class to inspect for attributes.
+
+    Returns
+    -------
+    list
+        A list of attribute names that end with a single underscore
+        but not a double underscore.
+    """
+    # Get all attributes and methods of the class
+    all_attributes = dir(cls)
+
+    # Filter out attributes that start with an underscore
+    filtered_attributes = [attr for attr in all_attributes if not attr.startswith("_")]
+
+    # Keep only attributes that end with a single underscore
+    result = [
+        attr
+        for attr in filtered_attributes
+        if attr.endswith("_") and not attr.endswith("__")
+    ]
+
+    return result
+
+
 def check_random_state(seed):
     """
     Modified version of sklearn's check_random_state using np.random.Generator.
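The hunk above adds `get_fitted_attributes`, which relies on the scikit-learn convention that fitted attributes end in a single trailing underscore. A standalone copy of that logic, demonstrated on a toy class (the class and attribute names here are illustrative, not from the repository):

```python
# Standalone copy of the helper added in this commit: list the "fitted"
# attributes of a class, i.e. public names ending in exactly one underscore
# (private names and dunders are excluded).
def get_fitted_attributes(cls):
    all_attributes = dir(cls)
    # drop private/dunder names
    public = [attr for attr in all_attributes if not attr.startswith("_")]
    # keep names ending in "_" but not "__"
    return [a for a in public if a.endswith("_") and not a.endswith("__")]


# Toy class for illustration only.
class FittedEstimator:
    n_features = 4      # plain parameter: not reported
    coef_ = [0.1, 0.2]  # fitted attribute: reported
    intercept_ = 0.0    # fitted attribute: reported


print(get_fitted_attributes(FittedEstimator))  # ['coef_', 'intercept_']
```

Since `dir()` returns names sorted alphabetically, the result order is deterministic.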
@@ -105,3 +140,56 @@ def seed_estimator(estimator, random_state=None):
         setattr(value, "random_state", RandomState(rng.bit_generator))

     return estimator
+
+
+def check_statistical_test(statistical_test, test_frac=None):
+    """
+    Validate and return a test-statistic function.
+
+    Parameters
+    ----------
+    statistical_test : str or callable
+        If str, must be one of 'ttest', 'wilcoxon', or 'nb-ttest'.
+        If callable, must be a function that can be used as a test statistic.
+    test_frac : float, optional
+        The fraction of data used for testing in the Nadeau-Bengio t-test.
+
+    Returns
+    -------
+    callable
+        A function that can be used as a test statistic.
+        For string inputs, returns a partial function of ttest_1samp, wilcoxon,
+        or nadeau_bengio_ttest.
+        For callable inputs, returns the input function.
+
+    Raises
+    ------
+    ValueError
+        If statistical_test is a string but not one of the supported test
+        names ('ttest', 'wilcoxon', 'nb-ttest').
+    ValueError
+        If statistical_test is neither a string nor a callable.
+    """
+    if isinstance(statistical_test, str):
+        if statistical_test == "ttest":
+            return partial(ttest_1samp, popmean=0, alternative="greater", axis=1)
+        elif statistical_test == "wilcoxon":
+            return partial(wilcoxon, alternative="greater", axis=1)
+        elif statistical_test == "nb-ttest":
+            return partial(
+                nadeau_bengio_ttest,
+                popmean=0,
+                test_frac=test_frac,
+                alternative="greater",
+                axis=1,
+            )
+        else:
+            raise ValueError(f"the test '{statistical_test}' is not supported")
+    elif callable(statistical_test):
+        return statistical_test
+    else:
+        raise ValueError(
+            f"Unsupported value for 'statistical_test'. "
+            f"The provided argument was '{statistical_test}'. "
+            f"Please choose from the following valid options: "
+            f"string values ('ttest', 'wilcoxon', 'nb-ttest') "
+            f"or a custom callable function with a `scipy.stats` API-compatible signature."
+        )
