Skip to content

Commit ace2612

Browse files
Update all unit tests to work with pd.Series refactor
1 parent e43bf38 commit ace2612

File tree

4 files changed

+34
-33
lines changed

4 files changed

+34
-33
lines changed

tests/testing_tests/test_causal_test_case.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def test_execute_test_observational_causal_forest_estimator(self):
118118
self.df,
119119
)
120120
causal_test_result = self.causal_test_case.execute_test(estimation_model, self.data_collector)
121-
self.assertAlmostEqual(causal_test_result.test_value.value, 4, delta=1)
121+
self.assertAlmostEqual(causal_test_result.test_value.value[0], 4, delta=1)
122122

123123
def test_invalid_causal_effect(self):
124124
"""Check that executing the causal test case returns the correct results for dummy data using a linear
@@ -140,7 +140,7 @@ def test_execute_test_observational_linear_regression_estimator(self):
140140
self.df,
141141
)
142142
causal_test_result = self.causal_test_case.execute_test(estimation_model, self.data_collector)
143-
self.assertAlmostEqual(causal_test_result.test_value.value, 4, delta=1e-10)
143+
self.assertAlmostEqual(causal_test_result.test_value.value[0], 4, delta=1e-10)
144144

145145
def test_execute_test_observational_linear_regression_estimator_direct_effect(self):
146146
"""Check that executing the causal test case returns the correct results for dummy data using a linear
@@ -167,7 +167,7 @@ def test_execute_test_observational_linear_regression_estimator_direct_effect(se
167167
self.df,
168168
)
169169
causal_test_result = causal_test_case.execute_test(estimation_model, self.data_collector)
170-
self.assertAlmostEqual(causal_test_result.test_value.value, 4, delta=1e-10)
170+
self.assertAlmostEqual(causal_test_result.test_value.value[0], 4, delta=1e-10)
171171

172172
def test_execute_test_observational_linear_regression_estimator_coefficient(self):
173173
"""Check that executing the causal test case returns the correct results for dummy data using a linear
@@ -227,7 +227,7 @@ def test_execute_test_observational_linear_regression_estimator_squared_term(sel
227227
formula=f"C ~ A + {'+'.join(self.minimal_adjustment_set)} + (D ** 2)",
228228
)
229229
causal_test_result = self.causal_test_case.execute_test(estimation_model, self.data_collector)
230-
self.assertAlmostEqual(round(causal_test_result.test_value.value, 1), 4, delta=1)
230+
self.assertAlmostEqual(round(causal_test_result.test_value.value[0], 1), 4, delta=1)
231231

232232
def test_execute_observational_causal_forest_estimator_cates(self):
233233
"""Check that executing the causal test case returns the correct conditional average treatment effects for

tests/testing_tests/test_causal_test_outcome.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import unittest
2+
import pandas as pd
23
from causal_testing.testing.causal_test_outcome import ExactValue, SomeEffect, Positive, Negative, NoEffect
34
from causal_testing.testing.causal_test_result import CausalTestResult, TestValue
45
from causal_testing.testing.estimators import LinearRegressionEstimator
@@ -69,7 +70,7 @@ def test_empty_adjustment_set(self):
6970
)
7071

