Skip to content

Commit 0cb5492

Browse files
authored
Merge pull request #269 from CITCOM-project/interaction-terms
Temporary workaround for "I(...) not in df" bug
2 parents 5b9f113 + 4d78ae9 commit 0cb5492

File tree

5 files changed

+31
-19
lines changed

5 files changed

+31
-19
lines changed

causal_testing/specification/causal_dag.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,8 +133,8 @@ class CausalDAG(nx.DiGraph):
133133
def __init__(self, dot_path: str = None, **attr):
134134
super().__init__(**attr)
135135
if dot_path:
136-
with open(dot_path, 'r', encoding='utf-8') as file:
137-
dot_content = file.read().replace('\n', '')
136+
with open(dot_path, "r", encoding="utf-8") as file:
137+
dot_content = file.read().replace("\n", "")
138138
# Previously, we used pydot_graph_from_file() to read in the dot_path directly, however,
139139
# this method does not currently have a way of removing spurious nodes.
140140
# Workaround: Read in the file using open(), remove new lines, and then create the pydot_graph.

causal_testing/surrogate/causal_surrogate_assisted.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ class SimulationResult:
1919
relationship: str
2020

2121

22-
class SearchAlgorithm(ABC): # pylint: disable=too-few-public-methods
22+
class SearchAlgorithm(ABC): # pylint: disable=too-few-public-methods
2323
"""Class to be inherited with the search algorithm consisting of a search function and the fitness function of the
2424
space to be searched"""
2525

causal_testing/surrogate/surrogate_search_algorithms.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def __init__(self, delta=0.05, config: dict = None) -> None:
2626

2727
# pylint: disable=too-many-locals
2828
def search(
29-
self, surrogate_models: list[CubicSplineRegressionEstimator], specification: CausalSpecification
29+
self, surrogate_models: list[CubicSplineRegressionEstimator], specification: CausalSpecification
3030
) -> list:
3131
solutions = []
3232

@@ -47,7 +47,8 @@ def fitness_function(ga, solution, idx): # pylint: disable=unused-argument
4747
ate = surrogate.estimate_ate_calculated(adjustment_dict)
4848
if len(ate) > 1:
4949
raise ValueError(
50-
"Multiple ate values provided but currently only single values supported in this method")
50+
"Multiple ate values provided but currently only single values supported in this method"
51+
)
5152
return contradiction_function(ate[0])
5253

5354
gene_types, gene_space = self.create_gene_types(surrogate, specification)
@@ -84,7 +85,7 @@ def fitness_function(ga, solution, idx): # pylint: disable=unused-argument
8485

8586
@staticmethod
8687
def create_gene_types(
87-
surrogate_model: CubicSplineRegressionEstimator, specification: CausalSpecification
88+
surrogate_model: CubicSplineRegressionEstimator, specification: CausalSpecification
8889
) -> tuple[list, list]:
8990
"""Generate the gene_types and gene_space for a given fitness function and specification
9091
:param surrogate_model: Instance of a CubicSplineRegressionEstimator

causal_testing/testing/causal_test_outcome.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,12 @@ class SomeEffect(CausalTestOutcome):
2929
def apply(self, res: CausalTestResult) -> bool:
3030
if res.test_value.type == "risk_ratio":
3131
return any(
32-
1 < ci_low < ci_high or ci_low < ci_high < 1 for ci_low, ci_high in zip(res.ci_low(), res.ci_high()))
33-
if res.test_value.type in ('coefficient', 'ate'):
32+
1 < ci_low < ci_high or ci_low < ci_high < 1 for ci_low, ci_high in zip(res.ci_low(), res.ci_high())
33+
)
34+
if res.test_value.type in ("coefficient", "ate"):
3435
return any(
35-
0 < ci_low < ci_high or ci_low < ci_high < 0 for ci_low, ci_high in zip(res.ci_low(), res.ci_high()))
36+
0 < ci_low < ci_high or ci_low < ci_high < 0 for ci_low, ci_high in zip(res.ci_low(), res.ci_high())
37+
)
3638

3739
raise ValueError(f"Test Value type {res.test_value.type} is not valid for this TestOutcome")
3840

@@ -51,17 +53,19 @@ def __init__(self, atol: float = 1e-10, ctol: float = 0.05):
5153

5254
def apply(self, res: CausalTestResult) -> bool:
5355
if res.test_value.type == "risk_ratio":
54-
return any(ci_low < 1 < ci_high or np.isclose(value, 1.0, atol=self.atol) for ci_low, ci_high, value in
55-
zip(res.ci_low(), res.ci_high(), res.test_value.value))
56-
if res.test_value.type in ('coefficient', 'ate'):
56+
return any(
57+
ci_low < 1 < ci_high or np.isclose(value, 1.0, atol=self.atol)
58+
for ci_low, ci_high, value in zip(res.ci_low(), res.ci_high(), res.test_value.value)
59+
)
60+
if res.test_value.type in ("coefficient", "ate"):
5761
value = res.test_value.value if isinstance(res.ci_high(), Iterable) else [res.test_value.value]
5862
return (
59-
sum(
60-
not ((ci_low < 0 < ci_high) or abs(v) < self.atol)
61-
for ci_low, ci_high, v in zip(res.ci_low(), res.ci_high(), value)
62-
)
63-
/ len(value)
64-
< self.ctol
63+
sum(
64+
not ((ci_low < 0 < ci_high) or abs(v) < self.atol)
65+
for ci_low, ci_high, v in zip(res.ci_low(), res.ci_high(), value)
66+
)
67+
/ len(value)
68+
< self.ctol
6569
)
6670

6771
raise ValueError(f"Test Value type {res.test_value.type} is not valid for this TestOutcome")

causal_testing/testing/estimators.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,14 @@ def estimate_coefficient(self) -> tuple[pd.Series, list[pd.Series, pd.Series]]:
352352
model = self._run_linear_regression()
353353
newline = "\n"
354354
patsy_md = ModelDesc.from_formula(self.treatment)
355-
if any((self.df.dtypes[factor.name()] == 'object' for factor in patsy_md.rhs_termlist[1].factors)):
355+
if any(
356+
(
357+
self.df.dtypes[factor.name()] == "object"
358+
for factor in patsy_md.rhs_termlist[1].factors
359+
# We want to remove this long term as it prevents us from discovering categoricals within I(...) blocks
360+
if factor.name() in self.df.dtypes
361+
)
362+
):
356363
design_info = dmatrix(self.formula.split("~")[1], self.df).design_info
357364
treatment = design_info.column_names[design_info.term_name_slices[self.treatment]]
358365
else:

0 commit comments

Comments
 (0)