Skip to content

Commit 0c9bdd3

Browse files
docstrings + black
1 parent 7453ee8 commit 0c9bdd3

File tree

1 file changed

+41
-31
lines changed

1 file changed

+41
-31
lines changed

causal_testing/json_front/json_class.py

Lines changed: 41 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -144,15 +144,21 @@ def run_json_tests(self, effects: dict, estimators: dict, f_flag: bool = False,
144144
failed, _ = self._execute_test_case(causal_test_case=causal_test_case, test=test, f_flag=f_flag)
145145

146146
msg = (
147-
f"Executing concrete test: {test['name']} \n"
148-
+ f"treatment variable: {test['treatment_variable']} \n"
149-
+ f"outcome_variable = {outcome_variable} \n"
150-
+ f"control value = {test['control_value']}, treatment value = {test['treatment_value']} \n"
151-
+ f"Result: {'FAILED' if failed else 'Passed'}"
147+
f"Executing concrete test: {test['name']} \n"
148+
+ f"treatment variable: {test['treatment_variable']} \n"
149+
+ f"outcome_variable = {outcome_variable} \n"
150+
+ f"control value = {test['control_value']}, treatment value = {test['treatment_value']} \n"
151+
+ f"Result: {'FAILED' if failed else 'Passed'}"
152152
)
153153
self._append_to_file(msg, logging.INFO)
154154

155-
def run_coefficient_test(self, test, f_flag):
155+
def run_coefficient_test(self, test: dict, f_flag: bool):
156+
"""Builds structures and runs test case for tests with an estimate_type of 'coefficient'.
157+
158+
:param test: Single JSON test definition stored in a mapping (dict)
159+
:param f_flag: Failure flag that if True the script will stop executing when a test fails.
160+
:return: String containing the message to be outputted
161+
"""
156162
base_test_case = BaseTestCase(
157163
treatment_variable=next(self.scenario.variables[v] for v in test["mutations"]),
158164
outcome_variable=next(self.scenario.variables[v] for v in test["expected_effect"]),
@@ -161,26 +167,28 @@ def run_coefficient_test(self, test, f_flag):
161167
assert len(test["expected_effect"]) == 1, "Can only have one expected effect."
162168
causal_test_case = CausalTestCase(
163169
base_test_case=base_test_case,
164-
expected_causal_effect=next(
165-
self.effects[effect] for variable, effect in test["expected_effect"].items()
166-
),
170+
expected_causal_effect=next(self.effects[effect] for variable, effect in test["expected_effect"].items()),
167171
estimate_type="coefficient",
168-
effect_modifier_configuration={
169-
self.scenario.variables[v] for v in test.get("effect_modifiers", [])
170-
},
172+
effect_modifier_configuration={self.scenario.variables[v] for v in test.get("effect_modifiers", [])},
171173
)
172174
result = self._execute_test_case(causal_test_case=causal_test_case, test=test, f_flag=f_flag)
173175
msg = (
174-
f"Executing test: {test['name']} \n"
175-
+ f" {causal_test_case} \n"
176-
+ " "
177-
+ ("\n ").join(str(result[1]).split("\n"))
178-
+ "==============\n"
179-
+ f" Result: {'FAILED' if result[0] else 'Passed'}"
176+
f"Executing test: {test['name']} \n"
177+
+ f" {causal_test_case} \n"
178+
+ " "
179+
+ ("\n ").join(str(result[1]).split("\n"))
180+
+ "==============\n"
181+
+ f" Result: {'FAILED' if result[0] else 'Passed'}"
180182
)
181183
return msg
182184

183-
def run_ate_test(self, test, f_flag):
185+
def run_ate_test(self, test: dict, f_flag: bool):
186+
"""Builds structures and runs test case for tests with an estimate_type of 'ate'.
187+
188+
:param test: Single JSON test definition stored in a mapping (dict)
189+
:param f_flag: Failure flag that if True the script will stop executing when a test fails.
190+
:return: String containing the message to be outputted
191+
"""
184192
if "sample_size" in test:
185193
sample_size = test["sample_size"]
186194
else:
@@ -190,17 +198,19 @@ def run_ate_test(self, test, f_flag):
190198
else:
191199
target_ks_score = 0.05
192200
abstract_test = self._create_abstract_test_case(test, self.mutates, self.effects)
193-
concrete_tests, _ = abstract_test.generate_concrete_tests(sample_size=sample_size, target_ks_score=target_ks_score)
201+
concrete_tests, _ = abstract_test.generate_concrete_tests(
202+
sample_size=sample_size, target_ks_score=target_ks_score
203+
)
194204
failures, _ = self._execute_tests(concrete_tests, test, f_flag)
195205

196206
msg = (
197-
f"Executing test: {test['name']} \n"
198-
+ " abstract_test \n"
199-
+ f" {abstract_test} \n"
200-
+ f" {abstract_test.treatment_variable.name},"
201-
+ f" {abstract_test.treatment_variable.distribution} \n"
202-
+ f" Number of concrete tests for test case: {str(len(concrete_tests))} \n"
203-
+ f" {failures}/{len(concrete_tests)} failed for {test['name']}"
207+
f"Executing test: {test['name']} \n"
208+
+ " abstract_test \n"
209+
+ f" {abstract_test} \n"
210+
+ f" {abstract_test.treatment_variable.name},"
211+
+ f" {abstract_test.treatment_variable.distribution} \n"
212+
+ f" Number of concrete tests for test case: {str(len(concrete_tests))} \n"
213+
+ f" {failures}/{len(concrete_tests)} failed for {test['name']}"
204214
)
205215
return msg
206216

@@ -233,7 +243,7 @@ def _populate_metas(self):
233243
meta.populate(self.data)
234244

235245
def _execute_test_case(
236-
self, causal_test_case: CausalTestCase, test: Iterable[Mapping], f_flag: bool
246+
self, causal_test_case: CausalTestCase, test: Iterable[Mapping], f_flag: bool
237247
) -> (bool, CausalTestResult):
238248
"""Executes a singular test case, prints the results and returns the test case result
239249
:param causal_test_case: The concrete test case to be executed
@@ -273,11 +283,11 @@ def _execute_test_case(
273283
return failed, causal_test_result
274284

275285
def _setup_test(
276-
self, causal_test_case: CausalTestCase, test: Mapping, conditions: list[str] = None
286+
self, causal_test_case: CausalTestCase, test: Mapping, conditions: list[str] = None
277287
) -> tuple[CausalTestEngine, Estimator]:
278288
"""Create the necessary inputs for a single test case
279289
:param causal_test_case: The concrete test case to be executed
280-
:param test: Single JSON test definition stored in a mapping (dict)
290+
`:param test: Single JSON test definition stored in a mapping (dict)`
281291
:param conditions: A list of conditions which should be applied to the
282292
data. Conditions should be in the query format detailed at
283293
https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.query.html
@@ -358,7 +368,7 @@ def get_args(test_args=None) -> argparse.Namespace:
358368
parser.add_argument(
359369
"-w",
360370
help="Specify to overwrite any existing output files. This can lead to the loss of existing outputs if not "
361-
"careful",
371+
"careful",
362372
action="store_true",
363373
)
364374
parser.add_argument(

0 commit comments

Comments
 (0)