Skip to content

Commit c6ba4ee

Browse files
Examples now use causal_test_case.execute_test
1 parent e509eb3 commit c6ba4ee

File tree

5 files changed

+31
-46
lines changed

5 files changed

+31
-46
lines changed

examples/covasim_/doubling_beta/example_beta.py

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from causal_testing.data_collection.data_collector import ObservationalDataCollector
99
from causal_testing.testing.causal_test_case import CausalTestCase
1010
from causal_testing.testing.causal_test_outcome import Positive
11-
from causal_testing.testing.causal_test_engine import CausalTestEngine
1211
from causal_testing.testing.estimators import LinearRegressionEstimator
1312
from causal_testing.testing.base_test_case import BaseTestCase
1413
from matplotlib.pyplot import rcParams
@@ -36,7 +35,7 @@
3635

3736

3837
def doubling_beta_CATE_on_csv(
39-
observational_data_path: str, simulate_counterfactuals: bool = False, verbose: bool = False
38+
observational_data_path: str, simulate_counterfactuals: bool = False, verbose: bool = False
4039
):
4140
"""Compute the CATE of increasing beta from 0.016 to 0.032 on cum_infections using the dataframe
4241
loaded from the specified path. Additionally simulate the counterfactuals by repeating the analysis
@@ -50,9 +49,9 @@ def doubling_beta_CATE_on_csv(
5049
"""
5150
results_dict = {"association": {}, "causation": {}}
5251

53-
# Read in the observational data, perform identification, and setup the causal_test_engine
52+
# Read in the observational data, perform identification
5453
past_execution_df = pd.read_csv(observational_data_path)
55-
_, causal_test_engine, causal_test_case = engine_setup(observational_data_path)
54+
data_collector, _, causal_test_case, causal_specification = setup(observational_data_path)
5655