7172
def test_Positive_ate_pass(self):
72-
test_value = TestValue(type="ate", value=5.05)
73+
test_value = TestValue(type="ate", value=pd.Series(5.05))
7374
ctr = CausalTestResult(
7475
estimator=self.estimator,
7576
test_value=test_value,
@@ -80,7 +81,7 @@ def test_Positive_ate_pass(self):
8081
self.assertTrue(ev.apply(ctr))
8182

8283
def test_Positive_risk_ratio_pass(self):
83-
test_value = TestValue(type="risk_ratio", value=2)
84+
test_value = TestValue(type="risk_ratio", value=pd.Series(2))
8485
ctr = CausalTestResult(
8586
estimator=self.estimator,
8687
test_value=test_value,
@@ -91,7 +92,7 @@ def test_Positive_risk_ratio_pass(self):
9192
self.assertTrue(ev.apply(ctr))
9293

9394
def test_Positive_fail(self):
94-
test_value = TestValue(type="ate", value=0)
95+
test_value = TestValue(type="ate", value=pd.Series(0))
9596
ctr = CausalTestResult(
9697
estimator=self.estimator,
9798
test_value=test_value,
@@ -102,7 +103,7 @@ def test_Positive_fail(self):
102103
self.assertFalse(ev.apply(ctr))
103104

104105
def test_Positive_fail_ci(self):
105-
test_value = TestValue(type="ate", value=0)
106+
test_value = TestValue(type="ate", value=pd.Series(0))
106107
ctr = CausalTestResult(
107108
estimator=self.estimator,
108109
test_value=test_value,
@@ -113,7 +114,7 @@ def test_Positive_fail_ci(self):
113114
self.assertFalse(ev.apply(ctr))
114115

115116
def test_Negative_ate_pass(self):
116-
test_value = TestValue(type="ate", value=-5.05)
117+
test_value = TestValue(type="ate", value=pd.Series(-5.05))
117118
ctr = CausalTestResult(
118119
estimator=self.estimator,
119120
test_value=test_value,
@@ -124,7 +125,7 @@ def test_Negative_ate_pass(self):
124125
self.assertTrue(ev.apply(ctr))
125126

126127
def test_Negative_risk_ratio_pass(self):
127-
test_value = TestValue(type="risk_ratio", value=0.2)
128+
test_value = TestValue(type="risk_ratio", value=pd.Series(0.2))
128129
ctr = CausalTestResult(
129130
estimator=self.estimator,
130131
test_value=test_value,
@@ -135,7 +136,7 @@ def test_Negative_risk_ratio_pass(self):
135136
self.assertTrue(ev.apply(ctr))
136137

137138
def test_Negative_fail(self):
138-
test_value = TestValue(type="ate", value=0)
139+
test_value = TestValue(type="ate", value=pd.Series(0))
139140
ctr = CausalTestResult(
140141
estimator=self.estimator,
141142
test_value=test_value,
@@ -146,7 +147,7 @@ def test_Negative_fail(self):
146147
self.assertFalse(ev.apply(ctr))
147148

148149
def test_Negative_fail_ci(self):
149-
test_value = TestValue(type="ate", value=0)
150+
test_value = TestValue(type="ate", value=pd.Series(0))
150151
ctr = CausalTestResult(
151152
estimator=self.estimator,
152153
test_value=test_value,
@@ -157,7 +158,7 @@ def test_Negative_fail_ci(self):
157158
self.assertFalse(ev.apply(ctr))
158159

159160
def test_exactValue_pass(self):
160-
test_value = TestValue(type="ate", value=5.05)
161+
test_value = TestValue(type="ate", value=pd.Series(5.05))
161162
ctr = CausalTestResult(
162163
estimator=self.estimator,
163164
test_value=test_value,
@@ -168,7 +169,7 @@ def test_exactValue_pass(self):
168169
self.assertTrue(ev.apply(ctr))
169170

170171
def test_exactValue_pass_ci(self):
171-
test_value = TestValue(type="ate", value=5.05)
172+
test_value = TestValue(type="ate", value=pd.Series(5.05))
172173
ctr = CausalTestResult(
173174
estimator=self.estimator,
174175
test_value=test_value,
@@ -179,7 +180,7 @@ def test_exactValue_pass_ci(self):
179180
self.assertTrue(ev.apply(ctr))
180181

181182
def test_exactValue_fail(self):
182-
test_value = TestValue(type="ate", value=0)
183+
test_value = TestValue(type="ate", value=pd.Series(0))
183184
ctr = CausalTestResult(
184185
estimator=self.estimator,
185186
test_value=test_value,
@@ -190,7 +191,7 @@ def test_exactValue_fail(self):
190191
self.assertFalse(ev.apply(ctr))
191192

192193
def test_invalid(self):
193-
test_value = TestValue(type="invalid", value=5.05)
194+
test_value = TestValue(type="invalid", value=pd.Series(5.05))
194195
ctr = CausalTestResult(
195196
estimator=self.estimator,
196197
test_value=test_value,
@@ -207,7 +208,7 @@ def test_invalid(self):
207208
Negative().apply(ctr)
208209

209210
def test_someEffect_pass_coefficient(self):
210-
test_value = TestValue(type="coefficient", value=5.05)
211+
test_value = TestValue(type="coefficient", value=pd.Series(5.05))
211212
ctr = CausalTestResult(
212213
estimator=self.estimator,
213214
test_value=test_value,
@@ -218,7 +219,7 @@ def test_someEffect_pass_coefficient(self):
218219
self.assertFalse(NoEffect().apply(ctr))
219220

220221
def test_someEffect_pass_ate(self):
221-
test_value = TestValue(type="ate", value=5.05)
222+
test_value = TestValue(type="ate", value=pd.Series(5.05))
222223
ctr = CausalTestResult(
223224
estimator=self.estimator,
224225
test_value=test_value,
@@ -229,7 +230,7 @@ def test_someEffect_pass_ate(self):
229230
self.assertFalse(NoEffect().apply(ctr))
230231

231232
def test_someEffect_pass_rr(self):
232-
test_value = TestValue(type="risk_ratio", value=5.05)
233+
test_value = TestValue(type="risk_ratio", value=pd.Series(5.05))
233234
ctr = CausalTestResult(
234235
estimator=self.estimator,
235236
test_value=test_value,
@@ -240,7 +241,7 @@ def test_someEffect_pass_rr(self):
240241
self.assertFalse(NoEffect().apply(ctr))
241242

242243
def test_someEffect_fail(self):
243-
test_value = TestValue(type="ate", value=0)
244+
test_value = TestValue(type="ate", value=pd.Series(0))
244245
ctr = CausalTestResult(
245246
estimator=self.estimator,
246247
test_value=test_value,

tests/testing_tests/test_causal_test_suite.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def test_execute_test_suite_single_base_test_case(self):
9898

9999
causal_test_results = self.test_suite.execute_test_suite(self.data_collector, self.causal_specification)
100100
causal_test_case_result = causal_test_results[self.base_test_case]
101-
self.assertAlmostEqual(causal_test_case_result["LinearRegressionEstimator"][0].test_value.value, 4, delta=1e-10)
101+
self.assertAlmostEqual(causal_test_case_result["LinearRegressionEstimator"][0].test_value.value[0], 4, delta=1e-10)
102102

103103
def test_execute_test_suite_multiple_estimators(self):
104104
"""Check that executing a test suite with multiple estimators returns correct results for the dummy data
@@ -114,5 +114,5 @@ def test_execute_test_suite_multiple_estimators(self):
114114
causal_test_case_result = causal_test_results[self.base_test_case]
115115
linear_regression_result = causal_test_case_result["LinearRegressionEstimator"][0]
116116
causal_forrest_result = causal_test_case_result["CausalForestEstimator"][0]
117-
self.assertAlmostEqual(linear_regression_result.test_value.value, 4, delta=1e-1)
118-
self.assertAlmostEqual(causal_forrest_result.test_value.value, 4, delta=1e-1)
117+
self.assertAlmostEqual(linear_regression_result.test_value.value[0], 4, delta=1e-1)
118+
self.assertAlmostEqual(causal_forrest_result.test_value.value[0], 4, delta=1e-1)

tests/testing_tests/test_estimators.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ def test_estimate_coefficient(self):
199199
instrument="Z",
200200
)
201201
coefficient, [low, high] = iv_estimator.estimate_coefficient()
202-
self.assertEqual(coefficient, 2)
202+
self.assertEqual(coefficient[0], 2)
203203

204204

205205
class TestLinearRegressionEstimator(unittest.TestCase):
@@ -364,8 +364,8 @@ def test_program_15_no_interaction_ate(self):
364364
# terms_to_square = ["age", "wt71", "smokeintensity", "smokeyrs"]
365365
# for term_to_square in terms_to_square:
366366
ate, [ci_low, ci_high] = linear_regression_estimator.estimate_ate()
367-
self.assertEqual(round(ate, 1), 3.5)
368-
self.assertEqual([round(ci_low, 1), round(ci_high, 1)], [2.6, 4.3])
367+
self.assertEqual(round(ate[0], 1), 3.5)
368+
self.assertEqual([round(ci_low[0], 1), round(ci_high[0], 1)], [2.6, 4.3])
369369

370370
def test_program_15_no_interaction_ate_calculated(self):
371371
"""Test whether our linear regression implementation produces the same results as program 15.1 (p. 163, 184)
@@ -402,8 +402,8 @@ def test_program_15_no_interaction_ate_calculated(self):
402402
ate, [ci_low, ci_high] = linear_regression_estimator.estimate_ate_calculated(
403403
adjustment_config={k: self.nhefs_df.mean()[k] for k in covariates}
404404
)
405-
self.assertEqual(round(ate, 1), 3.5)
406-
self.assertEqual([round(ci_low, 1), round(ci_high, 1)], [1.9, 5])
405+
self.assertEqual(round(ate[0], 1), 3.5)
406+
self.assertEqual([round(ci_low[0], 1), round(ci_high[0], 1)], [1.9, 5])
407407

408408
def test_program_11_2_with_robustness_validation(self):
409409
"""Test whether our linear regression estimator, as used in test_program_11_2 can correctly estimate robustness."""
@@ -449,8 +449,8 @@ def test_program_11_3_cublic_spline(self):
449449
ate_2 = cublic_spline_estimator.estimate_ate_calculated()
450450

451451
# Doubling the treatemebnt value should roughly but not exactly double the ATE
452-
self.assertNotEqual(ate_1 * 2, ate_2)
453-
self.assertAlmostEqual(ate_1 * 2, ate_2)
452+
self.assertNotEqual(ate_1[0] * 2, ate_2[0])
453+
self.assertAlmostEqual(ate_1[0] * 2, ate_2[0])
454454

455455

456456

@@ -488,8 +488,8 @@ def test_program_15_ate(self):
488488
}
489489
causal_forest = CausalForestEstimator("qsmk", 1, 0, covariates, "wt82_71", df, {"smokeintensity": 40})
490490
ate, _ = causal_forest.estimate_ate()
491-
self.assertGreater(round(ate, 1), 2.5)
492-
self.assertLess(round(ate, 1), 4.5)
491+
self.assertGreater(round(ate[0], 1), 2.5)
492+
self.assertLess(round(ate[0], 1), 4.5)
493493

494494
def test_program_15_cate(self):
495495
"""Test whether our causal forest implementation produces the similar CATE to program 15.1 (p. 163, 184)."""
@@ -535,7 +535,7 @@ def test_X1_effect(self):
535535
"X1", 1, 0, {"X2"}, "Y", effect_modifiers={x2.name: 0}, formula="Y ~ X1 + X2 + (X1 * X2)", df=self.df
536536
)
537537
test_results = lr_model.estimate_ate()
538-
ate = test_results[0]
538+
ate = test_results[0][0]
539539
self.assertAlmostEqual(ate, 2.0)
540540

541541

0 commit comments

Comments
 (0)