Skip to content

Commit a13d83b

Browse files
committed
Moved estimation to a separate package
1 parent c6f9d31 commit a13d83b

34 files changed

+1348
-1200
lines changed
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
"""This module contains the CubicSplineRegressionEstimator class, for estimating continuous outcomes with changes in behaviour"""
2+
3+
import logging
4+
from abc import ABC, abstractmethod
5+
from typing import Any
6+
from math import ceil
7+
8+
import numpy as np
9+
import pandas as pd
10+
import statsmodels.api as sm
11+
import statsmodels.formula.api as smf
12+
from patsy import dmatrix # pylint: disable = no-name-in-module
13+
from patsy import ModelDesc
14+
from statsmodels.regression.linear_model import RegressionResultsWrapper
15+
from statsmodels.tools.sm_exceptions import PerfectSeparationError
16+
from lifelines import CoxPHFitter
17+
18+
from causal_testing.specification.variable import Variable
19+
from causal_testing.specification.capabilities import TreatmentSequence, Capability
20+
from causal_testing.estimation.estimator import Estimator
21+
from causal_testing.estimation.linear_regression_estimator import LinearRegressionEstimator
22+
23+
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
24+
25+
26+
class CubicSplineRegressionEstimator(LinearRegressionEstimator):
    """A Cubic Spline Regression Estimator is a parametric estimator which restricts the variables in the data to a
    combination of parameters and basis functions of the variables.
    """

    def __init__(
        # pylint: disable=too-many-arguments
        self,
        treatment: str,
        treatment_value: float,
        control_value: float,
        adjustment_set: set,
        outcome: str,
        basis: int,
        df: pd.DataFrame = None,
        effect_modifiers: dict[Variable, Any] = None,
        formula: str = None,
        alpha: float = 0.05,
        expected_relationship=None,
    ):
        """
        Initialise the estimator and, when no formula is supplied, build a default cubic spline formula.

        :param treatment: Name of the treatment variable.
        :param treatment_value: Value of the treatment variable under treatment.
        :param control_value: Value of the treatment variable under control.
        :param adjustment_set: Names of the variables to adjust for.
        :param outcome: Name of the outcome variable.
        :param basis: Passed as the ``df`` argument of the ``cr(...)`` term in the default formula.
        :param df: The data, forwarded unchanged to the parent constructor.
        :param effect_modifiers: Mapping of effect modifiers to values, forwarded to the parent constructor. Its
                                 keys are also added (sorted) to the terms of the default formula.
        :param formula: A regression formula. If None, ``outcome ~ cr(<terms joined by '+'>, df=basis)`` is used.
        :param alpha: Forwarded to the parent constructor.
        :param expected_relationship: Stored on the instance as-is; it is not used anywhere in this class.
        """
        super().__init__(
            treatment, treatment_value, control_value, adjustment_set, outcome, df, effect_modifiers, formula, alpha
        )

        self.expected_relationship = expected_relationship

        if formula is None:
            # Sort the adjustment set and effect modifiers so the generated formula is deterministic.
            modifier_terms = [] if effect_modifiers is None else sorted(effect_modifiers)
            terms = [treatment] + sorted(adjustment_set) + modifier_terms
            self.formula = f"{outcome} ~ cr({'+'.join(terms)}, df={basis})"

    def estimate_ate_calculated(self, adjustment_config: dict = None) -> pd.Series:
        """
        Estimate the effect as the difference between the model's prediction at the treatment value and its
        prediction at the control value, with all other inputs held fixed.

        :param adjustment_config: Optional mapping of variable names to the values at which they are held for
                                  both predictions.
        :return: A one-element Series containing the treatment prediction minus the control prediction.
        """
        model = self._run_linear_regression()

        x = {"Intercept": 1, self.treatment: self.treatment_value}
        if adjustment_config is not None:
            x.update(adjustment_config)
        # Effect modifiers are applied last, so they take precedence over the adjustment configuration.
        if self.effect_modifiers is not None:
            x.update(self.effect_modifiers)

        treatment = model.predict(x).iloc[0]

        # Reuse the same inputs, changing only the treatment variable to its control value.
        x[self.treatment] = self.control_value
        control = model.predict(x).iloc[0]

        return pd.Series(treatment - control)
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
"""This module contains the Estimator abstract class"""
2+
3+
import logging
4+
from abc import ABC, abstractmethod
5+
from typing import Any
6+
from math import ceil
7+
8+
import numpy as np
9+
import pandas as pd
10+
import statsmodels.api as sm
11+
import statsmodels.formula.api as smf
12+
from patsy import dmatrix # pylint: disable = no-name-in-module
13+
from patsy import ModelDesc
14+
from statsmodels.regression.linear_model import RegressionResultsWrapper
15+
from statsmodels.tools.sm_exceptions import PerfectSeparationError
16+
from lifelines import CoxPHFitter
17+
18+
from causal_testing.specification.variable import Variable
19+
from causal_testing.specification.capabilities import TreatmentSequence, Capability
20+
21+
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
22+
23+
24+
class Estimator(ABC):
    # pylint: disable=too-many-instance-attributes
    """An estimator contains all of the information necessary to compute a causal estimate for the effect of changing
    a set of treatment variables to a set of values.

    All estimators must implement the following two methods:

    1) add_modelling_assumptions: The validity of a model-assisted causal inference result depends on whether
    the modelling assumptions imposed by a model actually hold. Therefore, for each model, it is important to state
    the modelling assumption upon which the validity of the results depend. To achieve this, the estimator object
    maintains a list of modelling assumptions (as strings). If a user wishes to implement their own estimator, they
    must implement this method and add all assumptions to the list of modelling assumptions.

    2) estimate_ate: All estimators must be capable of returning the average treatment effect as a minimum. That is, the
    average effect of the intervention (changing treatment from control to treated value) on the outcome of interest
    adjusted for all confounders.
    """

    def __init__(
        # pylint: disable=too-many-arguments
        self,
        treatment: str,
        treatment_value: float,
        control_value: float,
        adjustment_set: set,
        outcome: str,
        df: pd.DataFrame = None,
        effect_modifiers: dict[str, Any] = None,
        alpha: float = 0.05,
        query: str = "",
    ):
        """
        Store the estimation inputs and collect the modelling assumptions.

        :param treatment: Name of the treatment variable.
        :param treatment_value: Value of the treatment variable under treatment.
        :param control_value: Value of the treatment variable under control.
        :param adjustment_set: Names of the variables to adjust for.
        :param outcome: Name of the outcome variable.
        :param df: The data. If `query` is non-empty, the result of `df.query(query)` is stored instead.
        :param effect_modifiers: Mapping of effect modifiers to values. None is treated as an empty dict.
        :param alpha: Stored on the instance as `self.alpha`.
        :param query: Optional query string used to filter `df`. If non-empty, it is also recorded as the first
                      modelling assumption.
        :raises ValueError: If `effect_modifiers` is neither None nor a dict.
        """
        self.treatment = treatment
        self.treatment_value = treatment_value
        self.control_value = control_value
        self.adjustment_set = adjustment_set
        self.outcome = outcome
        self.alpha = alpha
        self.df = df.query(query) if query else df

        if effect_modifiers is None:
            self.effect_modifiers = {}
        elif isinstance(effect_modifiers, dict):
            self.effect_modifiers = effect_modifiers
        else:
            # Only dicts are accepted, so the message names the type that is actually required.
            raise ValueError(f"Unsupported type for effect_modifiers {effect_modifiers}. Expected dict")
        self.modelling_assumptions = []
        if query:
            self.modelling_assumptions.append(query)
        # Subclasses append their own assumptions after the (optional) query.
        self.add_modelling_assumptions()
        logger.debug("Effect Modifiers: %s", self.effect_modifiers)

    @abstractmethod
    def add_modelling_assumptions(self):
        """
        Add modelling assumptions to the estimator. This is a list of strings which list the modelling assumptions that
        must hold if the resulting causal inference is to be considered valid.
        """

    def compute_confidence_intervals(self) -> list[float]:
        """
        Estimate the 95% Wald confidence intervals for the effect of changing the treatment from control values to
        treatment values on the outcome.

        This base implementation has no body and therefore returns None.

        :return: 95% Wald confidence intervals.
        """
File renamed without changes.

0 commit comments

Comments
 (0)