Skip to content

Commit 2ebf675

Browse files
authored
Merge pull request #60 from CITCOM-project/poisson-process-example
Poisson process example and a couple of minor changes to support that
2 parents 06fe327 + da0b546 commit 2ebf675

20 files changed

+6560
-16
lines changed

causal_testing/testing/causal_test_engine.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,20 @@ def execute_test(self, estimator: Estimator, estimate_type: str = 'ate') -> Caus
160160
confidence_intervals=confidence_intervals)
161161
# causal_test_result = CausalTestResult(minimal_adjustment_set, ate, confidence_intervals)
162162
# causal_test_result.apply_test_oracle_procedure(self.causal_test_case.expected_causal_effect)
163+
elif estimate_type == "ate_calculated":
164+
logger.debug("calculating ate")
165+
ate, confidence_intervals = estimator.estimate_ate_calculated()
166+
causal_test_result = CausalTestResult(
167+
treatment=estimator.treatment,
168+
outcome=estimator.outcome,
169+
treatment_value=estimator.treatment_values,
170+
control_value=estimator.control_values,
171+
adjustment_set=estimator.adjustment_set,
172+
ate=ate,
173+
effect_modifier_configuration=self.causal_test_case.effect_modifier_configuration,
174+
confidence_intervals=confidence_intervals)
175+
# causal_test_result = CausalTestResult(minimal_adjustment_set, ate, confidence_intervals)
176+
# causal_test_result.apply_test_oracle_procedure(self.causal_test_case.expected_causal_effect)
163177
else:
164178
raise ValueError(f"Invalid estimate type {estimate_type}, expected 'ate', 'cate', or 'risk_ratio'")
165179
return causal_test_result

causal_testing/testing/causal_test_outcome.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,21 @@ def __str__(self):
4141
confidence_str += f"Confidence intervals: {self.confidence_intervals}\n"
4242
return base_str + confidence_str
4343

44+
def to_dict(self):
45+
base_dict = {
46+
"treatment": self.treatment[0],
47+
"control_value": self.control_value,
48+
"treatment_value": self.treatment_value,
49+
"outcome": self.outcome[0],
50+
"adjustment_set": self.adjustment_set,
51+
"ate": self.ate
52+
}
53+
if self.confidence_intervals:
54+
base_dict["ci_low"] = min(self.confidence_intervals)
55+
base_dict["ci_high"] = max(self.confidence_intervals)
56+
return base_dict
57+
58+
4459
def ci_low(self):
4560
"""Return the lower bracket of the confidence intervals."""
4661
if not self.confidence_intervals:

causal_testing/testing/estimators.py

Lines changed: 65 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -79,14 +79,20 @@ class LinearRegressionEstimator(Estimator):
7979
combination of parameters and functions of the variables (note these functions need not be linear).
8080
"""
8181
def __init__(self, treatment: tuple, treatment_values: float, control_values: float, adjustment_set: set,
82-
outcome: tuple, df: pd.DataFrame = None, effect_modifiers: dict[Variable: Any] = None, product_terms: list[tuple[Variable, Variable]] = None):
82+
outcome: tuple, df: pd.DataFrame = None, effect_modifiers: dict[Variable: Any] = None, product_terms: list[tuple[Variable, Variable]] = None, intercept: int = 1):
8383
super().__init__(treatment, treatment_values, control_values, adjustment_set, outcome, df, effect_modifiers)
84+
8485
if product_terms is None:
8586
product_terms = []
8687
for (term_a, term_b) in product_terms:
8788
self.add_product_term_to_df(term_a, term_b)
89+
for term in self.effect_modifiers:
90+
self.adjustment_set.add(term)
91+
92+
self.product_terms = product_terms
8893
self.square_terms = []
89-
self.product_terms = []
94+
self.inverse_terms = []
95+
self.intercept = intercept
9096

9197
def add_modelling_assumptions(self):
9298
"""
@@ -112,6 +118,21 @@ def add_squared_term_to_df(self, term_to_square: str):
112118
f'with {term_to_square}.'
113119
self.square_terms.append(term_to_square)
114120

121+
def add_inverse_term_to_df(self, term_to_invert: str):
122+
""" Add an inverse term to the linear regression model and df.
123+
124+
This enables the user to capture curvilinear relationships with a linear regression model, not just straight
125+
lines, while automatically adding the modelling assumption imposed by the addition of this term.
126+
127+
:param term_to_square: The term (column in data and variable in DAG) which is to be squared.
128+
"""
129+
new_term = "1/"+str(term_to_invert)
130+
self.df[new_term] = 1/self.df[term_to_invert]
131+
self.adjustment_set.add(new_term)
132+
self.modelling_assumptions += f'Relationship between {self.treatment} and {self.outcome} varies inversely'\
133+
f'with {term_to_invert}.'
134+
self.inverse_terms.append(term_to_invert)
135+
115136
def add_product_term_to_df(self, term_a: str, term_b: str):
116137
""" Add a product term to the linear regression model and df.
117138
@@ -146,6 +167,7 @@ def estimate_ate(self) -> tuple[float, list[float, float], float]:
146167
:return: The average treatment effect and the 95% Wald confidence intervals.
147168
"""
148169
model = self._run_linear_regression()
170+
print(model.summary())
149171
# Create an empty individual for the control and treated
150172
individuals = pd.DataFrame(1, index=['control', 'treated'], columns=model.params.index)
151173
individuals.loc['control', list(self.treatment)] = self.control_values
@@ -162,30 +184,57 @@ def estimate_ate(self) -> tuple[float, list[float, float], float]:
162184
confidence_intervals = list(t_test_results.conf_int().flatten())
163185
return ate, confidence_intervals
164186

