Skip to content

Commit 6e02f2a

Browse files
refactor abstract test case generation
1 parent e6900b5 commit 6e02f2a

File tree

1 file changed

+20
-12
lines changed

1 file changed

+20
-12
lines changed

causal_testing/json_front/json_class.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,22 @@ def setup(self):
8888
self._json_parse()
8989
self._populate_metas()
9090

91+
def _create_abstract_test_case(self, test, mutates, effects):
92+
abstract_test = AbstractCausalTestCase(
93+
scenario=self.modelling_scenario,
94+
intervention_constraints=mutates,
95+
treatment_variables={self.modelling_scenario.variables[v] for v in test["mutations"]},
96+
expected_causal_effect={
97+
self.modelling_scenario.variables[variable]: effects[effect]
98+
for variable, effect in test["expectedEffect"].items()
99+
},
100+
effect_modifiers={self.modelling_scenario.variables[v] for v in test["effect_modifiers"]}
101+
if "effect_modifiers" in test
102+
else {},
103+
estimate_type=test["estimate_type"],
104+
)
105+
return abstract_test
106+
91107
def execute_tests(self, effects: dict, mutates: dict, estimators: dict, f_flag: bool):
92108
"""Runs and evaluates each test case specified in the JSON input
93109
@@ -102,18 +118,10 @@ def execute_tests(self, effects: dict, mutates: dict, estimators: dict, f_flag:
102118
if "skip" in test and test["skip"]:
103119
continue
104120

105-
abstract_test = AbstractCausalTestCase(
106-
scenario=self.modelling_scenario,
107-
intervention_constraints=[mutates[v](k) for k, v in test["mutations"].items()],
108-
treatment_variables={self.modelling_scenario.variables[v] for v in test["mutations"]},
109-
expected_causal_effect={
110-
self.modelling_scenario.variables[variable]: effects[effect]
111-
for variable, effect in test["expectedEffect"].items()
112-
},
113-
effect_modifiers={self.modelling_scenario.variables[v] for v in test["effect_modifiers"]}
114-
if "effect_modifiers" in test
115-
else {},
116-
estimate_type=test["estimate_type"],
121+
abstract_test = self._create_abstract_test_case(
122+
test,
123+
[mutates[v](k) for k, v in test["mutations"].items()],
124+
effects
117125
)
118126

119127
concrete_tests, dummy = abstract_test.generate_concrete_tests(5, 0.05)

0 commit comments

Comments
 (0)