Skip to content

Commit 92b6250

Browse files
Fix merge issues
1 parent 6d0e39b commit 92b6250

File tree

2 files changed

+9
-6
lines changed

2 files changed

+9
-6
lines changed

causal_testing/generation/abstract_causal_test_case.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -121,18 +121,21 @@ def _generate_concrete_tests(
121121
)
122122
model = optimizer.model()
123123

124-
base_test_case = BaseTestCase(self.treatment_variables, list(self.expected_causal_effect.keys())[0])
124+
base_test_case = BaseTestCase(
125+
treatment_variable=self.treatment_variable,
126+
outcome_variable=list(self.expected_causal_effect.keys())[0],
127+
effect=self.effect,
128+
)
125129

126130
concrete_test = CausalTestCase(
127131
base_test_case=base_test_case,
128-
control_value=self.treatment_variables.cast(model[self.treatment_variables.z3]),
129-
treatment_value=self.treatment_variables.cast(
130-
model[self.scenario.treatment_variables[self.treatment_variables.name].z3]
132+
control_value=self.treatment_variable.cast(model[self.treatment_variable.z3]),
133+
treatment_value=self.treatment_variable.cast(
134+
model[self.scenario.treatment_variables[self.treatment_variable.name].z3]
131135
),
132136
expected_causal_effect=list(self.expected_causal_effect.values())[0],
133137
estimate_type=self.estimate_type,
134138
effect_modifier_configuration={v: v.cast(model[v.z3]) for v in self.effect_modifiers},
135-
effect=self.effect,
136139
)
137140

138141
for v in self.scenario.inputs():

causal_testing/json_front/json_class.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def generate_tests(self, effects: dict, mutates: dict, estimators: dict, f_flag:
123123
concrete_tests, dummy = abstract_test.generate_concrete_tests(5, 0.05)
124124
logger.info("Executing test: %s", test["name"])
125125
logger.info(abstract_test)
126-
logger.info([abstract_test.treatment_variables.name, abstract_test.treatment_variables.distribution])
126+
logger.info([abstract_test.treatment_variable.name, abstract_test.treatment_variable.distribution])
127127
logger.info("Number of concrete tests for test case: %s", str(len(concrete_tests)))
128128
failures = self._execute_tests(concrete_tests, estimators, test, f_flag)
129129
logger.info("%s/%s failed for %s\n", failures, len(concrete_tests), test["name"])

0 commit comments

Comments
 (0)