Skip to content

Commit 8662e96

Browse files
Update remaining test cases to work without engine
1 parent 0836236 commit 8662e96

File tree

3 files changed

+231
-260
lines changed

3 files changed

+231
-260
lines changed

tests/data_collection_tests/test_observational_data_collector.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,13 +44,15 @@ def test_not_all_variables_in_data(self):
4444
def test_all_variables_in_data(self):
4545
scenario = Scenario({self.X1, self.X2, self.X3, self.Y1, self.Y2})
4646
observational_data_collector = ObservationalDataCollector(scenario, self.observational_df)
47-
df = observational_data_collector.collect_data(index_col=0)
47+
observational_data_collector.collect_data(index_col=0)
48+
df = observational_data_collector.data
4849
assert df.equals(self.observational_df), f"\n{df}\nwas not equal to\n{self.observational_df}"
4950

5051
def test_data_constraints(self):
5152
scenario = Scenario({self.X1, self.X2, self.X3, self.Y1, self.Y2}, {self.X1.z3 > 2})
5253
observational_data_collector = ObservationalDataCollector(scenario, self.observational_df)
53-
df = observational_data_collector.collect_data(index_col=0)
54+
observational_data_collector.collect_data(index_col=0)
55+
df = observational_data_collector.data
5456
expected = self.observational_df.loc[[2, 3]]
5557
assert df.equals(expected), f"\n{df}\nwas not equal to\n{expected}"
5658

@@ -61,7 +63,8 @@ def populate_m(data):
6163
meta = Meta("M", int, populate_m)
6264
scenario = Scenario({self.X1, meta})
6365
observational_data_collector = ObservationalDataCollector(scenario, self.observational_df)
64-
data = observational_data_collector.collect_data()
66+
observational_data_collector.collect_data()
67+
data = observational_data_collector.data
6568
assert all((m == 2 * x1 for x1, m in zip(data["X1"], data["M"])))
6669

6770
def tearDown(self) -> None:

tests/testing_tests/test_causal_test_case.py

Lines changed: 224 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,16 @@
11
import unittest
22
import os
3+
import pandas as pd
4+
import numpy as np
5+
36
from tests.test_helpers import create_temp_dir_if_non_existent, remove_temp_dir_if_existent
7+
from causal_testing.specification.causal_specification import CausalSpecification, Scenario
48
from causal_testing.specification.variable import Input, Output
9+
from causal_testing.specification.causal_dag import CausalDAG
10+
from causal_testing.data_collection.data_collector import ObservationalDataCollector
511
from causal_testing.testing.causal_test_case import CausalTestCase
612
from causal_testing.testing.causal_test_outcome import ExactValue
13+
from causal_testing.testing.estimators import CausalForestEstimator, LinearRegressionEstimator
714
from causal_testing.testing.base_test_case import BaseTestCase
815

916

@@ -51,3 +58,220 @@ def test_str(self):
5158

