Skip to content

Commit 97ff646

Browse files
fix concrete test case after merge
1 parent 36c2e4f commit 97ff646

File tree

2 files changed

+7
-10
lines changed

2 files changed

+7
-10
lines changed

causal_testing/json_front/json_class.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ def run_json_tests(self, effects: dict, estimators: dict, f_flag: bool = False,
8787
for test in self.test_plan["tests"]:
8888
if "skip" in test and test["skip"]:
8989
continue
90+
test["estimator"] = estimators[test["estimator"]]
9091
if "mutations" in test:
9192
abstract_test = self._create_abstract_test_case(test, mutates, effects)
9293

@@ -117,9 +118,8 @@ def run_json_tests(self, effects: dict, estimators: dict, f_flag: bool = False,
117118
treatment_value=test["treatment_value"],
118119
estimate_type=test["estimate_type"],
119120
)
120-
121121
if self._execute_test_case(
122-
causal_test_case=causal_test_case, estimator=estimators[test["estimator"]], f_flag=f_flag
122+
causal_test_case=causal_test_case, test=test, f_flag=f_flag
123123
):
124124
result = "failed"
125125
else:
@@ -130,7 +130,7 @@ def run_json_tests(self, effects: dict, estimators: dict, f_flag: bool = False,
130130
+ f"treatment variable: {test['treatment_variable']} \n"
131131
+ f"outcome_variable = {outcome_variable} \n"
132132
+ f"control value = {test['control_value']}, treatment value = {test['treatment_value']} \n"
133-
+ f"result - {result} \n"
133+
+ f"result - {result}"
134134
)
135135
self._append_to_file(msg, logging.INFO)
136136

@@ -154,7 +154,6 @@ def _create_abstract_test_case(self, test, mutates, effects):
154154

155155
def _execute_tests(self, concrete_tests, estimators, test, f_flag):
156156
failures = 0
157-
test["estimator"] = estimators[test["estimator"]]
158157
if "formula" in test:
159158
self._append_to_file(f"Estimator formula used for test: {test['formula']}")
160159
for concrete_test in concrete_tests:
@@ -203,7 +202,6 @@ def _execute_test_case(self, causal_test_case: CausalTestCase, test: Iterable[Ma
203202

204203
test_passes = causal_test_case.expected_causal_effect.apply(causal_test_result)
205204

206-
result_string = str()
207205
if causal_test_result.ci_low() and causal_test_result.ci_high():
208206
result_string = (
209207
f"{causal_test_result.ci_low()} < {causal_test_result.test_value.value} < "
@@ -248,7 +246,6 @@ def _setup_test(self, causal_test_case: CausalTestCase, test: Mapping) -> tuple[
248246
}
249247
if "formula" in test:
250248
estimator_kwargs["formula"] = test["formula"]
251-
252249
estimation_model = test["estimator"](**estimator_kwargs)
253250
return causal_test_engine, estimation_model
254251

@@ -261,7 +258,7 @@ def _append_to_file(self, line: str, log_level: int = None):
261258
"""
262259
with open(self.output_path, "a", encoding="utf-8") as f:
263260
f.write(
264-
line + "\n",
261+
line + "\n"
265262
)
266263
if log_level:
267264
logger.log(level=log_level, msg=line)

tests/json_front_tests/test_json_class.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ def test_formula_in_json_test(self):
136136
"estimator": "LinearRegressionEstimator",
137137
"estimate_type": "ate",
138138
"effect_modifiers": [],
139-
"expectedEffect": {"test_output": "Positive"},
139+
"expected_effect": {"test_output": "Positive"},
140140
"skip": False,
141141
"formula": "test_output ~ test_input"
142142
}
@@ -150,7 +150,7 @@ def test_formula_in_json_test(self):
150150
}
151151
estimators = {"LinearRegressionEstimator": LinearRegressionEstimator}
152152

153-
self.json_class.generate_tests(effects, mutates, estimators, False)
153+
self.json_class.run_json_tests(effects=effects, mutates=mutates, estimators=estimators, f_flag=False)
154154
with open("temp_out.txt", 'r') as reader:
155155
temp_out = reader.readlines()
156156
self.assertIn("test_output ~ test_input", ''.join(temp_out))
@@ -174,7 +174,7 @@ def test_run_concrete_json_testcase(self):
174174
effects = {"NoEffect": NoEffect()}
175175
estimators = {"LinearRegressionEstimator": LinearRegressionEstimator}
176176

177-
self.json_class.run_json_tests(effects, estimators, False)
177+
self.json_class.run_json_tests(effects=effects, estimators=estimators, f_flag=False)
178178
with open("temp_out.txt", 'r') as reader:
179179
temp_out = reader.readlines()
180180
self.assertIn("failed", temp_out[-1])

0 commit comments

Comments
 (0)