Skip to content

Commit 7453ee8

Browse files
Add concrete test params to test json
1 parent 304853e commit 7453ee8

File tree

1 file changed

+61
-50
lines changed

1 file changed

+61
-50
lines changed

causal_testing/json_front/json_class.py

Lines changed: 61 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -104,48 +104,7 @@ def _create_abstract_test_case(self, test, mutates, effects):
104104
effect=test.get("effect", "total"),
105105
)
106106
return abstract_test
107-
def run_coefficient_test(self, test, f_flag):
108-
base_test_case = BaseTestCase(
109-
treatment_variable=next(self.scenario.variables[v] for v in test["mutations"]),
110-
outcome_variable=next(self.scenario.variables[v] for v in test["expected_effect"]),
111-
effect=test.get("effect", "direct"),
112-
)
113-
assert len(test["expected_effect"]) == 1, "Can only have one expected effect."
114-
causal_test_case = CausalTestCase(
115-
base_test_case=base_test_case,
116-
expected_causal_effect=next(
117-
self.effects[effect] for variable, effect in test["expected_effect"].items()
118-
),
119-
estimate_type="coefficient",
120-
effect_modifier_configuration={
121-
self.scenario.variables[v] for v in test.get("effect_modifiers", [])
122-
},
123-
)
124-
result = self._execute_test_case(causal_test_case=causal_test_case, test=test, f_flag=f_flag)
125-
msg = (
126-
f"Executing test: {test['name']} \n"
127-
+ f" {causal_test_case} \n"
128-
+ " "
129-
+ ("\n ").join(str(result[1]).split("\n"))
130-
+ "==============\n"
131-
+ f" Result: {'FAILED' if result[0] else 'Passed'}"
132-
)
133-
return msg
134-
def run_ate_test(self, test, f_flag):
135-
abstract_test = self._create_abstract_test_case(test, self.mutates, self.effects)
136-
concrete_tests, _ = abstract_test.generate_concrete_tests(sample_size=5, target_ks_score=0.05)
137-
failures, _ = self._execute_tests(concrete_tests, test, f_flag)
138107

