Skip to content

Commit 0d9b1d7

Browse files
Merge pull request #264 from CITCOM-project/interaction-terms
Estimator return types
2 parents 66c35ac + b8ad419 commit 0d9b1d7

File tree

12 files changed

+150
-132
lines changed

12 files changed

+150
-132
lines changed

causal_testing/surrogate/surrogate_search_algorithms.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ def search(
3535

3636
# The GA fitness function after including required variables into the function's scope
3737
# Unused arguments are required for pygad's fitness function signature
38-
#pylint: disable=cell-var-from-loop
39-
def fitness_function(ga, solution, idx): # pylint: disable=unused-argument
38+
# pylint: disable=cell-var-from-loop
39+
def fitness_function(ga, solution, idx): # pylint: disable=unused-argument
4040
surrogate.control_value = solution[0] - self.delta
4141
surrogate.treatment_value = solution[0] + self.delta
4242

@@ -45,8 +45,10 @@ def fitness_function(ga, solution, idx): # pylint: disable=unused-argument
4545
adjustment_dict[adjustment] = solution[i + 1]
4646

4747
ate = surrogate.estimate_ate_calculated(adjustment_dict)
48-
49-
return contradiction_function(ate)
48+
if len(ate) > 1:
49+
raise ValueError(
50+
"Multiple ate values provided but currently only single values supported in this method")
51+
return contradiction_function(ate[0])
5052

5153
gene_types, gene_space = self.create_gene_types(surrogate, specification)
5254

@@ -82,7 +84,7 @@ def fitness_function(ga, solution, idx): # pylint: disable=unused-argument
8284

8385
@staticmethod
8486
def create_gene_types(
85-
surrogate_model: CubicSplineRegressionEstimator, specification: CausalSpecification
87+
surrogate_model: CubicSplineRegressionEstimator, specification: CausalSpecification
8688
) -> tuple[list, list]:
8789
"""Generate the gene_types and gene_space for a given fitness function and specification
8890
:param surrogate_model: Instance of a CubicSplineRegressionEstimator

causal_testing/testing/causal_test_outcome.py

Lines changed: 29 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,13 @@ class SomeEffect(CausalTestOutcome):
2727
"""An extension of TestOutcome representing that the expected causal effect should not be zero."""
2828

2929
def apply(self, res: CausalTestResult) -> bool:
30-
if res.test_value.type == "ate":
31-
return (0 < res.ci_low() < res.ci_high()) or (res.ci_low() < res.ci_high() < 0)
32-
if res.test_value.type == "coefficient":
33-
ci_low = res.ci_low() if isinstance(res.ci_low(), Iterable) else [res.ci_low()]
34-
ci_high = res.ci_high() if isinstance(res.ci_high(), Iterable) else [res.ci_high()]
35-
return any(0 < ci_low < ci_high or ci_low < ci_high < 0 for ci_low, ci_high in zip(ci_low, ci_high))
3630
if res.test_value.type == "risk_ratio":
37-
return (1 < res.ci_low() < res.ci_high()) or (res.ci_low() < res.ci_high() < 1)
31+
return any(
32+
1 < ci_low < ci_high or ci_low < ci_high < 1 for ci_low, ci_high in zip(res.ci_low(), res.ci_high()))
33+
if res.test_value.type in ('coefficient', 'ate'):
34+
return any(
35+
0 < ci_low < ci_high or ci_low < ci_high < 0 for ci_low, ci_high in zip(res.ci_low(), res.ci_high()))
36+
3837
raise ValueError(f"Test Value type {res.test_value.type} is not valid for this TestOutcome")
3938

4039

@@ -51,23 +50,20 @@ def __init__(self, atol: float = 1e-10, ctol: float = 0.05):
5150
self.ctol = ctol
5251

5352
def apply(self, res: CausalTestResult) -> bool:
54-
if res.test_value.type == "ate":
55-
return (res.ci_low() < 0 < res.ci_high()) or (abs(res.test_value.value) < self.atol)
56-
if res.test_value.type == "coefficient":
57-
ci_low = res.ci_low() if isinstance(res.ci_low(), Iterable) else [res.ci_low()]
58-
ci_high = res.ci_high() if isinstance(res.ci_high(), Iterable) else [res.ci_high()]
53+
if res.test_value.type == "risk_ratio":
54+
return any(ci_low < 1 < ci_high or np.isclose(value, 1.0, atol=self.atol) for ci_low, ci_high, value in
55+
zip(res.ci_low(), res.ci_high(), res.test_value.value))
56+
if res.test_value.type in ('coefficient', 'ate'):
5957
value = res.test_value.value if isinstance(res.ci_high(), Iterable) else [res.test_value.value]
60-
6158
return (
62-
sum(
63-
not ((ci_low < 0 < ci_high) or abs(v) < self.atol)
64-
for ci_low, ci_high, v in zip(ci_low, ci_high, value)
65-
)
66-
/ len(value)
67-
< self.ctol
59+
sum(
60+
not ((ci_low < 0 < ci_high) or abs(v) < self.atol)
61+
for ci_low, ci_high, v in zip(res.ci_low(), res.ci_high(), value)
62+
)
63+
/ len(value)
64+
< self.ctol
6865
)
69-
if res.test_value.type == "risk_ratio":
70-
return (res.ci_low() < 1 < res.ci_high()) or np.isclose(res.test_value.value, 1.0, atol=self.atol)
66+
7167
raise ValueError(f"Test Value type {res.test_value.type} is not valid for this TestOutcome")
7268

7369

@@ -93,28 +89,33 @@ def __str__(self):
9389

9490

9591
class Positive(SomeEffect):
96-
"""An extension of TestOutcome representing that the expected causal effect should be positive."""
92+
"""An extension of TestOutcome representing that the expected causal effect should be positive.
93+
Currently only single values are supported for the test value"""
9794

9895
def apply(self, res: CausalTestResult) -> bool:
9996
if res.ci_valid() and not super().apply(res):
10097
return False
98+
if len(res.test_value.value) > 1:
99+
raise ValueError("Positive Effects are currently only supported on single float datatypes")
101100
if res.test_value.type in {"ate", "coefficient"}:
102-
return bool(res.test_value.value > 0)
101+
return bool(res.test_value.value[0] > 0)
103102
if res.test_value.type == "risk_ratio":
104-
return bool(res.test_value.value > 1)
105-
# Dead code but necessary for pylint
103+
return bool(res.test_value.value[0] > 1)
106104
raise ValueError(f"Test Value type {res.test_value.type} is not valid for this TestOutcome")
107105

108106

109107
class Negative(SomeEffect):
110-
"""An extension of TestOutcome representing that the expected causal effect should be negative."""
108+
"""An extension of TestOutcome representing that the expected causal effect should be negative.
109+
Currently only single values are supported for the test value"""
111110

112111
def apply(self, res: CausalTestResult) -> bool:
113112
if res.ci_valid() and not super().apply(res):
114113
return False
114+
if len(res.test_value.value) > 1:
115+
raise ValueError("Negative Effects are currently only supported on single float datatypes")
115116
if res.test_value.type in {"ate", "coefficient"}:
116-
return bool(res.test_value.value < 0)
117+
return bool(res.test_value.value[0] < 0)
117118
if res.test_value.type == "risk_ratio":
118-
return bool(res.test_value.value < 1)
119+
return bool(res.test_value.value[0] < 1)
119120
# Dead code but necessary for pylint
120121
raise ValueError(f"Test Value type {res.test_value.type} is not valid for this TestOutcome")

causal_testing/testing/causal_test_result.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def __init__(
2727
self,
2828
estimator: Estimator,
2929
test_value: TestValue,
30-
confidence_intervals: [float, float] = None,
30+
confidence_intervals: [pd.Series, pd.Series] = None,
3131
effect_modifier_configuration: {Variable: Any} = None,
3232
adequacy=None,
3333
):
@@ -99,12 +99,16 @@ def to_dict(self, json=False):
9999
def ci_low(self):
100100
"""Return the lower bracket of the confidence intervals."""
101101
if self.confidence_intervals:
102+
if isinstance(self.confidence_intervals[0], pd.Series):
103+
return self.confidence_intervals[0].to_list()
102104
return self.confidence_intervals[0]
103105
return None
104106

105107
def ci_high(self):
106108
"""Return the higher bracket of the confidence intervals."""
107109
if self.confidence_intervals:
110+
if isinstance(self.confidence_intervals[1], pd.Series):
111+
return self.confidence_intervals[1].to_list()
108112
return self.confidence_intervals[1]
109113
return None
110114

causal_testing/testing/estimators.py

Lines changed: 29 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import statsmodels.formula.api as smf
1212
from econml.dml import CausalForestDML
1313
from patsy import dmatrix # pylint: disable = no-name-in-module
14-
14+
from patsy import ModelDesc
1515
from sklearn.ensemble import GradientBoostingRegressor
1616
from statsmodels.regression.linear_model import RegressionResultsWrapper
1717
from statsmodels.tools.sm_exceptions import PerfectSeparationError
@@ -343,30 +343,28 @@ def add_modelling_assumptions(self):
343343
"do not need to be linear."
344344
)
345345

346-
def estimate_coefficient(self) -> float:
346+
def estimate_coefficient(self) -> tuple[pd.Series, list[pd.Series, pd.Series]]:
347347
"""Estimate the unit average treatment effect of the treatment on the outcome. That is, the change in outcome
348348
caused by a unit change in treatment.
349349
350350
:return: The unit average treatment effect and the 95% Wald confidence intervals.
351351
"""
352352
model = self._run_linear_regression()
353353
newline = "\n"
354-
treatment = [self.treatment]
355-
if str(self.df.dtypes[self.treatment]) == "object":
354+
patsy_md = ModelDesc.from_formula(self.treatment)
355+
if any((self.df.dtypes[factor.name()] == 'object' for factor in patsy_md.rhs_termlist[1].factors)):
356356
design_info = dmatrix(self.formula.split("~")[1], self.df).design_info
357357
treatment = design_info.column_names[design_info.term_name_slices[self.treatment]]
358+
else:
359+
treatment = [self.treatment]
358360
assert set(treatment).issubset(
359361
model.params.index.tolist()
360362
), f"{treatment} not in\n{' ' + str(model.params.index).replace(newline, newline + ' ')}"
361363
unit_effect = model.params[treatment] # Unit effect is the coefficient of the treatment
362364
[ci_low, ci_high] = self._get_confidence_intervals(model, treatment)
363-
if str(self.df.dtypes[self.treatment]) != "object":
364-
unit_effect = unit_effect[0]
365-
ci_low = ci_low[0]
366-
ci_high = ci_high[0]
367365
return unit_effect, [ci_low, ci_high]
368366

369-
def estimate_ate(self) -> tuple[float, list[float, float], float]:
367+
def estimate_ate(self) -> tuple[pd.Series, list[pd.Series, pd.Series]]:
370368
"""Estimate the average treatment effect of the treatment on the outcome. That is, the change in outcome caused
371369
by changing the treatment variable from the control value to the treatment value.
372370
@@ -384,8 +382,9 @@ def estimate_ate(self) -> tuple[float, list[float, float], float]:
384382

385383
# Perform a t-test to compare the predicted outcome of the control and treated individual (ATE)
386384
t_test_results = model.t_test(individuals.loc["treated"] - individuals.loc["control"])
387-
ate = t_test_results.effect[0]
385+
ate = pd.Series(t_test_results.effect[0])
388386
confidence_intervals = list(t_test_results.conf_int(alpha=self.alpha).flatten())
387+
confidence_intervals = [pd.Series(interval) for interval in confidence_intervals]
389388
return ate, confidence_intervals
390389

391390
def estimate_control_treatment(self, adjustment_config: dict = None) -> tuple[pd.Series, pd.Series]:
@@ -414,7 +413,7 @@ def estimate_control_treatment(self, adjustment_config: dict = None) -> tuple[pd
414413

415414
return y.iloc[1], y.iloc[0]
416415

417-
def estimate_risk_ratio(self, adjustment_config: dict = None) -> tuple[float, list[float, float]]:
416+
def estimate_risk_ratio(self, adjustment_config: dict = None) -> tuple[pd.Series, list[pd.Series, pd.Series]]:
418417
"""Estimate the risk_ratio effect of the treatment on the outcome. That is, the change in outcome caused
419418
by changing the treatment variable from the control value to the treatment value.
420419
@@ -423,12 +422,11 @@ def estimate_risk_ratio(self, adjustment_config: dict = None) -> tuple[float, li
423422
if adjustment_config is None:
424423
adjustment_config = {}
425424
control_outcome, treatment_outcome = self.estimate_control_treatment(adjustment_config=adjustment_config)
426-
ci_low = treatment_outcome["mean_ci_lower"] / control_outcome["mean_ci_upper"]
427-
ci_high = treatment_outcome["mean_ci_upper"] / control_outcome["mean_ci_lower"]
428-
429-
return (treatment_outcome["mean"] / control_outcome["mean"]), [ci_low, ci_high]
425+
ci_low = pd.Series(treatment_outcome["mean_ci_lower"] / control_outcome["mean_ci_upper"])
426+
ci_high = pd.Series(treatment_outcome["mean_ci_upper"] / control_outcome["mean_ci_lower"])
427+
return pd.Series(treatment_outcome["mean"] / control_outcome["mean"]), [ci_low, ci_high]
430428

431-
def estimate_ate_calculated(self, adjustment_config: dict = None) -> tuple[float, list[float, float]]:
429+
def estimate_ate_calculated(self, adjustment_config: dict = None) -> tuple[pd.Series, list[pd.Series, pd.Series]]:
432430
"""Estimate the ate effect of the treatment on the outcome. That is, the change in outcome caused
433431
by changing the treatment variable from the control value to the treatment value. Here, we actually
434432
calculate the expected outcomes under control and treatment and divide one by the other. This
@@ -439,10 +437,9 @@ def estimate_ate_calculated(self, adjustment_config: dict = None) -> tuple[float
439437
if adjustment_config is None:
440438
adjustment_config = {}
441439
control_outcome, treatment_outcome = self.estimate_control_treatment(adjustment_config=adjustment_config)
442-
ci_low = treatment_outcome["mean_ci_lower"] - control_outcome["mean_ci_upper"]
443-
ci_high = treatment_outcome["mean_ci_upper"] - control_outcome["mean_ci_lower"]
444-
445-
return (treatment_outcome["mean"] - control_outcome["mean"]), [ci_low, ci_high]
440+
ci_low = pd.Series(treatment_outcome["mean_ci_lower"] - control_outcome["mean_ci_upper"])
441+
ci_high = pd.Series(treatment_outcome["mean_ci_upper"] - control_outcome["mean_ci_lower"])
442+
return pd.Series(treatment_outcome["mean"] - control_outcome["mean"]), [ci_low, ci_high]
446443

447444
def _run_linear_regression(self) -> RegressionResultsWrapper:
448445
"""Run linear regression of the treatment and adjustment set against the outcome and return the model.
@@ -456,8 +453,8 @@ def _run_linear_regression(self) -> RegressionResultsWrapper:
456453
def _get_confidence_intervals(self, model, treatment):
457454
confidence_intervals = model.conf_int(alpha=self.alpha, cols=None)
458455
ci_low, ci_high = (
459-
confidence_intervals[0].loc[treatment],
460-
confidence_intervals[1].loc[treatment],
456+
pd.Series(confidence_intervals[0].loc[treatment]),
457+
pd.Series(confidence_intervals[1].loc[treatment]),
461458
)
462459
return [ci_low, ci_high]
463460

@@ -495,7 +492,7 @@ def __init__(
495492
terms = [treatment] + sorted(list(adjustment_set)) + sorted(list(effect_modifiers))
496493
self.formula = f"{outcome} ~ cr({'+'.join(terms)}, df={basis})"
497494

498-
def estimate_ate_calculated(self, adjustment_config: dict = None) -> float:
495+
def estimate_ate_calculated(self, adjustment_config: dict = None) -> pd.Series:
499496
model = self._run_linear_regression()
500497

501498
x = {"Intercept": 1, self.treatment: self.treatment_value}
@@ -511,7 +508,7 @@ def estimate_ate_calculated(self, adjustment_config: dict = None) -> float:
511508
x[self.treatment] = self.control_value
512509
control = model.predict(x).iloc[0]
513510

514-
return treatment - control
511+
return pd.Series(treatment - control)
515512

516513

517514
class InstrumentalVariableEstimator(Estimator):
@@ -567,7 +564,7 @@ def add_modelling_assumptions(self):
567564
"""
568565
)
569566

570-
def estimate_iv_coefficient(self, df):
567+
def estimate_iv_coefficient(self, df) -> float:
571568
"""
572569
Estimate the linear regression coefficient of the treatment on the
573570
outcome.
@@ -581,7 +578,7 @@ def estimate_iv_coefficient(self, df):
581578
# Estimate the coefficient of I on X by cancelling
582579
return ab / a
583580

584-
def estimate_coefficient(self, bootstrap_size=100):
581+
def estimate_coefficient(self, bootstrap_size=100) -> tuple[pd.Series, list[pd.Series, pd.Series]]:
585582
"""
586583
Estimate the unit ate (i.e. coefficient) of the treatment on the
587584
outcome.
@@ -590,10 +587,10 @@ def estimate_coefficient(self, bootstrap_size=100):
590587
[self.estimate_iv_coefficient(self.df.sample(len(self.df), replace=True)) for _ in range(bootstrap_size)]
591588
)
592589
bound = ceil((bootstrap_size * self.alpha) / 2)
593-
ci_low = bootstraps[bound]
594-
ci_high = bootstraps[bootstrap_size - bound]
590+
ci_low = pd.Series(bootstraps[bound])
591+
ci_high = pd.Series(bootstraps[bootstrap_size - bound])
595592

596-
return self.estimate_iv_coefficient(self.df), (ci_low, ci_high)
593+
return pd.Series(self.estimate_iv_coefficient(self.df)), [ci_low, ci_high]
597594

598595

599596
class CausalForestEstimator(Estimator):
@@ -610,7 +607,7 @@ def add_modelling_assumptions(self):
610607
"""
611608
self.modelling_assumptions.append("Non-parametric estimator: no restrictions imposed on the data.")
612609

613-
def estimate_ate(self) -> float:
610+
def estimate_ate(self) -> tuple[pd.Series, list[pd.Series, pd.Series]]:
614611
"""Estimate the average treatment effect.
615612
616613
:return ate, confidence_intervals: The average treatment effect and 95% confidence intervals.
@@ -638,9 +635,9 @@ def estimate_ate(self) -> float:
638635
model.fit(outcome_df, treatment_df, X=effect_modifier_df, W=confounders_df)
639636

640637
# Obtain the ATE and 95% confidence intervals
641-
ate = model.ate(effect_modifier_df, T0=self.control_value, T1=self.treatment_value)
638+
ate = pd.Series(model.ate(effect_modifier_df, T0=self.control_value, T1=self.treatment_value))
642639
ate_interval = model.ate_interval(effect_modifier_df, T0=self.control_value, T1=self.treatment_value)
643-
ci_low, ci_high = ate_interval[0], ate_interval[1]
640+
ci_low, ci_high = pd.Series(ate_interval[0]), pd.Series(ate_interval[1])
644641
return ate, [ci_low, ci_high]
645642

646643
def estimate_cates(self) -> pd.DataFrame:

0 commit comments

Comments
 (0)