1
1
import unittest
2
2
import os
3
+ import pandas as pd
4
+ import numpy as np
5
+
3
6
from tests .test_helpers import create_temp_dir_if_non_existent , remove_temp_dir_if_existent
7
+ from causal_testing .specification .causal_specification import CausalSpecification , Scenario
4
8
from causal_testing .specification .variable import Input , Output
9
+ from causal_testing .specification .causal_dag import CausalDAG
10
+ from causal_testing .data_collection .data_collector import ObservationalDataCollector
5
11
from causal_testing .testing .causal_test_case import CausalTestCase
6
12
from causal_testing .testing .causal_test_outcome import ExactValue
13
+ from causal_testing .testing .estimators import CausalForestEstimator , LinearRegressionEstimator
7
14
from causal_testing .testing .base_test_case import BaseTestCase
8
15
9
16
@@ -51,3 +58,220 @@ def test_str(self):
51
58
52
59
def tearDown (self ) -> None :
53
60
remove_temp_dir_if_existent ()
61
+
62
+
63
+
64
+ class TestCausalTestExecution (unittest .TestCase ):
65
+ """Test the CausalTestEngine workflow using observational data.
66
+
67
+ The causal test engine (CTE) is the main workflow for the causal testing framework. The CTE takes a causal test case
68
+ and a causal specification and computes the causal effect of the intervention on the outcome of interest.
69
+ """
70
+
71
+ def setUp (self ) -> None :
72
+ # 1. Create Causal DAG
73
+ temp_dir_path = create_temp_dir_if_non_existent ()
74
+ dag_dot_path = os .path .join (temp_dir_path , "dag.dot" )
75
+ dag_dot = """digraph G { A -> C; D -> A; D -> C}"""
76
+ with open (dag_dot_path , "w" ) as file :
77
+ file .write (dag_dot )
78
+ self .causal_dag = CausalDAG (dag_dot_path )
79
+
80
+ # 2. Create Scenario and Causal Specification
81
+ A = Input ("A" , float )
82
+ self .A = A
83
+ C = Output ("C" , float )
84
+ self .C = C
85
+ D = Output ("D" , float )
86
+ self .scenario = Scenario ({A , C , D })
87
+ self .causal_specification = CausalSpecification (scenario = self .scenario , causal_dag = self .causal_dag )
88
+
89
+ # 3. Create a causal test case
90
+ self .expected_causal_effect = ExactValue (4 )
91
+ self .base_test_case = BaseTestCase (A , C )
92
+ self .causal_test_case = CausalTestCase (
93
+ base_test_case = self .base_test_case ,
94
+ expected_causal_effect = self .expected_causal_effect ,
95
+ control_value = 0 ,
96
+ treatment_value = 1 ,
97
+ )
98
+
99
+ # 4. Create dummy test data and write to csv
100
+ np .random .seed (1 )
101
+ df = pd .DataFrame ({"D" : list (np .random .normal (60 , 10 , 1000 ))}) # D = exogenous
102
+ df ["A" ] = [1 if d > 50 else 0 for d in df ["D" ]]
103
+ df ["C" ] = df ["D" ] + (4 * (df ["A" ] + 2 )) # C = (4*(A+2)) + D
104
+ self .observational_data_csv_path = os .path .join (temp_dir_path , "observational_data.csv" )
105
+ df .to_csv (self .observational_data_csv_path , index = False )
106
+
107
+ # 5. Create observational data collector
108
+ # Obsolete?
109
+ self .data_collector = ObservationalDataCollector (self .scenario , df )
110
+ self .data_collector .collect_data ()
111
+ self .df = self .data_collector .data
112
+ self .minimal_adjustment_set = self .causal_dag .identification (self .base_test_case )
113
+ # 6. Easier to access treatment and outcome values
114
+ self .treatment_value = 1
115
+ self .control_value = 0
116
+
117
+ def test_check_minimum_adjustment_set (self ):
118
+ """Check that the minimum adjustment set is correctly made"""
119
+ minimal_adjustment_set = self .causal_dag .identification (self .base_test_case )
120
+ self .assertEqual (minimal_adjustment_set , {"D" })
121
+
122
+ def test_execute_test_observational_causal_forest_estimator (self ):
123
+ """Check that executing the causal test case returns the correct results for the dummy data using a causal
124
+ forest estimator."""
125
+ estimation_model = CausalForestEstimator (
126
+ "A" ,
127
+ self .treatment_value ,
128
+ self .control_value ,
129
+ self .minimal_adjustment_set ,
130
+ "C" ,
131
+ self .df ,
132
+ )
133
+ causal_test_result = self .causal_test_case .execute_test (estimation_model , self .data_collector , self .causal_specification )
134
+ self .assertAlmostEqual (causal_test_result .test_value .value , 4 , delta = 1 )
135
+
136
+ def test_invalid_causal_effect (self ):
137
+ """Check that executing the causal test case returns the correct results for dummy data using a linear
138
+ regression estimator."""
139
+ base_test_case = BaseTestCase (treatment_variable = self .A , outcome_variable = self .C , effect = "error" )
140
+
141
+ with self .assertRaises (Exception ):
142
+ self .causal_dag .identification (base_test_case )
143
+
144
+ def test_execute_test_observational_linear_regression_estimator (self ):
145
+ """Check that executing the causal test case returns the correct results for dummy data using a linear
146
+ regression estimator."""
147
+ estimation_model = LinearRegressionEstimator (
148
+ "A" ,
149
+ self .treatment_value ,
150
+ self .control_value ,
151
+ self .minimal_adjustment_set ,
152
+ "C" ,
153
+ self .df ,
154
+ )
155
+ causal_test_result = self .causal_test_case .execute_test (estimation_model , self .data_collector , self .causal_specification )
156
+ self .assertAlmostEqual (causal_test_result .test_value .value , 4 , delta = 1e-10 )
157
+
158
+ def test_execute_test_observational_linear_regression_estimator_direct_effect (self ):
159
+ """Check that executing the causal test case returns the correct results for dummy data using a linear
160
+ regression estimator."""
161
+ base_test_case = BaseTestCase (treatment_variable = self .A , outcome_variable = self .C , effect = "direct" )
162
+
163
+ causal_test_case = CausalTestCase (
164
+ base_test_case = base_test_case ,
165
+ expected_causal_effect = self .expected_causal_effect ,
166
+ control_value = 0 ,
167
+ treatment_value = 1 ,
168
+ )
169
+
170
+ minimal_adjustment_set = self .causal_dag .identification (base_test_case )
171
+ # 6. Easier to access treatment and outcome values
172
+ self .treatment_value = 1
173
+ self .control_value = 0
174
+ estimation_model = LinearRegressionEstimator (
175
+ "A" ,
176
+ self .treatment_value ,
177
+ self .control_value ,
178
+ minimal_adjustment_set ,
179
+ "C" ,
180
+ self .df ,
181
+ )
182
+ causal_test_result = causal_test_case .execute_test (estimation_model , self .data_collector , self .causal_specification )
183
+ self .assertAlmostEqual (causal_test_result .test_value .value , 4 , delta = 1e-10 )
184
+
185
+ def test_execute_test_observational_linear_regression_estimator_coefficient (self ):
186
+ """Check that executing the causal test case returns the correct results for dummy data using a linear
187
+ regression estimator."""
188
+ estimation_model = LinearRegressionEstimator (
189
+ "D" ,
190
+ self .treatment_value ,
191
+ self .control_value ,
192
+ self .minimal_adjustment_set ,
193
+ "A" ,
194
+ self .df ,
195
+ )
196
+ self .causal_test_case .estimate_type = "coefficient"
197
+ causal_test_result = self .causal_test_case .execute_test (estimation_model , self .data_collector , self .causal_specification )
198
+ self .assertEqual (int (causal_test_result .test_value .value ), 0 )
199
+
200
+ def test_execute_test_observational_linear_regression_estimator_risk_ratio (self ):
201
+ """Check that executing the causal test case returns the correct results for dummy data using a linear
202
+ regression estimator."""
203
+ estimation_model = LinearRegressionEstimator (
204
+ "D" ,
205
+ self .treatment_value ,
206
+ self .control_value ,
207
+ self .minimal_adjustment_set ,
208
+ "A" ,
209
+ self .df ,
210
+ )
211
+ self .causal_test_case .estimate_type = "risk_ratio"
212
+ causal_test_result = self .causal_test_case .execute_test (estimation_model , self .data_collector , self .causal_specification )
213
+ self .assertEqual (int (causal_test_result .test_value .value ), 0 )
214
+
215
+ def test_invalid_estimate_type (self ):
216
+ """Check that executing the causal test case returns the correct results for dummy data using a linear
217
+ regression estimator."""
218
+ estimation_model = LinearRegressionEstimator (
219
+ "D" ,
220
+ self .treatment_value ,
221
+ self .control_value ,
222
+ self .minimal_adjustment_set ,
223
+ "A" ,
224
+ self .df ,
225
+ )
226
+ self .causal_test_case .estimate_type = "invalid"
227
+ with self .assertRaises (AttributeError ):
228
+ self .causal_test_case .execute_test (estimation_model , self .data_collector , self .causal_specification )
229
+
230
+ def test_execute_test_observational_linear_regression_estimator_squared_term (self ):
231
+ """Check that executing the causal test case returns the correct results for dummy data with a squared term
232
+ using a linear regression estimator. C ~ 4*(A+2) + D + D^2"""
233
+ estimation_model = LinearRegressionEstimator (
234
+ "A" ,
235
+ self .treatment_value ,
236
+ self .control_value ,
237
+ self .minimal_adjustment_set ,
238
+ "C" ,
239
+ self .df ,
240
+ formula = f"C ~ A + { '+' .join (self .minimal_adjustment_set )} + (D ** 2)" ,
241
+ )
242
+ causal_test_result = self .causal_test_case .execute_test (estimation_model , self .data_collector , self .causal_specification )
243
+ self .assertAlmostEqual (round (causal_test_result .test_value .value , 1 ), 4 , delta = 1 )
244
+
245
+ def test_execute_observational_causal_forest_estimator_cates (self ):
246
+ """Check that executing the causal test case returns the correct conditional average treatment effects for
247
+ dummy data with effect multiplicative effect modification. C ~ (4*(A+2) + D)*M"""
248
+ # Add some effect modifier M that has a multiplicative effect on C
249
+ self .df ["M" ] = np .random .randint (
250
+ 1 , 5 , len (self .df )
251
+ )
252
+ self .df ["C" ] *= self .df [
253
+ "M"
254
+ ]
255
+ estimation_model = CausalForestEstimator (
256
+ "A" ,
257
+ self .treatment_value ,
258
+ self .control_value ,
259
+ self .minimal_adjustment_set ,
260
+ "C" ,
261
+ self .df ,
262
+ effect_modifiers = {"M" : None },
263
+ )
264
+ self .causal_test_case .estimate_type = "cates"
265
+ causal_test_result = self .causal_test_case .execute_test (estimation_model , self .data_collector , self .causal_specification )
266
+ causal_test_result = causal_test_result .test_value .value
267
+ # Check that each effect modifier's strata has a greater ATE than the last (ascending order)
268
+ causal_test_result_m1 = causal_test_result .loc [causal_test_result ["M" ] == 1 ]
269
+ causal_test_result_m2 = causal_test_result .loc [causal_test_result ["M" ] == 2 ]
270
+ causal_test_result_m3 = causal_test_result .loc [causal_test_result ["M" ] == 3 ]
271
+ causal_test_result_m4 = causal_test_result .loc [causal_test_result ["M" ] == 4 ]
272
+ self .assertLess (causal_test_result_m1 ["cate" ].mean (), causal_test_result_m2 ["cate" ].mean ())
273
+ self .assertLess (causal_test_result_m2 ["cate" ].mean (), causal_test_result_m3 ["cate" ].mean ())
274
+ self .assertLess (causal_test_result_m3 ["cate" ].mean (), causal_test_result_m4 ["cate" ].mean ())
275
+
276
+ def tearDown (self ) -> None :
277
+ remove_temp_dir_if_existent ()
0 commit comments