Skip to content

Commit 9a55b0c

Browse files
solve positivity error
1 parent 6d9f053 commit 9a55b0c

File tree

2 files changed

+11
-13
lines changed

2 files changed

+11
-13
lines changed

causal_testing/json_front/json_class.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def setup(self):
9191
def _create_abstract_test_case(self, test, mutates, effects):
9292
abstract_test = AbstractCausalTestCase(
9393
scenario=self.modelling_scenario,
94-
intervention_constraints=mutates,
94+
intervention_constraints=[mutates[v](k) for k, v in test["mutations"].items()],
9595
treatment_variables={self.modelling_scenario.variables[v] for v in test["mutations"]},
9696
expected_causal_effect={
9797
self.modelling_scenario.variables[variable]: effects[effect]
@@ -116,9 +116,8 @@ def generate_tests(self, effects: dict, mutates: dict, estimators: dict, f_flag:
116116
for test in self.test_plan["tests"]:
117117
if "skip" in test and test["skip"]:
118118
continue
119-
mutation = [mutates[v](k) for k, v in test["mutations"].items()]
120119
abstract_test = self._create_abstract_test_case(
121-
test, mutation, effects
120+
test, mutates, effects
122121
)
123122

124123
concrete_tests, dummy = abstract_test.generate_concrete_tests(5, 0.05)

tests/json_front_tests/test_json_class.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -111,13 +111,13 @@ def test_execute_concrete_test(self):
111111
"skip": False,
112112
}
113113
estimators = {"LinearRegressionEstimator": LinearRegressionEstimator}
114-
abstract_test_case = self.json_class._create_abstract_test_case(example_test, [mutates[v](k) for k, v in example_test["mutations"].items()], effects)
115-
concrete_tests = abstract_test_case.generate_concrete_tests(1, 0.5)
114+
abstract_test_case = self.json_class._create_abstract_test_case(example_test, mutates, effects)
115+
concrete_tests, dummy = abstract_test_case.generate_concrete_tests(5, 0.5)
116116
self.json_class._execute_tests(concrete_tests, estimators, example_test, False)
117117

118118
def tearDown(self) -> None:
119-
remove_temp_dir_if_existent()
120-
119+
#remove_temp_dir_if_existent()
120+
pass
121121

122122
def populate_example(*args, **kwargs):
123123
pass
@@ -144,16 +144,15 @@ def setup_json_file(json_path):
144144

145145

146146
def setup_data_file(data_path):
147-
header = ["test_input", "test_output"]
148-
data = [1, 2]
149-
with open(data_path, "w") as f:
147+
header = ["index", "test_input", "test_output"]
148+
data = [0, 1, 2]
149+
with open(data_path, "w", newline='') as f:
150150
writer = csv.writer(f)
151151
writer.writerow(header)
152152
writer.writerow(data)
153-
154-
153+
"""digraph G { A -> B; B -> C; D -> A; D -> C}"""
155154
def setup_dag_file(dag_path):
156-
dag_dot = "digraph G {A->B}"
155+
dag_dot = """digraph G { test_input -> temp; temp -> test_output; temp2 -> test_input; temp2 -> test_output}"""
157156
with open(dag_path, "w") as f:
158157
f.write(dag_dot)
159158
f.close()

0 commit comments

Comments
 (0)