8
8
from causal_testing .data_collection .data_collector import ObservationalDataCollector
9
9
from causal_testing .testing .causal_test_case import CausalTestCase
10
10
from causal_testing .testing .causal_test_outcome import Positive
11
- from causal_testing .testing .causal_test_engine import CausalTestEngine
12
11
from causal_testing .testing .estimators import LinearRegressionEstimator
13
12
from causal_testing .testing .base_test_case import BaseTestCase
14
13
from matplotlib .pyplot import rcParams
36
35
37
36
38
37
def doubling_beta_CATE_on_csv (
39
- observational_data_path : str , simulate_counterfactuals : bool = False , verbose : bool = False
38
+ observational_data_path : str , simulate_counterfactuals : bool = False , verbose : bool = False
40
39
):
41
40
"""Compute the CATE of increasing beta from 0.016 to 0.032 on cum_infections using the dataframe
42
41
loaded from the specified path. Additionally simulate the counterfactuals by repeating the analysis
@@ -50,9 +49,9 @@ def doubling_beta_CATE_on_csv(
50
49
"""
51
50
results_dict = {"association" : {}, "causation" : {}}
52
51
53
- # Read in the observational data, perform identification, and setup the causal_test_engine
52
+ # Read in the observational data, perform identification
54
53
past_execution_df = pd .read_csv (observational_data_path )
55
- _ , causal_test_engine , causal_test_case = engine_setup (observational_data_path )
54
+ data_collector , _ , causal_test_case , causal_specification = setup (observational_data_path )
56
55
57
56
linear_regression_estimator = LinearRegressionEstimator (
58
57
"beta" ,
@@ -65,7 +64,9 @@ def doubling_beta_CATE_on_csv(
65
64
)
66
65
67
66
# Add squared terms for beta, since it has a quadratic relationship with cumulative infections
68
- causal_test_result = causal_test_engine .execute_test (linear_regression_estimator , causal_test_case )
67
+ causal_test_result = causal_test_case .execute_test (estimator = linear_regression_estimator ,
68
+ data_collector = data_collector ,
69
+ causal_specification = causal_specification )
69
70
70
71
# Repeat for association estimate (no adjustment)
71
72
no_adjustment_linear_regression_estimator = LinearRegressionEstimator (
@@ -77,9 +78,9 @@ def doubling_beta_CATE_on_csv(
77
78
df = past_execution_df ,
78
79
formula = "cum_infections ~ beta + np.power(beta, 2)" ,
79
80
)
80
- association_test_result = causal_test_engine .execute_test (
81
- no_adjustment_linear_regression_estimator , causal_test_case
82
- )
81
+ association_test_result = causal_test_case .execute_test (estimator = no_adjustment_linear_regression_estimator ,
82
+ data_collector = data_collector ,
83
+ causal_specification = causal_specification )
83
84
84
85
# Store results for plotting
85
86
results_dict ["association" ] = {
@@ -109,8 +110,9 @@ def doubling_beta_CATE_on_csv(
109
110
df = counterfactual_past_execution_df ,
110
111
formula = "cum_infections ~ beta + np.power(beta, 2) + avg_age + contacts" ,
111
112
)
112
- counterfactual_causal_test_result = causal_test_engine .execute_test (
113
- linear_regression_estimator , causal_test_case
113
+ counterfactual_causal_test_result = causal_test_case .execute_test (
114
+ estimator = linear_regression_estimator , data_collector = data_collector ,
115
+ causal_specification = causal_specification
114
116
)
115
117
results_dict ["counterfactual" ] = {
116
118
"ate" : counterfactual_causal_test_result .test_value .value ,
@@ -218,7 +220,7 @@ def doubling_beta_CATEs(observational_data_path: str, simulate_counterfactual: b
218
220
age_contact_fig .savefig (outpath_base_str + "age_contact_executions.pdf" , format = "pdf" )
219
221
220
222
221
- def engine_setup (observational_data_path ):
223
+ def setup (observational_data_path ):
222
224
# 1. Read in the Causal DAG
223
225
causal_dag = CausalDAG (f"{ ROOT } /dag.dot" )
224
226
@@ -265,13 +267,10 @@ def engine_setup(observational_data_path):
265
267
# 7. Create a data collector
266
268
data_collector = ObservationalDataCollector (scenario , pd .read_csv (observational_data_path ))
267
269
268
- # 8. Create an instance of the causal test engine
269
- causal_test_engine = CausalTestEngine (causal_specification , data_collector )
270
-
271
- # 9. Obtain the minimal adjustment set for the base test case from the causal DAG
270
+ # 8. Obtain the minimal adjustment set for the base test case from the causal DAG
272
271
minimal_adjustment_set = causal_dag .identification (base_test_case )
273
272
274
- return minimal_adjustment_set , causal_test_engine , causal_test_case
273
+ return data_collector , minimal_adjustment_set , causal_test_case , causal_specification
275
274
276
275
277
276
def plot_doubling_beta_CATEs (results_dict , title , figure = None , axes = None , row = None , col = None ):
0 commit comments