Skip to content

Commit 4a6958b

Browse files
Merge pull request #67 from CITCOM-project/json_poisson_example
Json front end
2 parents 57e8007 + e29ebf0 commit 4a6958b

File tree

11 files changed

+1705
-4
lines changed

11 files changed

+1705
-4
lines changed

causal_testing/data_collection/data_collector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def filter_valid_data(self, data: pd.DataFrame, check_pos: bool = True) -> pd.Da
4343
solver = z3.Solver()
4444
for c in self.scenario.constraints:
4545
solver.assert_and_track(c, f"background: {c}")
46-
sat = []
46+
sat = list()
4747
unsat_core = None
4848
for _, row in data.iterrows():
4949
solver.push()

causal_testing/json_front/README.md

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# JSON Causal Testing Framework Frontend
2+
3+
The JSON frontend lets Causal Tests and their parameters be specified in JSON so that tests can be written quickly
4+
whilst retaining the flexibility of the Causal Testing Framework (CTF).
5+
6+
An example is provided in `examples/poisson` which will be walked through in this README to better understand
7+
the framework.
8+
9+
`examples/poisson/run_causal_tests.py` contains python code written by the user to implement scenario specific features
10+
such as:
11+
1. Custom Estimators
12+
2. Causal Variable specification
13+
3. Causal test case outcomes
14+
4. Meta constraint functions
15+
5. Mapping JSON distributions, effects, and estimators to python objects
16+
17+
Use case specific information is also declared here such as the paths to the relevant files needed for the tests.
18+
19+
`examples/poisson/causal_tests.json` is the JSON file that allows for the easy specification of multiple causal tests.
20+
Each test requires:
21+
1. Test name
22+
2. Mutations
23+
3. Estimator
24+
4. Estimate_type
25+
5. Effect modifiers
26+
6. Expected effects
27+
7. Skip: boolean; if set to true, the test is skipped and not executed
28+
29+
To run the JSON frontend example from the root directory of the project, use
30+
`python examples/poisson/run_causal_tests.py`
31+
32+
A failure flag `-f` can be specified to stop the framework running as soon as a test fails
33+
`python examples/poisson/run_causal_tests.py -f`
34+

causal_testing/json_front/__init__.py

