Skip to content

Commit fb06924

Browse files
refactor if block + pylint + black
1 parent d98a41a commit fb06924

File tree

3 files changed

+14
-26
lines changed

3 files changed

+14
-26
lines changed

causal_testing/json_front/json_class.py

Lines changed: 12 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -203,28 +203,19 @@ def _setup_test(self, causal_test_case: CausalTestCase, test: Mapping) -> tuple[
203203
minimal_adjustment_set = self.causal_specification.causal_dag.identification(causal_test_case.base_test_case)
204204
treatment_var = causal_test_case.treatment_variable
205205
minimal_adjustment_set = minimal_adjustment_set - {treatment_var}
206+
estimator_kwargs = {
207+
"treatment": treatment_var.name,
208+
"treatment_value": causal_test_case.treatment_value,
209+
"control_value": causal_test_case.control_value,
210+
"adjustment_set": minimal_adjustment_set,
211+
"outcome": causal_test_case.outcome_variable.name,
212+
"df": causal_test_engine.scenario_execution_data_df,
213+
"effect_modifiers": causal_test_case.effect_modifier_configuration,
214+
}
206215
if "formula" in test:
207-
estimation_model = test["estimator"](
208-
treatment=treatment_var.name,
209-
treatment_value=causal_test_case.treatment_value,
210-
control_value=causal_test_case.control_value,
211-
adjustment_set=minimal_adjustment_set,
212-
outcome=causal_test_case.outcome_variable.name,
213-
df=causal_test_engine.scenario_execution_data_df,
214-
effect_modifiers=causal_test_case.effect_modifier_configuration,
215-
formula=test["formula"],
216-
)
217-
else:
218-
estimation_model = test["estimator"](
219-
treatment=treatment_var.name,
220-
treatment_value=causal_test_case.treatment_value,
221-
control_value=causal_test_case.control_value,
222-
adjustment_set=minimal_adjustment_set,
223-
outcome=causal_test_case.outcome_variable.name,
224-
df=causal_test_engine.scenario_execution_data_df,
225-
effect_modifiers=causal_test_case.effect_modifier_configuration,
226-
)
216+
estimator_kwargs["formula"] = test["formula"]
227217

218+
estimation_model = test["estimator"](**estimator_kwargs)
228219
return causal_test_engine, estimation_model
229220

230221
def _append_to_file(self, line: str, log_level: int = None):
@@ -235,9 +226,7 @@ def _append_to_file(self, line: str, log_level: int = None):
235226
is possible to use the inbuilt logging level variables such as logging.INFO and logging.WARNING
236227
"""
237228
with open(self.output_path, "a", encoding="utf-8") as f:
238-
f.write(
239-
line
240-
)
229+
f.write(line)
241230
if log_level:
242231
logger.log(level=log_level, msg=line)
243232

causal_testing/testing/estimators.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,7 @@ def __init__(
320320
self.formula = formula
321321
else:
322322
terms = [treatment] + sorted(list(adjustment_set)) + sorted(list(effect_modifiers))
323-
self.formula = f"{outcome} ~ {'+'.join(((terms)))}"
323+
self.formula = f"{outcome} ~ {'+'.join(terms)}"
324324

325325
for term in self.effect_modifiers:
326326
self.adjustment_set.add(term)

tests/json_front_tests/test_json_class.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,8 +156,7 @@ def test_formula_in_json_test(self):
156156
self.assertIn("test_output ~ test_input", ''.join(temp_out))
157157

158158
def tearDown(self) -> None:
159-
pass
160-
# remove_temp_dir_if_existent()
159+
remove_temp_dir_if_existent()
161160

162161

163162
def populate_example(*args, **kwargs):

0 commit comments

Comments
 (0)