Skip to content

EffectEstimate dataclass #353

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Aug 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions causal_testing/estimation/cubic_spline_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from causal_testing.specification.variable import Variable
from causal_testing.estimation.linear_regression_estimator import LinearRegressionEstimator
from causal_testing.estimation.effect_estimate import EffectEstimate
from causal_testing.testing.base_test_case import BaseTestCase

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -47,7 +48,7 @@ def __init__(
)
self.formula = f"{base_test_case.outcome_variable.name} ~ cr({'+'.join(terms)}, df={basis})"

def estimate_ate_calculated(self, adjustment_config: dict = None) -> pd.Series:
def estimate_ate_calculated(self, adjustment_config: dict = None) -> EffectEstimate:
"""Estimate the ate effect of the treatment on the outcome. That is, the change in outcome caused
by changing the treatment variable from the control value to the treatment value. Here, we actually
calculate the expected outcomes under control and treatment and divide one by the other. This
Expand All @@ -74,4 +75,4 @@ def estimate_ate_calculated(self, adjustment_config: dict = None) -> pd.Series:
x[self.base_test_case.treatment_variable.name] = self.control_value
control = model.predict(x).iloc[0]

return pd.Series(treatment - control)
return EffectEstimate("ate", pd.Series(treatment - control))
43 changes: 43 additions & 0 deletions causal_testing/estimation/effect_estimate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""
This module contains the EffectEstimate dataclass.
"""

from dataclasses import dataclass
import pandas as pd


@dataclass
class EffectEstimate:
    """
    A dataclass to hold the value and confidence intervals of a causal effect estimate.

    The fields are stored exactly as supplied. The reporting helpers
    (:meth:`ci_valid`, :meth:`to_dict`, :meth:`to_df`) tolerate values that are
    not already a ``pd.Series`` (e.g. a plain dict or a scalar) by wrapping
    them in a Series on the fly, so a caller that forgets the wrapper does not
    trigger an ``AttributeError`` at reporting time.

    :ivar type: The type of estimate, e.g. ate, or risk_ratio
                (used to determine whether the estimate matches the expected effect)
    :ivar value: The estimated causal effect
    :ivar ci_low: The lower confidence interval (``None`` when not available)
    :ivar ci_high: The upper confidence interval (``None`` when not available)
    """

    type: str
    value: pd.Series
    ci_low: pd.Series = None
    ci_high: pd.Series = None

    @staticmethod
    def _as_series(data):
        """Return ``data`` untouched if it is ``None`` or a Series, else wrap it in a ``pd.Series``."""
        if data is None or isinstance(data, pd.Series):
            return data
        return pd.Series(data)

    def ci_valid(self) -> bool:
        """Return whether or not the result has valid confidence intervals.

        The intervals are valid when both bounds are present and neither
        contains a null entry.
        """
        if self.ci_low is None or self.ci_high is None:
            return False
        # Wrap first: pd.isnull on a bare scalar returns a plain bool, which has no .any().
        low = self._as_series(self.ci_low)
        high = self._as_series(self.ci_high)
        return not (low.isnull().any() or high.isnull().any())

    def to_dict(self) -> dict:
        """Return representation as a dict.

        The confidence interval keys are only included when :meth:`ci_valid` holds.
        """
        result = {
            "effect_measure": self.type,
            "effect_estimate": self._as_series(self.value).to_dict(),
        }
        if self.ci_valid():
            result["ci_low"] = self._as_series(self.ci_low).to_dict()
            result["ci_high"] = self._as_series(self.ci_high).to_dict()
        return result

    def to_df(self) -> pd.DataFrame:
        """Return representation as a pandas dataframe.

        Missing confidence bounds are passed through as ``None`` columns.
        """
        return pd.DataFrame(
            {
                "effect_estimate": self._as_series(self.value),
                "ci_low": self._as_series(self.ci_low),
                "ci_high": self._as_series(self.ci_high),
            }
        )
25 changes: 16 additions & 9 deletions causal_testing/estimation/experimental_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pandas as pd

