Skip to content

Commit 52efc8e

Browse files
authored
Fix the test_accuracy function by modifying the assertion logic (#867)
Fix the test_accuracy function by modifying the assertion logic Signed-off-by: kgao <[email protected]>
1 parent 6219695 commit 52efc8e

File tree

1 file changed

+46
-28
lines changed

1 file changed

+46
-28
lines changed

econml/tests/test_discrete_outcome.py

Lines changed: 46 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,7 @@ def test_accuracy(self):
2727
discrete_outcome = True
2828
discrete_treatment = True
2929
true_ate = 0.3
30-
W = np.random.uniform(-1, 1, size=(n, 1))
31-
D = np.random.binomial(1, .5 + .1 * W[:, 0], size=(n,))
32-
Y = np.random.binomial(1, .5 + true_ate * D + .1 * W[:, 0], size=(n,))
30+
num_iterations = 10
3331

3432
ests = [
3533
LinearDML(discrete_outcome=discrete_outcome, discrete_treatment=discrete_treatment),
@@ -39,34 +37,42 @@ def test_accuracy(self):
3937

4038
for est in ests:
4139

42-
if isinstance(est, CausalForestDML):
43-
est.fit(Y, D, X=W)
44-
ate = est.ate(X=W)
45-
ate_lb, ate_ub = est.ate_interval(X=W)
40+
count_within_interval = 0
4641

47-
else:
48-
est.fit(Y, D, W=W)
49-
ate = est.ate()
50-
ate_lb, ate_ub = est.ate_interval()
42+
for _ in range(num_iterations):
5143

52-
if isinstance(est, LinearDRLearner):
53-
est.summary(T=1)
54-
else:
55-
est.summary()
44+
W = np.random.uniform(-1, 1, size=(n, 1))
45+
D = np.random.binomial(1, .5 + .1 * W[:, 0], size=(n,))
46+
Y = np.random.binomial(1, .5 + true_ate * D + .1 * W[:, 0], size=(n,))
47+
48+
if isinstance(est, CausalForestDML):
49+
est.fit(Y, D, X=W)
50+
ate_lb, ate_ub = est.ate_interval(X=W)
51+
52+
else:
53+
est.fit(Y, D, W=W)
54+
ate_lb, ate_ub = est.ate_interval()
55+
56+
if isinstance(est, LinearDRLearner):
57+
est.summary(T=1)
58+
else:
59+
est.summary()
60+
61+
if ate_lb <= true_ate <= ate_ub:
62+
count_within_interval += 1
5663

57-
proportion_in_interval = ((ate_lb < true_ate) & (true_ate < ate_ub)).mean()
58-
np.testing.assert_array_less(0.50, proportion_in_interval)
64+
assert count_within_interval >= 7, (
65+
f"{est.__class__.__name__}: True ATE falls within the interval bounds "
66+
f"only {count_within_interval} times out of {num_iterations}"
67+
)
5968

6069
# accuracy test, DML
6170
def test_accuracy_iv(self):
62-
n = 10000
71+
n = 1000
6372
discrete_outcome = True
6473
discrete_treatment = True
6574
true_ate = 0.3
66-
W = np.random.uniform(-1, 1, size=(n, 1))
67-
Z = np.random.uniform(-1, 1, size=(n, 1))
68-
D = np.random.binomial(1, .5 + .1 * W[:, 0] + .1 * Z[:, 0], size=(n,))
69-
Y = np.random.binomial(1, .5 + true_ate * D + .1 * W[:, 0], size=(n,))
75+
num_iterations = 10
7076

7177
ests = [
7278
OrthoIV(discrete_outcome=discrete_outcome, discrete_treatment=discrete_treatment),
@@ -75,14 +81,26 @@ def test_accuracy_iv(self):
7581

7682
for est in ests:
7783

78-
est.fit(Y, D, W=W, Z=Z)
79-
ate = est.ate()
80-
ate_lb, ate_ub = est.ate_interval()
84+
count_within_interval = 0
85+
86+
for _ in range(num_iterations):
87+
88+
W = np.random.uniform(-1, 1, size=(n, 1))
89+
Z = np.random.uniform(-1, 1, size=(n, 1))
90+
D = np.random.binomial(1, .5 + .1 * W[:, 0] + .1 * Z[:, 0], size=(n,))
91+
Y = np.random.binomial(1, .5 + true_ate * D + .1 * W[:, 0], size=(n,))
92+
93+
est.fit(Y, D, W=W, Z=Z)
94+
ate_lb, ate_ub = est.ate_interval()
95+
est.summary()
8196

82-
est.summary()
97+
if ate_lb <= true_ate <= ate_ub:
98+
count_within_interval += 1
8399

84-
proportion_in_interval = ((ate_lb < true_ate) & (true_ate < ate_ub)).mean()
85-
np.testing.assert_array_less(0.50, proportion_in_interval)
100+
assert count_within_interval >= 7, (
101+
f"{est.__class__.__name__}: True ATE falls within the interval bounds "
102+
f"only {count_within_interval} times out of {num_iterations}"
103+
)
86104

87105
def test_string_outcome(self):
88106
n = 100

0 commit comments

Comments
 (0)