11import numpy as np
22import pytest
3+ import scipy .stats
34from sklearn .ensemble import RandomForestRegressor
45
56from ivmodels .tests import residual_prediction_test
67
78
9+ @pytest .mark .parametrize ("robust" , [False , True ])
810@pytest .mark .parametrize (
911 "n, k, mx, mc, fit_intercept" ,
10- [(200 , 3 , 3 , 1 , True ), (200 , 3 , 1 , 1 , False ), (200 , 15 , 5 , 5 , False )],
12+ [(500 , 3 , 3 , 1 , True ), (500 , 3 , 1 , 1 , False ), (500 , 15 , 5 , 5 , True )],
1113)
12- def test_residual_prediction_test (n , k , mx , mc , fit_intercept ):
14+ def test_residual_prediction_test (n , k , mx , mc , fit_intercept , robust ):
1315 rng = np .random .default_rng (0 )
1416
1517 Pi = rng .normal (size = (k , mx ))
@@ -19,7 +21,7 @@ def test_residual_prediction_test(n, k, mx, mc, fit_intercept):
1921 Pi_CX = rng .normal (size = (mc , mx ))
2022 Pi_Cy = rng .normal (size = (mc , 1 ))
2123
22- n_seeds = 50
24+ n_seeds = 20
2325 statistics = np .zeros (n_seeds )
2426 p_values = np .zeros (n_seeds )
2527
@@ -29,20 +31,27 @@ def test_residual_prediction_test(n, k, mx, mc, fit_intercept):
2931 U = rng .normal (size = (n , mx + 1 ))
3032 X = Z @ Pi + U @ gamma + C @ Pi_CX + rng .normal (size = (n , mx ))
3133 X [:, 0 ] += Z [:, 0 ] ** 2 # allow for nonlinearity Z -> X
32- y = X @ beta + U [:, 0 :1 ] + U [:, 0 :1 ] ** 3 + C @ Pi_Cy + rng .normal (size = (n , 1 ))
34+ noise = rng .normal (size = (n , 1 ))
35+ if robust :
36+ noise *= Z [:, 0 :1 ] ** 2
37+ y = X @ beta + U [:, 0 :1 ] + np .sin (U [:, 0 :1 ]) + C @ Pi_Cy + noise
3338
3439 statistics [idx ], p_values [idx ] = residual_prediction_test (
3540 Z = Z ,
3641 X = X ,
3742 y = y ,
3843 C = C ,
44+ robust = robust ,
3945 nonlinear_model = RandomForestRegressor (n_estimators = 20 , random_state = 0 ),
4046 fit_intercept = fit_intercept ,
41- train_fraction = 0.6 ,
47+ train_fraction = 0.4 ,
4248 seed = 0 ,
4349 )
4450
45- assert np .mean (p_values < 0.1 ) < 0.05
51+ assert (
52+ scipy .stats .kstest (p_values , scipy .stats .uniform (loc = 0.0 , scale = 1.0 ).cdf ).pvalue
53+ > 0.05
54+ )
4655
4756
4857def test_residual_prediction_test_rejects ():
0 commit comments