Skip to content

Commit 2c09cc3

Browse files
Handle case where there are no covariates
1 parent 6d6ad7c commit 2c09cc3

File tree

1 file changed

+78
-68
lines changed

1 file changed

+78
-68
lines changed

causal_testing/testing/estimators.py

Lines changed: 78 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -41,16 +41,16 @@ class Estimator(ABC):
4141
"""
4242

4343
def __init__(
44-
# pylint: disable=too-many-arguments
45-
self,
46-
treatment: str,
47-
treatment_value: float,
48-
control_value: float,
49-
adjustment_set: set,
50-
outcome: str,
51-
df: pd.DataFrame = None,
52-
effect_modifiers: dict[str:Any] = None,
53-
alpha: float = 0.05,
44+
# pylint: disable=too-many-arguments
45+
self,
46+
treatment: str,
47+
treatment_value: float,
48+
control_value: float,
49+
adjustment_set: set,
50+
outcome: str,
51+
df: pd.DataFrame = None,
52+
effect_modifiers: dict[str:Any] = None,
53+
alpha: float = 0.05,
5454
):
5555
self.treatment = treatment
5656
self.treatment_value = treatment_value
@@ -85,25 +85,24 @@ def compute_confidence_intervals(self) -> list[float, float]:
8585

8686

8787
class RegressionEstimator(Estimator):
88-
"""
89-
90-
"""
88+
""" """
9189

9290
def __init__(
93-
# pylint: disable=too-many-arguments
94-
self,
95-
treatment: str,
96-
treatment_value: float,
97-
control_value: float,
98-
adjustment_set: set,
99-
outcome: str,
100-
df: pd.DataFrame = None,
101-
effect_modifiers: dict[str:Any] = None,
102-
formula: str = None,
103-
alpha: float = 0.05,
91+
# pylint: disable=too-many-arguments
92+
self,
93+
treatment: str,
94+
treatment_value: float,
95+
control_value: float,
96+
adjustment_set: set,
97+
outcome: str,
98+
df: pd.DataFrame = None,
99+
effect_modifiers: dict[str:Any] = None,
100+
formula: str = None,
101+
alpha: float = 0.05,
104102
):
105-
super().__init__(treatment, treatment_value, control_value, adjustment_set, outcome, df, effect_modifiers,
106-
alpha=alpha)
103+
super().__init__(
104+
treatment, treatment_value, control_value, adjustment_set, outcome, df, effect_modifiers, alpha=alpha
105+
)
107106

108107
if effect_modifiers is None:
109108
effect_modifiers = []
@@ -134,14 +133,19 @@ def get_terms_from_formula(self):
134133
if self.treatment not in rhs_terms:
135134
raise ValueError(f"Treatment variable '{self.treatment}' not found in formula")
136135
covariates = rhs_terms.remove(self.treatment)
136+
if covariates is None:
137+
covariates = []
137138
return outcome, self.treatment, covariates
138139

139140
def validate_formula(self, causal_dag: CausalDAG):
140141
outcome, treatment, covariates = self.get_terms_from_formula()
141142
proper_backdoor_graph = causal_dag.get_proper_backdoor_graph(treatments=[treatment], outcomes=[outcome])
142-
return CausalDAG.constructive_backdoor_criterion(proper_backdoor_graph=proper_backdoor_graph,
143-
treatments=[treatment], outcomes=[outcome],
144-
covariates=list(covariates))
143+
return causal_dag.constructive_backdoor_criterion(
144+
proper_backdoor_graph=proper_backdoor_graph,
145+
treatments=[treatment],
146+
outcomes=[outcome],
147+
covariates=list(covariates),
148+
)
145149

146150

147151
class LogisticRegressionEstimator(RegressionEstimator):
@@ -151,19 +155,20 @@ class LogisticRegressionEstimator(RegressionEstimator):
151155
"""
152156

