Skip to content

Commit 8f39903

Browse files
authored
Merge pull request #77 from pymc-labs/misc-grab-bag-of-improvements
Misc grab bag of improvements
2 parents e011c9d + 8c83b92 commit 8f39903

File tree

9 files changed

+1028
-94
lines changed

9 files changed

+1028
-94
lines changed

causalpy/pymc_experiments.py

Lines changed: 72 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -9,19 +9,40 @@
99
from causalpy.plot_utils import plot_xY
1010

1111
LEGEND_FONT_SIZE = 12
12+
az.style.use("arviz-darkgrid")
1213

1314

1415
class ExperimentalDesign:
1516
"""Base class"""
1617

1718
prediction_model = None
19+
expt_type = None
1820

1921
def __init__(self, prediction_model=None, **kwargs):
2022
if prediction_model is not None:
2123
self.prediction_model = prediction_model
2224
if self.prediction_model is None:
2325
raise ValueError("fitting_model not set or passed.")
2426

27+
def print_coefficients(self):
28+
"""Prints the model coefficients"""
29+
print("Model coefficients:")
30+
coeffs = az.extract(self.prediction_model.idata.posterior, var_names="beta")
31+
# Note: f"{name: <30}" pads the name with spaces so that we have alignment of the stats despite variable names of different lengths
32+
for name in self.labels:
33+
coeff_samples = coeffs.sel(coeffs=name)
34+
print(
35+
f" {name: <30}{coeff_samples.mean().data:.2f}, 94% HDI [{coeff_samples.quantile(0.03).data:.2f}, {coeff_samples.quantile(1-0.03).data:.2f}]"
36+
)
37+
# add coeff for measurement std
38+
coeff_samples = az.extract(
39+
self.prediction_model.idata.posterior, var_names="sigma"
40+
)
41+
name = "sigma"
42+
print(
43+
f" {name: <30}{coeff_samples.mean().data:.2f}, 94% HDI [{coeff_samples.quantile(0.03).data:.2f}, {coeff_samples.quantile(1-0.03).data:.2f}]"
44+
)
45+
2546