5259
def tearDown(self) -> None:
5360
remove_temp_dir_if_existent()
61+
62+
63+
64+
class TestCausalTestExecution(unittest.TestCase):
65+
"""Test executing a CausalTestCase directly on observational data, without a causal test engine.
66+
67+
Each test calls CausalTestCase.execute_test with an estimator, a data collector and a causal specification
68+
and checks the estimated causal effect of the intervention on the outcome of interest.
69+
"""
70+
71+
def setUp(self) -> None:
72+
# 1. Create Causal DAG
73+
temp_dir_path = create_temp_dir_if_non_existent()
74+
dag_dot_path = os.path.join(temp_dir_path, "dag.dot")
75+
dag_dot = """digraph G { A -> C; D -> A; D -> C}"""
76+
with open(dag_dot_path, "w") as file:
77+
file.write(dag_dot)
78+
self.causal_dag = CausalDAG(dag_dot_path)
79+
80+
# 2. Create Scenario and Causal Specification
81+
A = Input("A", float)
82+
self.A = A
83+
C = Output("C", float)
84+
self.C = C
85+
D = Output("D", float)
86+
self.scenario = Scenario({A, C, D})
87+
self.causal_specification = CausalSpecification(scenario=self.scenario, causal_dag=self.causal_dag)
88+
89+
# 3. Create a causal test case
90+
self.expected_causal_effect = ExactValue(4)
91+
self.base_test_case = BaseTestCase(A, C)
92+
self.causal_test_case = CausalTestCase(
93+
base_test_case=self.base_test_case,
94+
expected_causal_effect=self.expected_causal_effect,
95+
control_value=0,
96+
treatment_value=1,
97+
)
98+
99+
# 4. Create dummy test data and write to csv
100+
np.random.seed(1)
101+
df = pd.DataFrame({"D": list(np.random.normal(60, 10, 1000))}) # D = exogenous
102+
df["A"] = [1 if d > 50 else 0 for d in df["D"]]
103+
df["C"] = df["D"] + (4 * (df["A"] + 2)) # C = (4*(A+2)) + D
104+
self.observational_data_csv_path = os.path.join(temp_dir_path, "observational_data.csv")
105+
df.to_csv(self.observational_data_csv_path, index=False)
106+
107+
# 5. Create observational data collector
108+
# Obsolete?
109+
self.data_collector = ObservationalDataCollector(self.scenario, df)
110+
self.data_collector.collect_data()
111+
self.df = self.data_collector.data
112+
self.minimal_adjustment_set = self.causal_dag.identification(self.base_test_case)
113+
# 6. Easier to access treatment and control values
114+
self.treatment_value = 1
115+
self.control_value = 0
116+
117+
def test_check_minimum_adjustment_set(self):
118+
"""Check that the minimum adjustment set is correctly made"""
119+
minimal_adjustment_set = self.causal_dag.identification(self.base_test_case)
120+
self.assertEqual(minimal_adjustment_set, {"D"})
121+
122+
def test_execute_test_observational_causal_forest_estimator(self):
123+
"""Check that executing the causal test case returns the correct results for the dummy data using a causal
124+
forest estimator."""
125+
estimation_model = CausalForestEstimator(
126+
"A",
127+
self.treatment_value,
128+
self.control_value,
129+
self.minimal_adjustment_set,
130+
"C",
131+
self.df,
132+
)
133+
causal_test_result = self.causal_test_case.execute_test(estimation_model, self.data_collector, self.causal_specification)
134+
self.assertAlmostEqual(causal_test_result.test_value.value, 4, delta=1)
135+
136+
def test_invalid_causal_effect(self):
137+
"""Check that identification raises an exception when the base test case specifies an invalid effect
138+
type (effect="error")."""
139+
base_test_case = BaseTestCase(treatment_variable=self.A, outcome_variable=self.C, effect="error")
140+
141+
with self.assertRaises(Exception):
142+
self.causal_dag.identification(base_test_case)
143+
144+
def test_execute_test_observational_linear_regression_estimator(self):
145+
"""Check that executing the causal test case returns the correct results for dummy data using a linear
146+
regression estimator."""
147+
estimation_model = LinearRegressionEstimator(
148+
"A",
149+
self.treatment_value,
150+
self.control_value,
151+
self.minimal_adjustment_set,
152+
"C",
153+
self.df,
154+
)
155+
causal_test_result = self.causal_test_case.execute_test(estimation_model, self.data_collector, self.causal_specification)
156+
self.assertAlmostEqual(causal_test_result.test_value.value, 4, delta=1e-10)
157+
158+
def test_execute_test_observational_linear_regression_estimator_direct_effect(self):
159+
"""Check that executing the causal test case returns the correct results for dummy data using a linear
160+
regression estimator when the base test case requests the direct effect."""
161+
base_test_case = BaseTestCase(treatment_variable=self.A, outcome_variable=self.C, effect="direct")
162+
163+
causal_test_case = CausalTestCase(
164+
base_test_case=base_test_case,
165+
expected_causal_effect=self.expected_causal_effect,
166+
control_value=0,
167+
treatment_value=1,
168+
)
169+
170+
minimal_adjustment_set = self.causal_dag.identification(base_test_case)
171+
# 6. Easier to access treatment and control values
172+
self.treatment_value = 1
173+
self.control_value = 0
174+
estimation_model = LinearRegressionEstimator(
175+
"A",
176+
self.treatment_value,
177+
self.control_value,
178+
minimal_adjustment_set,
179+
"C",
180+
self.df,
181+
)
182+
causal_test_result = causal_test_case.execute_test(estimation_model, self.data_collector, self.causal_specification)
183+
self.assertAlmostEqual(causal_test_result.test_value.value, 4, delta=1e-10)
184+
185+
def test_execute_test_observational_linear_regression_estimator_coefficient(self):
186+
"""Check that executing the causal test case returns the correct results for dummy data using a linear
187+
regression estimator with the "coefficient" estimate type."""
188+
estimation_model = LinearRegressionEstimator(
189+
"D",
190+
self.treatment_value,
191+
self.control_value,
192+
self.minimal_adjustment_set,
193+
"A",
194+
self.df,
195+
)
196+
self.causal_test_case.estimate_type = "coefficient"
197+
causal_test_result = self.causal_test_case.execute_test(estimation_model, self.data_collector, self.causal_specification)
198+
self.assertEqual(int(causal_test_result.test_value.value), 0)
199+
200+
def test_execute_test_observational_linear_regression_estimator_risk_ratio(self):
201+
"""Check that executing the causal test case returns the correct results for dummy data using a linear
202+
regression estimator with the "risk_ratio" estimate type."""
203+
estimation_model = LinearRegressionEstimator(
204+
"D",
205+
self.treatment_value,
206+
self.control_value,
207+
self.minimal_adjustment_set,
208+
"A",
209+
self.df,
210+
)
211+
self.causal_test_case.estimate_type = "risk_ratio"
212+
causal_test_result = self.causal_test_case.execute_test(estimation_model, self.data_collector, self.causal_specification)
213+
self.assertEqual(int(causal_test_result.test_value.value), 0)
214+
215+
def test_invalid_estimate_type(self):
216+
"""Check that executing the causal test case raises an AttributeError when its estimate type is set to
217+
an unrecognised value."""
218+
estimation_model = LinearRegressionEstimator(
219+
"D",
220+
self.treatment_value,
221+
self.control_value,
222+
self.minimal_adjustment_set,
223+
"A",
224+
self.df,
225+
)
226+
self.causal_test_case.estimate_type = "invalid"
227+
with self.assertRaises(AttributeError):
228+
self.causal_test_case.execute_test(estimation_model, self.data_collector, self.causal_specification)
229+
230+
def test_execute_test_observational_linear_regression_estimator_squared_term(self):
231+
"""Check that executing the causal test case returns the correct results for dummy data with a squared term
232+
using a linear regression estimator. C ~ 4*(A+2) + D + D^2"""
233+
estimation_model = LinearRegressionEstimator(
234+
"A",
235+
self.treatment_value,
236+
self.control_value,
237+
self.minimal_adjustment_set,
238+
"C",
239+
self.df,
240+
formula=f"C ~ A + {'+'.join(self.minimal_adjustment_set)} + (D ** 2)",
241+
)
242+
causal_test_result = self.causal_test_case.execute_test(estimation_model, self.data_collector, self.causal_specification)
243+
self.assertAlmostEqual(round(causal_test_result.test_value.value, 1), 4, delta=1)
244+
245+
def test_execute_observational_causal_forest_estimator_cates(self):
246+
"""Check that executing the causal test case returns the correct conditional average treatment effects for
247+
dummy data with multiplicative effect modification. C ~ (4*(A+2) + D)*M"""
248+
# Add some effect modifier M that has a multiplicative effect on C
249+
self.df["M"] = np.random.randint(
250+
1, 5, len(self.df)
251+
)
252+
self.df["C"] *= self.df[
253+
"M"
254+
]
255+
estimation_model = CausalForestEstimator(
256+
"A",
257+
self.treatment_value,
258+
self.control_value,
259+
self.minimal_adjustment_set,
260+
"C",
261+
self.df,
262+
effect_modifiers={"M": None},
263+
)
264+
self.causal_test_case.estimate_type = "cates"
265+
causal_test_result = self.causal_test_case.execute_test(estimation_model, self.data_collector, self.causal_specification)
266+
causal_test_result = causal_test_result.test_value.value
267+
# Check that each effect modifier's strata has a greater ATE than the last (ascending order)
268+
causal_test_result_m1 = causal_test_result.loc[causal_test_result["M"] == 1]
269+
causal_test_result_m2 = causal_test_result.loc[causal_test_result["M"] == 2]
270+
causal_test_result_m3 = causal_test_result.loc[causal_test_result["M"] == 3]
271+
causal_test_result_m4 = causal_test_result.loc[causal_test_result["M"] == 4]
272+
self.assertLess(causal_test_result_m1["cate"].mean(), causal_test_result_m2["cate"].mean())
273+
self.assertLess(causal_test_result_m2["cate"].mean(), causal_test_result_m3["cate"].mean())
274+
self.assertLess(causal_test_result_m3["cate"].mean(), causal_test_result_m4["cate"].mean())
275+
276+
def tearDown(self) -> None:
277+
remove_temp_dir_if_existent()

0 commit comments

Comments
 (0)