Skip to content

Commit b1310a6

Browse files
committed
#76 improve DID plotting + improve data pre-processing
1 parent 90cc898 commit b1310a6

File tree

3 files changed

+261
-234
lines changed

3 files changed

+261
-234
lines changed

causalpy/pymc_experiments.py

Lines changed: 55 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,8 @@ def __init__(
238238

239239
# TODO: check that data in column self.group_variable_name has TWO levels
240240

241+
# TODO: check we have `unit` as a predictor column which is an vector of labels of unique units
242+
241243
# TODO: `treated` is a deterministic function of group and time, so this should be a function rather than supplied data
242244

243245
# DEVIATION FROM SKL EXPERIMENT CODE =============================
@@ -303,18 +305,17 @@ def plot(self):
303305

304306
# Plot raw data
305307
# NOTE: This will not work when there is just ONE unit in each group
306-
# sns.lineplot(
307-
# self.data,
308-
# x=self.time_variable_name,
309-
# y=self.outcome_variable_name,
310-
# hue=self.group_variable_name,
311-
# # units="unit",
312-
# estimator=None,
313-
# alpha=0.25,
314-
# ax=ax,
315-
# )
308+
sns.lineplot(
309+
self.data,
310+
x=self.time_variable_name,
311+
y=self.outcome_variable_name,
312+
hue=self.group_variable_name,
313+
units="unit", # NOTE: assumes we have a `unit` predictor variable
314+
estimator=None,
315+
alpha=0.5,
316+
ax=ax,
317+
)
316318
# Plot model fit to control group
317-
# NOTE: This will not work when there is just ONE unit in each group
318319
parts = ax.violinplot(
319320
az.extract(
320321
self.y_pred_control, group="posterior_predictive", var_names="mu"
@@ -330,7 +331,6 @@ def plot(self):
330331
pc.set_alpha(0.5)
331332

332333
# Plot model fit to treatment group
333-
# NOTE: This will not work when there is just ONE unit in each group
334334
parts = ax.violinplot(
335335
az.extract(
336336
self.y_pred_treatment, group="posterior_predictive", var_names="mu"
@@ -340,20 +340,41 @@ def plot(self):
340340
showmedians=False,
341341
widths=0.2,
342342
)
343-
# # Plot counterfactual - post-test for treatment group IF no treatment had occurred.
344-
# # NOTE: This will not work when there is just ONE unit in each group
345-
# parts = ax.violinplot(
346-
# az.extract(
347-
# self.y_pred_counterfactual,
348-
# group="posterior_predictive",
349-
# var_names="mu",
350-
# ).values.T,
351-
# positions=self.x_pred_counterfactual[self.time_variable_name].values,
352-
# showmeans=False,
353-
# showmedians=False,
354-
# widths=0.2,
355-
# )
343+
for pc in parts["bodies"]:
344+
pc.set_facecolor("C1")
345+
pc.set_edgecolor("None")
346+
pc.set_alpha(0.5)
347+
# Plot counterfactual - post-test for treatment group IF no treatment had occurred.
348+
parts = ax.violinplot(
349+
az.extract(
350+
self.y_pred_counterfactual,
351+
group="posterior_predictive",
352+
var_names="mu",
353+
).values.T,
354+
positions=self.x_pred_counterfactual[self.time_variable_name].values,
355+
showmeans=False,
356+
showmedians=False,
357+
widths=0.2,
358+
)
359+
for pc in parts["bodies"]:
360+
pc.set_facecolor("C2")
361+
pc.set_edgecolor("None")
362+
pc.set_alpha(0.5)
356363
# arrow to label the causal impact
364+
self._plot_causal_impact_arrow(ax)
365+
# formatting
366+
ax.set(
367+
xticks=self.x_pred_treatment[self.time_variable_name].values,
368+
title=self._causal_impact_summary_stat(),
369+
)
370+
ax.legend(fontsize=LEGEND_FONT_SIZE)
371+
return (fig, ax)
372+
373+
def _plot_causal_impact_arrow(self, ax):
374+
"""
375+
draw a vertical arrow between `y_pred_counterfactual` and `y_pred_counterfactual`
376+
"""
377+
# Calculate y values to plot the arrow between
357378
y_pred_treatment = (
358379
self.y_pred_treatment["posterior_predictive"]
359380
.mu.isel({"obs_ind": 1})
@@ -363,32 +384,28 @@ def plot(self):
363384
y_pred_counterfactual = (
364385
self.y_pred_counterfactual["posterior_predictive"].mu.mean().data
365386
)
387+
# Calculate the x position to plot at
388+
diff = np.ptp(self.x_pred_treatment[self.time_variable_name].values)
389+
x = np.max(self.x_pred_treatment[self.time_variable_name].values) + 0.1 * diff
390+
# Plot the arrow
366391
ax.annotate(
367392
"",
368-
xy=(1.15, y_pred_counterfactual),
393+
xy=(x, y_pred_counterfactual),
369394
xycoords="data",
370-
xytext=(1.15, y_pred_treatment),
395+
xytext=(x, y_pred_treatment),
371396
textcoords="data",
372-
arrowprops={"arrowstyle": "<->", "color": "green", "lw": 3},
397+
arrowprops={"arrowstyle": "<-", "color": "green", "lw": 3},
373398
)
399+
# Plot text annotation next to arrow
374400
ax.annotate(
375401
"causal\nimpact",
376-
xy=(1.15, np.mean([y_pred_counterfactual, y_pred_treatment])),
402+
xy=(x, np.mean([y_pred_counterfactual, y_pred_treatment])),
377403
xycoords="data",
378404
xytext=(5, 0),
379405
textcoords="offset points",
380406
color="green",
381407
va="center",
382408
)
383-
# formatting
384-
ax.set(
385-
# xlim=[-0.15, 1.25],
386-
xticks=self.x_pred_treatment[self.time_variable_name].values,
387-
# xticklabels=["pre", "post"],
388-
title=self._causal_impact_summary_stat(),
389-
)
390-
ax.legend(fontsize=LEGEND_FONT_SIZE)
391-
return (fig, ax)
392409

393410
def _causal_impact_summary_stat(self):
394411
percentiles = self.causal_impact.quantile([0.03, 1 - 0.03]).values

docs/notebooks/did_pymc.ipynb

Lines changed: 112 additions & 26 deletions
Large diffs are not rendered by default.

docs/notebooks/did_pymc_banks.ipynb

Lines changed: 94 additions & 170 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)