Skip to content

Commit b899c2c

Browse files
Update causal_test_max_conductances.py to use new test_engine
1 parent 87bc91d commit b899c2c

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

examples/lr91/causal_test_max_conductances.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from causal_testing.testing.causal_test_outcome import Positive, Negative, NoEffect
1111
from causal_testing.testing.causal_test_engine import CausalTestEngine
1212
from causal_testing.testing.estimators import LinearRegressionEstimator
13+
from causal_testing.testing.base_causal_test import BaseCausalTest
1314
from matplotlib.pyplot import rcParams
1415

1516
# Uncommenting the code below will make all graphs publication quality but requires a suitable latex installation
@@ -109,12 +110,12 @@ def effects_on_APD90(observational_data_path, treatment_var, control_val, treatm
109110

110111
# 5. Create a causal specification from the scenario and causal DAG
111112
causal_specification = CausalSpecification(scenario, causal_dag)
112-
113+
base_test_case = BaseCausalTest(treatment_var, apd90)
113114
# 6. Create a causal test case
114-
causal_test_case = CausalTestCase(control_input_configuration={treatment_var: control_val},
115+
causal_test_case = CausalTestCase(base_causal_test=base_test_case,
115116
expected_causal_effect=expected_causal_effect,
116-
treatment_input_configuration={treatment_var: treatment_val},
117-
outcome_variables={apd90})
117+
control_value=control_val,
118+
treatment_value=treatment_val)
118119

119120
# 7. Create a data collector
120121
data_collector = ObservationalDataCollector(scenario, observational_data_path)
@@ -123,9 +124,9 @@ def effects_on_APD90(observational_data_path, treatment_var, control_val, treatm
123124
causal_test_engine = CausalTestEngine(causal_specification, data_collector)
124125

125126
# 9. Obtain the minimal adjustment set from the causal DAG
126-
causal_test_engine.identification(causal_test_case)
127+
minimal_adjustment_set = causal_dag.identification(base_test_case)
127128
linear_regression_estimator = LinearRegressionEstimator((treatment_var.name,), treatment_val, control_val,
128-
causal_test_engine.minimal_adjustment_set,
129+
minimal_adjustment_set,
129130
('APD90',)
130131
)
131132

@@ -157,7 +158,6 @@ def plot_ates_with_cis(results_dict: dict, xs: list, save: bool = True):
157158
latex_compatible_treatment_str = rf"${before_underscore}_{after_underscore_braces}$"
158159
cis_low = [c[0] for c in cis]
159160
cis_high = [c[1] for c in cis]
160-
161161
axes.fill_between(xs, cis_low, cis_high, alpha=.2, color=input_colors[treatment],
162162
label=latex_compatible_treatment_str)
163163
axes.plot(xs, ates, color=input_colors[treatment], linewidth=1)

0 commit comments

Comments
 (0)