11
11
from causal_testing .testing .causal_test_engine import CausalTestEngine
12
12
from causal_testing .testing .estimators import LinearRegressionEstimator
13
13
from causal_testing .testing .base_causal_test import BaseCausalTest
14
+ from causal_testing .testing .causal_test_suite import CausalTestSuite
14
15
from matplotlib .pyplot import rcParams
15
16
16
17
# Uncommenting the code below will make all graphs publication quality but requires a suitable latex installation
@@ -58,7 +59,8 @@ def causal_testing_sensitivity_analysis():
58
59
59
60
apd90 = Output ('APD90' , int )
60
61
outcome_variable = apd90
61
- test_suite = {}
62
+ test_suite = CausalTestSuite ()
63
+
62
64
for conductance_param , mean_and_oracle in conductance_means .items ():
63
65
treatment_variable = Input (conductance_param , float )
64
66
base_test_case = BaseCausalTest (treatment_variable , outcome_variable )
@@ -67,17 +69,16 @@ def causal_testing_sensitivity_analysis():
67
69
mean , oracle = mean_and_oracle
68
70
for treatment_value in treatment_values :
69
71
test_list .append (CausalTestCase (base_test_case , oracle , control_value , treatment_value ))
70
-
71
- test_suite [base_test_case ] = {'tests' : test_list ,
72
- 'estimators' : [LinearRegressionEstimator ],
73
- 'estimate_type' : "ate" }
72
+ test_suite .add_test_object (base_test_case , test_list , [LinearRegressionEstimator ], 'ate' )
74
73
75
74
causal_test_results = effects_on_APD90 (OBSERVATIONAL_DATA_PATH , test_suite )
76
75
76
+ # Extract data from causal_test_results needed for plotting
77
77
for base_test_case in causal_test_results :
78
- results [base_test_case .treatment_variable .name ] = {"ate" : [result .ate for result in causal_test_results [base_test_case ][0 ]],
79
- "cis" : [result .confidence_intervals for result in
80
- causal_test_results [base_test_case ][0 ]]}
78
+ results [base_test_case .treatment_variable .name ] = \
79
+ {"ate" : [result .ate for result in causal_test_results [base_test_case ][0 ]],
80
+ "cis" : [result .confidence_intervals for result in
81
+ causal_test_results [base_test_case ][0 ]]}
81
82
82
83
plot_ates_with_cis (results , treatment_values )
83
84
@@ -130,7 +131,7 @@ def effects_on_APD90(observational_data_path, test_suite):
130
131
# 9. Obtain the minimal adjustment set from the causal DAG
131
132
132
133
# 10. Run the causal test and print results
133
- causal_test_results = causal_test_engine .execute_test_suite (test_suite )
134
+ causal_test_results = causal_test_engine .execute_test_suite (test_suite . test_suite )
134
135
print (causal_test_results )
135
136
return causal_test_results
136
137
0 commit comments