Skip to content

Commit 86211d5

Browse files
committed
Changed to having estimators return an EffectEstimate instance rather than a tuple
1 parent 0ad05e2 commit 86211d5

27 files changed

+335
-404
lines changed

causal_testing/estimation/cubic_spline_estimator.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from causal_testing.specification.variable import Variable
1010
from causal_testing.estimation.linear_regression_estimator import LinearRegressionEstimator
11+
from causal_testing.estimation.effect_estimate import EffectEstimate
1112
from causal_testing.testing.base_test_case import BaseTestCase
1213

1314
logger = logging.getLogger(__name__)
@@ -47,7 +48,7 @@ def __init__(
4748
)
4849
self.formula = f"{base_test_case.outcome_variable.name} ~ cr({'+'.join(terms)}, df={basis})"
4950

50-
def estimate_ate_calculated(self, adjustment_config: dict = None) -> pd.Series:
51+
def estimate_ate_calculated(self, adjustment_config: dict = None) -> EffectEstimate:
5152
"""Estimate the ate effect of the treatment on the outcome. That is, the change in outcome caused
5253
by changing the treatment variable from the control value to the treatment value. Here, we actually
5354
calculate the expected outcomes under control and treatment and divide one by the other. This
@@ -74,4 +75,4 @@ def estimate_ate_calculated(self, adjustment_config: dict = None) -> pd.Series:
7475
x[self.base_test_case.treatment_variable.name] = self.control_value
7576
control = model.predict(x).iloc[0]
7677

77-
return pd.Series(treatment - control)
78+
return EffectEstimate("ate", pd.Series(treatment - control))
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
"""
2+
This module contains the EffectEstimate dataclass.
3+
"""
4+
5+
import pandas as pd
6+
from dataclasses import dataclass
7+
8+
9+
@dataclass
class EffectEstimate:
    """
    A dataclass to hold the value and confidence intervals of a causal effect estimate.

    On construction, ``value``, ``ci_low`` and ``ci_high`` are normalised to ``pd.Series`` when they are
    given as something else (e.g. a dict or a scalar), so that the methods below can rely on the
    Series API. ``None`` and existing pandas objects are stored unchanged.

    :ivar type: The type of estimate, e.g. ate, or risk_ratio
                (used to determine whether the estimate matches the expected effect)
    :ivar value: The estimated causal effect
    :ivar ci_low: The lower confidence interval
    :ivar ci_high: The upper confidence interval
    """

    type: str
    value: pd.Series
    ci_low: pd.Series = None
    ci_high: pd.Series = None

    def __post_init__(self):
        """Coerce the estimate and its confidence intervals to ``pd.Series`` where necessary."""
        self.value = self._as_series(self.value)
        self.ci_low = self._as_series(self.ci_low)
        self.ci_high = self._as_series(self.ci_high)

    @staticmethod
    def _as_series(data):
        """Return ``data`` unchanged if it is None or already a pandas object, else wrap it in a Series."""
        if data is None or isinstance(data, (pd.Series, pd.DataFrame)):
            return data
        return pd.Series(data)

    def ci_valid(self) -> bool:
        """Return whether or not the result has valid confidence intervals.

        The intervals are valid when both bounds are present and neither contains a null entry.
        """
        if self.ci_low is None or self.ci_high is None:
            return False
        return not (pd.isnull(self.ci_low).any() or pd.isnull(self.ci_high).any())

    def to_dict(self) -> dict:
        """Return representation as a dict.

        The ``ci_low`` and ``ci_high`` keys are only included when the confidence intervals are valid.
        """
        result = {"effect_measure": self.type, "effect_estimate": self.value.to_dict()}
        if self.ci_valid():
            return result | {"ci_low": self.ci_low.to_dict(), "ci_high": self.ci_high.to_dict()}
        return result

    def to_df(self) -> pd.DataFrame:
        """Return the estimate and its confidence bounds as a data frame.

        :return: A data frame with the columns ``effect_estimate``, ``ci_low`` and ``ci_high``.
                 Bounds that were not supplied are handed to pandas as ``None``.
        """
        return pd.DataFrame({"effect_estimate": self.value, "ci_low": self.ci_low, "ci_high": self.ci_high})

causal_testing/estimation/experimental_estimator.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import pandas as pd
66

77
from causal_testing.estimation.abstract_estimator import Estimator
8+
from causal_testing.estimation.effect_estimate import EffectEstimate
89
from causal_testing.testing.base_test_case import BaseTestCase
910

1011

@@ -55,7 +56,7 @@ def run_system(self, configuration: dict) -> dict:
5556
:returns: The resulting output as a dict.
5657
"""
5758

