@@ -257,12 +257,17 @@ def __init__(
257
257
258
258
# calculate causal impact
259
259
self .causal_impact = (
260
- self .y_pred_treatment ["posterior_predictive" ]
261
- .mu .isel ({"obs_ind" : 1 })
262
- .mean ()
263
- .data
264
- - self .y_pred_counterfactual ["posterior_predictive" ].mu .mean ().data
260
+ self .y_pred_treatment ["posterior_predictive" ].mu .isel ({"obs_ind" : 1 })
261
+ - self .y_pred_counterfactual ["posterior_predictive" ].mu .squeeze ()
265
262
)
263
+ # self.causal_impact = (
264
+ # self.y_pred_treatment["posterior_predictive"]
265
+ # .mu.isel({"obs_ind": 1})
266
+ # .stack(samples=["chain", "draw"])
267
+ # - self.y_pred_counterfactual["posterior_predictive"]
268
+ # .mu.stack(samples=["chain", "draw"])
269
+ # .squeeze()
270
+ # )
266
271
267
272
def plot (self ):
268
273
"""Plot the results"""
@@ -348,17 +353,25 @@ def plot(self):
348
353
xlim = [- 0.15 , 1.25 ],
349
354
xticks = [0 , 1 ],
350
355
xticklabels = ["pre" , "post" ],
351
- title = f"Causal impact = { self .causal_impact :.2f } " ,
356
+ title = self ._causal_impact_summary_stat () ,
352
357
)
353
358
ax .legend (fontsize = LEGEND_FONT_SIZE )
354
359
return (fig , ax )
355
360
361
+ def _causal_impact_summary_stat (self ):
362
+ percentiles = self .causal_impact .quantile ([0.03 , 1 - 0.03 ]).values
363
+ ci = r"$CI_{94\%}$" + f"[{ percentiles [0 ]:.2f} , { percentiles [1 ]:.2f} ]"
364
+ causal_impact = f"{ self .causal_impact .mean ():.2f} , "
365
+ return f"Causal impact = { causal_impact + ci } "
366
+
356
367
def summary (self ):
357
368
"""Print text output summarising the results"""
358
369
359
370
print (f"{ self .expt_type :=^80} " )
360
371
print (f"Formula: { self .formula } " )
372
+ print ("\n Results:" )
361
373
# TODO: extra experiment specific outputs here
374
+ print (self ._causal_impact_summary_stat ())
362
375
self .print_coefficients ()
363
376
364
377
0 commit comments