Skip to content

Commit a2c830c

Browse files
committed
add CRT
1 parent 5854f2e commit a2c830c

File tree

4 files changed

+772
-0
lines changed

4 files changed

+772
-0
lines changed
Lines changed: 254 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,254 @@
1+
from copy import deepcopy
2+
from itertools import product
3+
import warnings
4+
5+
import numpy as np
6+
from joblib import Parallel, delayed
7+
from sklearn.covariance import LedoitWolf
8+
from sklearn.utils.validation import check_memory
9+
from tqdm import tqdm
10+
11+
from hidimstat._utils.docstring import _aggregate_docstring
12+
from hidimstat.statistical_tools.gaussian_distribution import GaussianDistribution
13+
from hidimstat.statistical_tools.lasso_test import lasso_statistic
14+
from hidimstat.base_variable_importance import BaseVariableImportance
15+
16+
17+
class CRT(BaseVariableImportance):
18+
"""
19+
Implements conditional randomization test (CRT).
20+
The Conditional Randomization Test :footcite:t:`candes2018panning` is a method
21+
for statistical variable importance testing (see algorithm 2).
22+
Parameters
23+
----------
24+
generator : object, default=GaussianGenerator(cov_estimator=LedoitWolf(assume_centered=True))
25+
Generator object for simulating null distributions
26+
statistical_test : callable, default=lasso_statistic
27+
Function that computes test statistic
28+
n_permutation : int, default=10
29+
Number of permutations for the test
30+
n_jobs : int, default=1
31+
Number of parallel jobs
32+
memory : str or object, default=None
33+
Used for caching
34+
joblib_verbose : int, default=0
35+
Verbosity level for parallel jobs
36+
Attributes
37+
----------
38+
importances_ : ndarray of shape (n_features,)
39+
Feature importance scores
40+
pvalues_ : ndarray of shape (n_features,)
41+
P-values for each feature
42+
Notes
43+
-----
44+
The CRT tests feature importance by comparing observed test statistics against
45+
a conditional null distribution generated through simulation.
46+
See Also
47+
--------
48+
GaussianGenerator : Generator for Gaussian null distributions
49+
lasso_statistic : Default test statistic using Lasso coefficients
50+
References
51+
----------
52+
.. footbibliography::
53+
"""
54+
55+
def __init__(
56+
self,
57+
generator=GaussianDistribution(cov_estimator=LedoitWolf(assume_centered=True)),
58+
statistical_test=lasso_statistic,
59+
n_repeat=10,
60+
n_jobs=1,
61+
memory=None,
62+
joblib_verbose=0,
63+
):
64+
self.generator = generator
65+
assert n_repeat > 0, "n_samplings must be positive"
66+
self.n_repeat = n_repeat
67+
self.n_jobs = n_jobs
68+
self.memory = check_memory(memory)
69+
self.joblib_verbose = joblib_verbose
70+
self.statistical_test = statistical_test
71+
72+
def fit(self, X, y=None):
73+
"""
74+
Fit the CRT model by training the generator.
75+
Parameters
76+
----------
77+
X : array-like of shape (n_samples, n_features)
78+
Training data matrix where n_samples is the number of samples and
79+
n_features is the number of features.
80+
y : array-like of shape (n_samples,), default=None
81+
Target values. Not used in this method.
82+
Returns
83+
-------
84+
self : object
85+
Returns the instance itself.
86+
Notes
87+
-----
88+
The fit method only trains the generator component. The target values y
89+
are not used in this step.
90+
"""
91+
if y is not None:
92+
warnings.warn("y won't be used")
93+
94+
self.generator.fit(X)
95+
return self
96+
97+
def _check_fit(self):
98+
try:
99+
self.generator._check_fit()
100+
except ValueError as exc:
101+
raise ValueError(
102+
"The CRT requires to be fitted before computing importance"
103+
) from exc
104+
105+
def importance(self, X, y):
106+
"""
107+
Calculate p-values and identify significant features using the CRT test statistics.
108+
This function processes the results from Conditional Randomization Test (CRT) to identify
109+
statistically significant features. It computes p-values by comparing a reference test
110+
statistic to test statistics from permuted data.
111+
X : array-like of shape (n_samples, n_features)
112+
y : array-like of shape (n_samples,)
113+
Array of importance scores (p-values) for each feature. Lower p-values indicate
114+
higher importance. Values range from 0 to 1.
115+
Notes
116+
-----
117+
The p-values are calculated using the formula:
118+
(1 + #(T_perm >= T_obs)) / (n_permutations + 1)
119+
where T_perm are the test statistics from permuted data and T_obs is the
120+
reference test statistic.
121+
See Also
122+
--------
123+
statistical_test : Method that computes the test statistic used in this function.
124+
"""
125+
self._check_fit()
126+
reference_value = self.statistical_test(X, y)
127+
128+
parallel = Parallel(self.n_jobs, verbose=self.joblib_verbose)
129+
X_samples = []
130+
for i in range(self.n_repeat):
131+
X_samples.append(self.generator.sample())
132+
133+
self.test_scores_ = np.array(
134+
parallel(
135+
delayed(joblib_statitistic_test)(
136+
index, X, X_sample, y, self.statistical_test
137+
)
138+
for X_sample, index in tqdm(product(X_samples, range(X.shape[1])))
139+
)
140+
)
141+
self.test_scores_ = reference_value - np.array(self.test_scores_).reshape(
142+
self.n_repeat, -1
143+
)
144+
145+
self.importances_ = np.mean(np.abs(self.test_scores_), axis=0)
146+
# see equation of p-value in algorithm 2
147+
self.pvalues_ = (
148+
1
149+
+ np.sum(
150+
self.test_scores_ >= 0,
151+
axis=0,
152+
)
153+
) / (self.n_repeat + 1)
154+
return self.importances_
155+
156+
def fit_importance(self, X, y, cv=None):
157+
"""
158+
Fits the model to the data and computes feature importance.
159+
Parameters
160+
----------
161+
X : array-like of shape (n_samples, n_features)
162+
The input data matrix where n_samples is the number of samples and
163+
n_features is the number of features.
164+
y : array-like of shape (n_samples,)
165+
The target values.
166+
cv : None or cross-validation generator, default=None
167+
Cross-validation parameter. Not used in this method.
168+
A warning will be issued if provided.
169+
Returns
170+
-------
171+
importances_ : ndarray of shape (n_features,)
172+
Feature importance scores (p-values) for each feature.
173+
Lower values indicate higher importance. Values range from 0 to 1.
174+
Notes
175+
-----
176+
This method combines the fit and importance computation steps.
177+
It first fits the generator to X and then computes importance scores
178+
by comparing observed test statistics against permuted ones.
179+
See Also
180+
--------
181+
fit : Method for fitting the generator only
182+
importance : Method for computing importance scores only
183+
"""
184+
if cv is not None:
185+
warnings.warn("cv won't be used")
186+
187+
self.fit(X)
188+
return self.importance(X, y)
189+
190+
191+
def joblib_statitistic_test(index, X, X_sample, y, statistic_test):
192+
"""Compute test statistic for a single feature with permuted data.
193+
Parameters
194+
----------
195+
index : int
196+
Index of the feature to test
197+
X : array-like of shape (n_samples, n_features)
198+
Original input data matrix
199+
X_sample : array-like of shape (n_samples, n_features)
200+
Permuted data matrix
201+
y : array-like of shape (n_samples,)
202+
Target values
203+
statistic_test : callable
204+
Function that computes the test statistic
205+
Returns
206+
-------
207+
float
208+
Test statistic value for the specified feature
209+
"""
210+
X_tmp = deepcopy(X)
211+
X_tmp[:, index] = deepcopy(X_sample[:, index])
212+
return statistic_test(X_tmp, y)[index]
213+
214+
215+
def crt(
216+
X,
217+
y,
218+
generator=GaussianDistribution(cov_estimator=LedoitWolf(assume_centered=True)),
219+
statistical_test=lasso_statistic,
220+
n_repeat=10,
221+
n_jobs=1,
222+
memory=None,
223+
joblib_verbose=0,
224+
):
225+
crt = CRT(
226+
generator=generator,
227+
statistical_test=statistical_test,
228+
n_repeat=n_repeat,
229+
n_jobs=n_jobs,
230+
memory=memory,
231+
joblib_verbose=joblib_verbose,
232+
)
233+
return crt.fit_importance(X, y)
234+
235+
236+
# use the docstring of the class for the function
237+
crt.__doc__ = _aggregate_docstring(
238+
[
239+
CRT.__doc__,
240+
CRT.__init__.__doc__,
241+
CRT.fit_importance.__doc__,
242+
CRT.selection.__doc__,
243+
],
244+
"""
245+
Returns
246+
-------
247+
selection: binary array-like of shape (n_features)
248+
Binary array of the seleted features
249+
importance : array-like of shape (n_features)
250+
The computed feature importance scores.
251+
pvalues : array-like of shape (n_features)
252+
The computed significant of feature for the prediction.
253+
""",
254+
)
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import numpy as np
2+
from sklearn.linear_model import LassoCV
3+
from sklearn.model_selection import KFold
4+
5+
6+
def lasso_statistic(
7+
X,
8+
y,
9+
lasso=LassoCV(
10+
n_jobs=1,
11+
verbose=0,
12+
max_iter=200000,
13+
cv=KFold(n_splits=5, shuffle=True, random_state=0),
14+
tol=1e-6,
15+
),
16+
n_alphas=0,
17+
):
18+
"""
19+
Compute Lasso statistic using feature coefficients.
20+
Parameters
21+
----------
22+
X : array-like of shape (n_samples, n_features)
23+
The input data matrix.
24+
y : array-like of shape (n_samples,)
25+
The target values.
26+
lasso : estimator, default=LassoCV(n_jobs=None, verbose=0, max_iter=200000, cv=KFold(n_splits=5, shuffle=True, random_state=0), tol=1e-6)
27+
The Lasso estimator to use for computing the test statistic.
28+
n_alphas : int, default=0
29+
Number of alpha values to test for Lasso regularization path.
30+
If 0, uses the default alpha sequence from the estimator.
31+
Returns
32+
-------
33+
coef : ndarray
34+
Lasso coefficients for each feature.
35+
Raises
36+
------
37+
TypeError
38+
If the provided estimator does not have coef_ attribute or is not linear.
39+
"""
40+
if n_alphas != 0:
41+
alpha_max = np.max(np.dot(X.T, y)) / (X.shape[1])
42+
alphas = np.linspace(alpha_max * np.exp(-n_alphas), alpha_max, n_alphas)
43+
lasso.alphas = alphas
44+
lasso.fit(X, y)
45+
if hasattr(lasso, "coef_"):
46+
coef = np.ravel(lasso.coef_)
47+
elif hasattr(lasso, "best_estimator_") and hasattr(lasso.best_estimator_, "coef_"):
48+
coef = np.ravel(lasso.best_estimator_.coef_) # for CV object
49+
else:
50+
raise TypeError("estimator should be linear")
51+
return coef
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import numpy as np
2+
import pytest
3+
from sklearn.svm import SVR
4+
5+
from hidimstat.statistical_tools.lasso_test import (
6+
lasso_statistic,
7+
)
8+
9+
10+
def test_error_lasso_statistic():
11+
"""Test error lasso statistic"""
12+
with pytest.raises(TypeError, match="estimator should be linear"):
13+
lasso_statistic(X=np.random.rand(10, 10), y=np.random.rand(10), lasso=SVR())

0 commit comments

Comments
 (0)