@@ -340,7 +340,8 @@ def __init__(
340
340
(new_x ,) = build_design_matrices ([self ._x_design_info ], self .x_pred_treatment )
341
341
self .y_pred_treatment = self .model .predict (np .asarray (new_x ))
342
342
343
- # predicted outcome for counterfactual
343
+ # predicted outcome for counterfactual. This is given by removing the influence
344
+ # of the interaction term between the group and the post_treatment variable
344
345
self .x_pred_counterfactual = (
345
346
self .data
346
347
# just the treated group
@@ -349,24 +350,28 @@ def __init__(
349
350
.query ("post_treatment == True" )
350
351
# drop the outcome variable
351
352
.drop (self .outcome_variable_name , axis = 1 )
352
- # DO AN INTERVENTION. Set the post_treatment variable to False
353
- .assign (post_treatment = False )
354
353
# We may have multiple units per time point, we only want one time point
355
354
.groupby (self .time_variable_name )
356
355
.first ()
357
356
.reset_index ()
358
357
)
359
358
assert not self .x_pred_counterfactual .empty
360
359
(new_x ,) = build_design_matrices (
361
- [self ._x_design_info ], self .x_pred_counterfactual
360
+ [self ._x_design_info ], self .x_pred_counterfactual , return_type = "dataframe"
362
361
)
362
+ # INTERVENTION: set the interaction term between the group and the
363
+ # post_treatment variable to zero. This is the counterfactual.
364
+ for i , label in enumerate (self .labels ):
365
+ if "post_treatment" in label and self .group_variable_name in label :
366
+ new_x .iloc [:, i ] = 0
363
367
self .y_pred_counterfactual = self .model .predict (np .asarray (new_x ))
364
368
365
- # calculate causal impact
366
- self .causal_impact = (
367
- self .y_pred_treatment ["posterior_predictive" ].mu .isel ({"obs_ind" : 1 })
368
- - self .y_pred_counterfactual ["posterior_predictive" ].mu .squeeze ()
369
- )
369
+ # calculate causal impact.
370
+ # This is the coefficient on the interaction term
371
+ coeff_names = self .idata .posterior .coords ["coeffs" ].data
372
+ for i , label in enumerate (coeff_names ):
373
+ if "post_treatment" in label and self .group_variable_name in label :
374
+ self .causal_impact = self .idata .posterior ["beta" ].isel ({"coeffs" : i })
370
375
371
376
def plot (self ):
372
377
"""Plot the results.
0 commit comments