2647
class TimeSeriesExperiment(ExperimentalDesign):
2748
"""A class to analyse time series quasi-experiments"""
@@ -44,6 +65,7 @@ def __init__(
4465

4566
# set things up with pre-intervention data
4667
y, X = dmatrices(formula, self.datapre)
68+
self.outcome_variable_name = y.design_info.column_names[0]
4769
self._y_design_info = y.design_info
4870
self._x_design_info = X.design_info
4971
self.labels = X.design_info.column_names
@@ -144,10 +166,20 @@ def plot(self):
144166

145167
return (fig, ax)
146168

169+
def summary(self):
170+
"""Print text output summarising the results"""
171+
172+
print(f"{self.expt_type:=^80}")
173+
print(f"Formula: {self.formula}")
174+
# TODO: extra experiment specific outputs here
175+
self.print_coefficients()
176+
147177

148178
class SyntheticControl(TimeSeriesExperiment):
149179
"""A wrapper around the TimeSeriesExperiment class"""
150180

181+
expt_type = "Synthetic Control"
182+
151183
def plot(self):
152184
"""Plot the results"""
153185
fig, ax = super().plot()
@@ -160,7 +192,7 @@ def plot(self):
160192
class InterruptedTimeSeries(TimeSeriesExperiment):
161193
"""A wrapper around the TimeSeriesExperiment class"""
162194

163-
pass
195+
expt_type = "Interrupted Time Series"
164196

165197

166198
class DifferenceInDifferences(ExperimentalDesign):
@@ -177,20 +209,20 @@ def __init__(
177209
data,
178210
formula,
179211
time_variable_name="t",
180-
outcome_variable_name="y",
181212
prediction_model=None,
182213
**kwargs,
183214
):
184215
super().__init__(prediction_model=prediction_model, **kwargs)
185216
self.data = data
217+
self.expt_type = "Difference in Differences"
186218
self.formula = formula
187219
self.time_variable_name = time_variable_name
188-
self.outcome_variable_name = outcome_variable_name
189220
y, X = dmatrices(formula, self.data)
190221
self._y_design_info = y.design_info
191222
self._x_design_info = X.design_info
192223
self.labels = X.design_info.column_names
193224
self.y, self.X = np.asarray(y), np.asarray(X)
225+
self.outcome_variable_name = y.design_info.column_names[0]
194226

195227
# TODO: `treated` is a deterministic function of group and time, so this should be a function rather than supplied data
196228

@@ -224,14 +256,18 @@ def __init__(
224256
self.y_pred_counterfactual = self.prediction_model.predict(np.asarray(new_x))
225257

226258
# calculate causal impact
227-
# TODO: This should most likely be posterior estimate, not posterior predictive
228259
self.causal_impact = (
229-
self.y_pred_treatment["posterior_predictive"]
230-
.y_hat.isel({"obs_ind": 1})
231-
.mean()
232-
.data
233-
- self.y_pred_counterfactual["posterior_predictive"].y_hat.mean().data
260+
self.y_pred_treatment["posterior_predictive"].mu.isel({"obs_ind": 1})
261+
- self.y_pred_counterfactual["posterior_predictive"].mu.squeeze()
234262
)
263+
# self.causal_impact = (
264+
# self.y_pred_treatment["posterior_predictive"]
265+
# .mu.isel({"obs_ind": 1})
266+
# .stack(samples=["chain", "draw"])
267+
# - self.y_pred_counterfactual["posterior_predictive"]
268+
# .mu.stack(samples=["chain", "draw"])
269+
# .squeeze()
270+
# )
235271

236272
def plot(self):
237273
"""Plot the results"""
@@ -251,7 +287,7 @@ def plot(self):
251287
# Plot model fit to control group
252288
parts = ax.violinplot(
253289
az.extract(
254-
self.y_pred_control, group="posterior_predictive", var_names="y_hat"
290+
self.y_pred_control, group="posterior_predictive", var_names="mu"
255291
).values.T,
256292
positions=self.x_pred_control[self.time_variable_name].values,
257293
showmeans=False,
@@ -266,7 +302,7 @@ def plot(self):
266302
# Plot model fit to treatment group
267303
parts = ax.violinplot(
268304
az.extract(
269-
self.y_pred_treatment, group="posterior_predictive", var_names="y_hat"
305+
self.y_pred_treatment, group="posterior_predictive", var_names="mu"
270306
).values.T,
271307
positions=self.x_pred_treatment[self.time_variable_name].values,
272308
showmeans=False,
@@ -278,7 +314,7 @@ def plot(self):
278314
az.extract(
279315
self.y_pred_counterfactual,
280316
group="posterior_predictive",
281-
var_names="y_hat",
317+
var_names="mu",
282318
).values.T,
283319
positions=self.x_pred_counterfactual[self.time_variable_name].values,
284320
showmeans=False,
@@ -288,12 +324,12 @@ def plot(self):
288324
# arrow to label the causal impact
289325
y_pred_treatment = (
290326
self.y_pred_treatment["posterior_predictive"]
291-
.y_hat.isel({"obs_ind": 1})
327+
.mu.isel({"obs_ind": 1})
292328
.mean()
293329
.data
294330
)
295331
y_pred_counterfactual = (
296-
self.y_pred_counterfactual["posterior_predictive"].y_hat.mean().data
332+
self.y_pred_counterfactual["posterior_predictive"].mu.mean().data
297333
)
298334
ax.annotate(
299335
"",
@@ -317,11 +353,27 @@ def plot(self):
317353
xlim=[-0.15, 1.25],
318354
xticks=[0, 1],
319355
xticklabels=["pre", "post"],
320-
title=f"Causal impact = {self.causal_impact:.2f}",
356+
title=self._causal_impact_summary_stat(),
321357
)
322358
ax.legend(fontsize=LEGEND_FONT_SIZE)
323359
return (fig, ax)
324360

361+
def _causal_impact_summary_stat(self):
362+
percentiles = self.causal_impact.quantile([0.03, 1 - 0.03]).values
363+
ci = r"$CI_{94\%}$" + f"[{percentiles[0]:.2f}, {percentiles[1]:.2f}]"
364+
causal_impact = f"{self.causal_impact.mean():.2f}, "
365+
return f"Causal impact = {causal_impact + ci}"
366+
367+
def summary(self):
368+
"""Print text output summarising the results"""
369+
370+
print(f"{self.expt_type:=^80}")
371+
print(f"Formula: {self.formula}")
372+
print("\nResults:")
373+
# TODO: extra experiment specific outputs here
374+
print(self._causal_impact_summary_stat())
375+
self.print_coefficients()
376+
325377

326378
class RegressionDiscontinuity(ExperimentalDesign):
327379
"""
@@ -345,20 +397,20 @@ def __init__(
345397
treatment_threshold: float,
346398
prediction_model=None,
347399
running_variable_name: str = "x",
348-
outcome_variable_name="y",
349400
**kwargs,
350401
):
351402
super().__init__(prediction_model=prediction_model, **kwargs)
403+
self.expt_type = "Regression Discontinuity"
352404
self.data = data
353405
self.formula = formula
354406
self.running_variable_name = running_variable_name
355-
self.outcome_variable_name = outcome_variable_name
356407
self.treatment_threshold = treatment_threshold
357408
y, X = dmatrices(formula, self.data)
358409
self._y_design_info = y.design_info
359410
self._x_design_info = X.design_info
360411
self.labels = X.design_info.column_names
361412
self.y, self.X = np.asarray(y), np.asarray(X)
413+
self.outcome_variable_name = y.design_info.column_names[0]
362414

363415
# TODO: `treated` is a deterministic function of x and treatment_threshold, so this could be a function rather than supplied data
364416

@@ -445,25 +497,13 @@ def plot(self):
445497

446498
def summary(self):
447499
"""Print text output summarising the results"""
448-
print("Difference in Differences experiment")
500+
501+
print(f"{self.expt_type:=^80}")
449502
print(f"Formula: {self.formula}")
450503
print(f"Running variable: {self.running_variable_name}")
451504
print(f"Threshold on running variable: {self.treatment_threshold}")
452505
print(f"\nResults:")
453506
print(
454507
f"Discontinuity at threshold = {self.discontinuity_at_threshold.mean():.2f}"
455508
)
456-
print("Model coefficients:")
457-
coeffs = az.extract(self.prediction_model.idata.posterior, var_names="beta")
458-
for name in self.labels:
459-
coeff_samples = coeffs.sel(coeffs=name)
460-
print(
461-
f"\t{name}\t\t{coeff_samples.mean().data:.2f}, 94% HDI [{coeff_samples.quantile(0.03).data:.2f}, {coeff_samples.quantile(1-0.03).data:.2f}]"
462-
)
463-
# add coeff for measurement std
464-
coeff_samples = az.extract(
465-
self.prediction_model.idata.posterior, var_names="sigma"
466-
)
467-
print(
468-
f"\tsigma\t\t{coeff_samples.mean().data:.2f}, 94% HDI [{coeff_samples.quantile(0.03).data:.2f}, {coeff_samples.quantile(1-0.03).data:.2f}]"
469-
)
509+
self.print_coefficients()

causalpy/pymc_models.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -66,7 +66,7 @@ def build_model(self, X, y, coords):
6666
n_predictors = X.shape[1]
6767
X = pm.MutableData("X", X, dims=["obs_ind", "coeffs"])
6868
y = pm.MutableData("y", y[:, 0], dims="obs_ind")
69-
beta = pm.Dirichlet("beta", a=np.ones(n_predictors))
69+
beta = pm.Dirichlet("beta", a=np.ones(n_predictors), dims="coeffs")
7070
sigma = pm.HalfNormal("sigma", 1)
7171
mu = pm.Deterministic("mu", pm.math.dot(X, beta), dims="obs_ind")
7272
pm.Normal("y_hat", mu, sigma, observed=y, dims="obs_ind")

causalpy/skl_experiments.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@ class ExperimentalDesign:
1111
"""Base class for experiment designs"""
1212

1313
prediction_model = None
14+
outcome_variable_name = None
1415

1516
def __init__(self, prediction_model=None, **kwargs):
1617
if prediction_model is not None:
@@ -34,6 +35,7 @@ def __init__(self, data, treatment_time, formula, prediction_model=None, **kwarg
3435
self._y_design_info = y.design_info
3536
self._x_design_info = X.design_info
3637
self.labels = X.design_info.column_names
38+
self.outcome_variable_name = y.design_info.column_names[0]
3739
self.pre_y, self.pre_X = np.asarray(y), np.asarray(X)
3840
# process post-intervention data
3941
(new_y, new_x) = build_design_matrices(
@@ -174,20 +176,19 @@ def __init__(
174176
data,
175177
formula,
176178
time_variable_name="t",
177-
outcome_variable_name="y",
178179
prediction_model=None,
179180
**kwargs,
180181
):
181182
super().__init__(prediction_model=prediction_model, **kwargs)
182183
self.data = data
183184
self.formula = formula
184185
self.time_variable_name = time_variable_name
185-
self.outcome_variable_name = outcome_variable_name
186186
y, X = dmatrices(formula, self.data)
187187
self._y_design_info = y.design_info
188188
self._x_design_info = X.design_info
189189
self.labels = X.design_info.column_names
190190
self.y, self.X = np.asarray(y), np.asarray(X)
191+
self.outcome_variable_name = y.design_info.column_names[0]
191192

192193
# TODO: `treated` is a deterministic function of group and time, so this should be a function rather than supplied data
193194

@@ -307,20 +308,19 @@ def __init__(
307308
treatment_threshold,
308309
prediction_model=None,
309310
running_variable_name="x",
310-
outcome_variable_name="y",
311311
**kwargs,
312312
):
313313
super().__init__(prediction_model=prediction_model, **kwargs)
314314
self.data = data
315315
self.formula = formula
316316
self.running_variable_name = running_variable_name
317-
self.outcome_variable_name = outcome_variable_name
318317
self.treatment_threshold = treatment_threshold
319318
y, X = dmatrices(formula, self.data)
320319
self._y_design_info = y.design_info
321320
self._x_design_info = X.design_info
322321
self.labels = X.design_info.column_names
323322
self.y, self.X = np.asarray(y), np.asarray(X)
323+
self.outcome_variable_name = y.design_info.column_names[0]
324324

325325
# TODO: `treated` is a deterministic function of x and treatment_threshold, so this could be a function rather than supplied data
326326

docs/notebooks/did_pymc.ipynb

Lines changed: 45 additions & 4 deletions
Large diffs are not rendered by default.

docs/notebooks/its_pymc.ipynb

Lines changed: 235 additions & 13 deletions
Large diffs are not rendered by default.

docs/notebooks/rd_pymc.ipynb

Lines changed: 36 additions & 3 deletions
Large diffs are not rendered by default.

docs/notebooks/rd_pymc_drinking.ipynb

Lines changed: 399 additions & 14 deletions
Large diffs are not rendered by default.

docs/notebooks/rd_skl_drinking.ipynb

Lines changed: 5 additions & 7 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)