Skip to content

Commit b38f421

Browse files
Update json tests to use CausalTestCase to execute tests
1 parent cab1086 commit b38f421

File tree

1 file changed

+16
-16
lines changed

1 file changed

+16
-16
lines changed

tests/testing_tests/test_causal_test_engine.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from causal_testing.testing.causal_test_outcome import ExactValue
1313
from causal_testing.testing.estimators import CausalForestEstimator, LinearRegressionEstimator
1414
from causal_testing.testing.base_test_case import BaseTestCase
15-
15+
from causal_testing.testing.causal_test_result import CausalTestResult, TestValue
1616

1717
class TestCausalTestEngineObservational(unittest.TestCase):
1818
"""Test the CausalTestEngine workflow using observational data.
@@ -60,27 +60,27 @@ def setUp(self) -> None:
6060
# 5. Create observational data collector
6161
# Obsolete?
6262
self.data_collector = ObservationalDataCollector(self.scenario, df)
63-
63+
self.df = self.data_collector.collect_data()
6464
# 5. Create causal test engine
65-
self.causal_test_engine = CausalTestEngine(self.causal_specification, self.data_collector)
65+
# self.causal_test_engine = CausalTestEngine(self.causal_specification, self.data_collector)
6666
self.minimal_adjustment_set = self.causal_dag.identification(self.base_test_case)
6767
# 6. Easier to access treatment and outcome values
6868
self.treatment_value = 1
6969
self.control_value = 0
7070

7171
def test_positivity_violation_throws_exception(self):
72-
causal_test_engine = CausalTestEngine(self.causal_specification, self.data_collector)
73-
causal_test_engine.scenario_execution_data_df.drop("A", axis=1, inplace=True)
72+
data_collector = self.data_collector
73+
data_collector.data.drop("A", axis=1, inplace=True)
7474
estimation_model = LinearRegressionEstimator(
7575
"A",
7676
self.treatment_value,
7777
self.control_value,
7878
self.minimal_adjustment_set,
7979
"C",
80-
self.causal_test_engine.scenario_execution_data_df,
80+
self.df
8181
)
82-
with self.assertRaises(Exception):
83-
causal_test_engine.execute_test(estimation_model)
82+
with self.assertRaises(ValueError):
83+
self.causal_test_case.execute_test(estimation_model, self.data_collector, self.causal_specification)
8484

8585
def test_check_no_positivity_violation(self):
8686
"""Check that no positivity violation is identified when there is no positivity violation."""
@@ -124,7 +124,7 @@ def test_execute_test_observational_causal_forest_estimator(self):
124124
"C",
125125
self.causal_test_engine.scenario_execution_data_df,
126126
)
127-
causal_test_result = self.causal_test_engine.execute_test(estimation_model, self.causal_test_case)
127+
causal_test_result = self.causal_test_case.execute_test(estimation_model, self.causal_test_case)
128128
self.assertAlmostEqual(causal_test_result.test_value.value, 4, delta=1)
129129

130130
def test_invalid_causal_effect(self):
@@ -146,7 +146,7 @@ def test_execute_test_observational_linear_regression_estimator(self):
146146
"C",
147147
self.causal_test_engine.scenario_execution_data_df,
148148
)
149-
causal_test_result = self.causal_test_engine.execute_test(estimation_model, self.causal_test_case)
149+
causal_test_result = self.causal_test_case.execute_test(estimation_model, self.causal_test_case)
150150
self.assertAlmostEqual(causal_test_result.test_value.value, 4, delta=1e-10)
151151

152152
def test_execute_test_observational_linear_regression_estimator_direct_effect(self):
@@ -175,7 +175,7 @@ def test_execute_test_observational_linear_regression_estimator_direct_effect(se
175175
"C",
176176
causal_test_engine.scenario_execution_data_df,
177177
)
178-
causal_test_result = causal_test_engine.execute_test(estimation_model, causal_test_case)
178+
causal_test_result = causal_test_case.execute_test(estimation_model, causal_test_case)
179179
self.assertAlmostEqual(causal_test_result.test_value.value, 4, delta=1e-10)
180180

181181
def test_execute_test_observational_linear_regression_estimator_coefficient(self):
@@ -190,7 +190,7 @@ def test_execute_test_observational_linear_regression_estimator_coefficient(self
190190
self.causal_test_engine.scenario_execution_data_df,
191191
)
192192
self.causal_test_case.estimate_type = "coefficient"
193-
causal_test_result = self.causal_test_engine.execute_test(estimation_model, self.causal_test_case)
193+
causal_test_result = self.causal_test_case.execute_test(estimation_model, self.causal_test_case)
194194
self.assertEqual(int(causal_test_result.test_value.value), 0)
195195

196196
def test_execute_test_observational_linear_regression_estimator_risk_ratio(self):
@@ -205,7 +205,7 @@ def test_execute_test_observational_linear_regression_estimator_risk_ratio(self)
205205
self.causal_test_engine.scenario_execution_data_df,
206206
)
207207
self.causal_test_case.estimate_type = "risk_ratio"
208-
causal_test_result = self.causal_test_engine.execute_test(estimation_model, self.causal_test_case)
208+
causal_test_result = self.causal_test_case.execute_test(estimation_model, self.causal_test_case)
209209
self.assertEqual(int(causal_test_result.test_value.value), 0)
210210

211211
def test_invalid_estimate_type(self):
@@ -221,7 +221,7 @@ def test_invalid_estimate_type(self):
221221
)
222222
self.causal_test_case.estimate_type = "invalid"
223223
with self.assertRaises(AttributeError):
224-
self.causal_test_engine.execute_test(estimation_model, self.causal_test_case)
224+
self.causal_test_case.execute_test(estimation_model, self.causal_test_case)
225225

226226
def test_execute_test_observational_linear_regression_estimator_squared_term(self):
227227
"""Check that executing the causal test case returns the correct results for dummy data with a squared term
@@ -235,7 +235,7 @@ def test_execute_test_observational_linear_regression_estimator_squared_term(sel
235235
self.causal_test_engine.scenario_execution_data_df,
236236
formula=f"C ~ A + {'+'.join(self.minimal_adjustment_set)} + (D ** 2)",
237237
)
238-
causal_test_result = self.causal_test_engine.execute_test(estimation_model, self.causal_test_case)
238+
causal_test_result = self.causal_test_case.execute_test(estimation_model, self.causal_test_case)
239239
self.assertAlmostEqual(round(causal_test_result.test_value.value, 1), 4, delta=1)
240240

241241
def test_execute_observational_causal_forest_estimator_cates(self):
@@ -258,7 +258,7 @@ def test_execute_observational_causal_forest_estimator_cates(self):
258258
effect_modifiers={"M": None},
259259
)
260260
self.causal_test_case.estimate_type = "cates"
261-
causal_test_result = self.causal_test_engine.execute_test(estimation_model, self.causal_test_case)
261+
causal_test_result = self.causal_test_case.execute_test(estimation_model, self.causal_test_case)
262262
causal_test_result = causal_test_result.test_value.value
263263
# Check that each effect modifier's strata has a greater ATE than the last (ascending order)
264264
causal_test_result_m1 = causal_test_result.loc[causal_test_result["M"] == 1]

0 commit comments

Comments
 (0)