@@ -27,9 +27,7 @@ def test_accuracy(self):
2727 discrete_outcome = True
2828 discrete_treatment = True
2929 true_ate = 0.3
30- W = np .random .uniform (- 1 , 1 , size = (n , 1 ))
31- D = np .random .binomial (1 , .5 + .1 * W [:, 0 ], size = (n ,))
32- Y = np .random .binomial (1 , .5 + true_ate * D + .1 * W [:, 0 ], size = (n ,))
30+ num_iterations = 10
3331
3432 ests = [
3533 LinearDML (discrete_outcome = discrete_outcome , discrete_treatment = discrete_treatment ),
@@ -39,34 +37,42 @@ def test_accuracy(self):
3937
4038 for est in ests :
4139
42- if isinstance (est , CausalForestDML ):
43- est .fit (Y , D , X = W )
44- ate = est .ate (X = W )
45- ate_lb , ate_ub = est .ate_interval (X = W )
40+ count_within_interval = 0
4641
47- else :
48- est .fit (Y , D , W = W )
49- ate = est .ate ()
50- ate_lb , ate_ub = est .ate_interval ()
42+ for _ in range (num_iterations ):
5143
52- if isinstance (est , LinearDRLearner ):
53- est .summary (T = 1 )
54- else :
55- est .summary ()
44+ W = np .random .uniform (- 1 , 1 , size = (n , 1 ))
45+ D = np .random .binomial (1 , .5 + .1 * W [:, 0 ], size = (n ,))
46+ Y = np .random .binomial (1 , .5 + true_ate * D + .1 * W [:, 0 ], size = (n ,))
47+
48+ if isinstance (est , CausalForestDML ):
49+ est .fit (Y , D , X = W )
50+ ate_lb , ate_ub = est .ate_interval (X = W )
51+
52+ else :
53+ est .fit (Y , D , W = W )
54+ ate_lb , ate_ub = est .ate_interval ()
55+
56+ if isinstance (est , LinearDRLearner ):
57+ est .summary (T = 1 )
58+ else :
59+ est .summary ()
60+
61+ if ate_lb <= true_ate <= ate_ub :
62+ count_within_interval += 1
5663
57- proportion_in_interval = ((ate_lb < true_ate ) & (true_ate < ate_ub )).mean ()
58- np .testing .assert_array_less (0.50 , proportion_in_interval )
64+ assert count_within_interval >= 7 , (
65+ f"{ est .__class__ .__name__ } : True ATE falls within the interval bounds "
66+ f"only { count_within_interval } times out of { num_iterations } "
67+ )
5968
6069 # accuracy test, DML
6170 def test_accuracy_iv (self ):
62- n = 10000
71+ n = 1000
6372 discrete_outcome = True
6473 discrete_treatment = True
6574 true_ate = 0.3
66- W = np .random .uniform (- 1 , 1 , size = (n , 1 ))
67- Z = np .random .uniform (- 1 , 1 , size = (n , 1 ))
68- D = np .random .binomial (1 , .5 + .1 * W [:, 0 ] + .1 * Z [:, 0 ], size = (n ,))
69- Y = np .random .binomial (1 , .5 + true_ate * D + .1 * W [:, 0 ], size = (n ,))
75+ num_iterations = 10
7076
7177 ests = [
7278 OrthoIV (discrete_outcome = discrete_outcome , discrete_treatment = discrete_treatment ),
@@ -75,14 +81,26 @@ def test_accuracy_iv(self):
7581
7682 for est in ests :
7783
78- est .fit (Y , D , W = W , Z = Z )
79- ate = est .ate ()
80- ate_lb , ate_ub = est .ate_interval ()
84+ count_within_interval = 0
85+
86+ for _ in range (num_iterations ):
87+
88+ W = np .random .uniform (- 1 , 1 , size = (n , 1 ))
89+ Z = np .random .uniform (- 1 , 1 , size = (n , 1 ))
90+ D = np .random .binomial (1 , .5 + .1 * W [:, 0 ] + .1 * Z [:, 0 ], size = (n ,))
91+ Y = np .random .binomial (1 , .5 + true_ate * D + .1 * W [:, 0 ], size = (n ,))
92+
93+ est .fit (Y , D , W = W , Z = Z )
94+ ate_lb , ate_ub = est .ate_interval ()
95+ est .summary ()
8196
82- est .summary ()
97+ if ate_lb <= true_ate <= ate_ub :
98+ count_within_interval += 1
8399
84- proportion_in_interval = ((ate_lb < true_ate ) & (true_ate < ate_ub )).mean ()
85- np .testing .assert_array_less (0.50 , proportion_in_interval )
100+ assert count_within_interval >= 7 , (
101+ f"{ est .__class__ .__name__ } : True ATE falls within the interval bounds "
102+ f"only { count_within_interval } times out of { num_iterations } "
103+ )
86104
87105 def test_string_outcome (self ):
88106 n = 100
0 commit comments