Skip to content

Commit cd4d367

Browse files
committed
improve test
1 parent 4f6047b commit cd4d367

File tree

3 files changed

+72
-69
lines changed

3 files changed

+72
-69
lines changed

src/hidimstat/statistical_tools/lasso_test.py

Lines changed: 51 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -55,57 +55,6 @@ def preconfigure_LassoCV(estimator, X, X_tilde, y, n_alphas=20):
5555
return estimator
5656

5757

58-
def lasso_statistic(
59-
X,
60-
y,
61-
lasso=LassoCV(
62-
n_jobs=1,
63-
verbose=0,
64-
max_iter=200000,
65-
cv=KFold(n_splits=5, shuffle=True, random_state=0),
66-
tol=1e-6,
67-
),
68-
n_alphas=0,
69-
):
70-
"""
71-
Compute Lasso statistic using feature coefficients.
72-
73-
Parameters
74-
----------
75-
X : array-like of shape (n_samples, n_features)
76-
The input data matrix.
77-
y : array-like of shape (n_samples,)
78-
The target values.
79-
lasso : estimator, default=LassoCV(n_jobs=None, verbose=0, max_iter=200000, cv=KFold(n_splits=5, shuffle=True, random_state=0), tol=1e-6)
80-
The Lasso estimator to use for computing the test statistic.
81-
n_alphas : int, default=0
82-
Number of alpha values to test for Lasso regularization path.
83-
If 0, uses the default alpha sequence from the estimator.
84-
85-
Returns
86-
-------
87-
coef : ndarray
88-
Lasso coefficients for each feature.
89-
90-
Raises
91-
------
92-
TypeError
93-
If the provided estimator does not have coef_ attribute or is not linear.
94-
"""
95-
if n_alphas != 0:
96-
alpha_max = np.max(np.dot(X.T, y)) / (X.shape[1])
97-
alphas = np.linspace(alpha_max * np.exp(-n_alphas), alpha_max, n_alphas)
98-
lasso.alphas = alphas
99-
lasso.fit(X, y)
100-
if hasattr(lasso, "coef_"):
101-
coef = np.ravel(lasso.coef_)
102-
elif hasattr(lasso, "best_estimator_") and hasattr(lasso.best_estimator_, "coef_"):
103-
coef = np.ravel(lasso.best_estimator_.coef_) # for CV object
104-
else:
105-
raise TypeError("estimator should be linear")
106-
return coef
107-
108-
10958
def lasso_statistic_with_sampling(
11059
X,
11160
X_tilde,
@@ -180,3 +129,54 @@ def lasso_statistic_with_sampling(
180129
test_score = np.abs(coef[:n_features]) - np.abs(coef[n_features:])
181130

182131
return test_score
132+
133+
134+
def lasso_statistic(
    X,
    y,
    lasso=None,
    n_alphas=0,
):
    """
    Compute Lasso statistic using feature coefficients.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        The input data matrix.
    y : array-like of shape (n_samples,)
        The target values.
    lasso : estimator or None, default=None
        The estimator to fit for computing the test statistic. It is fitted
        in place, and when ``n_alphas != 0`` its ``alphas`` attribute is
        overwritten. If None, a fresh
        ``LassoCV(n_jobs=1, verbose=0, max_iter=200000,
        cv=KFold(n_splits=5, shuffle=True, random_state=0), tol=1e-6)``
        is created for this call.
    n_alphas : int, default=0
        Number of alpha values to test for Lasso regularization path.
        If 0, the ``alphas`` attribute of the estimator is left untouched.

    Returns
    -------
    coef : ndarray
        Flattened coefficients of the fitted estimator (or of its
        ``best_estimator_`` for search objects).

    Raises
    ------
    TypeError
        If neither the fitted estimator nor its ``best_estimator_`` exposes
        a ``coef_`` attribute.
    """
    if lasso is None:
        # Build the default estimator per call: a default instance created in
        # the signature would be shared between calls and would keep the
        # ``alphas`` grid and fitted state of a previous call.
        lasso = LassoCV(
            n_jobs=1,
            verbose=0,
            max_iter=200000,
            cv=KFold(n_splits=5, shuffle=True, random_state=0),
            tol=1e-6,
        )
    if n_alphas != 0:
        # NOTE(review): the bound uses the signed maximum of X^T y and divides
        # by the number of features (X.shape[1]); confirm this is intended
        # rather than max(|X^T y|) / n_samples.
        alpha_max = np.max(np.dot(X.T, y)) / (X.shape[1])
        alphas = np.linspace(alpha_max * np.exp(-n_alphas), alpha_max, n_alphas)
        lasso.alphas = alphas
    lasso.fit(X, y)
    if hasattr(lasso, "coef_"):
        coef = np.ravel(lasso.coef_)
    elif hasattr(lasso, "best_estimator_") and hasattr(lasso.best_estimator_, "coef_"):
        coef = np.ravel(lasso.best_estimator_.coef_)  # for CV object
    else:
        raise TypeError("estimator should be linear")
    return coef

test/test_conditional_randomization_test.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from hidimstat import CRT
1010
from hidimstat.conditional_randomization_test import crt
11-
from hidimstat.statistical_tools.gaussian_knockoff import GaussianGenerator
11+
from hidimstat.statistical_tools.gaussian_distribution import GaussianDistribution
1212
from hidimstat.statistical_tools.multiple_testing import fdp_power
1313
from hidimstat.statistical_tools.lasso_test import lasso_statistic
1414

@@ -45,7 +45,7 @@ def configure_linear_categorial_crt(X, y, n_repeat, seed, fdr):
4545
# instantiate CRT model with linear regression imputer
4646
crt = CRT(
4747
n_repeat=n_repeat,
48-
generator=GaussianGenerator(
48+
generator=GaussianDistribution(
4949
cov_estimator=LedoitWolf(assume_centered=True), random_state=seed
5050
),
5151
)
@@ -267,11 +267,12 @@ def lasso_statistic_gen(X, y):
267267
lasso=GridSearchCV(
268268
Lasso(), param_grid={"alpha": np.linspace(0.2, 0.3, 5)}
269269
),
270+
n_alphas=5,
270271
)
271272

272273
crt = CRT(
273274
n_repeat=1,
274-
generator=GaussianGenerator(
275+
generator=GaussianDistribution(
275276
cov_estimator=LedoitWolf(assume_centered=True), random_state=seed + 2
276277
),
277278
statistical_test=lasso_statistic_gen,
@@ -290,7 +291,7 @@ def test_estimate_distribution(self, data_generator, seed):
290291
X, y, important_features, _ = data_generator
291292
crt = CRT(
292293
n_repeat=1,
293-
generator=GaussianGenerator(
294+
generator=GaussianDistribution(
294295
cov_estimator=LedoitWolf(assume_centered=True), random_state=seed + 1
295296
),
296297
)
@@ -301,7 +302,7 @@ def test_estimate_distribution(self, data_generator, seed):
301302

302303
crt = CRT(
303304
n_repeat=1,
304-
generator=GaussianGenerator(
305+
generator=GaussianDistribution(
305306
cov_estimator=GraphicalLassoCV(
306307
alphas=[1e-3, 1e-2, 1e-1, 1],
307308
cv=KFold(n_splits=5, shuffle=True, random_state=0),
@@ -366,7 +367,7 @@ def test_crt_invariant_with_bootstrap(self, data_generator):
366367
# Single AKO (or vanilla KO) (verbose vs no verbose)
367368
crt_repeat = CRT(
368369
n_repeat=5,
369-
generator=GaussianGenerator(
370+
generator=GaussianDistribution(
370371
cov_estimator=LedoitWolf(assume_centered=True), random_state=5
371372
),
372373
)
@@ -375,7 +376,7 @@ def test_crt_invariant_with_bootstrap(self, data_generator):
375376

376377
crt_no_repeat = CRT(
377378
n_repeat=1,
378-
generator=GaussianGenerator(
379+
generator=GaussianDistribution(
379380
cov_estimator=LedoitWolf(assume_centered=True), random_state=5
380381
),
381382
)

test/test_knockoff.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88

99
from hidimstat import ModelXKnockoff
1010
from hidimstat.knockoffs import model_x_knockoff
11-
from hidimstat.statistical_tools.gaussian_knockoff import GaussianGenerator
11+
from hidimstat.statistical_tools.gaussian_distribution import GaussianDistribution
1212
from hidimstat.statistical_tools.multiple_testing import fdp_power
13-
from hidimstat.statistical_tools.lasso_test import lasso_statistic
13+
from hidimstat.statistical_tools.lasso_test import lasso_statistic_with_sampling
1414

1515

1616
def configure_linear_categorial_model_x_knockoff(X, y, n_repeat, seed, fdr):
@@ -45,7 +45,7 @@ def configure_linear_categorial_model_x_knockoff(X, y, n_repeat, seed, fdr):
4545
# instantiate ModelXKnockoff model with linear regression imputer
4646
model_x_knockoff = ModelXKnockoff(
4747
n_repeat=n_repeat,
48-
generator=GaussianGenerator(
48+
generator=GaussianDistribution(
4949
cov_estimator=LedoitWolf(assume_centered=True), random_state=seed
5050
),
5151
)
@@ -316,18 +316,20 @@ def test_model_x_knockoff_CV_estimator(self, data_generator, seed):
316316
fdr = 0.7
317317
X, y, important_features, _ = data_generator
318318

319-
def lasso_statistic_gen(X, y):
320-
return lasso_statistic(
319+
def lasso_statistic_gen(X, X_tilde, y):
320+
return lasso_statistic_with_sampling(
321321
X,
322+
X_tilde,
322323
y,
323324
lasso=GridSearchCV(
324325
Lasso(), param_grid={"alpha": np.linspace(0.2, 0.3, 5)}
325326
),
327+
preconfigure_lasso=None,
326328
)
327329

328330
model_x_knockoff = ModelXKnockoff(
329331
n_repeat=1,
330-
generator=GaussianGenerator(
332+
generator=GaussianDistribution(
331333
cov_estimator=LedoitWolf(assume_centered=True), random_state=seed + 2
332334
),
333335
statistical_test=lasso_statistic_gen,
@@ -346,7 +348,7 @@ def test_estimate_distribution(self, data_generator, seed):
346348
X, y, important_features, _ = data_generator
347349
model_x_knockoff = ModelXKnockoff(
348350
n_repeat=1,
349-
generator=GaussianGenerator(
351+
generator=GaussianDistribution(
350352
cov_estimator=LedoitWolf(assume_centered=True), random_state=seed + 1
351353
),
352354
)
@@ -357,7 +359,7 @@ def test_estimate_distribution(self, data_generator, seed):
357359

358360
model_x_knockoff = ModelXKnockoff(
359361
n_repeat=1,
360-
generator=GaussianGenerator(
362+
generator=GaussianDistribution(
361363
cov_estimator=GraphicalLassoCV(
362364
alphas=[1e-3, 1e-2, 1e-1, 1],
363365
cv=KFold(n_splits=5, shuffle=True, random_state=seed + 2),
@@ -408,7 +410,7 @@ def test_model_x_knockoff_repeat_e_values(self, data_generator, n_features):
408410
X, y, important_features, _ = data_generator
409411
model_x_knockoff = ModelXKnockoff(
410412
n_repeat=n_repeat,
411-
generator=GaussianGenerator(
413+
generator=GaussianDistribution(
412414
cov_estimator=LedoitWolf(assume_centered=True), random_state=0
413415
),
414416
)
@@ -431,7 +433,7 @@ def test_model_x_knockoff_invariant_with_bootstrap(self, data_generator):
431433
# Single AKO (or vanilla KO) (verbose vs no verbose)
432434
model_x_knockoff_repeat = ModelXKnockoff(
433435
n_repeat=10,
434-
generator=GaussianGenerator(
436+
generator=GaussianDistribution(
435437
cov_estimator=LedoitWolf(assume_centered=True), random_state=0
436438
),
437439
)
@@ -440,7 +442,7 @@ def test_model_x_knockoff_invariant_with_bootstrap(self, data_generator):
440442

441443
model_x_knockoff_no_repeat = ModelXKnockoff(
442444
n_repeat=1,
443-
generator=GaussianGenerator(
445+
generator=GaussianDistribution(
444446
cov_estimator=LedoitWolf(assume_centered=True), random_state=0
445447
),
446448
)

0 commit comments

Comments
 (0)