Skip to content

Commit 62ccef3

Browse files
pylint
1 parent 0c9bdd3 commit 62ccef3

File tree

1 file changed

+9
-10
lines changed

1 file changed

+9
-10
lines changed

causal_testing/json_front/json_class.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,6 @@ def __init__(self, output_path: str, output_overwrite: bool = False):
5555
self.causal_specification = None
5656
self.output_path = Path(output_path)
5757
self.check_file_exists(self.output_path, output_overwrite)
58-
self.effects = None
59-
self.mutates = None
6058

6159
def set_paths(self, json_path: str, dag_path: str, data_paths: str):
6260
"""
@@ -113,17 +111,15 @@ def run_json_tests(self, effects: dict, estimators: dict, f_flag: bool = False,
113111
:param estimators: Dictionary mapping estimator classes to string representations.
114112
:param f_flag: Failure flag that if True the script will stop executing when a test fails.
115113
"""
116-
self.effects = effects
117-
self.mutates = mutates
118114
for test in self.test_plan["tests"]:
119115
if "skip" in test and test["skip"]:
120116
continue
121117
test["estimator"] = estimators[test["estimator"]]
122118
if "mutations" in test:
123119
if test["estimate_type"] == "coefficient":
124-
msg = self.run_coefficient_test(test=test, f_flag=f_flag)
120+
msg = self._run_coefficient_test(test=test, f_flag=f_flag, effects=effects)
125121
else:
126-
msg = self.run_ate_test(test=test, f_flag=f_flag)
122+
msg = self._run_ate_test(test=test, f_flag=f_flag, effects=effects, mutates=mutates)
127123
self._append_to_file(msg, logging.INFO)
128124
else:
129125
outcome_variable = next(
@@ -152,11 +148,12 @@ def run_json_tests(self, effects: dict, estimators: dict, f_flag: bool = False,
152148
)
153149
self._append_to_file(msg, logging.INFO)
154150

155-
def run_coefficient_test(self, test: dict, f_flag: bool):
151+
def _run_coefficient_test(self, test: dict, f_flag: bool, effects: dict):
156152
"""Builds structures and runs test case for tests with an estimate_type of 'coefficient'.
157153
158154
:param test: Single JSON test definition stored in a mapping (dict)
159155
:param f_flag: Failure flag that if True the script will stop executing when a test fails.
156+
:param effects: Dictionary mapping effect class instances to string representations.
160157
:return: String containing the message to be outputted
161158
"""
162159
base_test_case = BaseTestCase(
@@ -167,7 +164,7 @@ def run_coefficient_test(self, test: dict, f_flag: bool):
167164
assert len(test["expected_effect"]) == 1, "Can only have one expected effect."
168165
causal_test_case = CausalTestCase(
169166
base_test_case=base_test_case,
170-
expected_causal_effect=next(self.effects[effect] for variable, effect in test["expected_effect"].items()),
167+
expected_causal_effect=next(effects[effect] for variable, effect in test["expected_effect"].items()),
171168
estimate_type="coefficient",
172169
effect_modifier_configuration={self.scenario.variables[v] for v in test.get("effect_modifiers", [])},
173170
)
@@ -182,11 +179,13 @@ def run_coefficient_test(self, test: dict, f_flag: bool):
182179
)
183180
return msg
184181

185-
def run_ate_test(self, test: dict, f_flag: bool):
182+
def _run_ate_test(self, test: dict, f_flag: bool, effects: dict, mutates: dict):
186183
"""Builds structures and runs test case for tests with an estimate_type of 'ate'.
187184
188185
:param test: Single JSON test definition stored in a mapping (dict)
189186
:param f_flag: Failure flag that if True the script will stop executing when a test fails.
187+
:param effects: Dictionary mapping effect class instances to string representations.
188+
:param mutates: Dictionary mapping mutation functions to string representations.
190189
:return: String containing the message to be outputted
191190
"""
192191
if "sample_size" in test:
@@ -197,7 +196,7 @@ def run_ate_test(self, test: dict, f_flag: bool):
197196
target_ks_score = test["target_ks_score"]
198197
else:
199198
target_ks_score = 0.05
200-
abstract_test = self._create_abstract_test_case(test, self.mutates, self.effects)
199+
abstract_test = self._create_abstract_test_case(test, mutates, effects)
201200
concrete_tests, _ = abstract_test.generate_concrete_tests(
202201
sample_size=sample_size, target_ks_score=target_ks_score
203202
)

0 commit comments

Comments
 (0)