Skip to content

Commit 127a2f4

Browse files
committed
Added experimental estimator to keep functionality of experimental data collector
1 parent 73eb5f1 commit 127a2f4

File tree

2 files changed

+147
-0
lines changed

2 files changed

+147
-0
lines changed
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
"""This module contains the ExperimentalEstimator class for directly interacting with the system under test."""
2+
3+
import pandas as pd
4+
from typing import Any
5+
from abc import abstractmethod
6+
7+
from causal_testing.estimation.abstract_estimator import Estimator
8+
9+
10+
class ExperimentalEstimator(Estimator):
    """An estimator which obtains effect estimates by directly executing the system under test.

    The system is run ``repeats`` times under a control configuration and ``repeats`` times under a treatment
    configuration. The two configurations differ only in the value assigned to the treatment variable. The effect
    is then computed from the outcome observed in each pair of runs.

    Subclasses must implement :meth:`run_system` to actually execute the system under test.
    """

    def __init__(
        # pylint: disable=too-many-arguments
        self,
        treatment: str,
        treatment_value: float,
        control_value: float,
        adjustment_set: dict[str, Any],
        outcome: str,
        effect_modifiers: dict[str, Any] = None,
        alpha: float = 0.05,
        repeats: int = 200,
    ):
        """
        :param treatment: Name of the treatment variable.
        :param treatment_value: Value of the treatment variable in the treatment configuration.
        :param control_value: Value of the treatment variable in the control configuration.
        :param adjustment_set: Mapping of variable names to the fixed values merged into every run configuration.
        :param outcome: Key of the outcome of interest in the dict returned by :meth:`run_system`.
        :param effect_modifiers: Optional mapping of further variable names to values merged into every run
            configuration. Defaults to an empty dict.
        :param alpha: Significance level used to select the confidence interval bounds.
        :param repeats: Number of times the system is run for each of the two configurations.
        """
        super().__init__(
            treatment=treatment,
            treatment_value=treatment_value,
            control_value=control_value,
            adjustment_set=adjustment_set,
            outcome=outcome,
            effect_modifiers=effect_modifiers,
            alpha=alpha,
        )
        if effect_modifiers is None:
            self.effect_modifiers = {}
        self.repeats = repeats

    def add_modelling_assumptions(self):
        """
        Add modelling assumptions to the estimator. This is a list of strings which list the modelling assumptions that
        must hold if the resulting causal inference is to be considered valid.
        """
        self.modelling_assumptions.append(
            "The supplied number of repeats must be sufficient for statistical significance"
        )

    @abstractmethod
    def run_system(self, configuration: dict) -> dict:
        """
        Runs the system under test with the supplied configuration and supplies the outputs as a dict.
        :param configuration: The run configuration arguments.
        :returns: The resulting output as a dict.
        """

    def _run_repeats(self) -> tuple[pd.Series, pd.Series]:
        """Run the system ``repeats`` times per configuration and collect the outcome of interest.

        :return: The control outcomes and the treatment outcomes, each indexed by run number.
        :raises ValueError: If ``repeats`` is less than one, since there would be nothing to summarise.
        """
        if self.repeats < 1:
            raise ValueError(f"repeats must be at least 1, got {self.repeats}")

        shared_configuration = self.adjustment_set | self.effect_modifiers
        control_configuration = shared_configuration | {self.treatment: self.control_value}
        treatment_configuration = shared_configuration | {self.treatment: self.treatment_value}

        # All control runs are executed before any treatment run.
        control_outcomes = pd.DataFrame([self.run_system(control_configuration) for _ in range(self.repeats)])
        treatment_outcomes = pd.DataFrame([self.run_system(treatment_configuration) for _ in range(self.repeats)])
        return control_outcomes[self.outcome], treatment_outcomes[self.outcome]

    def _summarise(self, effects: pd.Series) -> tuple[pd.Series, list[pd.Series]]:
        """Reduce per-run effects to a point estimate and percentile confidence bounds.

        :param effects: One effect value per pair of runs.
        :return: The mean effect and the ``[low, high]`` bounds, each a single-element series keyed by the
            treatment name.
        """
        ordered = effects.sort_values().reset_index(drop=True)

        ci_low_index = round(len(ordered) * (self.alpha / 2))
        # Mirror the lower index from the top of the sorted values. The ``- 1`` keeps the position in bounds when
        # the lower index rounds to zero and makes the two bounds symmetric.
        ci_high_index = len(ordered) - 1 - ci_low_index

        return pd.Series({self.treatment: ordered.mean()}), [
            pd.Series({self.treatment: ordered.iloc[ci_low_index]}),
            pd.Series({self.treatment: ordered.iloc[ci_high_index]}),
        ]

    def estimate_ate(self) -> tuple[pd.Series, list[pd.Series]]:
        """Estimate the average treatment effect of the treatment on the outcome. That is, the change in outcome caused
        by changing the treatment variable from the control value to the treatment value.

        The effect of each pair of runs is the treatment outcome minus the control outcome; runs are paired by
        their run number.

        :return: The average treatment effect and the percentile confidence intervals over the repeated runs.
        :raises ValueError: If ``repeats`` is less than one.
        """
        control_outcomes, treatment_outcomes = self._run_repeats()
        return self._summarise(treatment_outcomes - control_outcomes)

    def estimate_risk_ratio(self) -> tuple[pd.Series, list[pd.Series]]:
        """Estimate the risk ratio of the treatment on the outcome. That is, the ratio between the outcome observed
        under the treatment value and the outcome observed under the control value.

        The ratio is computed per pair of runs (paired by run number) and the point estimate is the mean of those
        ratios. There is no guard against a control outcome of zero.

        :return: The risk ratio and the percentile confidence intervals over the repeated runs.
        :raises ValueError: If ``repeats`` is less than one.
        """
        control_outcomes, treatment_outcomes = self._run_repeats()
        return self._summarise(treatment_outcomes / control_outcomes)
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import unittest
2+
from causal_testing.estimation.experimental_estimator import ExperimentalEstimator
3+
4+
5+
class ConcreteExperimentalEstimator(ExperimentalEstimator):
    """Deterministic stand-in for a system under test: the output Y is twice the input X."""

    def run_system(self, configuration):
        """Return the output dict produced for the supplied run configuration."""
        x_value = configuration["X"]
        return {"Y": 2 * x_value}
