Skip to content
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/lint-format.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ jobs:

- name: Archive production artifacts
if: ${{ success() }} || ${{ failure() }}
uses: actions/upload-artifact@v2
uses: actions/upload-artifact@v3
with:
name: MegaLinter reports
path: |
megalinter-reports
mega-linter.log
mega-linter.log
4 changes: 2 additions & 2 deletions .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -371,8 +371,8 @@ min-public-methods=2
[EXCEPTIONS]

# Exceptions that will emit a warning when caught.
overgeneral-exceptions=BaseException,
Exception
overgeneral-exceptions=builtins.BaseException,
builtins.Exception


[FORMAT]
Expand Down
75 changes: 75 additions & 0 deletions causal_testing/estimation/cubic_spline_estimator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
"""This module contains the CubicSplineRegressionEstimator class, for estimating
continuous outcomes with changes in behaviour"""

import logging
from typing import Any

import pandas as pd

from causal_testing.specification.variable import Variable
from causal_testing.estimation.linear_regression_estimator import LinearRegressionEstimator

logger = logging.getLogger(__name__)


class CubicSplineRegressionEstimator(LinearRegressionEstimator):
    """A Cubic Spline Regression Estimator is a parametric estimator which restricts the variables in the data to a
    combination of parameters and basis functions of the variables.

    When no explicit formula is supplied, the constructor builds one of the form
    ``outcome ~ cr(term_1+term_2+..., df=basis)`` from the treatment, the adjustment set and the effect modifiers.
    """

    def __init__(
        # pylint: disable=too-many-arguments
        self,
        treatment: str,
        treatment_value: float,
        control_value: float,
        adjustment_set: set,
        outcome: str,
        basis: int,
        df: pd.DataFrame = None,
        effect_modifiers: dict[Variable, Any] = None,
        formula: str = None,
        alpha: float = 0.05,
        expected_relationship=None,
    ):
        """Initialise the estimator and, if needed, derive the default cubic spline formula.

        :param treatment: Name of the treatment variable.
        :param treatment_value: Value of the treatment variable under treatment.
        :param control_value: Value of the treatment variable under control.
        :param adjustment_set: Names of the variables to adjust for.
        :param outcome: Name of the outcome variable.
        :param basis: Value passed as ``df=`` to the ``cr(...)`` term of the generated formula.
        :param df: The data to fit the model to.
        :param effect_modifiers: Mapping of effect modifiers to their values.
        :param formula: Optional explicit formula. When given, no formula is generated here.
        :param alpha: Significance level forwarded to the parent estimator.
        :param expected_relationship: Stored unchanged on the instance as ``expected_relationship``.
        """
        super().__init__(
            treatment, treatment_value, control_value, adjustment_set, outcome, df, effect_modifiers, formula, alpha
        )

        self.expected_relationship = expected_relationship

        # Only the local name is defaulted here (for building the formula below); the value forwarded to the
        # parent constructor above is whatever the caller passed.
        if effect_modifiers is None:
            effect_modifiers = []

        if formula is None:
            # Sort the adjustment set and effect modifiers so the generated formula is deterministic.
            terms = [treatment] + sorted(adjustment_set) + sorted(effect_modifiers)
            self.formula = f"{outcome} ~ cr({'+'.join(terms)}, df={basis})"

    def estimate_ate_calculated(self, adjustment_config: dict = None) -> pd.Series:
        """Estimate the ate effect of the treatment on the outcome. That is, the change in outcome caused
        by changing the treatment variable from the control value to the treatment value. Here, we actually
        calculate the expected outcomes under control and treatment and subtract the former from the latter.
        This allows for custom terms to be put in such as squares, inverses, products, etc.

        :param: adjustment_config: The configuration of the adjustment set as a dict mapping variable names to
                                   their values. N.B. Every variable in the adjustment set MUST have a value in
                                   order to estimate the outcome under control and treatment.

        :return: The average treatment effect, i.e. the predicted outcome under treatment minus the predicted
                 outcome under control, wrapped in a pandas Series.
        """
        model = self._run_regression()

        # Build a single prediction point: intercept, treatment at its treated value, then any adjustment and
        # effect modifier values (later entries overwrite earlier ones on key clashes).
        x = {"Intercept": 1, self.treatment: self.treatment_value}
        if adjustment_config is not None:
            for k, v in adjustment_config.items():
                x[k] = v
        if self.effect_modifiers is not None:
            for k, v in self.effect_modifiers.items():
                x[k] = v

        treatment = model.predict(x).iloc[0]

        # Re-use the same point, changing only the treatment variable to its control value.
        x[self.treatment] = self.control_value
        control = model.predict(x).iloc[0]

        return pd.Series(treatment - control)
