Skip to content

Commit 8eb35fc

Browse files
committed
Adding comparison in summary
1 parent 0c7ae25 commit 8eb35fc

File tree

4 files changed

+787
-438
lines changed

4 files changed

+787
-438
lines changed

causalpy/experiments/interrupted_time_series.py

Lines changed: 349 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -512,7 +512,8 @@ def effect_summary(
512512
For three-period designs, specify which period to summarize:
513513
- "intervention": Summary for intervention period only
514514
- "post": Summary for post-intervention period only
515-
- "comparison": Comparative summary with persistence metrics (NotImplementedError)
515+
- "comparison": Comparative summary with persistence metrics (persistence ratio,
516+
probability that effect persisted, HDI/CI interval comparison)
516517
- None: Default behavior (summarizes all post-treatment data, backward compatible)
517518
direction : {"increase", "decrease", "two-sided"}, default="increase"
518519
Direction for tail probability calculation (PyMC only)
@@ -552,9 +553,13 @@ def effect_summary(
552553
)
553554

554555
if period == "comparison":
555-
raise NotImplementedError(
556-
"period='comparison' is not yet implemented. "
557-
"This will provide a comparative summary with persistence metrics."
556+
# Comparison period: compare intervention and post-intervention periods
557+
return self._comparison_period_summary(
558+
direction=direction,
559+
alpha=alpha,
560+
cumulative=cumulative,
561+
relative=relative,
562+
min_effect=min_effect,
558563
)
559564

560565
# Select appropriate impact and prediction data based on period
@@ -658,6 +663,170 @@ def effect_summary(
658663
treated_unit=treated_unit,
659664
)
660665

666+
def _comparison_period_summary(
    self,
    direction: Literal["increase", "decrease", "two-sided"] = "increase",
    alpha: float = 0.05,
    cumulative: bool = True,
    relative: bool = True,
    min_effect: float | None = None,
):
    """Generate comparative summary between intervention and post-intervention periods.

    Compares the time-averaged impact of the intervention period with that of
    the post-intervention period and reports a persistence ratio
    (post-intervention mean as a percentage of the intervention mean).

    Parameters
    ----------
    direction : {"increase", "decrease", "two-sided"}, default="increase"
        Direction used for the persistence probability (PyMC only):
        "increase" reports P(post effect > 0), "decrease" reports
        P(post effect < 0), and "two-sided" reports
        ``1 - 2 * min(P(> 0), P(< 0))``.
    alpha : float, default=0.05
        Significance level; intervals are computed at ``1 - alpha``.
    cumulative : bool, default=True
        Accepted for signature parity with ``effect_summary``; not used by
        this summary (only average effects are compared).
    relative : bool, default=True
        Accepted for signature parity with ``effect_summary``; not used by
        this summary.
    min_effect : float, optional
        Accepted for signature parity with ``effect_summary``; not used by
        this summary.

    Returns
    -------
    EffectSummary
        Object with .table (DataFrame) and .text (str) attributes. The table
        has one row per period ("intervention", "post_intervention"). The
        persistence ratio is NaN when the intervention mean is numerically
        zero (absolute value below 1e-8), because the ratio is undefined.
    """
    from causalpy.reporting import EffectSummary, _extract_hdi_bounds

    is_pymc = isinstance(self.model, PyMCModel)
    time_dim = "obs_ind"
    hdi_prob = 1 - alpha
    # round() instead of int(): int() truncates, so float noise such as
    # 0.58 * 100 == 57.99999999999999 would be labelled "57%".
    interval_pct = int(round(hdi_prob * 100))
    epsilon = 1e-8
    period_index = ["intervention", "post_intervention"]
    prob_persisted: float | None

    def _persistence_pct(post_value, intervention_value):
        # Guard the division explicitly. Adding epsilon to the denominator
        # shifts every ratio and still divides by zero for a mean of -epsilon.
        if abs(intervention_value) < epsilon:
            return float("nan")
        return (post_value / intervention_value) * 100

    if is_pymc:
        # PyMC: average over time first, then summarize the posterior draws.
        intervention_avg = self.intervention_impact.mean(dim=time_dim)
        intervention_mean = float(
            intervention_avg.mean(dim=["chain", "draw"]).values
        )
        intervention_hdi = az.hdi(intervention_avg, hdi_prob=hdi_prob)
        intervention_lower, intervention_upper = _extract_hdi_bounds(
            intervention_hdi, hdi_prob
        )

        post_avg = self.post_intervention_impact.mean(dim=time_dim)
        post_mean = float(post_avg.mean(dim=["chain", "draw"]).values)
        post_hdi = az.hdi(post_avg, hdi_prob=hdi_prob)
        post_lower, post_upper = _extract_hdi_bounds(post_hdi, hdi_prob)

        persistence_ratio_pct = _persistence_pct(post_mean, intervention_mean)

        # Posterior probability that an effect persisted, in the requested
        # direction. The default ("increase") is P(post effect > 0).
        if direction == "increase":
            prob_persisted = float((post_avg > 0).mean().values)
        elif direction == "decrease":
            prob_persisted = float((post_avg < 0).mean().values)
        else:
            p_gt_0 = float((post_avg > 0).mean().values)
            p_lt_0 = float((post_avg < 0).mean().values)
            prob_persisted = 1 - 2 * min(p_gt_0, p_lt_0)

        table = pd.DataFrame(
            {
                "mean": [intervention_mean, post_mean],
                "hdi_lower": [intervention_lower, post_lower],
                "hdi_upper": [intervention_upper, post_upper],
                "persistence_ratio_pct": [None, persistence_ratio_pct],
                "prob_persisted": [None, prob_persisted],
            },
            index=period_index,
        )

        hdi_pct = interval_pct
        text = (
            f"Effect persistence: The post-intervention effect "
            f"({post_mean:.1f}, {hdi_pct}% HDI [{post_lower:.1f}, {post_upper:.1f}]) "
            f"was {persistence_ratio_pct:.1f}% of the intervention effect "
            f"({intervention_mean:.1f}, {hdi_pct}% HDI [{intervention_lower:.1f}, {intervention_upper:.1f}]), "
            f"with a posterior probability of {prob_persisted:.2f} that some effect persisted "
            f"beyond the intervention period."
        )

    else:
        # OLS: delegate per-period statistics to the shared reporting helper.
        from causalpy.reporting import _compute_statistics_ols

        def _as_array(impact):
            # Impacts may be xarray/pandas objects (with .values) or plain
            # array-likes.
            return impact.values if hasattr(impact, "values") else np.asarray(impact)

        intervention_stats = _compute_statistics_ols(
            _as_array(self.intervention_impact),
            self.intervention_pred,
            alpha=alpha,
            cumulative=False,
            relative=False,
        )
        post_stats = _compute_statistics_ols(
            _as_array(self.post_intervention_impact),
            self.post_intervention_pred,
            alpha=alpha,
            cumulative=False,
            relative=False,
        )
        intervention_avg_stats = intervention_stats["avg"]
        post_avg_stats = post_stats["avg"]

        persistence_ratio_pct = _persistence_pct(
            post_avg_stats["mean"], intervention_avg_stats["mean"]
        )

        # NOTE(review): 1 - p-value is only a rough proxy and is not a
        # posterior probability; it is kept as the existing OLS behavior.
        prob_persisted = (
            1 - post_avg_stats["p_value"] if "p_value" in post_avg_stats else None
        )

        table_data = {
            "mean": [intervention_avg_stats["mean"], post_avg_stats["mean"]],
            "ci_lower": [
                intervention_avg_stats["ci_lower"],
                post_avg_stats["ci_lower"],
            ],
            "ci_upper": [
                intervention_avg_stats["ci_upper"],
                post_avg_stats["ci_upper"],
            ],
            "persistence_ratio_pct": [None, persistence_ratio_pct],
        }
        if prob_persisted is not None:
            table_data["prob_persisted"] = [None, prob_persisted]

        table = pd.DataFrame(table_data, index=period_index)

        ci_pct = interval_pct
        text = (
            f"Effect persistence: The post-intervention effect "
            f"({post_avg_stats['mean']:.1f}, {ci_pct}% CI [{post_avg_stats['ci_lower']:.1f}, {post_avg_stats['ci_upper']:.1f}]) "
            f"was {persistence_ratio_pct:.1f}% of the intervention effect "
            f"({intervention_avg_stats['mean']:.1f}, {ci_pct}% CI [{intervention_avg_stats['ci_lower']:.1f}, {intervention_avg_stats['ci_upper']:.1f}])"
        )
        # The probability clause is appended only when a p-value was available.
        if prob_persisted is not None:
            text += (
                f", with a probability of {prob_persisted:.2f} that some effect persisted "
                f"beyond the intervention period."
            )
        else:
            text += "."

    return EffectSummary(table=table, text=text)
829+
661830
def summary(self, round_to: int | None = None) -> None:
662831
"""Print summary of main results and model coefficients.
663832
@@ -1049,3 +1218,179 @@ def get_plot_data_ols(self) -> pd.DataFrame:
10491218
self.plot_data = pd.concat([pre_data, post_data])
10501219

10511220
return self.plot_data
1221+
1222+
def analyze_persistence(
    self,
    hdi_prob: float = 0.95,
    direction: Literal["increase", "decrease", "two-sided"] = "increase",
) -> dict[str, Any]:
    """Analyze effect persistence between intervention and post-intervention periods.

    Computes mean effects, persistence ratio, and total (cumulative) impacts for both periods.
    The persistence ratio is the post-intervention mean effect divided by the intervention
    mean effect (as a decimal, e.g., 0.30 means 30% persistence, 1.5 means 150%).
    Note: The ratio can exceed 1.0 if the post-intervention effect is larger than the
    intervention effect. It is NaN when the intervention mean effect is numerically
    zero (absolute value below 1e-8), because the ratio is undefined in that case.

    Automatically prints a summary of the results.

    Parameters
    ----------
    hdi_prob : float, default=0.95
        Probability for the HDI interval (Bayesian models). For OLS models the
        confidence interval is computed with ``alpha = 1 - hdi_prob``.
    direction : {"increase", "decrease", "two-sided"}, default="increase"
        Accepted for API compatibility; no tail probability is computed by
        this method, so the value currently has no effect on the result.

    Returns
    -------
    dict[str, Any]
        Dictionary containing:
        - "mean_effect_during": Mean effect during intervention period
        - "mean_effect_post": Mean effect during post-intervention period
        - "persistence_ratio": Post-intervention mean effect divided by intervention mean (decimal, can exceed 1.0)
        - "total_effect_during": Total (cumulative) effect during intervention period
        - "total_effect_post": Total (cumulative) effect during post-intervention period

    Raises
    ------
    ValueError
        If treatment_end_time is not provided (two-period design)

    Examples
    --------
    >>> result = cp.InterruptedTimeSeries(
    ...     df,
    ...     treatment_time=pd.Timestamp("2024-01-01"),
    ...     treatment_end_time=pd.Timestamp("2024-04-01"),
    ...     formula="y ~ 1 + t",
    ... )
    >>> persistence = result.analyze_persistence()
    >>> # Results are automatically printed
    >>> print(f"Persistence ratio: {persistence['persistence_ratio']:.2f}")
    """
    if self.treatment_end_time is None:
        raise ValueError(
            "analyze_persistence() requires treatment_end_time to be provided. "
            "This method is only available for three-period designs."
        )

    is_pymc = isinstance(self.model, PyMCModel)
    time_dim = "obs_ind"

    if is_pymc:
        # PyMC: average over time first, then summarize the posterior draws.
        from causalpy.reporting import _extract_hdi_bounds

        intervention_avg = self.intervention_impact.mean(dim=time_dim)
        mean_during = float(intervention_avg.mean(dim=["chain", "draw"]).values)
        intervention_hdi = az.hdi(intervention_avg, hdi_prob=hdi_prob)
        intervention_ci_lower, intervention_ci_upper = _extract_hdi_bounds(
            intervention_hdi, hdi_prob
        )

        post_avg = self.post_intervention_impact.mean(dim=time_dim)
        mean_post = float(post_avg.mean(dim=["chain", "draw"]).values)
        post_hdi = az.hdi(post_avg, hdi_prob=hdi_prob)
        post_ci_lower, post_ci_upper = _extract_hdi_bounds(post_hdi, hdi_prob)

        # Total impact per period = last entry of the cumulative series.
        intervention_cum = self.intervention_impact_cumulative.isel({time_dim: -1})
        total_during = float(intervention_cum.mean(dim=["chain", "draw"]).values)

        post_cum = self.post_intervention_impact_cumulative.isel({time_dim: -1})
        total_post = float(post_cum.mean(dim=["chain", "draw"]).values)
    else:
        # OLS: delegate per-period statistics to the shared reporting helper.
        from causalpy.reporting import _compute_statistics_ols

        def _as_array(impact):
            # Impacts may be xarray/pandas objects (with .values) or plain
            # array-likes.
            return impact.values if hasattr(impact, "values") else np.asarray(impact)

        intervention_stats = _compute_statistics_ols(
            _as_array(self.intervention_impact),
            self.intervention_pred,
            alpha=1 - hdi_prob,
            cumulative=True,
            relative=False,
        )
        post_stats = _compute_statistics_ols(
            _as_array(self.post_intervention_impact),
            self.post_intervention_pred,
            alpha=1 - hdi_prob,
            cumulative=True,
            relative=False,
        )

        mean_during = intervention_stats["avg"]["mean"]
        mean_post = post_stats["avg"]["mean"]
        total_during = intervention_stats["cum"]["mean"]
        total_post = post_stats["cum"]["mean"]
        intervention_ci_lower = intervention_stats["avg"]["ci_lower"]
        intervention_ci_upper = intervention_stats["avg"]["ci_upper"]
        post_ci_lower = post_stats["avg"]["ci_lower"]
        post_ci_upper = post_stats["avg"]["ci_upper"]

    # Persistence ratio (as decimal, not percentage). Guard the division
    # explicitly: adding epsilon to the denominator shifts every ratio and
    # still divides by zero for a mean of -epsilon.
    epsilon = 1e-8
    if abs(mean_during) < epsilon:
        persistence_ratio = float("nan")
    else:
        persistence_ratio = mean_post / mean_during

    result = {
        "mean_effect_during": mean_during,
        "mean_effect_post": mean_post,
        "persistence_ratio": float(persistence_ratio),
        "total_effect_during": total_during,
        "total_effect_post": total_post,
    }

    # Print results.
    # round() instead of int(): int() truncates, so float noise such as
    # 0.58 * 100 == 57.99999999999999 would be labelled "57%".
    hdi_pct = int(round(hdi_prob * 100))
    ci_label = "HDI" if is_pymc else "CI"
    print("=" * 60)
    print("Effect Persistence Analysis")
    print("=" * 60)
    print("\nDuring intervention period:")
    print(f"  Mean effect: {result['mean_effect_during']:.2f}")
    print(
        f"  {hdi_pct}% {ci_label}: [{intervention_ci_lower:.2f}, {intervention_ci_upper:.2f}]"
    )
    print(f"  Total effect: {result['total_effect_during']:.2f}")
    print("\nPost-intervention period:")
    print(f"  Mean effect: {result['mean_effect_post']:.2f}")
    print(f"  {hdi_pct}% {ci_label}: [{post_ci_lower:.2f}, {post_ci_upper:.2f}]")
    print(f"  Total effect: {result['total_effect_post']:.2f}")
    print(f"\nPersistence ratio: {result['persistence_ratio']:.3f}")
    print(
        f"  ({result['persistence_ratio'] * 100:.1f}% of intervention effect persisted)"
    )
    print("=" * 60)

    return result

0 commit comments

Comments
 (0)