58-
def estimate_ate(self) -> tuple[pd.Series, list[pd.Series, pd.Series]]:
59+
def estimate_ate(self) -> EffectEstimate:
5960
"""Estimate the average treatment effect of the treatment on the outcome. That is, the change in outcome caused
6061
by changing the treatment variable from the control value to the treatment value.
6162
@@ -88,14 +89,20 @@ def estimate_ate(self) -> tuple[pd.Series, list[pd.Series, pd.Series]]:
8889
ci_low = difference.iloc[ci_low_index]
8990
ci_high = difference.iloc[self.repeats - ci_low_index]
9091

91-
return pd.Series(
92-
{self.base_test_case.treatment_variable.name: difference.mean()[self.base_test_case.outcome_variable.name]}
93-
), [
92+
return EffectEstimate(
93+
"ate",
94+
pd.Series(
95+
{
96+
self.base_test_case.treatment_variable.name: difference.mean()[
97+
self.base_test_case.outcome_variable.name
98+
]
99+
}
100+
),
94101
pd.Series({self.base_test_case.treatment_variable.name: ci_low[self.base_test_case.outcome_variable.name]}),
95102
pd.Series(
96103
{self.base_test_case.treatment_variable.name: ci_high[self.base_test_case.outcome_variable.name]}
97104
),
98-
]
105+
)
99106

100107
def estimate_risk_ratio(self) -> tuple[pd.Series, list[pd.Series, pd.Series]]:
101108
"""Estimate the risk ratio of the treatment on the outcome. That is, the change in outcome caused
@@ -130,11 +137,11 @@ def estimate_risk_ratio(self) -> tuple[pd.Series, list[pd.Series, pd.Series]]:
130137
ci_low = difference.iloc[ci_low_index]
131138
ci_high = difference.iloc[self.repeats - ci_low_index]
132139

133-
return pd.Series(
134-
{self.base_test_case.treatment_variable.name: difference.mean()[self.base_test_case.outcome_variable.name]}
135-
), [
140+
return EffectEstimate(
141+
"ate",
142+
{self.base_test_case.treatment_variable.name: difference.mean()[self.base_test_case.outcome_variable.name]},
136143
pd.Series({self.base_test_case.treatment_variable.name: ci_low[self.base_test_case.outcome_variable.name]}),
137144
pd.Series(
138145
{self.base_test_case.treatment_variable.name: ci_high[self.base_test_case.outcome_variable.name]}
139146
),
140-
]
147+
)

causal_testing/estimation/instrumental_variable_estimator.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import statsmodels.api as sm
88

99
from causal_testing.estimation.abstract_estimator import Estimator
10+
from causal_testing.estimation.effect_estimate import EffectEstimate
1011
from causal_testing.testing.base_test_case import BaseTestCase
1112

1213
logger = logging.getLogger(__name__)
@@ -61,7 +62,7 @@ def add_modelling_assumptions(self):
6162
"""
6263
)
6364

64-
def estimate_iv_coefficient(self, df) -> float:
65+
def iv_coefficient(self, df) -> float:
6566
"""
6667
Estimate the linear regression coefficient of the treatment on the
6768
outcome.
@@ -75,16 +76,16 @@ def estimate_iv_coefficient(self, df) -> float:
7576
# Estimate the coefficient of I on X by cancelling
7677
return ab / a
7778

78-
def estimate_coefficient(self, bootstrap_size=100) -> tuple[pd.Series, list[pd.Series, pd.Series]]:
79+
def estimate_coefficient(self, bootstrap_size=100) -> EffectEstimate:
7980
"""
8081
Estimate the unit ate (i.e. coefficient) of the treatment on the
8182
outcome.
8283
"""
8384
bootstraps = sorted(
84-
[self.estimate_iv_coefficient(self.df.sample(len(self.df), replace=True)) for _ in range(bootstrap_size)]
85+
[self.iv_coefficient(self.df.sample(len(self.df), replace=True)) for _ in range(bootstrap_size)]
8586
)
8687
bound = ceil((bootstrap_size * self.alpha) / 2)
8788
ci_low = pd.Series(bootstraps[bound])
8889
ci_high = pd.Series(bootstraps[bootstrap_size - bound])
8990

