Skip to content

Commit 237d80b

Browse files
covasim examples update
1 parent 54c1077 commit 237d80b

File tree

3 files changed

+26
-21
lines changed

3 files changed

+26
-21
lines changed

examples/covasim_/doubling_beta/causal_test_beta.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from causal_testing.testing.causal_test_outcome import Positive
1111
from causal_testing.testing.causal_test_engine import CausalTestEngine
1212
from causal_testing.testing.estimators import LinearRegressionEstimator
13+
from causal_testing.testing.base_test_case import BaseTestCase
1314
from matplotlib.pyplot import rcParams
1415

1516
# Uncommenting the code below will make all graphs publication quality but requires a suitable latex installation
@@ -203,22 +204,26 @@ def engine_setup(observational_data_path):
203204
# 4. Construct a causal specification from the scenario and causal DAG
204205
causal_specification = CausalSpecification(scenario, causal_dag)
205206

206-
# 5. Create a causal test case
207-
causal_test_case = CausalTestCase(control_input_configuration={beta: 0.016},
207+
# 5. Create a base test case
208+
base_test_case = BaseTestCase(treatment_variable=beta,
209+
outcome_variable=cum_infections)
210+
211+
# 6. Create a causal test case
212+
causal_test_case = CausalTestCase(base_test_case=base_test_case,
208213
expected_causal_effect=Positive,
209-
treatment_input_configuration={beta: 0.032},
210-
outcome_variables={cum_infections})
214+
control_value=0.016,
215+
treatment_value=0.032)
211216

212-
# 6. Create a data collector
217+
# 7. Create a data collector
213218
data_collector = ObservationalDataCollector(scenario, observational_data_path)
214219

215-
# 7. Create an instance of the causal test engine
220+
# 8. Create an instance of the causal test engine
216221
causal_test_engine = CausalTestEngine(causal_specification, data_collector)
217222

218-
# 8. Obtain the minimal adjustment set for the causal test case from the causal DAG
219-
causal_test_engine.identification(causal_test_case)
223+
# 9. Obtain the minimal adjustment set for the base test case from the causal DAG
224+
minimal_adjustment_set = causal_dag.identification(base_test_case)
220225

221-
return causal_test_engine.minimal_adjustment_set, causal_test_engine, causal_test_case
226+
return minimal_adjustment_set, causal_test_engine, causal_test_case
222227

223228

224229
def plot_doubling_beta_CATEs(results_dict, title, figure=None, axes=None, row=None, col=None):

examples/covasim_/vaccinating_elderly/causal_test_vaccine.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from causal_testing.testing.causal_test_outcome import Positive, Negative, NoEffect
1212
from causal_testing.testing.causal_test_engine import CausalTestEngine
1313
from causal_testing.testing.estimators import LinearRegressionEstimator
14+
from causal_testing.testing.base_test_case import BaseTestCase
1415

1516

1617
def experimental_causal_test_vaccinate_elderly(runs_per_test_per_config: int = 30, verbose: bool = False):
@@ -71,21 +72,21 @@ def experimental_causal_test_vaccinate_elderly(runs_per_test_per_config: int = 3
7172

7273
# 7. Create an instance of the causal test engine
7374
causal_test_engine = CausalTestEngine(causal_specification, data_collector, index_col=0)
75+
7476
for outcome_variable, expected_effect in expected_outcome_effects.items():
75-
causal_test_case = CausalTestCase(control_input_configuration={vaccine: 0},
77+
base_test_case = BaseTestCase(treatment_variable=vaccine,
78+
outcome_variable=outcome_variable)
79+
causal_test_case = CausalTestCase(base_test_case=base_test_case,
7680
expected_causal_effect=expected_effect,
77-
treatment_input_configuration={vaccine: 1},
78-
outcome_variables={outcome_variable})
79-
80-
81-
81+
control_value=0,
82+
treatment_value=1)
8283

8384
# 8. Obtain the minimal adjustment set for the causal test case from the causal DAG
84-
causal_test_engine.identification(causal_test_case)
85+
minimal_adjustment_set = causal_dag.identification(base_test_case)
8586

8687
# 9. Build statistical model
8788
linear_regression_estimator = LinearRegressionEstimator((vaccine.name,), 1, 0,
88-
causal_test_engine.minimal_adjustment_set,
89+
minimal_adjustment_set,
8990
(outcome_variable.name,))
9091

9192
# 10. Execute test and save results in dict

examples/lr91/causal_test_max_conductances_test_suite.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,20 +72,19 @@ def causal_testing_sensitivity_analysis():
7272
for treatment_value in treatment_values:
7373
test_list.append(CausalTestCase(base_test_case, oracle, control_value, treatment_value))
7474
test_suite.add_test_object(base_test_case=base_test_case,
75-
test_list=test_list,
76-
estimators=estimator_list,
75+
causal_test_case_list=test_list,
76+
estimators_classes=estimator_list,
7777
estimate_type='ate')
7878

7979
causal_test_results = effects_on_APD90(OBSERVATIONAL_DATA_PATH, test_suite)
8080

8181
# Extract data from causal_test_results needed for plotting
8282
for base_test_case in causal_test_results:
83-
8483
# Place results of test_suite into format required for plotting
8584
results[base_test_case.treatment_variable.name] = \
8685
{"ate": [result.ate for result in causal_test_results[base_test_case]['LinearRegressionEstimator']],
8786
"cis": [result.confidence_intervals for result in
88-
causal_test_results[base_test_case][0]]}
87+
causal_test_results[base_test_case]['LinearRegressionEstimator']]}
8988

9089
plot_ates_with_cis(results, treatment_values)
9190

0 commit comments

Comments
 (0)