165-
def estimate_risk_ratio(self) -> tuple[float, list[float, float]]:
166-
""" Estimate the average treatment effect of the treatment on the outcome. That is, the change in outcome caused
167-
by changing the treatment variable from the control value to the treatment value.
187+
def estimate_control_treatment(self) -> tuple[pd.Series, pd.Series]:
188+
""" Estimate the outcomes under control and treatment.
168189
169190
:return: The average treatment effect and the 95% Wald confidence intervals.
170191
"""
171192
model = self._run_linear_regression()
193+
self.model = model
172194

173195
x = pd.DataFrame()
174196
x[self.treatment[0]] = [self.treatment_values, self.control_values]
175-
x['Intercept'] = 1
197+
x['Intercept'] = self.intercept
198+
for k, v in self.effect_modifiers.items():
199+
x[k] = v
176200
for t in self.square_terms:
177201
x[t+'^2'] = x[t] ** 2
202+
for t in self.inverse_terms:
203+
x['1/'+t] = 1 / x[t]
178204
for a, b in self.product_terms:
179205
x[f"{a}*{b}"] = x[a] * x[b]
206+
x = x[model.params.index]
180207

181-
print(x)
182-
print(model.summary())
208+
y = model.get_prediction(x).summary_frame()
209+
return y.iloc[1], y.iloc[0]
183210

184-
y = model.predict(x)
185-
treatment_outcome = y.iloc[0]
186-
control_outcome = y.iloc[1]
187211

188-
return treatment_outcome/control_outcome, None
212+
def estimate_risk_ratio(self) -> tuple[float, list[float, float]]:
213+
""" Estimate the risk_ratio effect of the treatment on the outcome. That is, the change in outcome caused
214+
by changing the treatment variable from the control value to the treatment value.
215+
216+
:return: The average treatment effect and the 95% Wald confidence intervals.
217+
"""
218+
control_outcome, treatment_outcome = self.estimate_control_treatment()
219+
ci_low = treatment_outcome['mean_ci_lower'] / control_outcome['mean_ci_upper']
220+
ci_high = treatment_outcome['mean_ci_upper'] / control_outcome['mean_ci_lower']
221+
222+
return (treatment_outcome['mean']/control_outcome['mean']), [ci_low, ci_high]
223+
224+
225+
def estimate_ate_calculated(self) -> tuple[float, list[float, float]]:
226+
""" Estimate the ate effect of the treatment on the outcome. That is, the change in outcome caused
227+
by changing the treatment variable from the control value to the treatment value. Here, we actually
228+
calculate the expected outcomes under control and treatment and take one away from the other. This
229+
allows for custom terms to be put in such as squares, inverses, products, etc.
230+
231+
:return: The average treatment effect and the 95% Wald confidence intervals.
232+
"""
233+
control_outcome, treatment_outcome = self.estimate_control_treatment()
234+
ci_low = treatment_outcome['mean_ci_lower'] - control_outcome['mean_ci_upper']
235+
ci_high = treatment_outcome['mean_ci_upper'] - control_outcome['mean_ci_lower']
236+
237+
return (treatment_outcome['mean']-control_outcome['mean']), [ci_low, ci_high]
189238

190239
def estimate_cates(self) -> tuple[float, list[float, float]]:
191240
""" Estimate the conditional average treatment effect of the treatment on the outcome. That is, the change
@@ -196,7 +245,7 @@ def estimate_cates(self) -> tuple[float, list[float, float]]:
196245
assert self.effect_modifiers, f"Must have at least one effect modifier to compute CATE - {self.effect_modifiers}."
197246
x = pd.DataFrame()
198247
x[self.treatment[0]] = [self.treatment_values, self.control_values]
199-
x['Intercept'] = 1
248+
x['Intercept'] = self.intercept
200249
for k, v in self.effect_modifiers.items():
201250
self.adjustment_set.add(k)
202251
x[k] = v
@@ -226,11 +275,11 @@ def _run_linear_regression(self) -> RegressionResultsWrapper:
226275
necessary_cols = list(self.treatment) + list(self.adjustment_set) + list(self.outcome)
227276
missing_rows = reduced_df[necessary_cols].isnull().any(axis=1)
228277
reduced_df = reduced_df[~missing_rows]
278+
reduced_df = reduced_df.sort_values(list(self.treatment))
229279
logger.debug(reduced_df[necessary_cols])
230280

231281
# 2. Add intercept
232-
reduced_df['Intercept'] = 1
233-
282+
reduced_df['Intercept'] = self.intercept
234283

235284
# 3. Estimate the unit difference in outcome caused by unit difference in treatment
236285
cols = list(self.treatment)
@@ -289,6 +338,7 @@ def estimate_ate(self) -> float:
289338
model.fit(outcome_df, treatment_df, X=effect_modifier_df, W=confounders_df)
290339

291340
# Obtain the ATE and 95% confidence intervals
341+
print(dir(model))
292342
ate = model.ate(effect_modifier_df, T0=self.control_values, T1=self.treatment_values)
293343
ate_interval = model.ate_interval(effect_modifier_df, T0=self.control_values, T1=self.treatment_values)
294344
ci_low, ci_high = ate_interval[0], ate_interval[1]

0 commit comments

Comments
 (0)