|
10 | 10 | from causal_testing.testing.causal_test_outcome import Positive
|
11 | 11 | from causal_testing.testing.causal_test_engine import CausalTestEngine
|
12 | 12 | from causal_testing.testing.estimators import LinearRegressionEstimator
|
| 13 | +from causal_testing.testing.base_test_case import BaseTestCase |
13 | 14 | from matplotlib.pyplot import rcParams
|
14 | 15 |
|
15 | 16 | # Uncommenting the code below will make all graphs publication quality but requires a suitable latex installation
|
@@ -203,22 +204,26 @@ def engine_setup(observational_data_path):
|
203 | 204 | # 4. Construct a causal specification from the scenario and causal DAG
|
204 | 205 | causal_specification = CausalSpecification(scenario, causal_dag)
|
205 | 206 |
|
206 |
| - # 5. Create a causal test case |
207 |
| - causal_test_case = CausalTestCase(control_input_configuration={beta: 0.016}, |
| 207 | + # 5. Create a base test case |
| 208 | + base_test_case = BaseTestCase(treatment_variable=beta, |
| 209 | + outcome_variable=cum_infections) |
| 210 | + |
| 211 | + # 6. Create a causal test case |
| 212 | + causal_test_case = CausalTestCase(base_test_case=base_test_case, |
208 | 213 | expected_causal_effect=Positive,
|
209 |
| - treatment_input_configuration={beta: 0.032}, |
210 |
| - outcome_variables={cum_infections}) |
| 214 | + control_value=0.016, |
| 215 | + treatment_value=0.032) |
211 | 216 |
|
212 |
| - # 6. Create a data collector |
| 217 | + # 7. Create a data collector |
213 | 218 | data_collector = ObservationalDataCollector(scenario, observational_data_path)
|
214 | 219 |
|
215 |
| - # 7. Create an instance of the causal test engine |
| 220 | + # 8. Create an instance of the causal test engine |
216 | 221 | causal_test_engine = CausalTestEngine(causal_specification, data_collector)
|
217 | 222 |
|
218 |
| - # 8. Obtain the minimal adjustment set for the causal test case from the causal DAG |
219 |
| - causal_test_engine.identification(causal_test_case) |
| 223 | + # 9. Obtain the minimal adjustment set for the base test case from the causal DAG |
| 224 | + minimal_adjustment_set = causal_dag.identification(base_test_case) |
220 | 225 |
|
221 |
| - return causal_test_engine.minimal_adjustment_set, causal_test_engine, causal_test_case |
| 226 | + return minimal_adjustment_set, causal_test_engine, causal_test_case |
222 | 227 |
|
223 | 228 |
|
224 | 229 | def plot_doubling_beta_CATEs(results_dict, title, figure=None, axes=None, row=None, col=None):
|
|
0 commit comments