73 changes: 73 additions & 0 deletions causal_testing/estimation/estimator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
"""This module contains the Estimator abstract class"""

import logging
from abc import ABC, abstractmethod
from typing import Any

import pandas as pd

logger = logging.getLogger(__name__)


class Estimator(ABC):
    # pylint: disable=too-many-instance-attributes
    """An estimator contains all of the information necessary to compute a causal estimate for the effect of changing
    a set of treatment variables to a set of values.

    All estimators must implement the following two methods:

    1) add_modelling_assumptions: The validity of a model-assisted causal inference result depends on whether
    the modelling assumptions imposed by a model actually hold. Therefore, for each model, it is important to state
    the modelling assumptions upon which the validity of the results depends. To achieve this, the estimator object
    maintains a list of modelling assumptions (as strings). If a user wishes to implement their own estimator, they
    must implement this method and add all assumptions to the list of modelling assumptions.

    2) estimate_ate: All estimators must be capable of returning the average treatment effect as a minimum. That is, the
    average effect of the intervention (changing treatment from control to treated value) on the outcome of interest
    adjusted for all confounders.

    N.B. Only add_modelling_assumptions is declared abstract in this class; estimate_ate is a convention that
    is not enforced here.
    """

    def __init__(
        # pylint: disable=too-many-arguments
        self,
        treatment: str,
        treatment_value: float,
        control_value: float,
        adjustment_set: set,
        outcome: str,
        df: pd.DataFrame = None,
        effect_modifiers: dict[str, Any] = None,
        alpha: float = 0.05,
        query: str = "",
    ):
        """Store the estimation configuration and collect the modelling assumptions.

        :param treatment: Name of the treatment variable.
        :param treatment_value: Value of the treatment variable under treatment.
        :param control_value: Value of the treatment variable under control.
        :param adjustment_set: Names of the variables to adjust for.
        :param outcome: Name of the outcome variable.
        :param df: The data to estimate from. Filtered with ``query`` when a non-empty query is given.
        :param effect_modifiers: Mapping of effect modifier names to values. Defaults to an empty dict.
        :param alpha: Significance level.
        :param query: Optional ``DataFrame.query`` expression used to filter ``df``. A non-empty query is also
                      recorded as the first modelling assumption.
        """
        self.treatment = treatment
        self.treatment_value = treatment_value
        self.control_value = control_value
        self.adjustment_set = adjustment_set
        self.outcome = outcome
        self.alpha = alpha
        # An empty query leaves the data untouched (df may then also be None).
        self.df = df.query(query) if query else df

        if effect_modifiers is None:
            self.effect_modifiers = {}
        else:
            self.effect_modifiers = effect_modifiers
        self.modelling_assumptions = []
        if query:
            self.modelling_assumptions.append(query)
        # Subclass hook: runs after every attribute above has been set.
        self.add_modelling_assumptions()
        logger.debug("Effect Modifiers: %s", self.effect_modifiers)

    @abstractmethod
    def add_modelling_assumptions(self):
        """
        Add modelling assumptions to the estimator. This is a list of strings which list the modelling assumptions that
        must hold if the resulting causal inference is to be considered valid.
        """

    def compute_confidence_intervals(self) -> list[float]:
        """
        Estimate the 95% Wald confidence intervals for the effect of changing the treatment from control values to
        treatment values on the outcome.

        This base implementation has no body and therefore returns None.

        :return: 95% Wald confidence intervals.
        """
Loading
Loading