Skip to content

Commit c4e678e

Browse files
committed
Add plot customization options to EventStudy
Enhanced the _bayesian_plot and _ols_plot methods in EventStudy to support configurable figure size and HDI probability. Updated docstrings to document new parameters and improved plot labeling for clarity.
1 parent 9e5b790 commit c4e678e

File tree

2 files changed

+68
-20
lines changed

2 files changed

+68
-20
lines changed

causalpy/experiments/event_study.py

Lines changed: 48 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -420,10 +420,31 @@ def get_event_time_summary(self, round_to: int | None = 2) -> pd.DataFrame:
420420
return pd.DataFrame(rows)
421421

422422
def _bayesian_plot(
423-
self, round_to: int | None = 2, **kwargs: dict
423+
self,
424+
round_to: int | None = 2,
425+
figsize: tuple[float, float] = (10, 6),
426+
hdi_prob: float = 0.94,
427+
**kwargs: dict,
424428
) -> tuple[plt.Figure, plt.Axes]:
425-
"""Plot event-study coefficients with credible intervals (Bayesian)."""
426-
fig, ax = plt.subplots(figsize=(10, 6))
429+
"""Plot event-study coefficients with credible intervals (Bayesian).
430+
431+
Parameters
432+
----------
433+
round_to : int, optional
434+
Number of decimals for rounding. Defaults to 2.
435+
figsize : tuple[float, float], optional
436+
Figure size in inches (width, height). Defaults to (10, 6).
437+
hdi_prob : float, optional
438+
Probability mass for the highest density interval. Defaults to 0.94.
439+
**kwargs : dict
440+
Additional keyword arguments (currently unused).
441+
442+
Returns
443+
-------
444+
tuple[plt.Figure, plt.Axes]
445+
The matplotlib Figure and Axes objects.
446+
"""
447+
fig, ax = plt.subplots(figsize=figsize)
427448

428449
sorted_times = sorted(self.event_time_coeffs.keys())
429450
means_list: list[float] = []
@@ -437,7 +458,7 @@ def _bayesian_plot(
437458
lower_list.append(0.0)
438459
upper_list.append(0.0)
439460
else:
440-
hdi = az.hdi(coeff.values.flatten(), hdi_prob=0.94)
461+
hdi = az.hdi(coeff.values.flatten(), hdi_prob=hdi_prob)
441462
means_list.append(float(coeff.mean()))
442463
lower_list.append(float(hdi[0]))
443464
upper_list.append(float(hdi[1]))
@@ -447,6 +468,7 @@ def _bayesian_plot(
447468
upper = np.array(upper_list)
448469

449470
# Plot coefficients with error bars
471+
hdi_pct = int(hdi_prob * 100)
450472
ax.errorbar(
451473
sorted_times,
452474
means,
@@ -456,7 +478,7 @@ def _bayesian_plot(
456478
capthick=2,
457479
markersize=8,
458480
color="C0",
459-
label="Event-time coefficient",
481+
label=f"Event-time coefficient ({hdi_pct}% HDI)",
460482
)
461483

462484
# Add horizontal line at zero
@@ -502,10 +524,28 @@ def _bayesian_plot(
502524
return fig, ax
503525

504526
def _ols_plot(
505-
self, round_to: int | None = 2, **kwargs: dict
527+
self,
528+
round_to: int | None = 2,
529+
figsize: tuple[float, float] = (10, 6),
530+
**kwargs: dict,
506531
) -> tuple[plt.Figure, plt.Axes]:
507-
"""Plot event-study coefficients (OLS)."""
508-
fig, ax = plt.subplots(figsize=(10, 6))
532+
"""Plot event-study coefficients (OLS).
533+
534+
Parameters
535+
----------
536+
round_to : int, optional
537+
Number of decimals for rounding. Defaults to 2.
538+
figsize : tuple[float, float], optional
539+
Figure size in inches (width, height). Defaults to (10, 6).
540+
**kwargs : dict
541+
Additional keyword arguments (currently unused).
542+
543+
Returns
544+
-------
545+
tuple[plt.Figure, plt.Axes]
546+
The matplotlib Figure and Axes objects.
547+
"""
548+
fig, ax = plt.subplots(figsize=figsize)
509549

510550
sorted_times = sorted(self.event_time_coeffs.keys())
511551
coeffs = []

docs/source/notebooks/event_study_pymc.ipynb

Lines changed: 20 additions & 12 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)