90-
return pd.Series(self.estimate_iv_coefficient(self.df)), [ci_low, ci_high]
91+
return EffectEstimate("coefficient", pd.Series(self.iv_coefficient(self.df)), ci_low, ci_high)

causal_testing/estimation/ipcw_estimator.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from lifelines import CoxPHFitter
1212

1313
from causal_testing.estimation.abstract_estimator import Estimator
14+
from causal_testing.estimation.effect_estimate import EffectEstimate
1415
from causal_testing.testing.base_test_case import BaseTestCase
1516
from causal_testing.specification.variable import Variable
1617

@@ -285,7 +286,7 @@ def preprocess_data(self):
285286
if len(self.df.loc[self.df["trtrand"] == 1]) == 0:
286287
raise ValueError(f"No individuals began the treatment strategy {self.treatment_strategy}")
287288

288-
def estimate_hazard_ratio(self):
289+
def estimate_hazard_ratio(self) -> EffectEstimate:
289290
"""
290291
Estimate the hazard ratio.
291292
"""
@@ -380,4 +381,4 @@ def estimate_hazard_ratio(self):
380381

381382
ci_low, ci_high = [np.exp(cox_ph.confidence_intervals_)[col] for col in cox_ph.confidence_intervals_.columns]
382383

383-
return (cox_ph.hazard_ratios_, (ci_low, ci_high))
384+
return EffectEstimate("hazard_ratio", cox_ph.hazard_ratios_, ci_low, ci_high)

causal_testing/estimation/linear_regression_estimator.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from causal_testing.specification.variable import Variable
1111
from causal_testing.estimation.genetic_programming_regression_fitter import GP
1212
from causal_testing.estimation.abstract_regression_estimator import RegressionEstimator
13+
from causal_testing.estimation.effect_estimate import EffectEstimate
1314
from causal_testing.testing.base_test_case import BaseTestCase
1415

