Commit 413d3b0

Merge branch 'main' into json-cate

# Conflicts:
#	causal_testing/json_front/json_class.py
#	tests/json_front_tests/test_json_class.py

2 parents 08f5437 + 72c3997
File tree: 7 files changed (+181, -91 lines)
causal_testing/json_front/json_class.py
Lines changed: 100 additions & 45 deletions
@@ -20,6 +20,7 @@
 from causal_testing.specification.causal_specification import CausalSpecification
 from causal_testing.specification.scenario import Scenario
 from causal_testing.specification.variable import Input, Meta, Output
+from causal_testing.testing.base_test_case import BaseTestCase
 from causal_testing.testing.causal_test_case import CausalTestCase
 from causal_testing.testing.causal_test_engine import CausalTestEngine
 from causal_testing.testing.estimators import Estimator
@@ -47,7 +48,7 @@ class JsonUtility:
 
     def __init__(self, output_path: str, output_overwrite: bool = False):
         self.input_paths = None
-        self.variables = None
+        self.variables = {"inputs": {}, "outputs": {}, "metas": {}}
         self.data = []
         self.test_plan = None
         self.scenario = None
@@ -67,6 +68,7 @@ def set_paths(self, json_path: str, dag_path: str, data_paths: str):
     def setup(self, scenario: Scenario):
         """Function to populate all the necessary parts of the json_class needed to execute tests"""
         self.scenario = scenario
+        self._get_scenario_variables()
         self.scenario.setup_treatment_variables()
         self.causal_specification = CausalSpecification(
             scenario=self.scenario, causal_dag=CausalDAG(self.input_paths.dag_path)
@@ -100,7 +102,7 @@ def _create_abstract_test_case(self, test, mutates, effects):
         )
         return abstract_test
 
-    def generate_tests(self, effects: dict, mutates: dict, estimators: dict, f_flag: bool):
+    def run_json_tests(self, effects: dict, estimators: dict, f_flag: bool = False, mutates: dict = None):
         """Runs and evaluates each test case specified in the JSON input
 
         :param effects: Dictionary mapping effect class instances to string representations.
@@ -113,49 +115,96 @@ def generate_tests(self, effects: dict, mutates: dict, estimators: dict, f_flag:
         for test in self.test_plan["tests"]:
             if "skip" in test and test["skip"]:
                 continue
+            test["estimator"] = estimators[test["estimator"]]
+            if "mutations" in test:
+                if test["estimate_type"] == "coefficient":
+                    base_test_case = BaseTestCase(
+                        treatment_variable=next(self.scenario.variables[v] for v in test["mutations"]),
+                        outcome_variable=next(self.scenario.variables[v] for v in test["expectedEffect"]),
+                        effect=test.get("effect", "direct"),
+                    )
+                    assert len(test["expectedEffect"]) == 1, "Can only have one expected effect."
+                    concrete_tests = [
+                        CausalTestCase(
+                            base_test_case=base_test_case,
+                            expected_causal_effect=next(
+                                effects[effect] for variable, effect in test["expectedEffect"].items()
+                            ),
+                            estimate_type="coefficient",
+                            effect_modifier_configuration={
+                                self.scenario.variables[v] for v in test.get("effect_modifiers", [])
+                            },
+                        )
+                    ]
+                    failures = self._execute_tests(concrete_tests, estimators, test, f_flag)
+                    msg = (
+                        f"Executing test: {test['name']} \n"
+                        + f" {concrete_tests[0]} \n"
+                        + f" {failures}/{len(concrete_tests)} failed for {test['name']}"
+                    )
+                else:
+                    abstract_test = self._create_abstract_test_case(test, mutates, effects)
+                    concrete_tests, dummy = abstract_test.generate_concrete_tests(5, 0.05)
+                    failures = self._execute_tests(concrete_tests, estimators, test, f_flag)
+                    msg = (
+                        f"Executing test: {test['name']} \n"
+                        + " abstract_test \n"
+                        + f" {abstract_test} \n"
+                        + f" {abstract_test.treatment_variable.name},{abstract_test.treatment_variable.distribution} \n"
+                        + f" Number of concrete tests for test case: {str(len(concrete_tests))} \n"
+                        + f" {failures}/{len(concrete_tests)} failed for {test['name']}"
+                    )
+                self._append_to_file(msg, logging.INFO)
+            else:
+                outcome_variable = next(
+                    iter(test["expected_effect"])
+                )  # Take first key from dictionary of expected effect
+                base_test_case = BaseTestCase(
+                    treatment_variable=self.variables["inputs"][test["treatment_variable"]],
+                    outcome_variable=self.variables["outputs"][outcome_variable],
+                )
 
-            if test["estimate_type"] == "coefficient":
-                base_test_case = BaseTestCase(
-                    treatment_variable=next(self.scenario.variables[v] for v in test["mutations"]),
-                    outcome_variable=next(self.scenario.variables[v] for v in test["expectedEffect"]),
-                    effect=test.get("effect", "direct"),
-                )
-                assert len(test["expectedEffect"]) == 1, "Can only have one expected effect."
-                concrete_tests = [
-                    CausalTestCase(
+                causal_test_case = CausalTestCase(
                     base_test_case=base_test_case,
-                    expected_causal_effect=next(
-                        effects[effect] for variable, effect in test["expectedEffect"].items()
-                    ),
-                    estimate_type="coefficient",
-                    effect_modifier_configuration={
-                        self.scenario.variables[v] for v in test.get("effect_modifiers", [])
-                    },
+                    expected_causal_effect=effects[test["expected_effect"][outcome_variable]],
+                    control_value=test["control_value"],
+                    treatment_value=test["treatment_value"],
+                    estimate_type=test["estimate_type"],
                 )
-                ]
-                failures = self._execute_tests(concrete_tests, estimators, test, f_flag)
-                msg = (
-                    f"Executing test: {test['name']} \n"
-                    + f" {concrete_tests[0]} \n"
-                    + f" {failures}/{len(concrete_tests)} failed for {test['name']}"
-                )
-            else:
-                abstract_test = self._create_abstract_test_case(test, mutates, effects)
-                concrete_tests, dummy = abstract_test.generate_concrete_tests(5, 0.05)
-                failures = self._execute_tests(concrete_tests, estimators, test, f_flag)
-                msg = (
-                    f"Executing test: {test['name']} \n"
-                    + " abstract_test \n"
-                    + f" {abstract_test} \n"
-                    + f" {abstract_test.treatment_variable.name},{abstract_test.treatment_variable.distribution} \n"
-                    + f" Number of concrete tests for test case: {str(len(concrete_tests))} \n"
-                    + f" {failures}/{len(concrete_tests)} failed for {test['name']}"
-                )
-            self._append_to_file(msg, logging.INFO)
+                if self._execute_test_case(causal_test_case=causal_test_case, test=test, f_flag=f_flag):
+                    result = "failed"
+                else:
+                    result = "passed"
+
+                msg = (
+                    f"Executing concrete test: {test['name']} \n"
+                    + f"treatment variable: {test['treatment_variable']} \n"
+                    + f"outcome_variable = {outcome_variable} \n"
+                    + f"control value = {test['control_value']}, treatment value = {test['treatment_value']} \n"
+                    + f"result - {result}"
+                )
+                self._append_to_file(msg, logging.INFO)
+
+    def _create_abstract_test_case(self, test, mutates, effects):
+        assert len(test["mutations"]) == 1
+        abstract_test = AbstractCausalTestCase(
+            scenario=self.scenario,
+            intervention_constraints=[mutates[v](k) for k, v in test["mutations"].items()],
+            treatment_variable=next(self.scenario.variables[v] for v in test["mutations"]),
+            expected_causal_effect={
+                self.scenario.variables[variable]: effects[effect]
+                for variable, effect in test["expected_effect"].items()
+            },
+            effect_modifiers={self.scenario.variables[v] for v in test["effect_modifiers"]}
+            if "effect_modifiers" in test
+            else {},
+            estimate_type=test["estimate_type"],
+            effect=test.get("effect", "total"),
+        )
+        return abstract_test
 
-    def _execute_tests(self, concrete_tests, estimators, test, f_flag):
+    def _execute_tests(self, concrete_tests, test, f_flag):
         failures = 0
-        test["estimator"] = estimators[test["estimator"]]
         if "formula" in test:
             self._append_to_file(f"Estimator formula used for test: {test['formula']}")
         for concrete_test in concrete_tests:
@@ -206,7 +255,6 @@ def _execute_test_case(self, causal_test_case: CausalTestCase, test: Iterable[Ma
 
         test_passes = causal_test_case.expected_causal_effect.apply(causal_test_result)
 
-        result_string = str()
         if causal_test_result.ci_low() and causal_test_result.ci_high():
             result_string = (
                 f"{causal_test_result.ci_low()} < {causal_test_result.test_value.value} < "
@@ -226,7 +274,7 @@
         return failed
 
     def _setup_test(
-        self, causal_test_case: CausalTestCase, test: Mapping, conditions: list[str] = None
+        self, causal_test_case: CausalTestCase, test: Mapping, conditions: list[str] = None
    ) -> tuple[CausalTestEngine, Estimator]:
         """Create the necessary inputs for a single test case
         :param causal_test_case: The concrete test case to be executed
@@ -258,7 +306,6 @@ def _setup_test(
         }
         if "formula" in test:
             estimator_kwargs["formula"] = test["formula"]
-
         estimation_model = test["estimator"](**estimator_kwargs)
         return causal_test_engine, estimation_model
 
@@ -270,10 +317,18 @@ def _append_to_file(self, line: str, log_level: int = None):
         is possible to use the inbuilt logging level variables such as logging.INFO and logging.WARNING
         """
         with open(self.output_path, "a", encoding="utf-8") as f:
-            f.write(line)
+            f.write(line + "\n")
         if log_level:
             logger.log(level=log_level, msg=line)
 
+    def _get_scenario_variables(self):
+        for input_var in self.scenario.inputs():
+            self.variables["inputs"][input_var.name] = input_var
+        for output_var in self.scenario.outputs():
+            self.variables["outputs"][output_var.name] = output_var
+        for meta_var in self.scenario.metas():
+            self.variables["metas"][meta_var.name] = meta_var
+
     @staticmethod
     def check_file_exists(output_path: Path, overwrite: bool):
         """Method that checks if the given path to an output file already exists. If overwrite is true the check is
@@ -304,7 +359,7 @@ def get_args(test_args=None) -> argparse.Namespace:
     parser.add_argument(
         "-w",
         help="Specify to overwrite any existing output files. This can lead to the loss of existing outputs if not "
-        "careful",
+        "careful",
         action="store_true",
     )
     parser.add_argument(
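
Taken together, a minimal usage sketch of the renamed entry point. This is not part of the commit: the effect and estimator imports, and the exact `Scenario`/variable constructors, are assumptions based on the wider framework and the poisson example.

from causal_testing.json_front.json_class import JsonUtility
from causal_testing.specification.scenario import Scenario
from causal_testing.specification.variable import Input, Output
from causal_testing.testing.causal_test_outcome import NoEffect, Positive  # assumed module path
from causal_testing.testing.estimators import LinearRegressionEstimator  # assumed estimator

json_utility = JsonUtility("results/json_front.txt")  # output file for per-test messages
json_utility.set_paths("causal_tests.json", "dag.dot", "data.csv")

# setup() now also calls _get_scenario_variables(), populating
# json_utility.variables["inputs"/"outputs"/"metas"] keyed by variable name.
json_utility.setup(Scenario(variables={Input("width", float), Output("num_shapes", int)}))

# Dictionary keys must match the strings used in causal_tests.json;
# mutates is now optional and only needed for abstract (mutation-based) tests.
json_utility.run_json_tests(
    effects={"Positive": Positive(), "NoEffect": NoEffect()},
    estimators={"LinearRegressionEstimator": LinearRegressionEstimator},
    f_flag=False,  # failure flag; assumed to halt execution on a failing test when True
)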

docs/source/frontends/json_front_end.rst
Lines changed: 20 additions & 9 deletions
@@ -21,17 +21,28 @@ Use case specific information is also declared here such as the paths to the rel
 
 causal_tests.json
 -----------------
-`examples/poisson/causal_tests.json <https://github.com/CITCOM-project/CausalTestingFramework/blob/main/examples/poisson/causal_tests.json>`_ contains python code written by the user to implement scenario specific features
-is the JSON file that allows for the easy specification of multiple causal tests.
+`examples/poisson/causal_tests.json <https://github.com/CITCOM-project/CausalTestingFramework/blob/main/examples/poisson/causal_tests.json>`_ is the JSON file that allows for the easy
+specification of multiple causal tests. Tests can be specified in two ways: first, by specifying a mutation, as in the example tests, with the following structure:
 Each test requires:
-1. Test name
-2. Mutations
-3. Estimator
-4. Estimate_type
-5. Effect modifiers
-6. Expected effects
-7. Skip: boolean that if set true the test won't be executed and will be skipped
 
+#. name
+#. mutations
+#. estimator
+#. estimate_type
+#. effect_modifiers
+#. expected_effects
+#. skip: boolean that, if set true, the test will be skipped
+
+The second method of specifying a test is to specify the test in a concrete form with the following structure:
+
+#. name
+#. treatment_variable
+#. control_value
+#. treatment_value
+#. estimator
+#. estimate_type
+#. expected_effect
+#. skip
 
 Run Commands
 ------------
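
For illustration, the two test forms described in the documentation above might look as follows in a test plan. This is expressed as a Python literal mirroring the JSON structure; all names and values are hypothetical, while the field names follow this diff.

# Hypothetical test plan mirroring causal_tests.json; the "estimator" and
# "expected_effect" strings must match keys in the estimators/effects dicts
# passed to run_json_tests.
test_plan = {
    "tests": [
        {
            # Abstract form: a mutation is concretised into generated tests.
            "name": "width_increases_num_shapes",
            "mutations": {"width": "Increase"},
            "estimator": "LinearRegressionEstimator",
            "estimate_type": "ate",
            "effect_modifiers": [],
            "expected_effect": {"num_shapes": "Positive"},
            "skip": False,
        },
        {
            # Concrete form: control and treatment values are given directly.
            "name": "width_2_vs_1",
            "treatment_variable": "width",
            "control_value": 1.0,
            "treatment_value": 2.0,
            "estimator": "LinearRegressionEstimator",
            "estimate_type": "ate",
            "expected_effect": {"num_shapes": "Positive"},
            "skip": False,
        },
    ]
}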

examples/poisson/README.md
Lines changed: 1 addition & 1 deletion
@@ -6,6 +6,6 @@ To run this case study:
 1. Ensure all project dependencies are installed by running `pip install .` in the top level directory
    (instructions are provided in the project README).
 2. Change directory to `causal_testing/examples/poisson`.
-3. Run the command `python test_run_causal_tests.py --data_path data.csv --dag_path dag.dot --json_path causal_tests.json`
+3. Run the command `python example_run_causal_tests.py --data_path data.csv --dag_path dag.dot --json_path causal_tests.json`
 
 This should print a series of causal test results and produce two CSV files. `intensity_num_shapes_results_random_1000.csv` corresponds to table 1, and `width_num_shapes_results_random_1000.csv` relates to our findings regarding the relationship of width and `P_u`.
