Skip to content

Commit 2f6d774

Browse files
committed
Merge branch 'main' of github.com:CITCOM-project/CausalTestingFramework into functional_form
2 parents 7079204 + d22709d commit 2f6d774

File tree

17 files changed

+226
-203
lines changed

17 files changed

+226
-203
lines changed

.github/workflows/publish-to-pypi.yaml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
name: Publish python PyPI
22

3+
on:
4+
push:
5+
tags:
6+
- v*
7+
38
jobs:
49
build-release:
510
name: Build and publish PyPI
@@ -17,6 +22,8 @@ jobs:
1722
pip3 install .
1823
pip3 install .[pypi]
1924
pip3 install build
25+
pip3 install setuptools --upgrade
26+
pip3 install setuptools_scm
2027
- name: Build Package
2128
run: |
2229
python -m build --no-isolation

.github/workflows/publish-to-test-pypi.yaml

Lines changed: 0 additions & 33 deletions
This file was deleted.

.pylintrc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ disable=raw-checker-failed,
153153
deprecated-pragma,
154154
use-symbolic-message-instead,
155155
logging-fstring-interpolation,
156+
import-error,
156157

157158
# Enable the message, report, category or checker with the given id(s). You can
158159
# either give multiple identifier separated by comma (,) or put this option

causal_testing/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,13 @@
1+
"""
2+
This is the CausalTestingFramework Module
3+
It contains 5 subpackages:
4+
data_collection
5+
generation
6+
json_front
7+
specification
8+
testing
9+
"""
10+
111
import logging
212

313
logger = logging.getLogger(__name__)

causal_testing/generation/abstract_causal_test_case.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ class AbstractCausalTestCase:
2828
"""
2929

3030
def __init__(
31+
# pylint: disable=too-many-arguments
3132
self,
3233
scenario: Scenario,
3334
intervention_constraints: set[z3.ExprRef],
@@ -77,7 +78,11 @@ def sanitise(string):
7778
)
7879

7980
def _generate_concrete_tests(
80-
self, sample_size: int, rct: bool = False, seed: int = 0
81+
# pylint: disable=too-many-locals
82+
self,
83+
sample_size: int,
84+
rct: bool = False,
85+
seed: int = 0,
8186
) -> tuple[list[CausalTestCase], pd.DataFrame]:
8287
"""Generates a list of `num` concrete test cases.
8388
@@ -151,6 +156,7 @@ def _generate_concrete_tests(
151156
return concrete_tests, pd.DataFrame(runs, columns=run_columns + ["bin"])
152157

153158
def generate_concrete_tests(
159+
# pylint: disable=too-many-arguments, too-many-locals
154160
self,
155161
sample_size: int,
156162
target_ks_score: float = None,

causal_testing/json_front/json_class.py

Lines changed: 51 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
"""This module contains the JsonUtility class, details of using this class can be found here:
22
https://causal-testing-framework.readthedocs.io/en/latest/json_front_end.html"""
3+
34
import argparse
45
import json
56
import logging
67

78
from abc import ABC
9+
from dataclasses import dataclass
810
from pathlib import Path
911

1012
import pandas as pd
@@ -42,49 +44,38 @@ class JsonUtility(ABC):
4244
"""
4345

4446
def __init__(self, log_path):
45-
self.json_path = None
46-
self.dag_path = None
47-
self.data_path = None
48-
self.inputs = None
49-
self.outputs = None
50-
self.metas = None
47+
self.paths = None
48+
self.variables = None
5149
self.data = None
5250
self.test_plan = None
5351
self.modelling_scenario = None
5452
self.causal_specification = None
5553
self.setup_logger(log_path)
5654

57-
def set_path(self, json_path: str, dag_path: str, data_path: str):
55+
def set_paths(self, json_path: str, dag_path: str, data_path: str):
5856
"""
5957
Takes the paths of the scenario-specific files (JSON test specification, Causal DAG and data) and stores each one as a Path object
6058
:param json_path: string path representation to .json file containing test specifications
6159
:param dag_path: string path representation to the .dot file containing the Causal DAG
6260
:param data_path: string path representation to the data file
63-
:returns:
64-
- json_path -
65-
- dag_path -
66-
- data_path -
6761
"""
68-
self.json_path = Path(json_path)
69-
self.dag_path = Path(dag_path)
70-
self.data_path = Path(data_path)
62+
self.paths = JsonClassPaths(json_path=json_path, dag_path=dag_path, data_path=data_path)
7163

72-
def set_variables(self, inputs: dict, outputs: dict, metas: dict):
64+
def set_variables(self, inputs: list[dict], outputs: list[dict], metas: list[dict]):
7365
"""Populate the Causal Variables
7466
:param inputs: list of dicts, each with "name", "type" and "distribution" keys, describing the input variables
7567
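:param outputs: list of dicts, each with "name" and "type" keys (and optionally "distribution"), describing the output variables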
7668
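:param metas: list of dicts, each with "name", "type" and "populate" keys, describing the meta variables; may be empty or None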
7769
"""
78-
self.inputs = [Input(i["name"], i["type"], i["distribution"]) for i in inputs]
79-
self.outputs = [Output(i["name"], i["type"], i.get("distribution", None)) for i in outputs]
80-
self.metas = [Meta(i["name"], i["type"], i["populate"]) for i in metas] if metas else []
70+
71+
self.variables = CausalVariables(inputs=inputs, outputs=outputs, metas=metas)
8172

8273
def setup(self):
8374
"""Function to populate all the necessary parts of the json_class needed to execute tests"""
84-
self.modelling_scenario = Scenario(self.inputs + self.outputs + self.metas, None)
75+
self.modelling_scenario = Scenario(self.variables.inputs + self.variables.outputs + self.variables.metas, None)
8576
self.modelling_scenario.setup_treatment_variables()
8677
self.causal_specification = CausalSpecification(
87-
scenario=self.modelling_scenario, causal_dag=CausalDAG(self.dag_path)
78+
scenario=self.modelling_scenario, causal_dag=CausalDAG(self.paths.dag_path)
8879
)
8980
self._json_parse()
9081
self._populate_metas()
@@ -139,20 +130,20 @@ def _execute_tests(self, concrete_tests, estimators, test, f_flag):
139130

140131
def _json_parse(self):
141132
"""Load the JSON test specification file into the test plan and read the data file into a DataFrame"""
142-
with open(self.json_path, encoding="utf-8") as f:
133+
with open(self.paths.json_path, encoding="utf-8") as f:
143134
self.test_plan = json.load(f)
144135

145-
self.data = pd.read_csv(self.data_path)
136+
self.data = pd.read_csv(self.paths.data_path)
146137

147138
def _populate_metas(self):
148139
"""
149140
Populate data with meta-variable values and add distributions to Causal Testing Framework Variables
150141
"""
151142

152-
for meta in self.metas:
143+
for meta in self.variables.metas:
153144
meta.populate(self.data)
154145

155-
for var in self.metas + self.outputs:
146+
for var in self.variables.metas + self.variables.outputs:
156147
if not var.distribution:
157148
fitter = Fitter(self.data[var.name], distributions=get_common_distributions())
158149
fitter.fit()
@@ -202,7 +193,7 @@ def _setup_test(self, causal_test_case: CausalTestCase, estimator: Estimator) ->
202193
- causal_test_engine - Test Engine instance for the test being run
203194
- estimation_model - Estimator instance for the test being run
204195
"""
205-
data_collector = ObservationalDataCollector(self.modelling_scenario, self.data_path)
196+
data_collector = ObservationalDataCollector(self.modelling_scenario, self.paths.data_path)
206197
causal_test_engine = CausalTestEngine(self.causal_specification, data_collector, index_col=0)
207198
minimal_adjustment_set = self.causal_specification.causal_dag.identification(causal_test_case.base_test_case)
208199
treatment_var = causal_test_case.treatment_variable
@@ -273,3 +264,38 @@ def get_args(test_args=None) -> argparse.Namespace:
273264
required=True,
274265
)
275266
return parser.parse_args(test_args)
267+
268+
269+
@dataclass
270+
class JsonClassPaths:
271+
"""
272+
A dataclass that converts strings of paths to Path objects for use in the JsonUtility class
273+
:param json_path: string path representation to .json file containing test specifications
274+
:param dag_path: string path representation to the .dot file containing the Causal DAG
275+
:param data_path: string path representation to the data file
276+
"""
277+
278+
json_path: Path
279+
dag_path: Path
280+
data_path: Path
281+
282+
def __init__(self, json_path: str, dag_path: str, data_path: str):
283+
self.json_path = Path(json_path)
284+
self.dag_path = Path(dag_path)
285+
self.data_path = Path(data_path)
286+
287+
288+
@dataclass()
289+
class CausalVariables:
290+
"""
291+
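A dataclass that converts lists of variable dictionaries into lists of Input, Output and Meta objects for use in the JsonUtility class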
292+
"""
293+
294+
inputs: list[Input]
295+
outputs: list[Output]
296+
metas: list[Meta]
297+
298+
def __init__(self, inputs: list[dict], outputs: list[dict], metas: list[dict]):
299+
self.inputs = [Input(i["name"], i["type"], i["distribution"]) for i in inputs]
300+
self.outputs = [Output(i["name"], i["type"], i.get("distribution", None)) for i in outputs]
301+
self.metas = [Meta(i["name"], i["type"], i["populate"]) for i in metas] if metas else []

causal_testing/specification/causal_dag.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -150,19 +150,19 @@ def check_iv_assumptions(self, treatment, outcome, instrument) -> bool:
150150
raise ValueError(f"Instrument {instrument} is not associated with treatment {treatment} in the DAG")
151151

152152
# (ii) Instrument does not affect outcome except through its potential effect on treatment
153-
if not all([treatment in path for path in nx.all_simple_paths(self.graph, source=instrument, target=outcome)]):
153+
if not all((treatment in path for path in nx.all_simple_paths(self.graph, source=instrument, target=outcome))):
154154
raise ValueError(
155155
f"Instrument {instrument} affects the outcome {outcome} other than through the treatment {treatment}"
156156
)
157157

158158
# (iii) Instrument and outcome do not share causes
159159
if any(
160-
[
160+
(
161161
cause
162162
for cause in self.graph.nodes
163163
if list(nx.all_simple_paths(self.graph, source=cause, target=instrument))
164164
and list(nx.all_simple_paths(self.graph, source=cause, target=outcome))
165-
]
165+
)
166166
):
167167
raise ValueError(f"Instrument {instrument} and outcome {outcome} share common causes")
168168

causal_testing/specification/variable.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,14 @@
1010
import lhsmdu
1111
from pandas import DataFrame
1212
from scipy.stats._distn_infrastructure import rv_generic
13-
from z3 import Bool, BoolRef, Const, EnumSort, Int, RatNumRef, Real, String, DatatypeRef
13+
from z3 import Bool, BoolRef, Const, EnumSort, Int, RatNumRef, Real, String
1414

1515
# Declare type variable
1616
T = TypeVar("T")
17-
Z3 = TypeVar("Z3")
17+
z3 = TypeVar("Z3")
1818

1919

20-
def z3_types(datatype: T) -> Z3:
20+
def z3_types(datatype: T) -> z3:
2121
"""Cast datatype to Z3 datatype
2222
:param datatype: python datatype to be cast
2323
:return: Type name compatible with Z3 library
@@ -76,7 +76,6 @@ def __init__(self, name: str, datatype: T, distribution: rv_generic = None):
7676
def __repr__(self):
7777
return f"{self.typestring()}: {self.name}::{self.datatype.__name__}"
7878

79-
# TODO: We're going to need to implement all the supported Z3 operations like this
8079
def __ge__(self, other: Any) -> BoolRef:
8180
"""Create the Z3 expression `other >= self`.
8281
@@ -167,8 +166,6 @@ def cast(self, val: Any) -> T:
167166
return val.as_string()
168167
if (isinstance(val, (float, int, bool))) and (self.datatype in (float, int, bool)):
169168
return self.datatype(val)
170-
if issubclass(self.datatype, Enum) and isinstance(val, DatatypeRef):
171-
return self.datatype(str(val))
172169
return self.datatype(str(val))
173170

174171
def z3_val(self, z3_var, val: Any) -> T:

causal_testing/testing/causal_test_case.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111

1212
class CausalTestCase:
13+
# pylint: disable=too-many-instance-attributes
1314
"""
1415
A CausalTestCase extends the information held in a BaseTestCase. As well as storing the treatment and outcome
1516
variables, a CausalTestCase stores the values of these variables. Also the outcome variable and value are
@@ -22,6 +23,7 @@ class CausalTestCase:
2223
"""
2324

2425
def __init__(
26+
# pylint: disable=too-many-arguments
2527
self,
2628
base_test_case: BaseTestCase,
2729
expected_causal_effect: CausalTestOutcome,

0 commit comments

Comments
 (0)