from causal_testing.estimation.abstract_estimator import Estimator
from causal_testing.estimation.effect_estimate import EffectEstimate
from causal_testing.testing.base_test_case import BaseTestCase


Expand Down Expand Up @@ -55,7 +56,7 @@ def run_system(self, configuration: dict) -> dict:
:returns: The resulting output as a dict.
"""

def estimate_ate(self) -> tuple[pd.Series, list[pd.Series, pd.Series]]:
def estimate_ate(self) -> EffectEstimate:
"""Estimate the average treatment effect of the treatment on the outcome. That is, the change in outcome caused
by changing the treatment variable from the control value to the treatment value.

Expand Down Expand Up @@ -88,14 +89,20 @@ def estimate_ate(self) -> tuple[pd.Series, list[pd.Series, pd.Series]]:
ci_low = difference.iloc[ci_low_index]
ci_high = difference.iloc[self.repeats - ci_low_index]

return pd.Series(
{self.base_test_case.treatment_variable.name: difference.mean()[self.base_test_case.outcome_variable.name]}
), [
return EffectEstimate(
"ate",
pd.Series(
{
self.base_test_case.treatment_variable.name: difference.mean()[
self.base_test_case.outcome_variable.name
]
}
),
pd.Series({self.base_test_case.treatment_variable.name: ci_low[self.base_test_case.outcome_variable.name]}),
pd.Series(
{self.base_test_case.treatment_variable.name: ci_high[self.base_test_case.outcome_variable.name]}
),
]
)

def estimate_risk_ratio(self) -> tuple[pd.Series, list[pd.Series, pd.Series]]:
"""Estimate the risk ratio of the treatment on the outcome. That is, the change in outcome caused
Expand Down Expand Up @@ -130,11 +137,11 @@ def estimate_risk_ratio(self) -> tuple[pd.Series, list[pd.Series, pd.Series]]:
ci_low = difference.iloc[ci_low_index]
ci_high = difference.iloc[self.repeats - ci_low_index]

return pd.Series(
{self.base_test_case.treatment_variable.name: difference.mean()[self.base_test_case.outcome_variable.name]}
), [
return EffectEstimate(
"ate",
{self.base_test_case.treatment_variable.name: difference.mean()[self.base_test_case.outcome_variable.name]},
pd.Series({self.base_test_case.treatment_variable.name: ci_low[self.base_test_case.outcome_variable.name]}),
pd.Series(
{self.base_test_case.treatment_variable.name: ci_high[self.base_test_case.outcome_variable.name]}
),
]
)
9 changes: 5 additions & 4 deletions causal_testing/estimation/instrumental_variable_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import statsmodels.api as sm

from causal_testing.estimation.abstract_estimator import Estimator
from causal_testing.estimation.effect_estimate import EffectEstimate
from causal_testing.testing.base_test_case import BaseTestCase

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -61,7 +62,7 @@ def add_modelling_assumptions(self):
"""
)

