Skip to content

Commit 1e030d8

Browse files
Merge pull request #135 from CITCOM-project/pylint
Fix majority of small linting issues within the framework
2 parents e9f51e5 + 3a4f76a commit 1e030d8

17 files changed

+181
-121
lines changed

.pylintrc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ disable=raw-checker-failed,
152152
useless-suppression,
153153
deprecated-pragma,
154154
use-symbolic-message-instead,
155+
logging-fstring-interpolation,
155156

156157
# Enable the message, report, category or checker with the given id(s). You can
157158
# either give multiple identifier separated by comma (,) or put this option
@@ -239,7 +240,9 @@ good-names=i,
239240
j,
240241
k,
241242
ex,
243+
df,
242244
Run,
245+
z3,
243246
_
244247

245248
# Good variable names regexes, separated by a comma. If names match any regex,

causal_testing/data_collection/data_collector.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
"""This module contains the DataCollector abstract class, as well as its concrete extensions: ExperimentalDataCollector
2+
and ObservationalDataCollector"""
3+
14
import logging
25
from abc import ABC, abstractmethod
36
from enum import Enum
@@ -73,10 +76,7 @@ def filter_valid_data(self, data: pd.DataFrame, check_pos: bool = True) -> pd.Da
7376
size_diff = len(data) - len(satisfying_data)
7477
if size_diff > 0:
7578
logger.warning(
76-
"Discarded %s/%s values due to constraint violations.\n" "For example%s",
77-
size_diff,
78-
len(data),
79-
unsat_core,
79+
f"Discarded {size_diff}/{len(data)} values due to constraint violations.\n For example {unsat_core}",
8080
)
8181
return satisfying_data
8282

causal_testing/generation/abstract_causal_test_case.py

Lines changed: 45 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,21 @@
1+
"""This module contains the class AbstractCausalTestCase, which generates concrete test cases"""
2+
import itertools
13
import logging
4+
from enum import Enum
5+
from typing import Iterable
26

37
import lhsmdu
48
import pandas as pd
59
import z3
610
from scipy import stats
7-
import itertools
11+
812

913
from causal_testing.specification.scenario import Scenario
1014
from causal_testing.specification.variable import Variable
1115
from causal_testing.testing.causal_test_case import CausalTestCase
1216
from causal_testing.testing.causal_test_outcome import CausalTestOutcome
1317
from causal_testing.testing.base_test_case import BaseTestCase
1418

15-
from enum import Enum
1619

1720
logger = logging.getLogger(__name__)
1821

@@ -60,7 +63,9 @@ def __str__(self):
6063
)
6164
return f"When we apply intervention {self.intervention_constraints}, {outcome_string}"
6265

63-
def datapath(self):
66+
def datapath(self) -> str:
67+
"""Create and return the sanitised data path"""
68+
6469
def sanitise(string):
6570
return "".join([x for x in string if x.isalnum()])
6671

