@@ -60,7 +60,9 @@ def causal_testing_sensitivity_analysis():
60
60
apd90 = Output ('APD90' , int )
61
61
outcome_variable = apd90
62
62
test_suite = CausalTestSuite ()
63
+ estimator_list = [LinearRegressionEstimator ]
63
64
65
+ # For each parameter in conductance_means, setup variables and add a test case to the test suite
64
66
for conductance_param , mean_and_oracle in conductance_means .items ():
65
67
treatment_variable = Input (conductance_param , float )
66
68
base_test_case = BaseTestCase (treatment_variable , outcome_variable )
@@ -69,14 +71,19 @@ def causal_testing_sensitivity_analysis():
69
71
mean , oracle = mean_and_oracle
70
72
for treatment_value in treatment_values :
71
73
test_list .append (CausalTestCase (base_test_case , oracle , control_value , treatment_value ))
72
- test_suite .add_test_object (base_test_case , test_list , [LinearRegressionEstimator ], 'ate' )
74
+ test_suite .add_test_object (base_test_case = base_test_case ,
75
+ test_list = test_list ,
76
+ estimators = estimator_list ,
77
+ estimate_type = 'ate' )
73
78
74
79
causal_test_results = effects_on_APD90 (OBSERVATIONAL_DATA_PATH , test_suite )
75
80
76
81
# Extract data from causal_test_results needed for plotting
77
82
for base_test_case in causal_test_results :
83
+
84
+ # Place results of test_suite into format required for plotting
78
85
results [base_test_case .treatment_variable .name ] = \
79
- {"ate" : [result .ate for result in causal_test_results [base_test_case ][0 ]],
86
+ {"ate" : [result .ate for result in causal_test_results [base_test_case ]['LinearRegressionEstimator' ]],
80
87
"cis" : [result .confidence_intervals for result in
81
88
causal_test_results [base_test_case ][0 ]]}
82
89
0 commit comments