Skip to content

Commit b848a6c

Browse files
committed
add test for lasso
1 parent cd4d367 commit b848a6c

File tree

1 file changed

+54
-0
lines changed

1 file changed

+54
-0
lines changed
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import numpy as np
2+
import pytest
3+
from sklearn.linear_model import RidgeCV
4+
from sklearn.svm import SVR
5+
6+
from hidimstat.statistical_tools.lasso_test import (
7+
preconfigure_LassoCV,
8+
lasso_statistic_with_sampling,
9+
lasso_statistic,
10+
)
11+
12+
13+
def test_preconfigure_LassoCV():
14+
"""Test type errors"""
15+
with pytest.raises(
16+
TypeError, match="You should not use this function to configure the estimator"
17+
):
18+
preconfigure_LassoCV(
19+
estimator=RidgeCV(),
20+
X=np.random.rand(10, 10),
21+
y=np.random.rand(10),
22+
X_tilde=np.random.rand(10, 10),
23+
)
24+
25+
26+
def test_error_lasso_statistic_with_sampling_with_bad_config():
27+
"""Test error lasso statistic"""
28+
with pytest.raises(
29+
TypeError, match="You should not use this function to configure the estimator"
30+
):
31+
lasso_statistic_with_sampling(
32+
X=np.random.rand(10, 10),
33+
X_tilde=np.random.rand(10, 10),
34+
y=np.random.rand(10),
35+
lasso=SVR(),
36+
)
37+
38+
39+
def test_error_lasso_statistic_with_sampling():
40+
"""Test error lasso statistic"""
41+
with pytest.raises(TypeError, match="estimator should be linear"):
42+
lasso_statistic_with_sampling(
43+
X=np.random.rand(10, 10),
44+
X_tilde=np.random.rand(10, 10),
45+
y=np.random.rand(10),
46+
lasso=SVR(),
47+
preconfigure_lasso=None,
48+
)
49+
50+
51+
def test_error_lasso_statistic():
52+
"""Test error lasso statistic"""
53+
with pytest.raises(TypeError, match="estimator should be linear"):
54+
lasso_statistic(X=np.random.rand(10, 10), y=np.random.rand(10), lasso=SVR())

0 commit comments

Comments
 (0)