Skip to content

Commit 86ee372

Browse files
committed
black
1 parent bcbbb2e commit 86ee372

File tree

5 files changed

+24
-19
lines changed

5 files changed

+24
-19
lines changed

causal_testing/specification/causal_dag.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,8 +133,8 @@ class CausalDAG(nx.DiGraph):
133133
def __init__(self, dot_path: str = None, **attr):
134134
super().__init__(**attr)
135135
if dot_path:
136-
with open(dot_path, 'r', encoding='utf-8') as file:
137-
dot_content = file.read().replace('\n', '')
136+
with open(dot_path, "r", encoding="utf-8") as file:
137+
dot_content = file.read().replace("\n", "")
138138
# Previously, we used pydot_graph_from_file() to read in the dot_path directly, however,
139139
# this method does not currently have a way of removing spurious nodes.
140140
# Workaround: Read in the file using open(), remove new lines, and then create the pydot_graph.

causal_testing/surrogate/causal_surrogate_assisted.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ class SimulationResult:
1919
relationship: str
2020

2121

22-
class SearchAlgorithm(ABC): # pylint: disable=too-few-public-methods
22+
class SearchAlgorithm(ABC): # pylint: disable=too-few-public-methods
2323
"""Class to be inherited with the search algorithm consisting of a search function and the fitness function of the
2424
space to be searched"""
2525

causal_testing/surrogate/surrogate_search_algorithms.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def __init__(self, delta=0.05, config: dict = None) -> None:
2626

2727
# pylint: disable=too-many-locals
2828
def search(
29-
self, surrogate_models: list[CubicSplineRegressionEstimator], specification: CausalSpecification
29+
self, surrogate_models: list[CubicSplineRegressionEstimator], specification: CausalSpecification
3030
) -> list:
3131
solutions = []
3232

@@ -47,7 +47,8 @@ def fitness_function(ga, solution, idx): # pylint: disable=unused-argument
4747
ate = surrogate.estimate_ate_calculated(adjustment_dict)
4848
if len(ate) > 1:
4949
raise ValueError(
50-
"Multiple ate values provided but currently only single values supported in this method")
50+
"Multiple ate values provided but currently only single values supported in this method"
51+
)
5152
return contradiction_function(ate[0])
5253

5354
gene_types, gene_space = self.create_gene_types(surrogate, specification)
@@ -84,7 +85,7 @@ def fitness_function(ga, solution, idx): # pylint: disable=unused-argument
8485

8586
@staticmethod
8687
def create_gene_types(
87-
surrogate_model: CubicSplineRegressionEstimator, specification: CausalSpecification
88+
surrogate_model: CubicSplineRegressionEstimator, specification: CausalSpecification
8889
) -> tuple[list, list]:
8990
"""Generate the gene_types and gene_space for a given fitness function and specification
9091
:param surrogate_model: Instance of a CubicSplineRegressionEstimator

causal_testing/testing/causal_test_outcome.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,12 @@ class SomeEffect(CausalTestOutcome):
2929
def apply(self, res: CausalTestResult) -> bool:
3030
if res.test_value.type == "risk_ratio":
3131
return any(
32-
1 < ci_low < ci_high or ci_low < ci_high < 1 for ci_low, ci_high in zip(res.ci_low(), res.ci_high()))
33-
if res.test_value.type in ('coefficient', 'ate'):
32+
1 < ci_low < ci_high or ci_low < ci_high < 1 for ci_low, ci_high in zip(res.ci_low(), res.ci_high())
33+
)
34+
if res.test_value.type in ("coefficient", "ate"):
3435
return any(
35-
0 < ci_low < ci_high or ci_low < ci_high < 0 for ci_low, ci_high in zip(res.ci_low(), res.ci_high()))
36+
0 < ci_low < ci_high or ci_low < ci_high < 0 for ci_low, ci_high in zip(res.ci_low(), res.ci_high())
37+
)
3638

3739
raise ValueError(f"Test Value type {res.test_value.type} is not valid for this TestOutcome")
3840

@@ -51,17 +53,19 @@ def __init__(self, atol: float = 1e-10, ctol: float = 0.05):
5153

5254
def apply(self, res: CausalTestResult) -> bool:
5355
if res.test_value.type == "risk_ratio":
54-
return any(ci_low < 1 < ci_high or np.isclose(value, 1.0, atol=self.atol) for ci_low, ci_high, value in
55-
zip(res.ci_low(), res.ci_high(), res.test_value.value))
56-
if res.test_value.type in ('coefficient', 'ate'):
56+
return any(
57+
ci_low < 1 < ci_high or np.isclose(value, 1.0, atol=self.atol)
58+
for ci_low, ci_high, value in zip(res.ci_low(), res.ci_high(), res.test_value.value)
59+
)
60+
if res.test_value.type in ("coefficient", "ate"):
5761
value = res.test_value.value if isinstance(res.ci_high(), Iterable) else [res.test_value.value]
5862
return (
59-
sum(
60-
not ((ci_low < 0 < ci_high) or abs(v) < self.atol)
61-
for ci_low, ci_high, v in zip(res.ci_low(), res.ci_high(), value)
62-
)
63-
/ len(value)
64-
< self.ctol
63+
sum(
64+
not ((ci_low < 0 < ci_high) or abs(v) < self.atol)
65+
for ci_low, ci_high, v in zip(res.ci_low(), res.ci_high(), value)
66+
)
67+
/ len(value)
68+
< self.ctol
6569
)
6670

6771
raise ValueError(f"Test Value type {res.test_value.type} is not valid for this TestOutcome")

causal_testing/testing/estimators.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -356,7 +356,7 @@ def estimate_coefficient(self) -> tuple[pd.Series, list[pd.Series, pd.Series]]:
356356
(
357357
self.df.dtypes[factor.name()] == "object"
358358
for factor in patsy_md.rhs_termlist[1].factors
359-
# TODO: Remove the requirement for this as it prevents us from discovering categoricals within I(...) blocks
359+
# We want to remove this long term as it prevents us from discovering categoricals within I(...) blocks
360360
if factor in self.df
361361
)
362362
):

0 commit comments

Comments
 (0)