def estimate_iv_coefficient(self, df) -> float:
def iv_coefficient(self, df) -> float:
"""
Estimate the linear regression coefficient of the treatment on the
outcome.
Expand All @@ -75,16 +76,16 @@ def estimate_iv_coefficient(self, df) -> float:
# Estimate the coefficient of I on X by cancelling
return ab / a

def estimate_coefficient(self, bootstrap_size=100) -> tuple[pd.Series, list[pd.Series, pd.Series]]:
def estimate_coefficient(self, bootstrap_size=100) -> EffectEstimate:
"""
Estimate the unit ate (i.e. coefficient) of the treatment on the
outcome.
"""
bootstraps = sorted(
[self.estimate_iv_coefficient(self.df.sample(len(self.df), replace=True)) for _ in range(bootstrap_size)]
[self.iv_coefficient(self.df.sample(len(self.df), replace=True)) for _ in range(bootstrap_size)]
)
bound = ceil((bootstrap_size * self.alpha) / 2)
ci_low = pd.Series(bootstraps[bound])
ci_high = pd.Series(bootstraps[bootstrap_size - bound])

return pd.Series(self.estimate_iv_coefficient(self.df)), [ci_low, ci_high]
return EffectEstimate("coefficient", pd.Series(self.iv_coefficient(self.df)), ci_low, ci_high)
5 changes: 3 additions & 2 deletions causal_testing/estimation/ipcw_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from lifelines import CoxPHFitter

from causal_testing.estimation.abstract_estimator import Estimator
from causal_testing.estimation.effect_estimate import EffectEstimate
from causal_testing.testing.base_test_case import BaseTestCase
from causal_testing.specification.variable import Variable

Expand Down Expand Up @@ -285,7 +286,7 @@ def preprocess_data(self):
if len(self.df.loc[self.df["trtrand"] == 1]) == 0:
raise ValueError(f"No individuals began the treatment strategy {self.treatment_strategy}")

def estimate_hazard_ratio(self):
def estimate_hazard_ratio(self) -> EffectEstimate:
"""
Estimate the hazard ratio.
"""
Expand Down Expand Up @@ -380,4 +381,4 @@ def estimate_hazard_ratio(self):

ci_low, ci_high = [np.exp(cox_ph.confidence_intervals_)[col] for col in cox_ph.confidence_intervals_.columns]

return (cox_ph.hazard_ratios_, (ci_low, ci_high))
return EffectEstimate("hazard_ratio", cox_ph.hazard_ratios_, ci_low, ci_high)
21 changes: 12 additions & 9 deletions causal_testing/estimation/linear_regression_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from causal_testing.specification.variable import Variable
from causal_testing.estimation.genetic_programming_regression_fitter import GP
from causal_testing.estimation.abstract_regression_estimator import RegressionEstimator
from causal_testing.estimation.effect_estimate import EffectEstimate
from causal_testing.testing.base_test_case import BaseTestCase

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -92,7 +93,7 @@ def gp_formula(
formula = gp.simplify(formula)
self.formula = f"{self.base_test_case.outcome_variable.name} ~ I({formula}) - 1"

def estimate_coefficient(self) -> tuple[pd.Series, list[pd.Series, pd.Series]]:
def estimate_coefficient(self) -> EffectEstimate:
"""Estimate the unit average treatment effect of the treatment on the outcome. That is, the change in outcome
caused by a unit change in treatment.

Expand Down Expand Up @@ -121,9 +122,9 @@ def estimate_coefficient(self) -> tuple[pd.Series, list[pd.Series, pd.Series]]:
), f"{treatment} not in\n{' ' + str(model.params.index).replace(newline, newline + ' ')}"
unit_effect = model.params[treatment] # Unit effect is the coefficient of the treatment
[ci_low, ci_high] = self._get_confidence_intervals(model, treatment)
return unit_effect, [ci_low, ci_high]
return EffectEstimate("coefficient", unit_effect, ci_low, ci_high)

def estimate_ate(self) -> tuple[pd.Series, list[pd.Series, pd.Series]]:
def estimate_ate(self) -> EffectEstimate:
"""Estimate the average treatment effect of the treatment on the outcome. That is, the change in outcome caused
by changing the treatment variable from the control value to the treatment value.

Expand All @@ -146,10 +147,10 @@ def estimate_ate(self) -> tuple[pd.Series, list[pd.Series, pd.Series]]:
t_test_results = model.t_test(individuals.loc["treated"] - individuals.loc["control"])
ate = pd.Series(t_test_results.effect[0])
confidence_intervals = list(t_test_results.conf_int(alpha=self.alpha).flatten())
confidence_intervals = [pd.Series(interval) for interval in confidence_intervals]
return ate, confidence_intervals
ci_low, ci_high = [pd.Series(interval) for interval in confidence_intervals]
return EffectEstimate("ate", ate, ci_low, ci_high)

def estimate_risk_ratio(self, adjustment_config: dict = None) -> tuple[pd.Series, list[pd.Series, pd.Series]]:
def estimate_risk_ratio(self, adjustment_config: dict = None) -> EffectEstimate:
"""Estimate the risk_ratio effect of the treatment on the outcome. That is, the change in outcome caused
by changing the treatment variable from the control value to the treatment value.