8+
9+
10+
class TestExperimentalEstimator(unittest.TestCase):
    """
    Test the experimental estimator.
    """

    def setUp(self):
        """Build a fresh estimator for each test: X is changed from 1 to 2 and Y is the observed outcome."""
        self.estimator = ConcreteExperimentalEstimator(
            treatment="X",
            treatment_value=2,
            control_value=1,
            adjustment_set={},
            outcome="Y",
            effect_modifiers={},
            alpha=0.05,
            repeats=200,
        )

    def test_estimate_ate(self):
        """The deterministic system gives Y=4 under treatment and Y=2 under control, so every difference is 2."""
        ate, confidence_intervals = self.estimator.estimate_ate()
        ci_low, ci_high = confidence_intervals
        self.assertEqual(ate["X"], 2)
        self.assertEqual(ci_low["X"], 2)
        self.assertEqual(ci_high["X"], 2)

    def test_estimate_risk_ratio(self):
        """The deterministic system gives Y=4 under treatment and Y=2 under control, so every ratio is 2."""
        risk_ratio, confidence_intervals = self.estimator.estimate_risk_ratio()
        ci_low, ci_high = confidence_intervals
        self.assertEqual(risk_ratio["X"], 2)
        self.assertEqual(ci_low["X"], 2)
        self.assertEqual(ci_high["X"], 2)

0 commit comments

Comments
 (0)