Skip to content

Commit a75ec7d

Browse files
Update Json & abstract_causal_test_case.py to match new engine
1 parent c507180 commit a75ec7d

File tree

2 files changed

+15
-11
lines changed

2 files changed

+15
-11
lines changed

causal_testing/generation/abstract_causal_test_case.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from causal_testing.specification.variable import Variable
1010
from causal_testing.testing.causal_test_case import CausalTestCase
1111
from causal_testing.testing.causal_test_outcome import CausalTestOutcome
12+
from causal_testing.testing.base_causal_test import BaseCausalTest
1213

1314
logger = logging.getLogger(__name__)
1415

@@ -29,7 +30,7 @@ def __init__(
2930
effect_modifiers: set[Variable] = None,
3031
estimate_type: str = "ate",
3132
):
32-
assert treatment_variables.issubset(scenario.variables.values()), (
33+
assert {treatment_variables}.issubset(scenario.variables.values()), (
3334
"Treatment variables must be a subset of variables."
3435
+ f" Instead got:\ntreatment_variables={treatment_variables}\nvariables={scenario.variables}"
3536
)
@@ -109,13 +110,15 @@ def _generate_concrete_tests(
109110
)
110111
model = optimizer.model()
111112

113+
base_causal_test = BaseCausalTest(self.treatment_variables, list(self.expected_causal_effect.keys())[0])
114+
112115
concrete_test = CausalTestCase(
113-
control_input_configuration={v: v.cast(model[v.z3]) for v in self.treatment_variables},
114-
treatment_input_configuration={
115-
v: v.cast(model[self.scenario.treatment_variables[v.name].z3]) for v in self.treatment_variables
116-
},
116+
base_causal_test=base_causal_test,
117+
control_value=self.treatment_variables.cast(model[self.treatment_variables.z3]),
118+
treatment_value=self.treatment_variables.cast(
119+
model[self.scenario.treatment_variables[self.treatment_variables.name].z3]
120+
),
117121
expected_causal_effect=list(self.expected_causal_effect.values())[0],
118-
outcome_variables=list(self.expected_causal_effect.keys()),
119122
estimate_type=self.estimate_type,
120123
effect_modifier_configuration={v: v.cast(model[v.z3]) for v in self.effect_modifiers},
121124
)

causal_testing/json_front/json_class.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -89,10 +89,11 @@ def setup(self):
8989
self._populate_metas()
9090

9191
def _create_abstract_test_case(self, test, mutates, effects):
92+
9293
abstract_test = AbstractCausalTestCase(
9394
scenario=self.modelling_scenario,
9495
intervention_constraints=[mutates[v](k) for k, v in test["mutations"].items()],
95-
treatment_variables={self.modelling_scenario.variables[v] for v in test["mutations"]},
96+
treatment_variables=self.modelling_scenario.variables[next(iter(test["mutations"]))],
9697
expected_causal_effect={
9798
self.modelling_scenario.variables[variable]: effects[effect]
9899
for variable, effect in test["expectedEffect"].items()
@@ -121,7 +122,7 @@ def generate_tests(self, effects: dict, mutates: dict, estimators: dict, f_flag:
121122
concrete_tests, dummy = abstract_test.generate_concrete_tests(5, 0.05)
122123
logger.info("Executing test: %s", test["name"])
123124
logger.info(abstract_test)
124-
logger.info([(v.name, v.distribution) for v in abstract_test.treatment_variables])
125+
logger.info([abstract_test.treatment_variables.name, abstract_test.treatment_variables.distribution])
125126
logger.info("Number of concrete tests for test case: %s", str(len(concrete_tests)))
126127
failures = self._execute_tests(concrete_tests, estimators, test, f_flag)
127128
logger.info("%s/%s failed", failures, len(concrete_tests))
@@ -204,15 +205,15 @@ def _setup_test(self, causal_test_case: CausalTestCase, estimator: Estimator) ->
204205
"""
205206
data_collector = ObservationalDataCollector(self.modelling_scenario, self.data_path)
206207
causal_test_engine = CausalTestEngine(self.causal_specification, data_collector, index_col=0)
207-
causal_test_engine.identification(causal_test_case)
208+
minimal_adjustment_set = self.causal_specification.causal_dag.identification(causal_test_case.base_causal_test)
208209
treatment_vars = list(causal_test_case.treatment_input_configuration)
209-
minimal_adjustment_set = causal_test_engine.minimal_adjustment_set - {v.name for v in treatment_vars}
210+
minimal_adjustment_set = minimal_adjustment_set - {v.name for v in treatment_vars}
210211
estimation_model = estimator(
211212
(list(treatment_vars)[0].name,),
212213
[causal_test_case.treatment_input_configuration[v] for v in treatment_vars][0],
213214
[causal_test_case.control_input_configuration[v] for v in treatment_vars][0],
214215
minimal_adjustment_set,
215-
(list(causal_test_case.outcome_variables)[0].name,),
216+
(causal_test_case.outcome_variable.name,),
216217
causal_test_engine.scenario_execution_data_df,
217218
effect_modifiers=causal_test_case.effect_modifier_configuration,
218219
)

0 commit comments

Comments
 (0)