1516
logger = logging.getLogger(__name__)
@@ -92,7 +93,7 @@ def gp_formula(
9293
formula = gp.simplify(formula)
9394
self.formula = f"{self.base_test_case.outcome_variable.name} ~ I({formula}) - 1"
9495

95-
def estimate_coefficient(self) -> tuple[pd.Series, list[pd.Series, pd.Series]]:
96+
def estimate_coefficient(self) -> EffectEstimate:
9697
"""Estimate the unit average treatment effect of the treatment on the outcome. That is, the change in outcome
9798
caused by a unit change in treatment.
9899
@@ -121,9 +122,9 @@ def estimate_coefficient(self) -> tuple[pd.Series, list[pd.Series, pd.Series]]:
121122
), f"{treatment} not in\n{' ' + str(model.params.index).replace(newline, newline + ' ')}"
122123
unit_effect = model.params[treatment] # Unit effect is the coefficient of the treatment
123124
[ci_low, ci_high] = self._get_confidence_intervals(model, treatment)
124-
return unit_effect, [ci_low, ci_high]
125+
return EffectEstimate("coefficient", unit_effect, ci_low, ci_high)
125126

126-
def estimate_ate(self) -> tuple[pd.Series, list[pd.Series, pd.Series]]:
127+
def estimate_ate(self) -> EffectEstimate:
127128
"""Estimate the average treatment effect of the treatment on the outcome. That is, the change in outcome caused
128129
by changing the treatment variable from the control value to the treatment value.
129130
@@ -146,10 +147,10 @@ def estimate_ate(self) -> tuple[pd.Series, list[pd.Series, pd.Series]]:
146147
t_test_results = model.t_test(individuals.loc["treated"] - individuals.loc["control"])
147148
ate = pd.Series(t_test_results.effect[0])
148149
confidence_intervals = list(t_test_results.conf_int(alpha=self.alpha).flatten())
149-
confidence_intervals = [pd.Series(interval) for interval in confidence_intervals]
150-
return ate, confidence_intervals
150+
ci_low, ci_high = [pd.Series(interval) for interval in confidence_intervals]
151+
return EffectEstimate("ate", ate, ci_low, ci_high)
151152

152-
def estimate_risk_ratio(self, adjustment_config: dict = None) -> tuple[pd.Series, list[pd.Series, pd.Series]]:
153+
def estimate_risk_ratio(self, adjustment_config: dict = None) -> EffectEstimate:
153154
"""Estimate the risk_ratio effect of the treatment on the outcome. That is, the change in outcome caused
154155
by changing the treatment variable from the control value to the treatment value.
155156
@@ -159,9 +160,11 @@ def estimate_risk_ratio(self, adjustment_config: dict = None) -> tuple[pd.Series
159160
control_outcome, treatment_outcome = prediction.iloc[1], prediction.iloc[0]
160161
ci_low = pd.Series(treatment_outcome["mean_ci_lower"] / control_outcome["mean_ci_upper"])
161162
ci_high = pd.Series(treatment_outcome["mean_ci_upper"] / control_outcome["mean_ci_lower"])
162-
return pd.Series(treatment_outcome["mean"] / control_outcome["mean"]), [ci_low, ci_high]
163+
return EffectEstimate(
164+
"risk_ratio", pd.Series(treatment_outcome["mean"] / control_outcome["mean"]), ci_low, ci_high
165+
)
163166

164-
def estimate_ate_calculated(self, adjustment_config: dict = None) -> tuple[pd.Series, list[pd.Series, pd.Series]]:
167+
def estimate_ate_calculated(self, adjustment_config: dict = None) -> EffectEstimate:
165168
"""Estimate the ate effect of the treatment on the outcome. That is, the change in outcome caused
166169
by changing the treatment variable from the control value to the treatment value. Here, we actually
167170
calculate the expected outcomes under control and treatment and divide one by the other. This
@@ -177,7 +180,7 @@ def estimate_ate_calculated(self, adjustment_config: dict = None) -> tuple[pd.Se
177180
control_outcome, treatment_outcome = prediction.iloc[1], prediction.iloc[0]
178181
ci_low = pd.Series(treatment_outcome["mean_ci_lower"] - control_outcome["mean_ci_upper"])
179182
ci_high = pd.Series(treatment_outcome["mean_ci_upper"] - control_outcome["mean_ci_lower"])
180-
return pd.Series(treatment_outcome["mean"] - control_outcome["mean"]), [ci_low, ci_high]
183+
return EffectEstimate("ate", pd.Series(treatment_outcome["mean"] - control_outcome["mean"]), ci_low, ci_high)
181184

182185
def _get_confidence_intervals(self, model, treatment):
183186
confidence_intervals = model.conf_int(alpha=self.alpha, cols=None)

causal_testing/estimation/logistic_regression_estimator.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import pandas as pd
77
import statsmodels.formula.api as smf
88

9+
from causal_testing.estimation.effect_estimate import EffectEstimate
910
from causal_testing.estimation.abstract_regression_estimator import RegressionEstimator
1011

1112
logger = logging.getLogger(__name__)
@@ -32,15 +33,17 @@ def add_modelling_assumptions(self):
3233
self.modelling_assumptions.append("The outcome must be binary.")
3334
self.modelling_assumptions.append("Independently and identically distributed errors.")
3435

35-
def estimate_unit_odds_ratio(self) -> tuple[pd.Series, list[pd.Series, pd.Series]]:
36+
def estimate_unit_odds_ratio(self) -> EffectEstimate:
3637
"""Estimate the odds ratio of increasing the treatment by one. In logistic regression, this corresponds to the
3738
coefficient of the treatment of interest.
3839
3940
:return: The odds ratio. Confidence intervals are not yet supported.
4041
"""
4142
model = self.fit_model(self.df)
4243
ci_low, ci_high = np.exp(model.conf_int(self.alpha).loc[self.base_test_case.treatment_variable.name])
43-
return pd.Series(np.exp(model.params[self.base_test_case.treatment_variable.name])), [
44+
return EffectEstimate(
45+
"odds_ratio",
46+
pd.Series(np.exp(model.params[self.base_test_case.treatment_variable.name])),
4447
pd.Series(ci_low),
4548
pd.Series(ci_high),
46-
]
49+
)

causal_testing/main.py

Lines changed: 16 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from causal_testing.testing.causal_test_case import CausalTestCase
2020
from causal_testing.testing.base_test_case import BaseTestCase
2121
from causal_testing.testing.causal_effect import NoEffect, SomeEffect, Positive, Negative
22-
from causal_testing.testing.causal_test_result import CausalTestResult, TestValue
22+
from causal_testing.testing.causal_test_result import CausalTestResult
2323
from causal_testing.estimation.linear_regression_estimator import LinearRegressionEstimator
2424
from causal_testing.estimation.logistic_regression_estimator import LogisticRegressionEstimator
2525

@@ -332,7 +332,6 @@ def create_causal_test(self, test: dict, base_test: BaseTestCase) -> CausalTestC
332332
expected_causal_effect=expected_effect,
333333
estimate_type=test.get("estimate_type", "ate"),
334334
estimate_params=test.get("estimate_params"),
335-
effect_modifier_configuration=test.get("effect_modifier_configuration"),
336335
estimator=estimator,
337336
)
338337

@@ -376,10 +375,7 @@ def run_tests_in_batches(self, batch_size: int = 100, silent: bool = False) -> L
376375
logger.error(f"Type or attribute error in test: {str(e)}")
377376
raise
378377
batch_results.append(
379-
CausalTestResult(
380-
estimator=test_case.estimator,
381-
test_value=TestValue("Error", str(e)),
382-
)
378+
CausalTestResult(effect_estimate=None, estimator=test_case.estimator, error_message=str(e))
383379
)
384380

385381
progress.update(1)
@@ -410,10 +406,7 @@ def run_tests(self, silent=False) -> List[CausalTestResult]:
410406
if not silent:
411407
logger.error(f"Error running test {test_case}: {str(e)}")
412408
raise
413-
result = CausalTestResult(
414-
estimator=test_case.estimator,
415-
test_value=TestValue("Error", str(e)),
416-
)
409+
result = CausalTestResult(estimator=test_case.estimator, effect_estimate=None, error_message=str(e))
417410
results.append(result)
418411
logger.info(f"Test errored: {test_case}")
419412

@@ -432,17 +425,10 @@ def save_results(self, results: List[CausalTestResult], output_path: str = None)
432425
# Combine test configs with their results
433426
json_results = []
434427
for test_config, test_case, result in zip(test_configs["tests"], self.test_cases, results):
435-
# Handle effect estimate - could be a Series or other format
436-
effect_estimate = result.test_value.value
437-
if isinstance(effect_estimate, pd.Series):
438-
effect_estimate = effect_estimate.to_dict()
439-
440-
# Handle confidence intervals - convert to list if needed
441-
ci_low = result.ci_low()
442-
ci_high = result.ci_high()
443-
444428
# Determine if test failed based on expected vs actual effect
445-
test_passed = test_case.expected_causal_effect.apply(result) if result.test_value.type != "Error" else False
429+
test_passed = (
430+
test_case.expected_causal_effect.apply(result) if result.effect_estimate is not None else False
431+
)
446432

447433
output = {
448434
"name": test_config["name"],
@@ -454,15 +440,16 @@ def save_results(self, results: List[CausalTestResult], output_path: str = None)
454440
"alpha": test_config.get("alpha", 0.05),
455441
"skip": test_config.get("skip", False),
456442
"passed": test_passed,
457-
"result": {
458-
"treatment": result.estimator.base_test_case.treatment_variable.name,
459-
"outcome": result.estimator.base_test_case.outcome_variable.name,
460-
"adjustment_set": list(result.adjustment_set) if result.adjustment_set else [],
461-
"effect_measure": result.test_value.type,
462-
"effect_estimate": effect_estimate,
463-
"ci_low": ci_low,
464-
"ci_high": ci_high,
465-
},
443+
"result": (
444+
{
445+
"treatment": result.estimator.base_test_case.treatment_variable.name,
446+
"outcome": result.estimator.base_test_case.outcome_variable.name,
447+
"adjustment_set": list(result.adjustment_set) if result.adjustment_set else [],
448+
}
449+
| result.effect_estimate.to_dict()
450+
if result.effect_estimate
451+
else {"error": result.error_message}
452+
),
466453
}
467454
json_results.append(output)
468455

causal_testing/surrogate/surrogate_search_algorithms.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""Module containing implementation of search algorithm for surrogate search """
1+
"""Module containing implementation of search algorithm for surrogate search"""
22

33
# Fitness functions are required to be iteratively defined, including all variables within.
44

@@ -45,7 +45,7 @@ def fitness_function(ga, solution, idx): # pylint: disable=unused-argument
4545
for i, adjustment in enumerate(surrogate_model.adjustment_set):
4646
adjustment_dict[adjustment] = solution[i + 1]
4747

48-
ate = surrogate_model.estimate_ate_calculated(adjustment_dict)
48+
ate = surrogate_model.estimate_ate_calculated(adjustment_dict).value
4949
if len(ate) > 1:
5050
raise ValueError(
5151
"Multiple ate values provided but currently only single values supported in this method"

0 commit comments

Comments
 (0)