Skip to content

Commit e2ae2e8

Browse files
Three period ITS (#575)
* Reporting enhancements * Adding comparison in summary * Updating Plots * Updating doctest * Delete planning.md * Delete fix_seasonality.py * Delete issue.md * Delete simplify_data.py * Fixing pre commit * Moving note up and updating in its_pymc * Updating Intro an Title of its_post_intervention_analysis * Updating visualisation and summary * Adding it into index * Update lift test * Update lift test * Updating docs * tweaks to new fixed intervention ITS notebook * fix cross reference * clarification on fixed point intervention definition * minor notebook tweaks * minor tweaks to the its lift test notebook * wrap text output + minor simplification + re-run --------- Co-authored-by: Benjamin T. Vincent <[email protected]>
1 parent 5309a1f commit e2ae2e8

File tree

10 files changed

+2437
-61
lines changed

10 files changed

+2437
-61
lines changed

causalpy/experiments/interrupted_time_series.py

Lines changed: 770 additions & 38 deletions
Large diffs are not rendered by default.

causalpy/reporting.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -629,12 +629,13 @@ def _compute_statistics(
629629
cumulative=True,
630630
relative=True,
631631
min_effect=None,
632+
time_dim="obs_ind",
632633
):
633634
"""Compute all summary statistics from posterior draws."""
634635
stats = {}
635636

636637
# Average effect over window
637-
avg_effect = impact.mean(dim="obs_ind")
638+
avg_effect = impact.mean(dim=time_dim)
638639
stats["avg"] = {
639640
"mean": float(avg_effect.mean(dim=["chain", "draw"]).values),
640641
"median": float(avg_effect.median(dim=["chain", "draw"]).values),
@@ -677,9 +678,9 @@ def _compute_statistics(
677678
# Cumulative effect
678679
if cumulative:
679680
# Use cumulative sum over window
680-
cum_effect = impact.cumsum(dim="obs_ind")
681+
cum_effect = impact.cumsum(dim=time_dim)
681682
# Take final value (cumulative over entire window)
682-
cum_final = cum_effect.isel(obs_ind=-1)
683+
cum_final = cum_effect.isel({time_dim: -1})
683684

684685
stats["cum"] = {
685686
"mean": float(cum_final.mean(dim=["chain", "draw"]).values),
@@ -720,7 +721,7 @@ def _compute_statistics(
720721
# Relative effects
721722
if relative:
722723
epsilon = 1e-8 # Guard against division by zero
723-
counterfactual_mean = counterfactual.mean(dim="obs_ind")
724+
counterfactual_mean = counterfactual.mean(dim=time_dim)
724725
rel_avg = (avg_effect / (counterfactual_mean + epsilon)) * 100
725726

726727
stats["avg"]["relative_mean"] = float(
@@ -746,7 +747,9 @@ def _compute_statistics(
746747

747748
if cumulative:
748749
# Relative cumulative: (cumulative effect / cumulative counterfactual) * 100
749-
counterfactual_cum = counterfactual.cumsum(dim="obs_ind").isel(obs_ind=-1)
750+
counterfactual_cum = counterfactual.cumsum(dim=time_dim).isel(
751+
{time_dim: -1}
752+
)
750753
rel_cum = (cum_final / (counterfactual_cum + epsilon)) * 100
751754

752755
stats["cum"]["relative_mean"] = float(
@@ -850,6 +853,7 @@ def _generate_prose(
850853
direction="increase",
851854
cumulative=True,
852855
relative=True,
856+
prefix="Post-period",
853857
):
854858
"""Generate prose summary from statistics."""
855859
hdi_pct = int((1 - alpha) * 100)
@@ -883,7 +887,7 @@ def fmt_num(x, decimals=2):
883887
direction_text = "effect"
884888

885889
prose_parts = [
886-
f"Post-period ({window_str}), the average effect was {fmt_num(avg_mean)} "
890+
f"{prefix} ({window_str}), the average effect was {fmt_num(avg_mean)} "
887891
f"({hdi_pct}% HDI [{fmt_num(avg_lower)}, {fmt_num(avg_upper)}]), "
888892
f"with a posterior probability of an {direction_text} of {fmt_num(p_val, 3)}."
889893
]
@@ -1138,6 +1142,7 @@ def _generate_prose_ols(
11381142
alpha=0.05,
11391143
cumulative=True,
11401144
relative=True,
1145+
prefix="Post-period",
11411146
):
11421147
"""Generate prose summary for OLS models."""
11431148
ci_pct = int((1 - alpha) * 100)
@@ -1161,7 +1166,7 @@ def fmt_num(x, decimals=2):
11611166
p_val = stats["avg"]["p_value"]
11621167

11631168
prose_parts = [
1164-
f"Post-period ({window_str}), the average effect was {fmt_num(avg_mean)} "
1169+
f"{prefix} ({window_str}), the average effect was {fmt_num(avg_mean)} "
11651170
f"({ci_pct}% CI [{fmt_num(avg_lower)}, {fmt_num(avg_upper)}]), "
11661171
f"with a p-value of {fmt_num(p_val, 3)}."
11671172
]

0 commit comments

Comments
 (0)