Skip to content

Commit fbc4c94

Browse files
committed
move plotting into experiment classes, removing PlotComponent entirely
1 parent 1cdf7c2 commit fbc4c94

File tree

12 files changed

+649
-716
lines changed

12 files changed

+649
-716
lines changed

causalpy/experiments/base.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,11 @@
1515
Base class for quasi experimental designs.
1616
"""
1717

18+
from abc import abstractmethod
19+
20+
from causalpy.pymc_models import PyMCModel
21+
from causalpy.skl_models import ScikitLearnModel
22+
1823

1924
class BaseExperiment:
2025
"""Base class for quasi experimental designs."""
@@ -33,3 +38,26 @@ def idata(self):
3338
def print_coefficients(self, round_to=None):
    """Print the coefficients of the fitted model.

    The work is delegated to ``self.model.print_coefficients``, which is
    given this experiment's ``labels`` and the requested rounding.

    :param round_to:
        Rounding setting passed straight through to the model.
    """
    fitted_model = self.model
    fitted_model.print_coefficients(self.labels, round_to)
41+
42+
def plot(self, *args, **kwargs) -> tuple:
    """Plot the experiment with the routine that matches the model type.

    A ``PyMCModel`` is routed to ``bayesian_plot`` and a
    ``ScikitLearnModel`` to ``ols_plot``. All positional and keyword
    arguments are forwarded unchanged, and the chosen method's return value
    is returned as-is.

    :raises ValueError:
        If ``self.model`` is neither a ``PyMCModel`` nor a
        ``ScikitLearnModel``.
    """
    model = self.model
    if isinstance(model, PyMCModel):
        plot_method = self.bayesian_plot
    elif isinstance(model, ScikitLearnModel):
        plot_method = self.ols_plot
    else:
        raise ValueError("Unsupported model type")
    return plot_method(*args, **kwargs)
54+
55+
@abstractmethod
def bayesian_plot(self, *args, **kwargs):
    """Plot the experiment when the model is a ``PyMCModel``.

    ``plot`` dispatches here for Bayesian models. This base implementation
    only raises; concrete experiment classes are expected to override it.

    :raises NotImplementedError:
        Always, unless overridden by a subclass.
    """
    raise NotImplementedError()
59+
60+
@abstractmethod
def ols_plot(self, *args, **kwargs):
    """Plot the experiment when the model is a ``ScikitLearnModel``.

    ``plot`` dispatches here for scikit-learn models. This base
    implementation only raises; concrete experiment classes are expected to
    override it.

    :raises NotImplementedError:
        Always, unless overridden by a subclass.
    """
    raise NotImplementedError()

causalpy/experiments/diff_in_diff.py

Lines changed: 219 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,21 +15,26 @@
1515
Difference in differences
1616
"""
1717

18+
import arviz as az
1819
import numpy as np
1920
import pandas as pd
21+
import seaborn as sns
2022
from matplotlib import pyplot as plt
2123
from patsy import build_design_matrices, dmatrices
2224

2325
from causalpy.custom_exceptions import (
2426
DataException,
2527
FormulaException,
2628
)
29+
from causalpy.plot_utils import plot_xY
2730
from causalpy.pymc_models import PyMCModel
2831
from causalpy.skl_models import ScikitLearnModel
29-
from causalpy.utils import _is_variable_dummy_coded, convert_to_string
32+
from causalpy.utils import _is_variable_dummy_coded, convert_to_string, round_num
3033

3134
from .base import BaseExperiment
3235

36+
# Font size passed as ``fontsize`` to ``ax.legend`` by the plotting methods below.
LEGEND_FONT_SIZE = 12
37+
3338

3439
class DifferenceInDifferences(BaseExperiment):
3540
"""A class to analyse data from Difference in Difference settings.
@@ -205,18 +210,6 @@ def input_validation(self):
205210
coded. Consisting of 0's and 1's only."""
206211
)
207212

