Skip to content

Commit 3be14bc

Browse files
Merge pull request #189 from CITCOM-project/json_concrete_param
Expose concrete test generation parameters in JSON Frontend
2 parents 0a5c924 + f95a6f1 commit 3be14bc

File tree

2 files changed

+99
-41
lines changed

2 files changed

+99
-41
lines changed

causal_testing/json_front/json_class.py

Lines changed: 68 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -123,54 +123,15 @@ def run_json_tests(self, effects: dict, estimators: dict, f_flag: bool = False,
123123
:param estimators: Dictionary mapping estimator classes to string representations.
124124
:param f_flag: Failure flag that if True the script will stop executing when a test fails.
125125
"""
126-
failures = 0
127-
msg = ""
128126
for test in self.test_plan["tests"]:
129127
if "skip" in test and test["skip"]:
130128
continue
131129
test["estimator"] = estimators[test["estimator"]]
132130
if "mutations" in test:
133131
if test["estimate_type"] == "coefficient":
134-
base_test_case = BaseTestCase(
135-
treatment_variable=next(self.scenario.variables[v] for v in test["mutations"]),
136-
outcome_variable=next(self.scenario.variables[v] for v in test["expected_effect"]),
137-
effect=test.get("effect", "direct"),
138-
)
139-
assert len(test["expected_effect"]) == 1, "Can only have one expected effect."
140-
causal_test_case = CausalTestCase(
141-
base_test_case=base_test_case,
142-
expected_causal_effect=next(
143-
effects[effect] for variable, effect in test["expected_effect"].items()
144-
),
145-
estimate_type="coefficient",
146-
effect_modifier_configuration={
147-
self.scenario.variables[v] for v in test.get("effect_modifiers", [])
148-
},
149-
)
150-
result = self._execute_test_case(causal_test_case=causal_test_case, test=test, f_flag=f_flag)
151-
msg = (
152-
f"Executing test: {test['name']} \n"
153-
+ f" {causal_test_case} \n"
154-
+ " "
155-
+ ("\n ").join(str(result[1]).split("\n"))
156-
+ "==============\n"
157-
+ f" Result: {'FAILED' if result[0] else 'Passed'}"
158-
)
159-
print(msg)
132+
msg = self._run_coefficient_test(test=test, f_flag=f_flag, effects=effects)
160133
else:
161-
abstract_test = self._create_abstract_test_case(test, mutates, effects)
162-
concrete_tests, _ = abstract_test.generate_concrete_tests(5, 0.05)
163-
failures, _ = self._execute_tests(concrete_tests, test, f_flag)
164-
165-
msg = (
166-
f"Executing test: {test['name']} \n"
167-
+ " abstract_test \n"
168-
+ f" {abstract_test} \n"
169-
+ f" {abstract_test.treatment_variable.name},"
170-
+ f" {abstract_test.treatment_variable.distribution} \n"
171-
+ f" Number of concrete tests for test case: {str(len(concrete_tests))} \n"
172-
+ f" {failures}/{len(concrete_tests)} failed for {test['name']}"
173-
)
134+
msg = self._run_ate_test(test=test, f_flag=f_flag, effects=effects, mutates=mutates)
174135
self._append_to_file(msg, logging.INFO)
175136
else:
176137
outcome_variable = next(
@@ -197,8 +158,74 @@ def run_json_tests(self, effects: dict, estimators: dict, f_flag: bool = False,
197158
+ f"control value = {test['control_value']}, treatment value = {test['treatment_value']} \n"
198159
+ f"Result: {'FAILED' if failed else 'Passed'}"
199160
)
161+
print(msg)
200162
self._append_to_file(msg, logging.INFO)
201163

164+
def _run_coefficient_test(self, test: dict, f_flag: bool, effects: dict) -> str:
    """Build and execute a causal test case for tests with an estimate_type of 'coefficient'.

    :param test: Single JSON test definition stored in a mapping (dict)
    :param f_flag: Failure flag that if True the script will stop executing when a test fails.
    :param effects: Dictionary mapping effect class instances to string representations.
    :return: String containing the message to be outputted
    """
    # Validate *before* consuming expected_effect below — the original ran this
    # assert only after base_test_case had already been built from it.
    assert len(test["expected_effect"]) == 1, "Can only have one expected effect."
    base_test_case = BaseTestCase(
        treatment_variable=next(self.scenario.variables[v] for v in test["mutations"]),
        outcome_variable=next(self.scenario.variables[v] for v in test["expected_effect"]),
        effect=test.get("effect", "direct"),
    )
    causal_test_case = CausalTestCase(
        base_test_case=base_test_case,
        # expected_effect has exactly one entry (asserted above); map its effect
        # name to the corresponding effect instance. Only the values are needed.
        expected_causal_effect=next(effects[effect] for effect in test["expected_effect"].values()),
        estimate_type="coefficient",
        effect_modifier_configuration={self.scenario.variables[v] for v in test.get("effect_modifiers", [])},
    )
    # result is a (failed, test_result) pair produced by the shared executor.
    result = self._execute_test_case(causal_test_case=causal_test_case, test=test, f_flag=f_flag)
    msg = (
        f"Executing test: {test['name']} \n"
        + f" {causal_test_case} \n"
        + " "
        + ("\n ").join(str(result[1]).split("\n"))
        + "==============\n"
        + f" Result: {'FAILED' if result[0] else 'Passed'}"
    )
    return msg
def _run_ate_test(self, test: dict, f_flag: bool, effects: dict, mutates: dict) -> str:
    """Build and execute concrete causal test cases for tests with an estimate_type of 'ate'.

    :param test: Single JSON test definition stored in a mapping (dict)
    :param f_flag: Failure flag that if True the script will stop executing when a test fails.
    :param effects: Dictionary mapping effect class instances to string representations.
    :param mutates: Dictionary mapping mutation functions to string representations.
    :return: String containing the message to be outputted
    """
    # Concrete-test generation parameters are configurable per JSON test; the
    # defaults (5, 0.05) preserve the previously hard-coded behaviour.
    sample_size = test.get("sample_size", 5)
    target_ks_score = test.get("target_ks_score", 0.05)
    abstract_test = self._create_abstract_test_case(test, mutates, effects)
    concrete_tests, _ = abstract_test.generate_concrete_tests(
        sample_size=sample_size, target_ks_score=target_ks_score
    )
    failures, _ = self._execute_tests(concrete_tests, test, f_flag)

    msg = (
        f"Executing test: {test['name']} \n"
        + " abstract_test \n"
        + f" {abstract_test} \n"
        + f" {abstract_test.treatment_variable.name},"
        + f" {abstract_test.treatment_variable.distribution} \n"
        + f" Number of concrete tests for test case: {str(len(concrete_tests))} \n"
        + f" {failures}/{len(concrete_tests)} failed for {test['name']}"
    )
    return msg
202229
def _execute_tests(self, concrete_tests, test, f_flag):
203230
failures = 0
204231
details = []

tests/json_front_tests/test_json_class.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,37 @@ def test_run_concrete_json_testcase(self):
239239
with open("temp_out.txt", "r") as reader:
240240
temp_out = reader.readlines()
241241
self.assertIn("FAILED", temp_out[-1])
242+
def test_concrete_generate_params(self):
    """The JSON front end forwards per-test `sample_size` and `target_ks_score`
    to concrete test generation and logs the pass/fail summary."""
    self.json_class.test_plan = {
        "tests": [
            {
                "name": "test1",
                "mutations": {"test_input": "Increase"},
                "estimator": "LinearRegressionEstimator",
                "estimate_type": "ate",
                "effect_modifiers": [],
                "expected_effect": {"test_output": "NoEffect"},
                "sample_size": 5,
                "target_ks_score": 0.05,
                "skip": False,
            }
        ]
    }
    effects = {"NoEffect": NoEffect()}
    mutates = {
        "Increase": lambda x: self.json_class.scenario.treatment_variables[x].z3
        > self.json_class.scenario.variables[x].z3
    }
    estimators = {"LinearRegressionEstimator": LinearRegressionEstimator}

    self.json_class.run_json_tests(effects=effects, estimators=estimators, f_flag=False, mutates=mutates)

    # Failed tests are expected behaviour for this scenario, so the last log
    # line should report the failure count.
    with open("temp_out.txt", "r") as reader:
        log_lines = reader.readlines()
    self.assertIn("failed", log_lines[-1])
242273

243274
def test_no_data_provided(self):
244275
example_test = {

0 commit comments

Comments
 (0)