Skip to content

Commit 36b4511

Browse files
committed
#76 improve DiD plotting + rerun notebooks
1 parent f0aefd4 commit 36b4511

File tree

3 files changed

+66
-83
lines changed

3 files changed

+66
-83
lines changed

causalpy/pymc_experiments.py

Lines changed: 41 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -370,53 +370,52 @@ def plot(self):
370370
alpha=0.5,
371371
ax=ax,
372372
)
373+
373374
# Plot model fit to control group
374-
parts = ax.violinplot(
375-
az.extract(
376-
self.y_pred_control, group="posterior_predictive", var_names="mu"
377-
).values.T,
378-
positions=self.x_pred_control[self.time_variable_name].values,
379-
showmeans=False,
380-
showmedians=False,
381-
widths=0.2,
382-
)
383-
for pc in parts["bodies"]:
384-
pc.set_facecolor("C0")
385-
pc.set_edgecolor("None")
386-
pc.set_alpha(0.5)
375+
time_points = self.x_pred_control[self.time_variable_name].values
376+
plot_xY(
377+
time_points,
378+
self.y_pred_control.posterior_predictive.y_hat,
379+
ax=ax,
380+
plot_hdi_kwargs={"color": "C0"},
381+
)
387382

388383
# Plot model fit to treatment group
389-
parts = ax.violinplot(
390-
az.extract(
391-
self.y_pred_treatment, group="posterior_predictive", var_names="mu"
392-
).values.T,
393-
positions=self.x_pred_treatment[self.time_variable_name].values,
394-
showmeans=False,
395-
showmedians=False,
396-
widths=0.2,
397-
)
398-
399-
for pc in parts["bodies"]:
400-
pc.set_facecolor("C1")
401-
pc.set_edgecolor("None")
402-
pc.set_alpha(0.5)
384+
time_points = self.x_pred_control[self.time_variable_name].values
385+
plot_xY(
386+
time_points,
387+
self.y_pred_treatment.posterior_predictive.y_hat,
388+
ax=ax,
389+
plot_hdi_kwargs={"color": "C1"},
390+
)
391+
403392
# Plot counterfactual - post-test for treatment group IF no treatment
404393
# had occurred.
405-
parts = ax.violinplot(
406-
az.extract(
407-
self.y_pred_counterfactual,
408-
group="posterior_predictive",
409-
var_names="mu",
410-
).values.T,
411-
positions=self.x_pred_counterfactual[self.time_variable_name].values,
412-
showmeans=False,
413-
showmedians=False,
414-
widths=0.2,
415-
)
416-
for pc in parts["bodies"]:
417-
pc.set_facecolor("C2")
418-
pc.set_edgecolor("None")
419-
pc.set_alpha(0.5)
394+
time_points = self.x_pred_counterfactual[self.time_variable_name].values
395+
if len(time_points) == 1:
396+
parts = ax.violinplot(
397+
az.extract(
398+
self.y_pred_counterfactual,
399+
group="posterior_predictive",
400+
var_names="mu",
401+
).values.T,
402+
positions=self.x_pred_counterfactual[self.time_variable_name].values,
403+
showmeans=False,
404+
showmedians=False,
405+
widths=0.2,
406+
)
407+
for pc in parts["bodies"]:
408+
pc.set_facecolor("C2")
409+
pc.set_edgecolor("None")
410+
pc.set_alpha(0.5)
411+
else:
412+
plot_xY(
413+
time_points,
414+
self.y_pred_counterfactual.posterior_predictive.y_hat,
415+
ax=ax,
416+
plot_hdi_kwargs={"color": "C2"},
417+
)
418+
420419
# arrow to label the causal impact
421420
self._plot_causal_impact_arrow(ax)
422421
# formatting

docs/notebooks/did_pymc.ipynb

Lines changed: 5 additions & 5 deletions
Large diffs are not rendered by default.

docs/notebooks/did_pymc_banks.ipynb

Lines changed: 20 additions & 36 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)