Expand All @@ -159,9 +160,11 @@ def estimate_risk_ratio(self, adjustment_config: dict = None) -> tuple[pd.Series
control_outcome, treatment_outcome = prediction.iloc[1], prediction.iloc[0]
ci_low = pd.Series(treatment_outcome["mean_ci_lower"] / control_outcome["mean_ci_upper"])
ci_high = pd.Series(treatment_outcome["mean_ci_upper"] / control_outcome["mean_ci_lower"])
return pd.Series(treatment_outcome["mean"] / control_outcome["mean"]), [ci_low, ci_high]
return EffectEstimate(
"risk_ratio", pd.Series(treatment_outcome["mean"] / control_outcome["mean"]), ci_low, ci_high
)

def estimate_ate_calculated(self, adjustment_config: dict = None) -> tuple[pd.Series, list[pd.Series, pd.Series]]:
def estimate_ate_calculated(self, adjustment_config: dict = None) -> EffectEstimate:
"""Estimate the ate effect of the treatment on the outcome. That is, the change in outcome caused
by changing the treatment variable from the control value to the treatment value. Here, we actually
calculate the expected outcomes under control and treatment and divide one by the other. This
Expand All @@ -177,7 +180,7 @@ def estimate_ate_calculated(self, adjustment_config: dict = None) -> tuple[pd.Se
control_outcome, treatment_outcome = prediction.iloc[1], prediction.iloc[0]
ci_low = pd.Series(treatment_outcome["mean_ci_lower"] - control_outcome["mean_ci_upper"])
ci_high = pd.Series(treatment_outcome["mean_ci_upper"] - control_outcome["mean_ci_lower"])
return pd.Series(treatment_outcome["mean"] - control_outcome["mean"]), [ci_low, ci_high]
return EffectEstimate("ate", pd.Series(treatment_outcome["mean"] - control_outcome["mean"]), ci_low, ci_high)

def _get_confidence_intervals(self, model, treatment):
confidence_intervals = model.conf_int(alpha=self.alpha, cols=None)
Expand Down
9 changes: 6 additions & 3 deletions causal_testing/estimation/logistic_regression_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pandas as pd
import statsmodels.formula.api as smf

from causal_testing.estimation.effect_estimate import EffectEstimate
from causal_testing.estimation.abstract_regression_estimator import RegressionEstimator

logger = logging.getLogger(__name__)
Expand All @@ -32,15 +33,17 @@ def add_modelling_assumptions(self):
self.modelling_assumptions.append("The outcome must be binary.")
self.modelling_assumptions.append("Independently and identically distributed errors.")

def estimate_unit_odds_ratio(self) -> tuple[pd.Series, list[pd.Series, pd.Series]]:
def estimate_unit_odds_ratio(self) -> EffectEstimate:
"""Estimate the odds ratio of increasing the treatment by one. In logistic regression, this corresponds to the
coefficient of the treatment of interest.

:return: The odds ratio. Confidence intervals are not yet supported.
"""
model = self.fit_model(self.df)
ci_low, ci_high = np.exp(model.conf_int(self.alpha).loc[self.base_test_case.treatment_variable.name])
return pd.Series(np.exp(model.params[self.base_test_case.treatment_variable.name])), [
return EffectEstimate(
"odds_ratio",
pd.Series(np.exp(model.params[self.base_test_case.treatment_variable.name])),
pd.Series(ci_low),
pd.Series(ci_high),
]
)
45 changes: 16 additions & 29 deletions causal_testing/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from causal_testing.testing.causal_test_case import CausalTestCase
from causal_testing.testing.base_test_case import BaseTestCase
from causal_testing.testing.causal_effect import NoEffect, SomeEffect, Positive, Negative
from causal_testing.testing.causal_test_result import CausalTestResult, TestValue
from causal_testing.testing.causal_test_result import CausalTestResult
from causal_testing.estimation.linear_regression_estimator import LinearRegressionEstimator
from causal_testing.estimation.logistic_regression_estimator import LogisticRegressionEstimator

Expand Down Expand Up @@ -332,7 +332,6 @@ def create_causal_test(self, test: dict, base_test: BaseTestCase) -> CausalTestC
expected_causal_effect=expected_effect,
estimate_type=test.get("estimate_type", "ate"),
estimate_params=test.get("estimate_params"),
effect_modifier_configuration=test.get("effect_modifier_configuration"),
estimator=estimator,
)

Expand Down Expand Up @@ -376,10 +375,7 @@ def run_tests_in_batches(self, batch_size: int = 100, silent: bool = False) -> L
logger.error(f"Type or attribute error in test: {str(e)}")
raise
batch_results.append(
CausalTestResult(
estimator=test_case.estimator,
test_value=TestValue("Error", str(e)),
)
CausalTestResult(effect_estimate=None, estimator=test_case.estimator, error_message=str(e))
)

progress.update(1)
Expand Down Expand Up @@ -410,10 +406,7 @@ def run_tests(self, silent=False) -> List[CausalTestResult]:
if not silent:
logger.error(f"Error running test {test_case}: {str(e)}")
raise
result = CausalTestResult(
estimator=test_case.estimator,
test_value=TestValue("Error", str(e)),
)
result = CausalTestResult(estimator=test_case.estimator, effect_estimate=None, error_message=str(e))
results.append(result)
logger.info(f"Test errored: {test_case}")

Expand All @@ -432,17 +425,10 @@ def save_results(self, results: List[CausalTestResult], output_path: str = None)
# Combine test configs with their results
json_results = []
for test_config, test_case, result in zip(test_configs["tests"], self.test_cases, results):
# Handle effect estimate - could be a Series or other format
effect_estimate = result.test_value.value
if isinstance(effect_estimate, pd.Series):
effect_estimate = effect_estimate.to_dict()

# Handle confidence intervals - convert to list if needed
ci_low = result.ci_low()
ci_high = result.ci_high()

# Determine if test failed based on expected vs actual effect
test_passed = test_case.expected_causal_effect.apply(result) if result.test_value.type != "Error" else False
test_passed = (
test_case.expected_causal_effect.apply(result) if result.effect_estimate is not None else False
)

output = {
"name": test_config["name"],
Expand All @@ -454,15 +440,16 @@ def save_results(self, results: List[CausalTestResult], output_path: str = None)
"alpha": test_config.get("alpha", 0.05),
"skip": test_config.get("skip", False),
"passed": test_passed,
"result": {
"treatment": result.estimator.base_test_case.treatment_variable.name,
"outcome": result.estimator.base_test_case.outcome_variable.name,
"adjustment_set": list(result.adjustment_set) if result.adjustment_set else [],
"effect_measure": result.test_value.type,
"effect_estimate": effect_estimate,
"ci_low": ci_low,
"ci_high": ci_high,
},
"result": (
{
"treatment": result.estimator.base_test_case.treatment_variable.name,
"outcome": result.estimator.base_test_case.outcome_variable.name,
"adjustment_set": list(result.adjustment_set) if result.adjustment_set else [],
}
| result.effect_estimate.to_dict()
if result.effect_estimate
else {"error": result.error_message}
),
}
json_results.append(output)

Expand Down
4 changes: 2 additions & 2 deletions causal_testing/surrogate/surrogate_search_algorithms.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Module containing implementation of search algorithm for surrogate search """
"""Module containing implementation of search algorithm for surrogate search"""

# Fitness functions are required to be iteratively defined, including all variables within.

Expand Down Expand Up @@ -45,7 +45,7 @@ def fitness_function(ga, solution, idx): # pylint: disable=unused-argument
for i, adjustment in enumerate(surrogate_model.adjustment_set):
adjustment_dict[adjustment] = solution[i + 1]

ate = surrogate_model.estimate_ate_calculated(adjustment_dict)
ate = surrogate_model.estimate_ate_calculated(adjustment_dict).value
if len(ate) > 1:
raise ValueError(
"Multiple ate values provided but currently only single values supported in this method"
Expand Down
Loading
Loading