Skip to content

Commit 8052df5

Browse files
black formatting
1 parent d81bb50 commit 8052df5

14 files changed

+45
-27
lines changed

causal_testing/generation/abstract_causal_test_case.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""This module contains the class AbstractCausalTestCase, which generates concrete test cases"""
2+
23
import itertools
34
import logging
45
from enum import Enum

causal_testing/json_front/json_class.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,9 +108,9 @@ def _create_abstract_test_case(self, test, mutates, effects):
108108
self.scenario.variables[variable]: effects[effect]
109109
for variable, effect in test["expected_effect"].items()
110110
},
111-
effect_modifiers={self.scenario.variables[v] for v in test["effect_modifiers"]}
112-
if "effect_modifiers" in test
113-
else {},
111+
effect_modifiers=(
112+
{self.scenario.variables[v] for v in test["effect_modifiers"]} if "effect_modifiers" in test else {}
113+
),
114114
estimate_type=test["estimate_type"],
115115
effect=test.get("effect", "total"),
116116
)

causal_testing/specification/metamorphic_relation.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,11 @@ def generate_follow_up(self, n_tests: int, min_val: float, max_val: float, seed:
7373
source_follow_up_test_inputs[[follow_up_input]]
7474
.rename(columns={follow_up_input: self.treatment_var})
7575
.to_dict(orient="records"),
76-
test_inputs.to_dict(orient="records")
77-
if not test_inputs.empty
78-
else [{}] * len(source_follow_up_test_inputs),
76+
(
77+
test_inputs.to_dict(orient="records")
78+
if not test_inputs.empty
79+
else [{}] * len(source_follow_up_test_inputs)
80+
),
7981
)
8082
]
8183

causal_testing/specification/scenario.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""This module holds the Scenario Class"""
2+
23
from collections.abc import Iterable, Mapping
34

45
from tabulate import tabulate

causal_testing/surrogate/causal_surrogate_assisted.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ class SimulationResult:
1919
relationship: str
2020

2121

22-
class SearchAlgorithm(ABC): # pylint: disable=too-few-public-methods
22+
class SearchAlgorithm(ABC): # pylint: disable=too-few-public-methods
2323
"""Class to be inherited with the search algorithm consisting of a search function and the fitness function of the
2424
space to be searched"""
2525

causal_testing/surrogate/surrogate_search_algorithms.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Module containing implementation of search algorithm for surrogate search """
2+
23
# Fitness functions are required to be iteratively defined, including all variables within.
34

45
from operator import itemgetter
@@ -26,7 +27,7 @@ def __init__(self, delta=0.05, config: dict = None) -> None:
2627

2728
# pylint: disable=too-many-locals
2829
def search(
29-
self, surrogate_models: list[CubicSplineRegressionEstimator], specification: CausalSpecification
30+
self, surrogate_models: list[CubicSplineRegressionEstimator], specification: CausalSpecification
3031
) -> list:
3132
solutions = []
3233

@@ -47,7 +48,8 @@ def fitness_function(ga, solution, idx): # pylint: disable=unused-argument
4748
ate = surrogate.estimate_ate_calculated(adjustment_dict)
4849
if len(ate) > 1:
4950
raise ValueError(
50-
"Multiple ate values provided but currently only single values supported in this method")
51+
"Multiple ate values provided but currently only single values supported in this method"
52+
)
5153
return contradiction_function(ate[0])
5254

5355
gene_types, gene_space = self.create_gene_types(surrogate, specification)
@@ -84,7 +86,7 @@ def fitness_function(ga, solution, idx): # pylint: disable=unused-argument
8486

8587
@staticmethod
8688
def create_gene_types(
87-
surrogate_model: CubicSplineRegressionEstimator, specification: CausalSpecification
89+
surrogate_model: CubicSplineRegressionEstimator, specification: CausalSpecification
8890
) -> tuple[list, list]:
8991
"""Generate the gene_types and gene_space for a given fitness function and specification
9092
:param surrogate_model: Instance of a CubicSplineRegressionEstimator

causal_testing/testing/base_test_case.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""This module contains the BaseTestCase dataclass, which stores the information required for identification"""
2+
23
from dataclasses import dataclass
34
from causal_testing.specification.variable import Variable
45
from causal_testing.testing.effect import Effect

causal_testing/testing/causal_test_adequacy.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
22
This module contains code to measure various aspects of causal test adequacy.
33
"""
4+
45
from itertools import combinations
56
from copy import deepcopy
67
import pandas as pd

causal_testing/testing/causal_test_case.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""This module contains the CausalTestCase class, a class that holds the information required for a causal test"""
2+
23
import logging
34
from typing import Any
45
import numpy as np

causal_testing/testing/causal_test_outcome.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,12 @@ class SomeEffect(CausalTestOutcome):
2929
def apply(self, res: CausalTestResult) -> bool:
3030
if res.test_value.type == "risk_ratio":
3131
return any(
32-
1 < ci_low < ci_high or ci_low < ci_high < 1 for ci_low, ci_high in zip(res.ci_low(), res.ci_high()))
33-
if res.test_value.type in ('coefficient', 'ate'):
32+
1 < ci_low < ci_high or ci_low < ci_high < 1 for ci_low, ci_high in zip(res.ci_low(), res.ci_high())
33+
)
34+
if res.test_value.type in ("coefficient", "ate"):
3435
return any(
35-
0 < ci_low < ci_high or ci_low < ci_high < 0 for ci_low, ci_high in zip(res.ci_low(), res.ci_high()))
36+
0 < ci_low < ci_high or ci_low < ci_high < 0 for ci_low, ci_high in zip(res.ci_low(), res.ci_high())
37+
)
3638

3739
raise ValueError(f"Test Value type {res.test_value.type} is not valid for this TestOutcome")
3840

@@ -51,17 +53,19 @@ def __init__(self, atol: float = 1e-10, ctol: float = 0.05):
5153

5254
def apply(self, res: CausalTestResult) -> bool:
5355
if res.test_value.type == "risk_ratio":
54-
return any(ci_low < 1 < ci_high or np.isclose(value, 1.0, atol=self.atol) for ci_low, ci_high, value in
55-
zip(res.ci_low(), res.ci_high(), res.test_value.value))
56-
if res.test_value.type in ('coefficient', 'ate'):
56+
return any(
57+
ci_low < 1 < ci_high or np.isclose(value, 1.0, atol=self.atol)
58+
for ci_low, ci_high, value in zip(res.ci_low(), res.ci_high(), res.test_value.value)
59+
)
60+
if res.test_value.type in ("coefficient", "ate"):
5761
value = res.test_value.value if isinstance(res.ci_high(), Iterable) else [res.test_value.value]
5862
return (
59-
sum(
60-
not ((ci_low < 0 < ci_high) or abs(v) < self.atol)
61-
for ci_low, ci_high, v in zip(res.ci_low(), res.ci_high(), value)
62-
)
63-
/ len(value)
64-
< self.ctol
63+
sum(
64+
not ((ci_low < 0 < ci_high) or abs(v) < self.atol)
65+
for ci_low, ci_high, v in zip(res.ci_low(), res.ci_high(), value)
66+
)
67+
/ len(value)
68+
< self.ctol
6569
)
6670

6771
raise ValueError(f"Test Value type {res.test_value.type} is not valid for this TestOutcome")

0 commit comments

Comments
 (0)