@@ -135,7 +135,8 @@ def __init__(
135
135
if isinstance (self .model , PyMCModel ):
136
136
COORDS = {
137
137
# key must stay as "coeffs" unless we can find a way to auto identify
138
- # the predictor dimension name
138
+ # the predictor dimension name. "coeffs" is assumed by
139
+ # PyMCModel.print_coefficients for example.
139
140
"coeffs" : self .control_units ,
140
141
"treated_units" : self .treated_units ,
141
142
"obs_ind" : np .arange (self .datapre .shape [0 ]),
@@ -423,45 +424,49 @@ def get_plot_data_bayesian(self, hdi_prob: float = 0.94) -> pd.DataFrame:
423
424
:param hdi_prob:
424
425
Prob for which the highest density interval will be computed. The default value is defined as the default from the :func:`arviz.hdi` function.
425
426
"""
426
- if isinstance (self .model , PyMCModel ):
427
- hdi_pct = int ( round ( hdi_prob * 100 ) )
427
+ if not isinstance (self .model , PyMCModel ):
428
+ raise ValueError ( "Unsupported model type" )
428
429
429
- pred_lower_col = f"pred_hdi_lower_{ hdi_pct } "
430
- pred_upper_col = f"pred_hdi_upper_{ hdi_pct } "
431
- impact_lower_col = f"impact_hdi_lower_{ hdi_pct } "
432
- impact_upper_col = f"impact_hdi_upper_{ hdi_pct } "
430
+ hdi_pct = int (round (hdi_prob * 100 ))
433
431
434
- pre_data = self .datapre .copy ()
435
- post_data = self .datapost .copy ()
432
+ pred_lower_col = f"pred_hdi_lower_{ hdi_pct } "
433
+ pred_upper_col = f"pred_hdi_upper_{ hdi_pct } "
434
+ impact_lower_col = f"impact_hdi_lower_{ hdi_pct } "
435
+ impact_upper_col = f"impact_hdi_upper_{ hdi_pct } "
436
436
437
- pre_data ["prediction" ] = (
438
- az .extract (self .pre_pred , group = "posterior_predictive" , var_names = "mu" )
439
- .mean ("sample" )
440
- .values
441
- )
442
- post_data ["prediction" ] = (
443
- az .extract (self .post_pred , group = "posterior_predictive" , var_names = "mu" )
444
- .mean ("sample" )
445
- .values
446
- )
447
- pre_data [[pred_lower_col , pred_upper_col ]] = get_hdi_to_df (
448
- self .pre_pred ["posterior_predictive" ].mu , hdi_prob = hdi_prob
449
- ).set_index (pre_data .index )
450
- post_data [[pred_lower_col , pred_upper_col ]] = get_hdi_to_df (
451
- self .post_pred ["posterior_predictive" ].mu , hdi_prob = hdi_prob
452
- ).set_index (post_data .index )
453
-
454
- pre_data ["impact" ] = self .pre_impact .mean (dim = ["chain" , "draw" ]).values
455
- post_data ["impact" ] = self .post_impact .mean (dim = ["chain" , "draw" ]).values
456
- pre_data [[impact_lower_col , impact_upper_col ]] = get_hdi_to_df (
457
- self .pre_impact , hdi_prob = hdi_prob
458
- ).set_index (pre_data .index )
459
- post_data [[impact_lower_col , impact_upper_col ]] = get_hdi_to_df (
460
- self .post_impact , hdi_prob = hdi_prob
461
- ).set_index (post_data .index )
462
-
463
- self .plot_data = pd .concat ([pre_data , post_data ])
464
-
465
- return self .plot_data
466
- else :
467
- raise ValueError ("Unsupported model type" )
437
+ pre_data = self .datapre .copy ()
438
+ post_data = self .datapost .copy ()
439
+
440
+ pre_data ["prediction" ] = (
441
+ az .extract (self .pre_pred , group = "posterior_predictive" , var_names = "mu" )
442
+ .mean ("sample" )
443
+ .values
444
+ )
445
+ post_data ["prediction" ] = (
446
+ az .extract (self .post_pred , group = "posterior_predictive" , var_names = "mu" )
447
+ .mean ("sample" )
448
+ .values
449
+ )
450
+ pre_data [[pred_lower_col , pred_upper_col ]] = get_hdi_to_df (
451
+ self .pre_pred ["posterior_predictive" ].mu , hdi_prob = hdi_prob
452
+ ).set_index (pre_data .index )
453
+ post_data [[pred_lower_col , pred_upper_col ]] = get_hdi_to_df (
454
+ self .post_pred ["posterior_predictive" ].mu , hdi_prob = hdi_prob
455
+ ).set_index (post_data .index )
456
+
457
+ pre_data ["impact" ] = (
458
+ self .pre_impact .mean (dim = ["chain" , "draw" ]).isel (treated_units = 0 ).values
459
+ )
460
+ post_data ["impact" ] = (
461
+ self .post_impact .mean (dim = ["chain" , "draw" ]).isel (treated_units = 0 ).values
462
+ )
463
+ pre_data [[impact_lower_col , impact_upper_col ]] = get_hdi_to_df (
464
+ self .pre_impact , hdi_prob = hdi_prob
465
+ ).set_index (pre_data .index )
466
+ post_data [[impact_lower_col , impact_upper_col ]] = get_hdi_to_df (
467
+ self .post_impact , hdi_prob = hdi_prob
468
+ ).set_index (post_data .index )
469
+
470
+ self .plot_data = pd .concat ([pre_data , post_data ])
471
+
472
+ return self .plot_data
0 commit comments