Skip to content

Commit d127487

Browse files
committed
unit tests
1 parent dae1800 commit d127487

File tree

2 files changed

+41
-0
lines changed

2 files changed

+41
-0
lines changed

tests/testing_tests/test_causal_test_outcome.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from causal_testing.testing.causal_test_outcome import ExactValue, SomeEffect, Positive, Negative
33
from causal_testing.testing.causal_test_result import CausalTestResult, TestValue
44
from causal_testing.testing.estimators import LinearRegressionEstimator
5+
from causal_testing.testing.validation import CausalValidator
56

67

78
class TestCausalTestOutcome(unittest.TestCase):
@@ -176,3 +177,33 @@ def test_someEffect_fail(self):
176177
"ci_high": 0.2,
177178
},
178179
)
180+
181+
def test_positive_risk_ratio_e_value(self):
182+
test_value = TestValue("risk_ratio", 1.5)
183+
ctr = CausalTestResult(
184+
estimator=self.estimator,
185+
test_value=test_value,
186+
confidence_intervals=[1.2, 1.8],
187+
effect_modifier_configuration=None,
188+
)
189+
190+
cv = CausalValidator()
191+
e_value, e_confidence_intervals = cv.estimate_e_value(ctr.test_value.value, ctr.confidence_intervals)
192+
self.assertEqual(round(e_value, 4), 2.366)
193+
self.assertEqual(round(e_confidence_intervals[0], 4), 1.6899)
194+
self.assertEqual(e_confidence_intervals[1], 1)
195+
196+
def test_negative_risk_ratio_e_value(self):
197+
test_value = TestValue("risk_ratio", 0.8)
198+
ctr = CausalTestResult(
199+
estimator=self.estimator,
200+
test_value=test_value,
201+
confidence_intervals=[0.2, 0.9],
202+
effect_modifier_configuration=None,
203+
)
204+
205+
cv = CausalValidator()
206+
e_value, e_confidence_intervals = cv.estimate_e_value(ctr.test_value.value, ctr.confidence_intervals)
207+
self.assertEqual(round(e_value, 4), 1.809)
208+
self.assertEqual(e_confidence_intervals[0], 1)
209+
self.assertEqual(round(e_confidence_intervals[1], 4), 1.4625)

tests/testing_tests/test_estimators.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
InstrumentalVariableEstimator,
1010
)
1111
from causal_testing.specification.variable import Input
12+
from causal_testing.testing.validation import CausalValidator
1213

1314

1415
def plot_results_df(df):
@@ -372,6 +373,15 @@ def test_program_15_no_interaction_ate_calculated(self):
372373
self.assertEqual(round(ate, 1), 3.5)
373374
self.assertEqual([round(ci_low, 1), round(ci_high, 1)], [1.9, 5])
374375

376+
def test_program_11_2_with_robustness_validation(self):
377+
"""Test whether our linear regression estimator, as used in test_program_11_2 can correctly estimate robustness."""
378+
df = self.chapter_11_df.copy()
379+
linear_regression_estimator = LinearRegressionEstimator("treatments", 100, 90, set(), "outcomes", df)
380+
model = linear_regression_estimator._run_linear_regression()
381+
382+
cv = CausalValidator()
383+
self.assertEqual(round(cv.estimate_robustness(model)["treatments"], 4), 0.7353)
384+
375385

376386
class TestCausalForestEstimator(unittest.TestCase):
377387
"""Test the linear regression estimator against the programming exercises in Section 2 of Hernán and Robins [1].

0 commit comments

Comments
 (0)