@@ -101,25 +106,7 @@ def _generate_concrete_tests(
101106
samples[var.name] = lhsmdu.inverseTransformSample(var.distribution, samples[var.name])
102107

103108
for index, row in samples.iterrows():
104-
optimizer = z3.Optimize()
105-
for c in self.scenario.constraints:
106-
optimizer.assert_and_track(c, str(c))
107-
for c in self.intervention_constraints:
108-
optimizer.assert_and_track(c, str(c))
109-
110-
for v in run_columns:
111-
optimizer.add_soft(
112-
self.scenario.variables[v].z3
113-
== self.scenario.variables[v].z3_val(self.scenario.variables[v].z3, row[v])
114-
)
115-
116-
if optimizer.check() == z3.unsat:
117-
logger.warning(
118-
"Satisfiability of test case was unsat.\n" "Constraints \n %s \n Unsat core %s",
119-
optimizer,
120-
optimizer.unsat_core(),
121-
)
122-
model = optimizer.model()
109+
model = self._optimizer_model(run_columns, row)
123110

124111
base_test_case = BaseTestCase(
125112
treatment_variable=self.treatment_variable,
@@ -146,7 +133,7 @@ def _generate_concrete_tests(
146133
+ f"{constraints}\nUsing value {v.cast(model[v.z3])} instead in test\n{concrete_test}"
147134
)
148135

149-
if not any([vars(t) == vars(concrete_test) for t in concrete_tests]):
136+
if not any((vars(t) == vars(concrete_test) for t in concrete_tests)):
150137
concrete_tests.append(concrete_test)
151138
# Control run
152139
control_run = {
@@ -197,12 +184,12 @@ def generate_concrete_tests(
197184

198185
pre_break = False
199186
for i in range(hard_max):
200-
concrete_tests_, runs_ = self._generate_concrete_tests(sample_size, rct, seed + i)
201-
for t_ in concrete_tests_:
202-
if not any([vars(t_) == vars(t) for t in concrete_tests]):
203-
concrete_tests.append(t_)
204-
runs = pd.concat([runs, runs_])
205-
assert concrete_tests_ not in concrete_tests, "Duplicate entries unlikely unless something went wrong"
187+
concrete_tests_temp, runs_temp = self._generate_concrete_tests(sample_size, rct, seed + i)
188+
for test in concrete_tests_temp:
189+
if not any((vars(test) == vars(t) for t in concrete_tests)):
190+
concrete_tests.append(test)
191+
runs = pd.concat([runs, runs_temp])
192+
assert concrete_tests_temp not in concrete_tests, "Duplicate entries unlikely unless something went wrong"
206193

207194
control_configs = pd.DataFrame([{test.treatment_variable: test.control_value} for test in concrete_tests])
208195
ks_stats = {
@@ -230,7 +217,7 @@ def generate_concrete_tests(
230217
control_values = [test.control_value for test in concrete_tests]
231218
treatment_values = [test.treatment_value for test in concrete_tests]
232219

233-
if self.treatment_variable.datatype is bool and set([(True, False), (False, True)]).issubset(
220+
if self.treatment_variable.datatype is bool and {(True, False), (False, True)}.issubset(
234221
set(zip(control_values, treatment_values))
235222
):
236223
pre_break = True
@@ -244,7 +231,7 @@ def generate_concrete_tests(
244231
).issubset(zip(control_values, treatment_values)):
245232
pre_break = True
246233
break
247-
elif target_ks_score and all((stat <= target_ks_score for stat in ks_stats.values())):
234+
if target_ks_score and all((stat <= target_ks_score for stat in ks_stats.values())):
248235
pre_break = True
249236
break
250237

@@ -256,3 +243,30 @@ def generate_concrete_tests(
256243
len(concrete_tests),
257244
)
258245
return concrete_tests, runs
246+
247+
def _optimizer_model(self, run_columns: Iterable[str], row: pd.Series) -> z3.ModelRef:
    """
    Build a z3 optimizer for a single sampled row and return the model it yields.

    Every scenario constraint and intervention constraint is asserted and tracked under its own string
    representation. Each value in ``row`` is then added as a soft constraint on the matching scenario
    variable.

    :param run_columns: A sorted list of Variable names from the scenario variables
    :param row: A pandas Series containing a row from the Samples dataframe
    :return: The model obtained from the optimizer once the constraints have been checked
    :rtype: z3.ModelRef
    """
    optimizer = z3.Optimize()
    for constraint in self.scenario.constraints:
        optimizer.assert_and_track(constraint, str(constraint))
    for constraint in self.intervention_constraints:
        optimizer.assert_and_track(constraint, str(constraint))

    for name in run_columns:
        # Look the variable up once instead of three times per iteration.
        variable = self.scenario.variables[name]
        optimizer.add_soft(variable.z3 == variable.z3_val(variable.z3, row[name]))

    if optimizer.check() == z3.unsat:
        logger.warning(
            "Satisfiability of test case was unsat.\n"
            f"Constraints \n {optimizer} \n Unsat core {optimizer.unsat_core()}",
        )
    # NOTE(review): the unsat case above only logs; the model is still requested afterwards, exactly as
    # before. Confirm whether z3 raises here when no model is available and whether that is intended.
    return optimizer.model()

causal_testing/json_front/json_class.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
"""This module contains the JsonUtility class, details of using this class can be found here:
2+
https://causal-testing-framework.readthedocs.io/en/latest/json_front_end.html"""
13
import argparse
24
import json
35
import logging
@@ -68,7 +70,6 @@ def set_path(self, json_path: str, dag_path: str, data_path: str):
6870
self.data_path = Path(data_path)
6971

7072
def set_variables(self, inputs: dict, outputs: dict, metas: dict):
71-
7273
"""Populate the Causal Variables
7374
:param inputs:
7475
:param outputs:
@@ -137,9 +138,8 @@ def _execute_tests(self, concrete_tests, estimators, test, f_flag):
137138
return failures
138139

139140
def _json_parse(self):
140-
141141
"""Parse a JSON input file into inputs, outputs, metas and a test plan"""
142-
with open(self.json_path) as f:
142+
with open(self.json_path, encoding="utf-8") as f:
143143
self.test_plan = json.load(f)
144144

145145
self.data = pd.read_csv(self.data_path)
@@ -179,7 +179,10 @@ def _execute_test_case(self, causal_test_case: CausalTestCase, estimator: Estima
179179

180180
result_string = str()
181181
if causal_test_result.ci_low() and causal_test_result.ci_high():
182-
result_string = f"{causal_test_result.ci_low()} < {causal_test_result.test_value.value} < {causal_test_result.ci_high()}"
182+
result_string = (
183+
f"{causal_test_result.ci_low()} < {causal_test_result.test_value.value} < "
184+
f"{causal_test_result.ci_high()}"
185+
)
183186
else:
184187
result_string = f"{causal_test_result.test_value.value} no confidence intervals"
185188
if f_flag:
@@ -218,7 +221,7 @@ def _setup_test(self, causal_test_case: CausalTestCase, estimator: Estimator) ->
218221

219222
return causal_test_engine, estimation_model
220223

221-
def add_modelling_assumptions(self, estimation_model: Estimator):
224+
def add_modelling_assumptions(self, estimation_model: Estimator): # pylint: disable=unused-argument
222225
"""Optional abstract method where user functionality can be written to determine what assumptions are required
223226
for specific test cases
224227
:param estimation_model: estimator model instance for the current running test.

causal_testing/specification/causal_dag.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,18 @@
1+
"""This module contains the CausalDAG class, as well as the functions list_all_min_sep and close_seperator"""
2+
3+
from __future__ import annotations
4+
15
import logging
26
from itertools import combinations
37
from random import sample
4-
from typing import TypeVar, Union
8+
from typing import Union
59

610
import networkx as nx
711

812
from .scenario import Scenario
913
from .variable import Output
1014

1115
Node = Union[str, int] # Node type hint: A node is a string or an int
12-
CausalDAG = TypeVar("CausalDAG")
1316

1417
logger = logging.getLogger(__name__)
1518

@@ -49,7 +52,6 @@ def list_all_min_sep(
4952

5053
# 4. Confirm that the connected component containing the treatment node is disjoint with the outcome node set
5154
if not treatment_connected_component_node_set.intersection(outcome_node_set):
52-
5355
# 5. Update the treatment node set to the set of nodes in the connected component containing the treatment node
5456
treatment_node_set = treatment_connected_component_node_set
5557

@@ -60,7 +62,6 @@ def list_all_min_sep(
6062

6163
# 7. Check that there exists at least one neighbour of the treatment nodes that is not in the outcome node set
6264
if treatment_node_set_neighbours.difference(outcome_node_set):
63-
6465
# 7.1. If so, sample a random node from the set of treatment nodes' neighbours not in the outcome node set
6566
node = set(sample(treatment_node_set_neighbours.difference(outcome_node_set), 1))
6667

@@ -82,7 +83,6 @@ def list_all_min_sep(
8283
outcome_node_set.union(node),
8384
)
8485
else:
85-
8686
# 8. If all neighbours of the treatments nodes are in the outcome node set, return the set of treatment
8787
# node neighbours
8888
yield treatment_node_set_neighbours
@@ -352,10 +352,8 @@ def adjustment_set_is_minimal(self, treatments: list[str], outcomes: list[str],
352352
proper_backdoor_graph, treatments, outcomes, smaller_adjustment_set
353353
):
354354
logger.info(
355-
"Z=%s is not minimal because Z'=Z\\{{'%s'}}=" "%s is also a valid adjustment set.",
356-
adjustment_set,
357-
variable,
358-
smaller_adjustment_set,
355+
f"Z={adjustment_set} is not minimal because Z'=Z\\{variable} = {smaller_adjustment_set} is also a"
356+
f"valid adjustment set.",
359357
)
360358
return False
361359

@@ -466,7 +464,7 @@ def depends_on_outputs(self, node: Node, scenario: Scenario) -> bool:
466464
"""
467465
if isinstance(scenario.variables[node], Output):
468466
return True
469-
return any([self.depends_on_outputs(n, scenario) for n in self.graph.predecessors(node)])
467+
return any((self.depends_on_outputs(n, scenario) for n in self.graph.predecessors(node)))
470468

471469
def identification(self, base_test_case):
472470
"""Identify and return the minimum adjustment set
Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,23 @@
1-
import logging
1+
"""This module holds the abstract CausalSpecification data class, which holds a Scenario and CausalDag"""
2+
23
from abc import ABC
4+
from dataclasses import dataclass
35
from typing import Union
46

57
from causal_testing.specification.causal_dag import CausalDAG
68
from causal_testing.specification.scenario import Scenario
79

810
Node = Union[str, int] # Node type hint: A node is a string or an int
9-
logger = logging.getLogger(__name__)
1011

1112

13+
@dataclass
1214
class CausalSpecification(ABC):
1315
"""
1416
Abstract Class for the Causal Specification (combination of Scenario and Causal Dag)
1517
"""
1618

17-
def __init__(self, scenario: Scenario, causal_dag: CausalDAG):
18-
self.scenario = scenario
19-
self.causal_dag = causal_dag
19+
scenario: Scenario
20+
causal_dag: CausalDAG
2021

2122
def __str__(self):
2223
return f"Scenario: {self.scenario}\nCausal DAG:\n{self.causal_dag}"

causal_testing/specification/scenario.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
"""This module holds the Scenario Class"""
12
from collections.abc import Iterable, Mapping
23

34
from tabulate import tabulate
@@ -30,12 +31,16 @@ def __init__(self, variables: Iterable[Variable] = None, constraints: set[ExprRe
3031
if variables is not None:
3132
self.variables = {v.name: v for v in variables}
3233
else:
33-
self.variables = dict()
34+
self.variables = {}
3435
if constraints is not None:
3536
self.constraints = set(constraints)
3637
else:
3738
self.constraints = set()
3839

40+
self.prime = {}
41+
self.unprime = {}
42+
self.treatment_variables = {}
43+
3944
def __str__(self):
4045
"""Returns a printable string of a scenario, e.g.
4146
Modelling scenario with variables:
@@ -94,9 +99,6 @@ def setup_treatment_variables(self) -> None:
9499
to the contraint set such that the "primed" variables are constrained in
95100
the same way as their unprimed counterparts.
96101
"""
97-
self.prime = {}
98-
self.unprime = {}
99-
self.treatment_variables = {}
100102
for k, v in self.variables.items():
101103
v_prime = self._fresh(v)
102104
self.treatment_variables[k] = v_prime
@@ -141,4 +143,7 @@ def metas(self) -> set[Meta]:
141143
return self.variables_of_type(Meta)
142144

143145
def add_variable(self, v: Variable) -> None:
    """Add variable to variables attribute, keyed by the variable's name.

    An existing entry with the same name is replaced.

    :param v: Variable to be added
    """
    # This must be an assignment. The previous form, ``self.variables[v.name]: v``, was an
    # annotation statement and stored nothing, so the variable was never added.
    self.variables[v.name] = v

0 commit comments

Comments
 (0)