@@ -629,12 +629,13 @@ def _compute_statistics(
629629 cumulative = True ,
630630 relative = True ,
631631 min_effect = None ,
632+ time_dim = "obs_ind" ,
632633):
633634 """Compute all summary statistics from posterior draws."""
634635 stats = {}
635636
636637 # Average effect over window
637- avg_effect = impact .mean (dim = "obs_ind" )
638+ avg_effect = impact .mean (dim = time_dim )
638639 stats ["avg" ] = {
639640 "mean" : float (avg_effect .mean (dim = ["chain" , "draw" ]).values ),
640641 "median" : float (avg_effect .median (dim = ["chain" , "draw" ]).values ),
@@ -677,9 +678,9 @@ def _compute_statistics(
677678 # Cumulative effect
678679 if cumulative :
679680 # Use cumulative sum over window
680- cum_effect = impact .cumsum (dim = "obs_ind" )
681+ cum_effect = impact .cumsum (dim = time_dim )
681682 # Take final value (cumulative over entire window)
682- cum_final = cum_effect .isel (obs_ind = - 1 )
683+ cum_final = cum_effect .isel ({ time_dim : - 1 } )
683684
684685 stats ["cum" ] = {
685686 "mean" : float (cum_final .mean (dim = ["chain" , "draw" ]).values ),
@@ -720,7 +721,7 @@ def _compute_statistics(
720721 # Relative effects
721722 if relative :
722723 epsilon = 1e-8 # Guard against division by zero
723- counterfactual_mean = counterfactual .mean (dim = "obs_ind" )
724+ counterfactual_mean = counterfactual .mean (dim = time_dim )
724725 rel_avg = (avg_effect / (counterfactual_mean + epsilon )) * 100
725726
726727 stats ["avg" ]["relative_mean" ] = float (
@@ -746,7 +747,9 @@ def _compute_statistics(
746747
747748 if cumulative :
748749 # Relative cumulative: (cumulative effect / cumulative counterfactual) * 100
749- counterfactual_cum = counterfactual .cumsum (dim = "obs_ind" ).isel (obs_ind = - 1 )
750+ counterfactual_cum = counterfactual .cumsum (dim = time_dim ).isel (
751+ {time_dim : - 1 }
752+ )
750753 rel_cum = (cum_final / (counterfactual_cum + epsilon )) * 100
751754
752755 stats ["cum" ]["relative_mean" ] = float (
@@ -850,6 +853,7 @@ def _generate_prose(
850853 direction = "increase" ,
851854 cumulative = True ,
852855 relative = True ,
856+ prefix = "Post-period" ,
853857):
854858 """Generate prose summary from statistics."""
855859 hdi_pct = int ((1 - alpha ) * 100 )
@@ -883,7 +887,7 @@ def fmt_num(x, decimals=2):
883887 direction_text = "effect"
884888
885889 prose_parts = [
886- f"Post-period ({ window_str } ), the average effect was { fmt_num (avg_mean )} "
890+ f"{ prefix } ({ window_str } ), the average effect was { fmt_num (avg_mean )} "
887891 f"({ hdi_pct } % HDI [{ fmt_num (avg_lower )} , { fmt_num (avg_upper )} ]), "
888892 f"with a posterior probability of an { direction_text } of { fmt_num (p_val , 3 )} ."
889893 ]
@@ -1138,6 +1142,7 @@ def _generate_prose_ols(
11381142 alpha = 0.05 ,
11391143 cumulative = True ,
11401144 relative = True ,
1145+ prefix = "Post-period" ,
11411146):
11421147 """Generate prose summary for OLS models."""
11431148 ci_pct = int ((1 - alpha ) * 100 )
@@ -1161,7 +1166,7 @@ def fmt_num(x, decimals=2):
11611166 p_val = stats ["avg" ]["p_value" ]
11621167
11631168 prose_parts = [
1164- f"Post-period ({ window_str } ), the average effect was { fmt_num (avg_mean )} "
1169+ f"{ prefix } ({ window_str } ), the average effect was { fmt_num (avg_mean )} "
11651170 f"({ ci_pct } % CI [{ fmt_num (avg_lower )} , { fmt_num (avg_upper )} ]), "
11661171 f"with a p-value of { fmt_num (p_val , 3 )} ."
11671172 ]
0 commit comments