Skip to content

Commit b7ec69e

Browse files
Update example to use CausalTestSuite class
1 parent 75c2972 commit b7ec69e

File tree

2 files changed

+13
-12
lines changed

2 files changed

+13
-12
lines changed

causal_testing/testing/causal_test_suite.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@ def __init__(
1313
):
1414
self.test_suite = {}
1515

16-
def add_test_object(self, base_test_case, causal_test_case_list, estimators):
17-
test_object = {'tests': causal_test_case_list, 'estimators': list(estimators)}
18-
self.test_suite['base_test_case'] = test_object
16+
def add_test_object(self, base_test_case, causal_test_case_list, estimators, estimate_type):
17+
test_object = {'tests': causal_test_case_list, 'estimators': estimators, 'estimate_type': estimate_type}
18+
self.test_suite[base_test_case] = test_object
1919

2020
def get_single_test_object(self, base_test_case):
2121
return self.test_suite[base_test_case]

examples/lr91/causal_test_max_conductances_test_suite.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from causal_testing.testing.causal_test_engine import CausalTestEngine
1212
from causal_testing.testing.estimators import LinearRegressionEstimator
1313
from causal_testing.testing.base_causal_test import BaseCausalTest
14+
from causal_testing.testing.causal_test_suite import CausalTestSuite
1415
from matplotlib.pyplot import rcParams
1516

1617
# Uncommenting the code below will make all graphs publication quality but requires a suitable latex installation
@@ -58,7 +59,8 @@ def causal_testing_sensitivity_analysis():
5859

5960
apd90 = Output('APD90', int)
6061
outcome_variable = apd90
61-
test_suite = {}
62+
test_suite = CausalTestSuite()
63+
6264
for conductance_param, mean_and_oracle in conductance_means.items():
6365
treatment_variable = Input(conductance_param, float)
6466
base_test_case = BaseCausalTest(treatment_variable, outcome_variable)
@@ -67,17 +69,16 @@ def causal_testing_sensitivity_analysis():
6769
mean, oracle = mean_and_oracle
6870
for treatment_value in treatment_values:
6971
test_list.append(CausalTestCase(base_test_case, oracle, control_value, treatment_value))
70-
71-
test_suite[base_test_case] = {'tests': test_list,
72-
'estimators': [LinearRegressionEstimator],
73-
'estimate_type': "ate"}
72+
test_suite.add_test_object(base_test_case, test_list, [LinearRegressionEstimator], 'ate')
7473

7574
causal_test_results = effects_on_APD90(OBSERVATIONAL_DATA_PATH, test_suite)
7675

76+
# Extract data from causal_test_results needed for plotting
7777
for base_test_case in causal_test_results:
78-
results[base_test_case.treatment_variable.name] = {"ate": [result.ate for result in causal_test_results[base_test_case][0]],
79-
"cis": [result.confidence_intervals for result in
80-
causal_test_results[base_test_case][0]]}
78+
results[base_test_case.treatment_variable.name] = \
79+
{"ate": [result.ate for result in causal_test_results[base_test_case][0]],
80+
"cis": [result.confidence_intervals for result in
81+
causal_test_results[base_test_case][0]]}
8182

8283
plot_ates_with_cis(results, treatment_values)
8384

@@ -130,7 +131,7 @@ def effects_on_APD90(observational_data_path, test_suite):
130131
# 9. Obtain the minimal adjustment set from the causal DAG
131132

132133
# 10. Run the causal test and print results
133-
causal_test_results = causal_test_engine.execute_test_suite(test_suite)
134+
causal_test_results = causal_test_engine.execute_test_suite(test_suite.test_suite)
134135
print(causal_test_results)
135136
return causal_test_results
136137

0 commit comments

Comments
 (0)