|
6 | 6 | from causal_testing.specification.scenario import Scenario
|
7 | 7 | from causal_testing.specification.variable import Input, Output
|
8 | 8 | from causal_testing.specification.causal_specification import CausalSpecification
|
9 |
| -from causal_testing.data_collection.data_collector import ObservationalDataCollector |
10 | 9 | from causal_testing.testing.causal_test_case import CausalTestCase
|
11 | 10 | from causal_testing.testing.causal_test_outcome import Positive
|
12 | 11 | from causal_testing.estimation.linear_regression_estimator import LinearRegressionEstimator
|
@@ -52,7 +51,26 @@ def doubling_beta_CATE_on_csv(
|
52 | 51 |
|
53 | 52 | # Read in the observational data, perform identification
|
54 | 53 | past_execution_df = pd.read_csv(observational_data_path)
|
55 |
| - data_collector, _, causal_test_case, causal_specification = setup(past_execution_df) |
| 54 | + |
| 55 | + # 2. Create variables |
| 56 | + pop_size = Input("pop_size", int) |
| 57 | + pop_infected = Input("pop_infected", int) |
| 58 | + n_days = Input("n_days", int) |
| 59 | + cum_infections = Output("cum_infections", int) |
| 60 | + cum_deaths = Output("cum_deaths", int) |
| 61 | + location = Input("location", str) |
| 62 | + variants = Input("variants", str) |
| 63 | + avg_age = Input("avg_age", float) |
| 64 | + beta = Input("beta", float) |
| 65 | + contacts = Input("contacts", float) |
| 66 | + |
| 67 | + # 5. Create a base test case |
| 68 | + base_test_case = BaseTestCase(treatment_variable=beta, outcome_variable=cum_infections) |
| 69 | + |
| 70 | + # 6. Create a causal test case |
| 71 | + causal_test_case = CausalTestCase( |
| 72 | + base_test_case=base_test_case, expected_causal_effect=Positive, control_value=0.016, treatment_value=0.032 |
| 73 | + ) |
56 | 74 |
|
57 | 75 | linear_regression_estimator = LinearRegressionEstimator(
|
58 | 76 | "beta",
|
@@ -98,15 +116,6 @@ def doubling_beta_CATE_on_csv(
|
98 | 116 | # Repeat causal inference after deleting all rows with treatment value to obtain counterfactual inferences
|
99 | 117 | if simulate_counterfactuals:
|
100 | 118 | counterfactual_past_execution_df = past_execution_df[past_execution_df["beta"] != 0.032]
|
101 |
| - counterfactual_linear_regression_estimator = LinearRegressionEstimator( |
102 |
| - "beta", |
103 |
| - 0.032, |
104 |
| - 0.016, |
105 |
| - {"avg_age", "contacts"}, |
106 |
| - "cum_infections", |
107 |
| - df=counterfactual_past_execution_df, |
108 |
| - formula="cum_infections ~ beta + I(beta ** 2) + avg_age + contacts", |
109 |
| - ) |
110 | 119 | counterfactual_causal_test_result = causal_test_case.execute_test(estimator=linear_regression_estimator)
|
111 | 120 |
|
112 | 121 | results_dict["counterfactual"] = {
|
@@ -215,59 +224,6 @@ def doubling_beta_CATEs(observational_data_path: str, simulate_counterfactual: b
|
215 | 224 | age_contact_fig.savefig(outpath_base_str + "age_contact_executions.pdf", format="pdf")
|
216 | 225 |
|
217 | 226 |
|
218 |
| -def setup(observational_data): |
219 |
| - # 1. Read in the Causal DAG |
220 |
| - causal_dag = CausalDAG(f"{ROOT}/dag.dot") |
221 |
| - |
222 |
| - # 2. Create variables |
223 |
| - pop_size = Input("pop_size", int) |
224 |
| - pop_infected = Input("pop_infected", int) |
225 |
| - n_days = Input("n_days", int) |
226 |
| - cum_infections = Output("cum_infections", int) |
227 |
| - cum_deaths = Output("cum_deaths", int) |
228 |
| - location = Input("location", str) |
229 |
| - variants = Input("variants", str) |
230 |
| - avg_age = Input("avg_age", float) |
231 |
| - beta = Input("beta", float) |
232 |
| - contacts = Input("contacts", float) |
233 |
| - |
234 |
| - # 3. Create scenario by applying constraints over a subset of the input variables |
235 |
| - scenario = Scenario( |
236 |
| - variables={ |
237 |
| - pop_size, |
238 |
| - pop_infected, |
239 |
| - n_days, |
240 |
| - cum_infections, |
241 |
| - cum_deaths, |
242 |
| - location, |
243 |
| - variants, |
244 |
| - avg_age, |
245 |
| - beta, |
246 |
| - contacts, |
247 |
| - }, |
248 |
| - constraints={pop_size.z3 == 51633, pop_infected.z3 == 1000, n_days.z3 == 216}, |
249 |
| - ) |
250 |
| - |
251 |
| - # 4. Construct a causal specification from the scenario and causal DAG |
252 |
| - causal_specification = CausalSpecification(scenario, causal_dag) |
253 |
| - |
254 |
| - # 5. Create a base test case |
255 |
| - base_test_case = BaseTestCase(treatment_variable=beta, outcome_variable=cum_infections) |
256 |
| - |
257 |
| - # 6. Create a causal test case |
258 |
| - causal_test_case = CausalTestCase( |
259 |
| - base_test_case=base_test_case, expected_causal_effect=Positive, control_value=0.016, treatment_value=0.032 |
260 |
| - ) |
261 |
| - |
262 |
| - # 7. Create a data collector |
263 |
| - data_collector = ObservationalDataCollector(scenario, observational_data) |
264 |
| - |
265 |
| - # 8. Obtain the minimal adjustment set for the base test case from the causal DAG |
266 |
| - minimal_adjustment_set = causal_dag.identification(base_test_case) |
267 |
| - |
268 |
| - return data_collector, minimal_adjustment_set, causal_test_case, causal_specification |
269 |
| - |
270 |
| - |
271 | 227 | def plot_doubling_beta_CATEs(results_dict, title, figure=None, axes=None, row=None, col=None):
|
272 | 228 | # Get the CATE as a percentage for association and causation
|
273 | 229 | ate = results_dict["causation"]["ate"][0]
|
|
0 commit comments