Whitespace-only changes.
Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
from pathlib import Path
2+
3+
from abc import ABC
4+
import json
5+
from fitter import Fitter, get_common_distributions
6+
import pandas as pd
7+
import scipy
8+
9+
from causal_testing.specification.variable import Input, Output, Meta
10+
from causal_testing.specification.scenario import Scenario
11+
from causal_testing.specification.causal_dag import CausalDAG
12+
from causal_testing.specification.causal_specification import CausalSpecification
13+
from causal_testing.generation.abstract_causal_test_case import AbstractCausalTestCase
14+
from causal_testing.data_collection.data_collector import ObservationalDataCollector
15+
from causal_testing.testing.causal_test_engine import CausalTestEngine
16+
from causal_testing.testing.causal_test_case import CausalTestCase
17+
from causal_testing.testing.estimators import Estimator
18+
19+
20+
class JsonUtility(ABC):
    """
    The JsonUtility Class provides the functionality to use structured JSON to setup and run causal tests on the
    CausalTestingFramework.

    Typical call order: ``set_path`` -> ``set_variables`` -> ``setup`` -> ``execute_tests``.

    :attr {Path} json_path: Path to the JSON input file.
    :attr {Path} dag_path: Path to the dag.dot file containing the Causal DAG.
    :attr {Path} data_path: Path to the csv data file.
    :attr {list} inputs: Causal variables (Input) representing inputs.
    :attr {list} outputs: Causal variables (Output) representing outputs.
    :attr {list} metas: Causal variables (Meta) representing metavariables.
    :attr {pd.DataFrame} data: Pandas DataFrame containing runtime data, read from data_path.
    :attr {dict} test_plan: Dictionary containing the key value pairs from the loaded json test plan.
    :attr {Scenario} modelling_scenario: Scenario built by ``setup`` from the inputs, outputs and metas.
    :attr {CausalSpecification} causal_specification: Specification built by ``setup`` from the modelling scenario
        and the causal DAG loaded from dag_path.
    """

    def __init__(self):
        # Everything starts unset; the set_*/setup methods below fill these in.
        self.json_path = None
        self.dag_path = None
        self.data_path = None
        self.inputs = None
        self.outputs = None
        self.metas = None
        self.data = None
        self.test_plan = None
        self.modelling_scenario = None
        self.causal_specification = None

    def set_path(self, json_path: str, dag_path: str, data_path: str):
        """
        Store the locations of the scenario specific files as Path objects on the instance. Nothing is returned.
        :param json_path: string path representation to .json file containing test specifications
        :param dag_path: string path representation to the .dot file containing the Causal DAG
        :param data_path: string path representation to the data file
        """
        self.json_path = Path(json_path)
        self.dag_path = Path(dag_path)
        self.data_path = Path(data_path)

    def set_variables(self, inputs: dict, outputs: dict, metas: dict, distributions: dict, populates: dict):
        """ Populate the Causal Variables
        :param inputs: iterable of dicts, each read for the keys 'name', 'type' and 'distribution'
        :param outputs: iterable of dicts, each read for the keys 'name' and 'type'
        :param metas: iterable of dicts, each read for the keys 'name', 'type' and 'populate'; may be empty or None
        :param distributions: mapping from the 'distribution' value of an input to the object passed to Input
        :param populates: mapping from the 'populate' value of a meta to the object passed to Meta
        """
        self.inputs = [Input(i['name'], i['type'], distributions[i['distribution']]) for i in inputs]
        self.outputs = [Output(i['name'], i['type']) for i in outputs]
        # A falsy `metas` (None or empty) yields an empty list so that `setup` can still concatenate the lists.
        self.metas = [Meta(i['name'], i['type'], populates[i['populate']]) for i in metas] if metas else []

    def setup(self):
        """ Function to populate all the necessary parts of the json_class needed to execute tests

        Requires ``set_path`` and ``set_variables`` to have been called first, since the paths and variable lists
        are read here.
        """
        self.modelling_scenario = Scenario(self.inputs + self.outputs + self.metas, None)
        self.modelling_scenario.setup_treatment_variables()
        self.causal_specification = CausalSpecification(scenario=self.modelling_scenario,
                                                        causal_dag=CausalDAG(self.dag_path))
        # The data must be loaded before the metas can be populated from it.
        self._json_parse()
        self._populate_metas()

    def execute_tests(self, effects: dict, mutates: dict, estimators: dict, f_flag: bool):
        """ Runs and evaluates each test case specified in the JSON input

        Tests whose "skip" entry is truthy are not executed. A summary line "<failures>/<executed> failed" is
        printed once all tests have run.

        :param effects: Dictionary mapping effect class instances to string representations.
        :param mutates: Dictionary mapping mutation functions to string representations.
        :param estimators: Dictionary mapping estimator classes to string representations.
        :param f_flag: Failure flag that if True the script will stop executing when a test fails.
        """
        executed_tests = 0
        failures = 0
        for test in self.test_plan['tests']:
            if test.get("skip"):
                continue

            variables = self.modelling_scenario.variables
            abstract_test = AbstractCausalTestCase(
                scenario=self.modelling_scenario,
                # 'mutations' maps a variable name (k) to the name of a mutation function (v).
                intervention_constraints=[mutates[v](k) for k, v in test['mutations'].items()],
                treatment_variables={variables[v] for v in test['mutations']},
                expected_causal_effect={variables[variable]: effects[effect]
                                        for variable, effect in test["expectedEffect"].items()},
                # An absent 'effect_modifiers' entry gives an empty set, the same type as the populated case.
                effect_modifiers={variables[v] for v in test.get('effect_modifiers', ())},
                estimate_type=test['estimate_type']
            )

            # NOTE(review): 5 and 0.05 are passed positionally; presumably a sample size and a target
            # threshold for test generation - confirm against AbstractCausalTestCase.generate_concrete_tests.
            concrete_tests, _ = abstract_test.generate_concrete_tests(5, 0.05)
            print(abstract_test)
            print([(v.name, v.distribution) for v in abstract_test.treatment_variables])
            print(len(concrete_tests))
            for concrete_test in concrete_tests:
                executed_tests += 1
                if self._execute_test_case(concrete_test, estimators[test['estimator']], f_flag):
                    failures += 1

        print(f"{failures}/{executed_tests} failed")

    def _json_parse(self):
        """Load the JSON test plan from json_path into test_plan and the csv data from data_path into data"""
        # JSON text is UTF-8 by specification, so do not rely on the platform default encoding.
        with open(self.json_path, encoding="utf-8") as f:
            self.test_plan = json.load(f)

        self.data = pd.read_csv(self.data_path)

    def _populate_metas(self):
        """
        Populate data with meta-variable values and add distributions to Causal Testing Framework Variables
        """
        for meta in self.metas:
            meta.populate(self.data)

        # Inputs receive their distribution in set_variables; metas and outputs get one fitted from the data.
        for var in self.metas + self.outputs:
            fitter = Fitter(self.data[var.name], distributions=get_common_distributions())
            fitter.fit()
            # The first (distribution name, parameters) pair of the best fit is used.
            (dist, params) = list(fitter.get_best(method="sumsquare_error").items())[0]
            var.distribution = getattr(scipy.stats, dist)(**params)
            print(var.name, f"{dist}({params})")

    def _execute_test_case(self, causal_test_case: CausalTestCase, estimator: Estimator, f_flag: bool) -> bool:
        """ Executes a singular test case, prints the results and returns the test case result
        :param causal_test_case: The concrete test case to be executed
        :param estimator: Estimator class (called as a constructor in _setup_test) used for this test
        :param f_flag: Failure flag that if True the script will stop executing when a test fails.
        :return: A boolean that if True indicates the causal test case FAILED and if False indicates it passed.
        :rtype: bool
        :raises AssertionError: if the test case fails and f_flag is True.
        """
        causal_test_engine, estimation_model = self._setup_test(causal_test_case, estimator)
        causal_test_result = causal_test_engine.execute_test(estimation_model,
                                                             estimate_type=causal_test_case.estimate_type)

        test_passes = causal_test_case.expected_causal_effect.apply(causal_test_result)

        # Compare against None rather than relying on truthiness: a confidence bound of exactly 0 is a
        # legitimate value and must not be mistaken for "no confidence interval".
        # NOTE(review): assumes ci_low()/ci_high() return None when no confidence intervals exist - confirm.
        ci_low = causal_test_result.ci_low()
        ci_high = causal_test_result.ci_high()
        if ci_low is not None and ci_high is not None:
            result_string = f"{ci_low} < {causal_test_result.ate} < {ci_high}"
        else:
            result_string = causal_test_result.ate

        if test_passes:
            return False

        if f_flag:
            # Raised explicitly instead of using `assert`, which is removed under `python -O` and would
            # silently disable the fail-fast flag. The exception type is unchanged.
            raise AssertionError(
                f"{causal_test_case}\n FAILED - expected {causal_test_case.expected_causal_effect}, "
                f"got {result_string}"
            )

        print(f" FAILED - expected {causal_test_case.expected_causal_effect}, got {causal_test_result.ate}")
        return True

    def _setup_test(self, causal_test_case: CausalTestCase, estimator: Estimator) -> tuple[CausalTestEngine, Estimator]:
        """ Create the necessary inputs for a single test case
        :param causal_test_case: The concrete test case to be executed
        :param estimator: Estimator class, called here to build the estimation model
        :returns:
            - causal_test_engine - Test Engine instance for the test being run
            - estimation_model - Estimator instance for the test being run
        """
        data_collector = ObservationalDataCollector(self.modelling_scenario, self.data_path)
        causal_test_engine = CausalTestEngine(causal_test_case, self.causal_specification, data_collector)
        minimal_adjustment_set = causal_test_engine.load_data(index_col=0)

        treatment_vars = list(causal_test_case.treatment_input_configuration)
        # Treatment variables are removed from the adjustment set handed to the estimator.
        minimal_adjustment_set = minimal_adjustment_set - {v.name for v in treatment_vars}

        # Only the first treatment variable and the first outcome variable are passed to the estimator.
        treatment_var = treatment_vars[0]
        estimation_model = estimator((treatment_var.name,),
                                     causal_test_case.treatment_input_configuration[treatment_var],
                                     causal_test_case.control_input_configuration[treatment_var],
                                     minimal_adjustment_set,
                                     (list(causal_test_case.outcome_variables)[0].name,),
                                     causal_test_engine.scenario_execution_data_df,
                                     effect_modifiers=causal_test_case.effect_modifier_configuration
                                     )

        self.add_modelling_assumptions(estimation_model)

        return causal_test_engine, estimation_model

    def add_modelling_assumptions(self, estimation_model: Estimator):
        """ Optional hook where user functionality can be written to determine what assumptions are required
        for specific test cases. The default implementation does nothing; subclasses may override it.
        :param estimation_model: estimator model instance for the current running test.
        """
        return

causal_testing/testing/causal_test_outcome.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,6 @@ def to_dict(self):
5555
base_dict["ci_high"] = max(self.confidence_intervals)
5656
return base_dict
5757

58-
5958
def ci_low(self):
6059
"""Return the lower bracket of the confidence intervals."""
6160
if not self.confidence_intervals:

docs/source/index.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,12 @@ Causal testing is a causal inference-driven framework for functional black-box t
2121

2222
/autoapi/causal_testing/index
2323

24+
.. toctree::
25+
:maxdepth: 1
26+
:caption: Examples
27+
28+
json_frontend
29+
2430
Indices and tables
2531
==================
2632

0 commit comments

Comments
 (0)