@@ -256,13 +256,12 @@ def __init__(
256
256
self .y_pred_counterfactual = self .prediction_model .predict (np .asarray (new_x ))
257
257
258
258
# calculate causal impact
259
- # TODO: This should most likely be posterior estimate, not posterior predictive
260
259
self .causal_impact = (
261
260
self .y_pred_treatment ["posterior_predictive" ]
262
- .y_hat .isel ({"obs_ind" : 1 })
261
+ .mu .isel ({"obs_ind" : 1 })
263
262
.mean ()
264
263
.data
265
- - self .y_pred_counterfactual ["posterior_predictive" ].y_hat .mean ().data
264
+ - self .y_pred_counterfactual ["posterior_predictive" ].mu .mean ().data
266
265
)
267
266
268
267
def plot (self ):
@@ -283,7 +282,7 @@ def plot(self):
283
282
# Plot model fit to control group
284
283
parts = ax .violinplot (
285
284
az .extract (
286
- self .y_pred_control , group = "posterior_predictive" , var_names = "y_hat "
285
+ self .y_pred_control , group = "posterior_predictive" , var_names = "mu "
287
286
).values .T ,
288
287
positions = self .x_pred_control [self .time_variable_name ].values ,
289
288
showmeans = False ,
@@ -298,7 +297,7 @@ def plot(self):
298
297
# Plot model fit to treatment group
299
298
parts = ax .violinplot (
300
299
az .extract (
301
- self .y_pred_treatment , group = "posterior_predictive" , var_names = "y_hat "
300
+ self .y_pred_treatment , group = "posterior_predictive" , var_names = "mu "
302
301
).values .T ,
303
302
positions = self .x_pred_treatment [self .time_variable_name ].values ,
304
303
showmeans = False ,
@@ -310,7 +309,7 @@ def plot(self):
310
309
az .extract (
311
310
self .y_pred_counterfactual ,
312
311
group = "posterior_predictive" ,
313
- var_names = "y_hat " ,
312
+ var_names = "mu " ,
314
313
).values .T ,
315
314
positions = self .x_pred_counterfactual [self .time_variable_name ].values ,
316
315
showmeans = False ,
@@ -320,12 +319,12 @@ def plot(self):
320
319
# arrow to label the causal impact
321
320
y_pred_treatment = (
322
321
self .y_pred_treatment ["posterior_predictive" ]
323
- .y_hat .isel ({"obs_ind" : 1 })
322
+ .mu .isel ({"obs_ind" : 1 })
324
323
.mean ()
325
324
.data
326
325
)
327
326
y_pred_counterfactual = (
328
- self .y_pred_counterfactual ["posterior_predictive" ].y_hat .mean ().data
327
+ self .y_pred_counterfactual ["posterior_predictive" ].mu .mean ().data
329
328
)
330
329
ax .annotate (
331
330
"" ,
0 commit comments