208-
def plot(self, round_to=None) -> tuple[plt.Figure, plt.Axes]:
209-
"""
210-
Plot the results
211-
212-
:param round_to:
213-
Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers.
214-
"""
215-
# Get a BayesianPlotComponent or OLSPlotComponent depending on the model
216-
plot_component = self.model.get_plot_component()
217-
fig, ax = plot_component.plot_difference_in_differences(self, round_to=round_to)
218-
return fig, ax
219-
220213
def summary(self, round_to=None) -> None:
221214
"""Print summary of main results and model coefficients.
222215
@@ -232,3 +225,216 @@ def summary(self, round_to=None) -> None:
232225
def _causal_impact_summary_stat(self, round_to=None) -> str:
    """Return the one-line "Causal impact = ..." summary string.

    The numeric part is produced by ``convert_to_string`` applied to
    ``self.causal_impact`` (presumably a mean with credible interval
    bounds -- confirm against ``convert_to_string``).

    :param round_to:
        Rounding setting forwarded to ``convert_to_string``.
    """
    impact_text = convert_to_string(self.causal_impact, round_to=round_to)
    return f"Causal impact = {impact_text}"
228+
229+
def bayesian_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, plt.Axes]:
    """
    Plot the difference-in-differences results for a Bayesian (PyMC) model.

    Draws the raw data as a scatter plot, the posterior predictive ``mu`` for
    the control and treatment groups, the counterfactual for the treatment
    group, and an arrow annotating the causal impact.

    :param round_to:
        Rounding setting forwarded to ``_causal_impact_summary_stat`` when
        building the figure title.
    :param kwargs:
        Not used by this method.
    :return:
        The ``(fig, ax)`` pair of the newly created figure.
    """
    # NOTE: ``round_to`` is a named parameter, so it can never be present in
    # ``kwargs``. It is used directly; looking it up in ``kwargs`` would
    # always give ``None`` and silently discard the caller's value.

    def _plot_causal_impact_arrow(results, ax):
        """
        Draw a vertical arrow between `y_pred_treatment` and
        `y_pred_counterfactual`, with a "causal impact" text label beside it.
        """
        # Calculate y values to plot the arrow between.
        # obs_ind 1 is presumably the post-treatment observation -- TODO confirm.
        y_pred_treatment = (
            results.y_pred_treatment["posterior_predictive"]
            .mu.isel({"obs_ind": 1})
            .mean()
            .data
        )
        y_pred_counterfactual = (
            results.y_pred_counterfactual["posterior_predictive"].mu.mean().data
        )
        # Calculate the x position to plot at: 10% of the time range to the
        # right of the last time point.
        # Note that we force to be float to avoid a type error using np.ptp
        # with boolean values.
        treatment_times = np.array(
            results.x_pred_treatment[results.time_variable_name].values
        ).astype(float)
        x = np.max(treatment_times) + 0.1 * np.ptp(treatment_times)
        # Plot the arrow
        ax.annotate(
            "",
            xy=(x, y_pred_counterfactual),
            xycoords="data",
            xytext=(x, y_pred_treatment),
            textcoords="data",
            arrowprops={"arrowstyle": "<-", "color": "green", "lw": 3},
        )
        # Plot text annotation next to arrow, vertically centred on it
        ax.annotate(
            "causal\nimpact",
            xy=(x, np.mean([y_pred_counterfactual, y_pred_treatment])),
            xycoords="data",
            xytext=(5, 0),
            textcoords="offset points",
            color="green",
            va="center",
        )

    fig, ax = plt.subplots()

    # Plot raw data
    sns.scatterplot(
        self.data,
        x=self.time_variable_name,
        y=self.outcome_variable_name,
        hue=self.group_variable_name,
        alpha=1,
        legend=False,
        markers=True,
        ax=ax,
    )

    # Plot model fit to control group
    time_points = self.x_pred_control[self.time_variable_name].values
    h_line, h_patch = plot_xY(
        time_points,
        self.y_pred_control.posterior_predictive.mu,
        ax=ax,
        plot_hdi_kwargs={"color": "C0"},
        label="Control group",
    )
    handles = [(h_line, h_patch)]
    labels = ["Control group"]

    # Plot model fit to treatment group. The x values come from the treatment
    # design points (not the control ones) so that they pair with
    # ``y_pred_treatment``, matching the arrow and xticks below.
    time_points = self.x_pred_treatment[self.time_variable_name].values
    h_line, h_patch = plot_xY(
        time_points,
        self.y_pred_treatment.posterior_predictive.mu,
        ax=ax,
        plot_hdi_kwargs={"color": "C1"},
        label="Treatment group",
    )
    handles.append((h_line, h_patch))
    labels.append("Treatment group")

    # Plot counterfactual - post-test for treatment group IF no treatment
    # had occurred.
    time_points = self.x_pred_counterfactual[self.time_variable_name].values
    if len(time_points) == 1:
        # A single time point cannot be drawn as a line, so show the
        # distribution of ``mu`` as a violin at that time point instead.
        parts = ax.violinplot(
            az.extract(
                self.y_pred_counterfactual,
                group="posterior_predictive",
                var_names="mu",
            ).values.T,
            positions=time_points,
            showmeans=False,
            showmedians=False,
            widths=0.2,
        )
        for pc in parts["bodies"]:
            pc.set_facecolor("C0")
            pc.set_edgecolor("None")
            pc.set_alpha(0.5)
    else:
        h_line, h_patch = plot_xY(
            time_points,
            self.y_pred_counterfactual.posterior_predictive.mu,
            ax=ax,
            plot_hdi_kwargs={"color": "C2"},
            label="Counterfactual",
        )
        handles.append((h_line, h_patch))
        labels.append("Counterfactual")

    # arrow to label the causal impact
    _plot_causal_impact_arrow(self, ax)

    # formatting
    ax.set(
        xticks=self.x_pred_treatment[self.time_variable_name].values,
        title=self._causal_impact_summary_stat(round_to),
    )
    ax.legend(
        handles=handles,
        labels=labels,
        fontsize=LEGEND_FONT_SIZE,
    )
    return fig, ax
367+
368+
def ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, plt.Axes]:
    """Generate plot for difference-in-differences with an OLS model.

    Draws one line per unit of raw data, the model's point predictions for
    the control and treatment groups, the counterfactual prediction, and an
    arrow annotating the causal impact.

    :param round_to:
        Rounding setting forwarded to ``round_num`` when building the figure
        title.
    :param kwargs:
        Not used by this method.
    :return:
        The ``(fig, ax)`` pair of the newly created figure.
    """
    # NOTE: ``round_to`` is a named parameter, so it can never be present in
    # ``kwargs``. It is used directly; looking it up in ``kwargs`` would
    # always give ``None`` and silently discard the caller's value.
    fig, ax = plt.subplots()

    # Plot raw data, one line per unit. The hue column follows the configured
    # group variable, consistent with ``bayesian_plot``.
    # NOTE(review): ``units="unit"`` assumes the data has a column literally
    # named "unit" -- confirm against input validation.
    sns.lineplot(
        self.data,
        x=self.time_variable_name,
        y=self.outcome_variable_name,
        hue=self.group_variable_name,
        units="unit",
        estimator=None,
        alpha=0.25,
        ax=ax,
    )
    # Plot model fit to control group
    ax.plot(
        self.x_pred_control[self.time_variable_name],
        self.y_pred_control,
        "o",
        c="C0",
        markersize=10,
        label="model fit (control group)",
    )
    # Plot model fit to treatment group
    ax.plot(
        self.x_pred_treatment[self.time_variable_name],
        self.y_pred_treatment,
        "o",
        c="C1",
        markersize=10,
        label="model fit (treatment group)",
    )
    # Plot counterfactual - post-test for treatment group IF no treatment
    # had occurred.
    ax.plot(
        self.x_pred_counterfactual[self.time_variable_name],
        self.y_pred_counterfactual,
        "go",
        markersize=10,
        label="counterfactual",
    )

    # End points of the causal impact arrow. Both annotations below use the
    # same indexed values so the arrow and its label always line up.
    # Index 1 of ``y_pred_treatment`` is presumably the post-treatment
    # prediction -- TODO confirm ordering of ``x_pred_treatment``.
    y_counterfactual = self.y_pred_counterfactual[0]
    y_treatment_post = self.y_pred_treatment[1]
    # NOTE(review): the x position, limits and ticks assume time is coded
    # as 0 (pre) and 1 (post).
    arrow_x = 1.05

    # arrow to label the causal impact
    ax.annotate(
        "",
        xy=(arrow_x, y_counterfactual),
        xycoords="data",
        xytext=(arrow_x, y_treatment_post),
        textcoords="data",
        arrowprops={"arrowstyle": "<->", "color": "green", "lw": 3},
    )
    ax.annotate(
        "causal\nimpact",
        xy=(
            arrow_x,
            np.mean([y_counterfactual, y_treatment_post]),
        ),
        xycoords="data",
        xytext=(5, 0),
        textcoords="offset points",
        color="green",
        va="center",
    )
    # formatting
    ax.set(
        xlim=[-0.05, 1.1],
        xticks=[0, 1],
        xticklabels=["pre", "post"],
        title=f"Causal impact = {round_num(self.causal_impact, round_to)}",
    )
    ax.legend(fontsize=LEGEND_FONT_SIZE)
    return fig, ax

0 commit comments

Comments
 (0)