5756
linear_regression_estimator = LinearRegressionEstimator(
5857
"beta",
@@ -65,7 +64,9 @@ def doubling_beta_CATE_on_csv(
6564
)
6665

6766
# Add squared terms for beta, since it has a quadratic relationship with cumulative infections
68-
causal_test_result = causal_test_engine.execute_test(linear_regression_estimator, causal_test_case)
67+
causal_test_result = causal_test_case.execute_test(estimator=linear_regression_estimator,
68+
data_collector=data_collector,
69+
causal_specification=causal_specification)
6970

7071
# Repeat for association estimate (no adjustment)
7172
no_adjustment_linear_regression_estimator = LinearRegressionEstimator(
@@ -77,9 +78,9 @@ def doubling_beta_CATE_on_csv(
7778
df=past_execution_df,
7879
formula="cum_infections ~ beta + np.power(beta, 2)",
7980
)
80-
association_test_result = causal_test_engine.execute_test(
81-
no_adjustment_linear_regression_estimator, causal_test_case
82-
)
81+
association_test_result = causal_test_case.execute_test(estimator=no_adjustment_linear_regression_estimator,
82+
data_collector=data_collector,
83+
causal_specification=causal_specification)
8384

8485
# Store results for plotting
8586
results_dict["association"] = {
@@ -109,8 +110,9 @@ def doubling_beta_CATE_on_csv(
109110
df=counterfactual_past_execution_df,
110111
formula="cum_infections ~ beta + np.power(beta, 2) + avg_age + contacts",
111112
)
112-
counterfactual_causal_test_result = causal_test_engine.execute_test(
113-
linear_regression_estimator, causal_test_case
113+
counterfactual_causal_test_result = causal_test_case.execute_test(
114+
estimator=linear_regression_estimator, data_collector=data_collector,
115+
causal_specification=causal_specification
114116
)
115117
results_dict["counterfactual"] = {
116118
"ate": counterfactual_causal_test_result.test_value.value,
@@ -218,7 +220,7 @@ def doubling_beta_CATEs(observational_data_path: str, simulate_counterfactual: b
218220
age_contact_fig.savefig(outpath_base_str + "age_contact_executions.pdf", format="pdf")
219221

220222

221-
def engine_setup(observational_data_path):
223+
def setup(observational_data_path):
222224
# 1. Read in the Causal DAG
223225
causal_dag = CausalDAG(f"{ROOT}/dag.dot")
224226

@@ -265,13 +267,10 @@ def engine_setup(observational_data_path):
265267
# 7. Create a data collector
266268
data_collector = ObservationalDataCollector(scenario, pd.read_csv(observational_data_path))
267269

268-
# 8. Create an instance of the causal test engine
269-
causal_test_engine = CausalTestEngine(causal_specification, data_collector)
270-
271-
# 9. Obtain the minimal adjustment set for the base test case from the causal DAG
270+
# 8. Obtain the minimal adjustment set for the base test case from the causal DAG
272271
minimal_adjustment_set = causal_dag.identification(base_test_case)
273272

274-
return minimal_adjustment_set, causal_test_engine, causal_test_case
273+
return data_collector, minimal_adjustment_set, causal_test_case, causal_specification
275274

276275

277276
def plot_doubling_beta_CATEs(results_dict, title, figure=None, axes=None, row=None, col=None):

examples/covasim_/vaccinating_elderly/example_vaccine.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from causal_testing.data_collection.data_collector import ExperimentalDataCollector
1111
from causal_testing.testing.causal_test_case import CausalTestCase
1212
from causal_testing.testing.causal_test_outcome import Positive, Negative, NoEffect
13-
from causal_testing.testing.causal_test_engine import CausalTestEngine
1413
from causal_testing.testing.estimators import LinearRegressionEstimator
1514
from causal_testing.testing.base_test_case import BaseTestCase
1615

@@ -81,25 +80,23 @@ def test_experimental_vaccinate_elderly(runs_per_test_per_config: int = 30, verb
8180
}
8281
results_dict = {"cum_infections": {}, "cum_vaccinations": {}, "cum_vaccinated": {}, "max_doses": {}}
8382

84-
# 7. Create an instance of the causal test engine
85-
causal_test_engine = CausalTestEngine(causal_specification, data_collector, index_col=0)
86-
8783
for outcome_variable, expected_effect in expected_outcome_effects.items():
8884
base_test_case = BaseTestCase(treatment_variable=vaccine, outcome_variable=outcome_variable)
8985
causal_test_case = CausalTestCase(
9086
base_test_case=base_test_case, expected_causal_effect=expected_effect, control_value=0, treatment_value=1
9187
)
9288

93-
# 8. Obtain the minimal adjustment set for the causal test case from the causal DAG
89+
# 7. Obtain the minimal adjustment set for the causal test case from the causal DAG
9490
minimal_adjustment_set = causal_dag.identification(base_test_case)
9591

96-
# 9. Build statistical model
92+
# 8. Build statistical model
9793
linear_regression_estimator = LinearRegressionEstimator(
9894
vaccine.name, 1, 0, minimal_adjustment_set, outcome_variable.name
9995
)
10096

101-
# 10. Execute test and save results in dict
102-
causal_test_result = causal_test_engine.execute_test(linear_regression_estimator, causal_test_case)
97+
# 9. Execute test and save results in dict
98+
causal_test_result = causal_test_case.execute_test(linear_regression_estimator, data_collector,
99+
causal_specification)
103100
if verbose:
104101
logging.info("Causation:\n%s", causal_test_result)
105102
results_dict[outcome_variable.name]["ate"] = causal_test_result.test_value.value

examples/lr91/example_max_conductances.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from causal_testing.data_collection.data_collector import ObservationalDataCollector
99
from causal_testing.testing.causal_test_case import CausalTestCase
1010
from causal_testing.testing.causal_test_outcome import Positive, Negative, NoEffect
11-
from causal_testing.testing.causal_test_engine import CausalTestEngine
1211
from causal_testing.testing.estimators import LinearRegressionEstimator
1312
from causal_testing.testing.base_test_case import BaseTestCase
1413
from matplotlib.pyplot import rcParams
@@ -138,17 +137,14 @@ def effects_on_APD90(observational_data_path, treatment_var, control_val, treatm
138137
# 7. Create a data collector
139138
data_collector = ObservationalDataCollector(scenario, pd.read_csv(observational_data_path))
140139

141-
# 8. Create an instance of the causal test engine
142-
causal_test_engine = CausalTestEngine(causal_specification, data_collector)
143-
144-
# 9. Obtain the minimal adjustment set from the causal DAG
140+
# 8. Obtain the minimal adjustment set from the causal DAG
145141
minimal_adjustment_set = causal_dag.identification(base_test_case)
146142
linear_regression_estimator = LinearRegressionEstimator(
147143
treatment_var.name, treatment_val, control_val, minimal_adjustment_set, "APD90"
148144
)
149145

150-
# 10. Run the causal test and print results
151-
causal_test_result = causal_test_engine.execute_test(linear_regression_estimator, causal_test_case)
146+
# 9. Run the causal test and print results
147+
causal_test_result = causal_test_case.execute_test(linear_regression_estimator, data_collector, causal_specification)
152148
logger.info("%s", causal_test_result)
153149
return causal_test_result.test_value.value, causal_test_result.confidence_intervals
154150

@@ -198,4 +194,4 @@ def normalise_data(df, columns=None):
198194

199195

200196
if __name__ == "__main__":
201-
test_sensitivity_analysis(show=True)
197+
test_sensitivity_analysis()

examples/lr91/example_max_conductances_test_suite.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from causal_testing.data_collection.data_collector import ObservationalDataCollector
99
from causal_testing.testing.causal_test_case import CausalTestCase
1010
from causal_testing.testing.causal_test_outcome import Positive, Negative, NoEffect
11-
from causal_testing.testing.causal_test_engine import CausalTestEngine
1211
from causal_testing.testing.estimators import LinearRegressionEstimator
1312
from causal_testing.testing.base_test_case import BaseTestCase
1413
from causal_testing.testing.causal_test_suite import CausalTestSuite
@@ -147,11 +146,9 @@ def effects_on_APD90(observational_data_path, test_suite):
147146
# 7. Create a data collector
148147
data_collector = ObservationalDataCollector(scenario, pd.read_csv(observational_data_path))
149148

150-
# 8. Create an instance of the causal test engine
151-
causal_test_engine = CausalTestEngine(causal_specification, data_collector)
152149

153-
# 9. Run the causal test suite
154-
causal_test_results = causal_test_engine.execute_test_suite(test_suite)
150+
# 8. Run the causal test suite
151+
causal_test_results = test_suite.execute_test_suite(data_collector, causal_specification)
155152
return causal_test_results
156153

157154

@@ -200,4 +197,4 @@ def normalise_data(df, columns=None):
200197

201198

202199
if __name__ == "__main__":
203-
test_sensitivity_analysis(show=True, save=True)
200+
test_sensitivity_analysis()

examples/poisson-line-process/example_poisson_process.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from causal_testing.data_collection.data_collector import ObservationalDataCollector
66
from causal_testing.testing.causal_test_case import CausalTestCase
77
from causal_testing.testing.causal_test_outcome import ExactValue, Positive
8-
from causal_testing.testing.causal_test_engine import CausalTestEngine
98
from causal_testing.testing.estimators import LinearRegressionEstimator, Estimator
109
from causal_testing.testing.base_test_case import BaseTestCase
1110

@@ -85,13 +84,10 @@ def causal_test_intensity_num_shapes(
8584
# 6. Create a data collector
8685
data_collector = ObservationalDataCollector(scenario, pd.read_csv(observational_data_path))
8786

88-
# 7. Create an instance of the causal test engine
89-
causal_test_engine = CausalTestEngine(causal_specification, data_collector)
90-
91-
# 8. Obtain the minimal adjustment set for the causal test case from the causal DAG
87+
# 7. Obtain the minimal adjustment set for the causal test case from the causal DAG
9288
minimal_adjustment_set = causal_dag.identification(causal_test_case.base_test_case)
9389

94-
# 9. Set up an estimator
90+
# 8. Set up an estimator
9591
data = pd.read_csv(observational_data_path)
9692

9793
treatment = causal_test_case.get_treatment_variable()
@@ -122,8 +118,8 @@ def causal_test_intensity_num_shapes(
122118
formula=f"{outcome} ~ {treatment} + {'+'.join(square_terms + inverse_terms + list([e for e in causal_test_case.effect_modifier_configuration]))} -1",
123119
)
124120

125-
# 10. Execute the test
126-
causal_test_result = causal_test_engine.execute_test(estimator, causal_test_case)
121+
# 9. Execute the test
122+
causal_test_result = causal_test_case.execute_test(estimator, data_collector, causal_specification)
127123

128124
return causal_test_result
129125

0 commit comments

Comments
 (0)