139-
msg = (
140-
f"Executing test: {test['name']} \n"
141-
+ " abstract_test \n"
142-
+ f" {abstract_test} \n"
143-
+ f" {abstract_test.treatment_variable.name},"
144-
+ f" {abstract_test.treatment_variable.distribution} \n"
145-
+ f" Number of concrete tests for test case: {str(len(concrete_tests))} \n"
146-
+ f" {failures}/{len(concrete_tests)} failed for {test['name']}"
147-
)
148-
return msg
149108
def run_json_tests(self, effects: dict, estimators: dict, f_flag: bool = False, mutates: dict = None):
150109
"""Runs and evaluates each test case specified in the JSON input
151110
@@ -164,7 +123,7 @@ def run_json_tests(self, effects: dict, estimators: dict, f_flag: bool = False,
164123
if test["estimate_type"] == "coefficient":
165124
msg = self.run_coefficient_test(test=test, f_flag=f_flag)
166125
else:
167-
msg = self.run_ate_test(test=test, f_flag=f_flag)
126+
msg = self.run_ate_test(test=test, f_flag=f_flag)
168127
self._append_to_file(msg, logging.INFO)
169128
else:
170129
outcome_variable = next(
@@ -185,14 +144,66 @@ def run_json_tests(self, effects: dict, estimators: dict, f_flag: bool = False,
185144
failed, _ = self._execute_test_case(causal_test_case=causal_test_case, test=test, f_flag=f_flag)
186145

187146
msg = (
188-
f"Executing concrete test: {test['name']} \n"
189-
+ f"treatment variable: {test['treatment_variable']} \n"
190-
+ f"outcome_variable = {outcome_variable} \n"
191-
+ f"control value = {test['control_value']}, treatment value = {test['treatment_value']} \n"
192-
+ f"Result: {'FAILED' if failed else 'Passed'}"
147+
f"Executing concrete test: {test['name']} \n"
148+
+ f"treatment variable: {test['treatment_variable']} \n"
149+
+ f"outcome_variable = {outcome_variable} \n"
150+
+ f"control value = {test['control_value']}, treatment value = {test['treatment_value']} \n"
151+
+ f"Result: {'FAILED' if failed else 'Passed'}"
193152
)
194153
self._append_to_file(msg, logging.INFO)
195154

155+
def run_coefficient_test(self, test, f_flag):
156+
base_test_case = BaseTestCase(
157+
treatment_variable=next(self.scenario.variables[v] for v in test["mutations"]),
158+
outcome_variable=next(self.scenario.variables[v] for v in test["expected_effect"]),
159+
effect=test.get("effect", "direct"),
160+
)
161+
assert len(test["expected_effect"]) == 1, "Can only have one expected effect."
162+
causal_test_case = CausalTestCase(
163+
base_test_case=base_test_case,
164+
expected_causal_effect=next(
165+
self.effects[effect] for variable, effect in test["expected_effect"].items()
166+
),
167+
estimate_type="coefficient",
168+
effect_modifier_configuration={
169+
self.scenario.variables[v] for v in test.get("effect_modifiers", [])
170+
},
171+
)
172+
result = self._execute_test_case(causal_test_case=causal_test_case, test=test, f_flag=f_flag)
173+
msg = (
174+
f"Executing test: {test['name']} \n"
175+
+ f" {causal_test_case} \n"
176+
+ " "
177+
+ ("\n ").join(str(result[1]).split("\n"))
178+
+ "==============\n"
179+
+ f" Result: {'FAILED' if result[0] else 'Passed'}"
180+
)
181+
return msg
182+
183+
def run_ate_test(self, test, f_flag):
184+
if "sample_size" in test:
185+
sample_size = test["sample_size"]
186+
else:
187+
sample_size = 5
188+
if "target_ks_score" in test:
189+
target_ks_score = test["target_ks_score"]
190+
else:
191+
target_ks_score = 0.05
192+
abstract_test = self._create_abstract_test_case(test, self.mutates, self.effects)
193+
concrete_tests, _ = abstract_test.generate_concrete_tests(sample_size=sample_size, target_ks_score=target_ks_score)
194+
failures, _ = self._execute_tests(concrete_tests, test, f_flag)
195+
196+
msg = (
197+
f"Executing test: {test['name']} \n"
198+
+ " abstract_test \n"
199+
+ f" {abstract_test} \n"
200+
+ f" {abstract_test.treatment_variable.name},"
201+
+ f" {abstract_test.treatment_variable.distribution} \n"
202+
+ f" Number of concrete tests for test case: {str(len(concrete_tests))} \n"
203+
+ f" {failures}/{len(concrete_tests)} failed for {test['name']}"
204+
)
205+
return msg
206+
196207
def _execute_tests(self, concrete_tests, test, f_flag):
197208
failures = 0
198209
details = []
@@ -222,7 +233,7 @@ def _populate_metas(self):
222233
meta.populate(self.data)
223234

224235
def _execute_test_case(
225-
self, causal_test_case: CausalTestCase, test: Iterable[Mapping], f_flag: bool
236+
self, causal_test_case: CausalTestCase, test: Iterable[Mapping], f_flag: bool
226237
) -> (bool, CausalTestResult):
227238
"""Executes a singular test case, prints the results and returns the test case result
228239
:param causal_test_case: The concrete test case to be executed
@@ -262,7 +273,7 @@ def _execute_test_case(
262273
return failed, causal_test_result
263274

264275
def _setup_test(
265-
self, causal_test_case: CausalTestCase, test: Mapping, conditions: list[str] = None
276+
self, causal_test_case: CausalTestCase, test: Mapping, conditions: list[str] = None
266277
) -> tuple[CausalTestEngine, Estimator]:
267278
"""Create the necessary inputs for a single test case
268279
:param causal_test_case: The concrete test case to be executed
@@ -347,7 +358,7 @@ def get_args(test_args=None) -> argparse.Namespace:
347358
parser.add_argument(
348359
"-w",
349360
help="Specify to overwrite any existing output files. This can lead to the loss of existing outputs if not "
350-
"careful",
361+
"careful",
351362
action="store_true",
352363
)
353364
parser.add_argument(

0 commit comments

Comments (0)