10
10
from causal_testing .testing .causal_test_outcome import Positive , Negative , NoEffect
11
11
from causal_testing .testing .causal_test_engine import CausalTestEngine
12
12
from causal_testing .testing .estimators import LinearRegressionEstimator
13
+ from causal_testing .testing .base_causal_test import BaseCausalTest
13
14
from matplotlib .pyplot import rcParams
14
15
15
16
# Uncommenting the code below will make all graphs publication quality but requires a suitable latex installation
@@ -109,12 +110,12 @@ def effects_on_APD90(observational_data_path, treatment_var, control_val, treatm
109
110
110
111
# 5. Create a causal specification from the scenario and causal DAG
111
112
causal_specification = CausalSpecification (scenario , causal_dag )
112
-
113
+ base_test_case = BaseCausalTest ( treatment_var , apd90 )
113
114
# 6. Create a causal test case
114
- causal_test_case = CausalTestCase (control_input_configuration = { treatment_var : control_val } ,
115
+ causal_test_case = CausalTestCase (base_causal_test = base_test_case ,
115
116
expected_causal_effect = expected_causal_effect ,
116
- treatment_input_configuration = { treatment_var : treatment_val } ,
117
- outcome_variables = { apd90 } )
117
+ control_value = control_val ,
118
+ treatment_value = treatment_val )
118
119
119
120
# 7. Create a data collector
120
121
data_collector = ObservationalDataCollector (scenario , observational_data_path )
@@ -123,9 +124,9 @@ def effects_on_APD90(observational_data_path, treatment_var, control_val, treatm
123
124
causal_test_engine = CausalTestEngine (causal_specification , data_collector )
124
125
125
126
# 9. Obtain the minimal adjustment set from the causal DAG
126
- causal_test_engine .identification (causal_test_case )
127
+ minimal_adjustment_set = causal_dag .identification (base_test_case )
127
128
linear_regression_estimator = LinearRegressionEstimator ((treatment_var .name ,), treatment_val , control_val ,
128
- causal_test_engine . minimal_adjustment_set ,
129
+ minimal_adjustment_set ,
129
130
('APD90' ,)
130
131
)
131
132
@@ -157,7 +158,6 @@ def plot_ates_with_cis(results_dict: dict, xs: list, save: bool = True):
157
158
latex_compatible_treatment_str = rf"${ before_underscore } _{ after_underscore_braces } $"
158
159
cis_low = [c [0 ] for c in cis ]
159
160
cis_high = [c [1 ] for c in cis ]
160
-
161
161
axes .fill_between (xs , cis_low , cis_high , alpha = .2 , color = input_colors [treatment ],
162
162
label = latex_compatible_treatment_str )
163
163
axes .plot (xs , ates , color = input_colors [treatment ], linewidth = 1 )
0 commit comments