Skip to content

Commit d22709d

Browse files
Merge pull request #142 from CITCOM-project/pylint_refactoring
Pylint refactoring
2 parents bbe2300 + d745bde commit d22709d

File tree

13 files changed

+138
-131
lines changed

13 files changed

+138
-131
lines changed

.pylintrc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ disable=raw-checker-failed,
153153
deprecated-pragma,
154154
use-symbolic-message-instead,
155155
logging-fstring-interpolation,
156+
import-error,
156157

157158
# Enable the message, report, category or checker with the given id(s). You can
158159
# either give multiple identifier separated by comma (,) or put this option

causal_testing/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,13 @@
1+
"""
2+
This is the CausalTestingFramework Module
3+
It contains 5 subpackages:
4+
data_collection
5+
generation
6+
json_front
7+
specification
8+
testing
9+
"""
10+
111
import logging
212

313
logger = logging.getLogger(__name__)

causal_testing/generation/abstract_causal_test_case.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ class AbstractCausalTestCase:
2828
"""
2929

3030
def __init__(
31+
# pylint: disable=too-many-arguments
3132
self,
3233
scenario: Scenario,
3334
intervention_constraints: set[z3.ExprRef],
@@ -77,7 +78,11 @@ def sanitise(string):
7778
)
7879

7980
def _generate_concrete_tests(
80-
self, sample_size: int, rct: bool = False, seed: int = 0
81+
# pylint: disable=too-many-locals
82+
self,
83+
sample_size: int,
84+
rct: bool = False,
85+
seed: int = 0,
8186
) -> tuple[list[CausalTestCase], pd.DataFrame]:
8287
"""Generates a list of `num` concrete test cases.
8388
@@ -151,6 +156,7 @@ def _generate_concrete_tests(
151156
return concrete_tests, pd.DataFrame(runs, columns=run_columns + ["bin"])
152157

153158
def generate_concrete_tests(
159+
# pylint: disable=too-many-arguments, too-many-locals
154160
self,
155161
sample_size: int,
156162
target_ks_score: float = None,

causal_testing/json_front/json_class.py

Lines changed: 51 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
"""This module contains the JsonUtility class, details of using this class can be found here:
22
https://causal-testing-framework.readthedocs.io/en/latest/json_front_end.html"""
3+
34
import argparse
45
import json
56
import logging
67

78
from abc import ABC
9+
from dataclasses import dataclass
810
from pathlib import Path
911

1012
import pandas as pd
@@ -42,49 +44,38 @@ class JsonUtility(ABC):
4244
"""
4345

4446
def __init__(self, log_path):
45-
self.json_path = None
46-
self.dag_path = None
47-
self.data_path = None
48-
self.inputs = None
49-
self.outputs = None
50-
self.metas = None
47+
self.paths = None
48+
self.variables = None
5149
self.data = None
5250
self.test_plan = None
5351
self.modelling_scenario = None
5452
self.causal_specification = None
5553
self.setup_logger(log_path)
5654

57-
def set_path(self, json_path: str, dag_path: str, data_path: str):
55+
def set_paths(self, json_path: str, dag_path: str, data_path: str):
5856
"""
5957
Takes a path of the directory containing all scenario specific files and creates individual paths for each file
6058
:param json_path: string path representation to .json file containing test specifications
6159
:param dag_path: string path representation to the .dot file containing the Causal DAG
6260
:param data_path: string path representation to the data file
63-
:returns:
64-
- json_path -
65-
- dag_path -
66-
- data_path -
6761
"""
68-
self.json_path = Path(json_path)
69-
self.dag_path = Path(dag_path)
70-
self.data_path = Path(data_path)
62+
self.paths = JsonClassPaths(json_path=json_path, dag_path=dag_path, data_path=data_path)
7163

72-
def set_variables(self, inputs: dict, outputs: dict, metas: dict):
64+
def set_variables(self, inputs: list[dict], outputs: list[dict], metas: list[dict]):
7365
"""Populate the Causal Variables
7466
:param inputs:
7567
:param outputs:
7668
:param metas:
7769
"""
78-
self.inputs = [Input(i["name"], i["type"], i["distribution"]) for i in inputs]
79-
self.outputs = [Output(i["name"], i["type"], i.get("distribution", None)) for i in outputs]
80-
self.metas = [Meta(i["name"], i["type"], i["populate"]) for i in metas] if metas else []
70+
71+
self.variables = CausalVariables(inputs=inputs, outputs=outputs, metas=metas)
8172

8273
def setup(self):
8374
"""Function to populate all the necessary parts of the json_class needed to execute tests"""
84-
self.modelling_scenario = Scenario(self.inputs + self.outputs + self.metas, None)
75+
self.modelling_scenario = Scenario(self.variables.inputs + self.variables.outputs + self.variables.metas, None)
8576
self.modelling_scenario.setup_treatment_variables()
8677
self.causal_specification = CausalSpecification(
87-
scenario=self.modelling_scenario, causal_dag=CausalDAG(self.dag_path)
78+
scenario=self.modelling_scenario, causal_dag=CausalDAG(self.paths.dag_path)
8879
)
8980
self._json_parse()
9081
self._populate_metas()
@@ -139,20 +130,20 @@ def _execute_tests(self, concrete_tests, estimators, test, f_flag):
139130

140131
def _json_parse(self):
141132
"""Parse a JSON input file into inputs, outputs, metas and a test plan"""
142-
with open(self.json_path, encoding="utf-8") as f:
133+
with open(self.paths.json_path, encoding="utf-8") as f:
143134
self.test_plan = json.load(f)
144135

145-
self.data = pd.read_csv(self.data_path)
136+
self.data = pd.read_csv(self.paths.data_path)
146137

147138
def _populate_metas(self):
148139
"""
149140
Populate data with meta-variable values and add distributions to Causal Testing Framework Variables
150141
"""
151142

152-
for meta in self.metas:
143+
for meta in self.variables.metas:
153144
meta.populate(self.data)
154145

155-
for var in self.metas + self.outputs:
146+
for var in self.variables.metas + self.variables.outputs:
156147
if not var.distribution:
157148
fitter = Fitter(self.data[var.name], distributions=get_common_distributions())
158149
fitter.fit()
@@ -202,7 +193,7 @@ def _setup_test(self, causal_test_case: CausalTestCase, estimator: Estimator) ->
202193
- causal_test_engine - Test Engine instance for the test being run
203194
- estimation_model - Estimator instance for the test being run
204195
"""
205-
data_collector = ObservationalDataCollector(self.modelling_scenario, self.data_path)
196+
data_collector = ObservationalDataCollector(self.modelling_scenario, self.paths.data_path)
206197
causal_test_engine = CausalTestEngine(self.causal_specification, data_collector, index_col=0)
207198
minimal_adjustment_set = self.causal_specification.causal_dag.identification(causal_test_case.base_test_case)
208199
treatment_var = causal_test_case.treatment_variable
@@ -273,3 +264,38 @@ def get_args(test_args=None) -> argparse.Namespace:
273264
required=True,
274265
)
275266
return parser.parse_args(test_args)
267+
268+
269+
@dataclass
270+
class JsonClassPaths:
271+
"""
272+
A dataclass that converts strings of paths to Path objects for use in the JsonUtility class
273+
:param json_path: string path representation to .json file containing test specifications
274+
:param dag_path: string path representation to the .dot file containing the Causal DAG
275+
:param data_path: string path representation to the data file
276+
"""
277+
278+
json_path: Path
279+
dag_path: Path
280+
data_path: Path
281+
282+
def __init__(self, json_path: str, dag_path: str, data_path: str):
283+
self.json_path = Path(json_path)
284+
self.dag_path = Path(dag_path)
285+
self.data_path = Path(data_path)
286+
287+
288+
@dataclass()
289+
class CausalVariables:
290+
"""
291+
A dataclass that converts
292+
"""
293+
294+
inputs: list[Input]
295+
outputs: list[Output]
296+
metas: list[Meta]
297+
298+
def __init__(self, inputs: list[dict], outputs: list[dict], metas: list[dict]):
299+
self.inputs = [Input(i["name"], i["type"], i["distribution"]) for i in inputs]
300+
self.outputs = [Output(i["name"], i["type"], i.get("distribution", None)) for i in outputs]
301+
self.metas = [Meta(i["name"], i["type"], i["populate"]) for i in metas] if metas else []

causal_testing/specification/causal_dag.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -150,19 +150,19 @@ def check_iv_assumptions(self, treatment, outcome, instrument) -> bool:
150150
raise ValueError(f"Instrument {instrument} is not associated with treatment {treatment} in the DAG")
151151

152152
# (ii) Instrument does not affect outcome except through its potential effect on treatment
153-
if not all([treatment in path for path in nx.all_simple_paths(self.graph, source=instrument, target=outcome)]):
153+
if not all((treatment in path for path in nx.all_simple_paths(self.graph, source=instrument, target=outcome))):
154154
raise ValueError(
155155
f"Instrument {instrument} affects the outcome {outcome} other than through the treatment {treatment}"
156156
)
157157

158158
# (iii) Instrument and outcome do not share causes
159159
if any(
160-
[
160+
(
161161
cause
162162
for cause in self.graph.nodes
163163
if list(nx.all_simple_paths(self.graph, source=cause, target=instrument))
164164
and list(nx.all_simple_paths(self.graph, source=cause, target=outcome))
165-
]
165+
)
166166
):
167167
raise ValueError(f"Instrument {instrument} and outcome {outcome} share common causes")
168168

causal_testing/specification/variable.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,14 @@
1010
import lhsmdu
1111
from pandas import DataFrame
1212
from scipy.stats._distn_infrastructure import rv_generic
13-
from z3 import Bool, BoolRef, Const, EnumSort, Int, RatNumRef, Real, String, DatatypeRef
13+
from z3 import Bool, BoolRef, Const, EnumSort, Int, RatNumRef, Real, String
1414

1515
# Declare type variable
1616
T = TypeVar("T")
17-
Z3 = TypeVar("Z3")
17+
z3 = TypeVar("Z3")
1818

1919

20-
def z3_types(datatype: T) -> Z3:
20+
def z3_types(datatype: T) -> z3:
2121
"""Cast datatype to Z3 datatype
2222
:param datatype: python datatype to be cast
2323
:return: Type name compatible with Z3 library
@@ -76,7 +76,6 @@ def __init__(self, name: str, datatype: T, distribution: rv_generic = None):
7676
def __repr__(self):
7777
return f"{self.typestring()}: {self.name}::{self.datatype.__name__}"
7878

79-
# TODO: We're going to need to implement all the supported Z3 operations like this
8079
def __ge__(self, other: Any) -> BoolRef:
8180
"""Create the Z3 expression `other >= self`.
8281
@@ -167,8 +166,6 @@ def cast(self, val: Any) -> T:
167166
return val.as_string()
168167
if (isinstance(val, (float, int, bool))) and (self.datatype in (float, int, bool)):
169168
return self.datatype(val)
170-
if issubclass(self.datatype, Enum) and isinstance(val, DatatypeRef):
171-
return self.datatype(str(val))
172169
return self.datatype(str(val))
173170

174171
def z3_val(self, z3_var, val: Any) -> T:

causal_testing/testing/causal_test_case.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111

1212
class CausalTestCase:
13+
# pylint: disable=too-many-instance-attributes
1314
"""
1415
A CausalTestCase extends the information held in a BaseTestCase. As well as storing the treatment and outcome
1516
variables, a CausalTestCase stores the values of these variables. Also the outcome variable and value are
@@ -22,6 +23,7 @@ class CausalTestCase:
2223
"""
2324

2425
def __init__(
26+
# pylint: disable=too-many-arguments
2527
self,
2628
base_test_case: BaseTestCase,
2729
expected_causal_effect: CausalTestOutcome,

causal_testing/testing/causal_test_engine.py

Lines changed: 11 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def execute_test_suite(self, test_suite: CausalTestSuite) -> list[CausalTestResu
6060
causal_test_result objects
6161
"""
6262
if self.scenario_execution_data_df.empty:
63-
raise Exception("No data has been loaded. Please call load_data prior to executing a causal test case.")
63+
raise ValueError("No data has been loaded. Please call load_data prior to executing a causal test case.")
6464
test_suite_results = {}
6565
for edge in test_suite:
6666
print("edge: ")
@@ -75,7 +75,7 @@ def execute_test_suite(self, test_suite: CausalTestSuite) -> list[CausalTestResu
7575
list(minimal_adjustment_set) + [edge.treatment_variable.name] + [edge.outcome_variable.name]
7676
)
7777
if self._check_positivity_violation(variables_for_positivity):
78-
raise Exception("POSITIVITY VIOLATION -- Cannot proceed.")
78+
raise ValueError("POSITIVITY VIOLATION -- Cannot proceed.")
7979

8080
estimators = test_suite[edge]["estimators"]
8181
tests = test_suite[edge]["tests"]
@@ -85,13 +85,10 @@ def execute_test_suite(self, test_suite: CausalTestSuite) -> list[CausalTestResu
8585
causal_test_results = []
8686

8787
for test in tests:
88-
treatment_variable = test.treatment_variable
89-
treatment_value = test.treatment_value
90-
control_value = test.control_value
9188
estimator = estimator_class(
92-
treatment_variable.name,
93-
treatment_value,
94-
control_value,
89+
test.treatment_variable.name,
90+
test.treatment_value,
91+
test.control_value,
9592
minimal_adjustment_set,
9693
test.outcome_variable.name,
9794
)
@@ -125,7 +122,7 @@ def execute_test(
125122
:return causal_test_result: A CausalTestResult for the executed causal test case.
126123
"""
127124
if self.scenario_execution_data_df.empty:
128-
raise Exception("No data has been loaded. Please call load_data prior to executing a causal test case.")
125+
raise ValueError("No data has been loaded. Please call load_data prior to executing a causal test case.")
129126
if estimator.df is None:
130127
estimator.df = self.scenario_execution_data_df
131128
treatment_variable = causal_test_case.treatment_variable
@@ -141,7 +138,7 @@ def execute_test(
141138
variables_for_positivity = list(minimal_adjustment_set) + [treatment_variable.name] + [outcome_variable.name]
142139

143140
if self._check_positivity_violation(variables_for_positivity):
144-
raise Exception("POSITIVITY VIOLATION -- Cannot proceed.")
141+
raise ValueError("POSITIVITY VIOLATION -- Cannot proceed.")
145142

146143
causal_test_result = self._return_causal_test_results(estimate_type, estimator, causal_test_case)
147144
return causal_test_result
@@ -161,11 +158,7 @@ def _return_causal_test_results(self, estimate_type, estimator, causal_test_case
161158

162159
cates_df, confidence_intervals = estimator.estimate_cates()
163160
causal_test_result = CausalTestResult(
164-
treatment=estimator.treatment,
165-
outcome=estimator.outcome,
166-
treatment_value=estimator.treatment_value,
167-
control_value=estimator.control_value,
168-
adjustment_set=estimator.adjustment_set,
161+
estimator=estimator,
169162
test_value=TestValue("ate", cates_df),
170163
effect_modifier_configuration=causal_test_case.effect_modifier_configuration,
171164
confidence_intervals=confidence_intervals,
@@ -174,11 +167,7 @@ def _return_causal_test_results(self, estimate_type, estimator, causal_test_case
174167
logger.debug("calculating risk_ratio")
175168
risk_ratio, confidence_intervals = estimator.estimate_risk_ratio()
176169
causal_test_result = CausalTestResult(
177-
treatment=estimator.treatment,
178-
outcome=estimator.outcome,
179-
treatment_value=estimator.treatment_value,
180-
control_value=estimator.control_value,
181-
adjustment_set=estimator.adjustment_set,
170+
estimator=estimator,
182171
test_value=TestValue("risk_ratio", risk_ratio),
183172
effect_modifier_configuration=causal_test_case.effect_modifier_configuration,
184173
confidence_intervals=confidence_intervals,
@@ -187,11 +176,7 @@ def _return_causal_test_results(self, estimate_type, estimator, causal_test_case
187176
logger.debug("calculating ate")
188177
ate, confidence_intervals = estimator.estimate_ate()
189178
causal_test_result = CausalTestResult(
190-
treatment=estimator.treatment,
191-
outcome=estimator.outcome,
192-
treatment_value=estimator.treatment_value,
193-
control_value=estimator.control_value,
194-
adjustment_set=estimator.adjustment_set,
179+
estimator=estimator,
195180
test_value=TestValue("ate", ate),
196181
effect_modifier_configuration=causal_test_case.effect_modifier_configuration,
197182
confidence_intervals=confidence_intervals,
@@ -202,11 +187,7 @@ def _return_causal_test_results(self, estimate_type, estimator, causal_test_case
202187
logger.debug("calculating ate")
203188
ate, confidence_intervals = estimator.estimate_ate_calculated()
204189
causal_test_result = CausalTestResult(
205-
treatment=estimator.treatment,
206-
outcome=estimator.outcome,
207-
treatment_value=estimator.treatment_value,
208-
control_value=estimator.control_value,
209-
adjustment_set=estimator.adjustment_set,
190+
estimator=estimator,
210191
test_value=TestValue("ate", ate),
211192
effect_modifier_configuration=causal_test_case.effect_modifier_configuration,
212193
confidence_intervals=confidence_intervals,

0 commit comments

Comments
 (0)