153157
def __init__(
154-
# pylint: disable=too-many-arguments
155-
self,
156-
treatment: str,
157-
treatment_value: float,
158-
control_value: float,
159-
adjustment_set: set,
160-
outcome: str,
161-
df: pd.DataFrame = None,
162-
effect_modifiers: dict[str:Any] = None,
163-
formula: str = None,
158+
# pylint: disable=too-many-arguments
159+
self,
160+
treatment: str,
161+
treatment_value: float,
162+
control_value: float,
163+
adjustment_set: set,
164+
outcome: str,
165+
df: pd.DataFrame = None,
166+
effect_modifiers: dict[str:Any] = None,
167+
formula: str = None,
164168
):
165-
super().__init__(treatment, treatment_value, control_value, adjustment_set, outcome, df, effect_modifiers,
166-
formula)
169+
super().__init__(
170+
treatment, treatment_value, control_value, adjustment_set, outcome, df, effect_modifiers, formula
171+
)
167172

168173
self.model = None
169174

@@ -218,7 +223,7 @@ def estimate(self, data: pd.DataFrame, adjustment_config: dict = None) -> Regres
218223
return model.predict(x)
219224

220225
def estimate_control_treatment(
221-
self, adjustment_config: dict = None, bootstrap_size: int = 100
226+
self, adjustment_config: dict = None, bootstrap_size: int = 100
222227
) -> tuple[pd.Series, pd.Series]:
223228
"""Estimate the outcomes under control and treatment.
224229
@@ -336,23 +341,28 @@ class LinearRegressionEstimator(RegressionEstimator):
336341
"""
337342

338343
def __init__(
339-
# pylint: disable=too-many-arguments
340-
self,
341-
treatment: str,
342-
treatment_value: float,
343-
control_value: float,
344-
adjustment_set: set,
345-
outcome: str,
346-
df: pd.DataFrame = None,
347-
effect_modifiers: dict[Variable:Any] = None,
348-
formula: str = None,
349-
alpha: float = 0.05,
350-
344+
# pylint: disable=too-many-arguments
345+
self,
346+
treatment: str,
347+
treatment_value: float,
348+
control_value: float,
349+
adjustment_set: set,
350+
outcome: str,
351+
df: pd.DataFrame = None,
352+
effect_modifiers: dict[Variable:Any] = None,
353+
formula: str = None,
354+
alpha: float = 0.05,
351355
):
352-
353356
super().__init__(
354-
treatment, treatment_value, control_value, adjustment_set, outcome, df, effect_modifiers, alpha=alpha,
355-
formula=formula
357+
treatment,
358+
treatment_value,
359+
control_value,
360+
adjustment_set,
361+
outcome,
362+
df,
363+
effect_modifiers,
364+
alpha=alpha,
365+
formula=formula,
356366
)
357367

358368
self.model = None
@@ -497,17 +507,17 @@ class InstrumentalVariableEstimator(Estimator):
497507
"""
498508

499509
def __init__(
500-
# pylint: disable=too-many-arguments
501-
self,
502-
treatment: str,
503-
treatment_value: float,
504-
control_value: float,
505-
adjustment_set: set,
506-
outcome: str,
507-
instrument: str,
508-
df: pd.DataFrame = None,
509-
intercept: int = 1,
510-
effect_modifiers: dict = None, # Not used (yet?). Needed for compatibility
510+
# pylint: disable=too-many-arguments
511+
self,
512+
treatment: str,
513+
treatment_value: float,
514+
control_value: float,
515+
adjustment_set: set,
516+
outcome: str,
517+
instrument: str,
518+
df: pd.DataFrame = None,
519+
intercept: int = 1,
520+
effect_modifiers: dict = None, # Not used (yet?). Needed for compatibility
511521
):
512522
super().__init__(treatment, treatment_value, control_value, adjustment_set, outcome, df, None)
513523
self.intercept = intercept

0 commit comments

Comments
 (0)