Skip to content

Commit 90cc898

Browse files
committed
#76 #44 start to fix did plot
1 parent 7840611 commit 90cc898

File tree

1 file changed

+29
-25
lines changed

1 file changed

+29
-25
lines changed

causalpy/pymc_experiments.py

Lines changed: 29 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -302,17 +302,19 @@ def plot(self):
302302
fig, ax = plt.subplots()
303303

304304
# Plot raw data
305-
sns.lineplot(
306-
self.data,
307-
x=self.time_variable_name,
308-
y=self.outcome_variable_name,
309-
hue=self.group_variable_name,
310-
units="unit",
311-
estimator=None,
312-
alpha=0.25,
313-
ax=ax,
314-
)
305+
# NOTE: This will not work when there is just ONE unit in each group
306+
# sns.lineplot(
307+
# self.data,
308+
# x=self.time_variable_name,
309+
# y=self.outcome_variable_name,
310+
# hue=self.group_variable_name,
311+
# # units="unit",
312+
# estimator=None,
313+
# alpha=0.25,
314+
# ax=ax,
315+
# )
315316
# Plot model fit to control group
317+
# NOTE: This will not work when there is just ONE unit in each group
316318
parts = ax.violinplot(
317319
az.extract(
318320
self.y_pred_control, group="posterior_predictive", var_names="mu"
@@ -328,6 +330,7 @@ def plot(self):
328330
pc.set_alpha(0.5)
329331

330332
# Plot model fit to treatment group
333+
# NOTE: This will not work when there is just ONE unit in each group
331334
parts = ax.violinplot(
332335
az.extract(
333336
self.y_pred_treatment, group="posterior_predictive", var_names="mu"
@@ -337,18 +340,19 @@ def plot(self):
337340
showmedians=False,
338341
widths=0.2,
339342
)
340-
# Plot counterfactual - post-test for treatment group IF no treatment had occurred.
341-
parts = ax.violinplot(
342-
az.extract(
343-
self.y_pred_counterfactual,
344-
group="posterior_predictive",
345-
var_names="mu",
346-
).values.T,
347-
positions=self.x_pred_counterfactual[self.time_variable_name].values,
348-
showmeans=False,
349-
showmedians=False,
350-
widths=0.2,
351-
)
343+
# # Plot counterfactual - post-test for treatment group IF no treatment had occurred.
344+
# # NOTE: This will not work when there is just ONE unit in each group
345+
# parts = ax.violinplot(
346+
# az.extract(
347+
# self.y_pred_counterfactual,
348+
# group="posterior_predictive",
349+
# var_names="mu",
350+
# ).values.T,
351+
# positions=self.x_pred_counterfactual[self.time_variable_name].values,
352+
# showmeans=False,
353+
# showmedians=False,
354+
# widths=0.2,
355+
# )
352356
# arrow to label the causal impact
353357
y_pred_treatment = (
354358
self.y_pred_treatment["posterior_predictive"]
@@ -378,9 +382,9 @@ def plot(self):
378382
)
379383
# formatting
380384
ax.set(
381-
xlim=[-0.15, 1.25],
382-
xticks=[0, 1],
383-
xticklabels=["pre", "post"],
385+
# xlim=[-0.15, 1.25],
386+
xticks=self.x_pred_treatment[self.time_variable_name].values,
387+
# xticklabels=["pre", "post"],
384388
title=self._causal_impact_summary_stat(),
385389
)
386390
ax.legend(fontsize=LEGEND_FONT_SIZE)

0 commit comments

Comments
 (0)