Skip to content

Commit 2ceb654

Browse files
committed
Removed datacollector from testing
1 parent b8ace8e commit 2ceb654

File tree

2 files changed

+25
-72
lines changed

2 files changed

+25
-72
lines changed

examples/covasim_/doubling_beta/example_beta.py

Lines changed: 20 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from causal_testing.specification.scenario import Scenario
77
from causal_testing.specification.variable import Input, Output
88
from causal_testing.specification.causal_specification import CausalSpecification
9-
from causal_testing.data_collection.data_collector import ObservationalDataCollector
109
from causal_testing.testing.causal_test_case import CausalTestCase
1110
from causal_testing.testing.causal_test_outcome import Positive
1211
from causal_testing.estimation.linear_regression_estimator import LinearRegressionEstimator
@@ -52,7 +51,26 @@ def doubling_beta_CATE_on_csv(
5251

5352
# Read in the observational data, perform identification
5453
past_execution_df = pd.read_csv(observational_data_path)
55-
data_collector, _, causal_test_case, causal_specification = setup(past_execution_df)
54+
55+
# 2. Create variables
56+
pop_size = Input("pop_size", int)
57+
pop_infected = Input("pop_infected", int)
58+
n_days = Input("n_days", int)
59+
cum_infections = Output("cum_infections", int)
60+
cum_deaths = Output("cum_deaths", int)
61+
location = Input("location", str)
62+
variants = Input("variants", str)
63+
avg_age = Input("avg_age", float)
64+
beta = Input("beta", float)
65+
contacts = Input("contacts", float)
66+
67+
# 5. Create a base test case
68+
base_test_case = BaseTestCase(treatment_variable=beta, outcome_variable=cum_infections)
69+
70+
# 6. Create a causal test case
71+
causal_test_case = CausalTestCase(
72+
base_test_case=base_test_case, expected_causal_effect=Positive, control_value=0.016, treatment_value=0.032
73+
)
5674

5775
linear_regression_estimator = LinearRegressionEstimator(
5876
"beta",
@@ -98,15 +116,6 @@ def doubling_beta_CATE_on_csv(
98116
# Repeat causal inference after deleting all rows with treatment value to obtain counterfactual inferences
99117
if simulate_counterfactuals:
100118
counterfactual_past_execution_df = past_execution_df[past_execution_df["beta"] != 0.032]
101-
counterfactual_linear_regression_estimator = LinearRegressionEstimator(
102-
"beta",
103-
0.032,
104-
0.016,
105-
{"avg_age", "contacts"},
106-
"cum_infections",
107-
df=counterfactual_past_execution_df,
108-
formula="cum_infections ~ beta + I(beta ** 2) + avg_age + contacts",
109-
)
110119
counterfactual_causal_test_result = causal_test_case.execute_test(estimator=linear_regression_estimator)
111120

112121
results_dict["counterfactual"] = {
@@ -215,59 +224,6 @@ def doubling_beta_CATEs(observational_data_path: str, simulate_counterfactual: b
215224
age_contact_fig.savefig(outpath_base_str + "age_contact_executions.pdf", format="pdf")
216225

217226

218-
def setup(observational_data):
219-
# 1. Read in the Causal DAG
220-
causal_dag = CausalDAG(f"{ROOT}/dag.dot")
221-
222-
# 2. Create variables
223-
pop_size = Input("pop_size", int)
224-
pop_infected = Input("pop_infected", int)
225-
n_days = Input("n_days", int)
226-
cum_infections = Output("cum_infections", int)
227-
cum_deaths = Output("cum_deaths", int)
228-
location = Input("location", str)
229-
variants = Input("variants", str)
230-
avg_age = Input("avg_age", float)
231-
beta = Input("beta", float)
232-
contacts = Input("contacts", float)
233-
234-
# 3. Create scenario by applying constraints over a subset of the input variables
235-
scenario = Scenario(
236-
variables={
237-
pop_size,
238-
pop_infected,
239-
n_days,
240-
cum_infections,
241-
cum_deaths,
242-
location,
243-
variants,
244-
avg_age,
245-
beta,
246-
contacts,
247-
},
248-
constraints={pop_size.z3 == 51633, pop_infected.z3 == 1000, n_days.z3 == 216},
249-
)
250-
251-
# 4. Construct a causal specification from the scenario and causal DAG
252-
causal_specification = CausalSpecification(scenario, causal_dag)
253-
254-
# 5. Create a base test case
255-
base_test_case = BaseTestCase(treatment_variable=beta, outcome_variable=cum_infections)
256-
257-
# 6. Create a causal test case
258-
causal_test_case = CausalTestCase(
259-
base_test_case=base_test_case, expected_causal_effect=Positive, control_value=0.016, treatment_value=0.032
260-
)
261-
262-
# 7. Create a data collector
263-
data_collector = ObservationalDataCollector(scenario, observational_data)
264-
265-
# 8. Obtain the minimal adjustment set for the base test case from the causal DAG
266-
minimal_adjustment_set = causal_dag.identification(base_test_case)
267-
268-
return data_collector, minimal_adjustment_set, causal_test_case, causal_specification
269-
270-
271227
def plot_doubling_beta_CATEs(results_dict, title, figure=None, axes=None, row=None, col=None):
272228
# Get the CATE as a percentage for association and causation
273229
ate = results_dict["causation"]["ate"][0]

examples/covasim_/vaccinating_elderly/example_vaccine.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from causal_testing.specification.scenario import Scenario
77
from causal_testing.specification.variable import Input, Output
88
from causal_testing.specification.causal_specification import CausalSpecification
9-
from causal_testing.data_collection.data_collector import ObservationalDataCollector
109
from causal_testing.testing.causal_test_case import CausalTestCase
1110
from causal_testing.testing.causal_test_outcome import Positive, Negative, NoEffect
1211
from causal_testing.estimation.linear_regression_estimator import LinearRegressionEstimator
@@ -19,8 +18,8 @@
1918

2019

2120
def setup_test_case(verbose: bool = False):
22-
"""Run the causal test case for the effect of changing vaccine to prioritise elderly from an observational
23-
data collector that was previously simulated.
21+
"""Run the causal test case for the effect of changing vaccine to prioritise elderly from observational
22+
data that was previously simulated.
2423
2524
:param verbose: Whether to print verbose details (causal test results).
2625
:return results_dict: A dictionary containing ATE, 95% CIs, and Test Pass/Fail
@@ -57,11 +56,9 @@ def setup_test_case(verbose: bool = False):
5756
# 4. Construct a causal specification from the scenario and causal DAG
5857
causal_specification = CausalSpecification(scenario, causal_dag)
5958

60-
# 5. Instantiate the observational data collector using the previously simulated data
59+
# 5. Read the previously simulated data
6160
obs_df = pd.read_csv("simulated_data.csv")
6261

63-
data_collector = ObservationalDataCollector(scenario, obs_df)
64-
6562
# 6. Express expected outcomes
6663
expected_outcome_effects = {
6764
cum_infections: Positive(),
@@ -90,7 +87,7 @@ def setup_test_case(verbose: bool = False):
9087
)
9188

9289
# 9. Execute test and save results in dict
93-
causal_test_result = causal_test_case.execute_test(linear_regression_estimator, data_collector)
90+
causal_test_result = causal_test_case.execute_test(linear_regression_estimator, obs_df)
9491

9592
if verbose:
9693
logging.info("Causation:\n%s", causal_test_result)
@@ -110,4 +107,4 @@ def setup_test_case(verbose: bool = False):
110107

111108
test_results = setup_test_case(verbose=True)
112109

113-
logging.info("%s", test_results)
110+
logging.info("%s", test_results)

0 commit comments

Comments
 (0)