@@ -135,7 +135,8 @@ def __init__(
135135 if isinstance (self .model , PyMCModel ):
136136 COORDS = {
137137 # key must stay as "coeffs" unless we can find a way to auto identify
138- # the predictor dimension name
138+ # the predictor dimension name. "coeffs" is assumed by
139+ # PyMCModel.print_coefficients for example.
139140 "coeffs" : self .control_units ,
140141 "treated_units" : self .treated_units ,
141142 "obs_ind" : np .arange (self .datapre .shape [0 ]),
@@ -423,45 +424,49 @@ def get_plot_data_bayesian(self, hdi_prob: float = 0.94) -> pd.DataFrame:
423424 :param hdi_prob:
424425 Prob for which the highest density interval will be computed. The default value is defined as the default from the :func:`arviz.hdi` function.
425426 """
426- if isinstance (self .model , PyMCModel ):
427- hdi_pct = int ( round ( hdi_prob * 100 ) )
427+ if not isinstance (self .model , PyMCModel ):
428+ raise ValueError ( "Unsupported model type" )
428429
429- pred_lower_col = f"pred_hdi_lower_{ hdi_pct } "
430- pred_upper_col = f"pred_hdi_upper_{ hdi_pct } "
431- impact_lower_col = f"impact_hdi_lower_{ hdi_pct } "
432- impact_upper_col = f"impact_hdi_upper_{ hdi_pct } "
430+ hdi_pct = int (round (hdi_prob * 100 ))
433431
434- pre_data = self .datapre .copy ()
435- post_data = self .datapost .copy ()
432+ pred_lower_col = f"pred_hdi_lower_{ hdi_pct } "
433+ pred_upper_col = f"pred_hdi_upper_{ hdi_pct } "
434+ impact_lower_col = f"impact_hdi_lower_{ hdi_pct } "
435+ impact_upper_col = f"impact_hdi_upper_{ hdi_pct } "
436436
437- pre_data ["prediction" ] = (
438- az .extract (self .pre_pred , group = "posterior_predictive" , var_names = "mu" )
439- .mean ("sample" )
440- .values
441- )
442- post_data ["prediction" ] = (
443- az .extract (self .post_pred , group = "posterior_predictive" , var_names = "mu" )
444- .mean ("sample" )
445- .values
446- )
447- pre_data [[pred_lower_col , pred_upper_col ]] = get_hdi_to_df (
448- self .pre_pred ["posterior_predictive" ].mu , hdi_prob = hdi_prob
449- ).set_index (pre_data .index )
450- post_data [[pred_lower_col , pred_upper_col ]] = get_hdi_to_df (
451- self .post_pred ["posterior_predictive" ].mu , hdi_prob = hdi_prob
452- ).set_index (post_data .index )
453-
454- pre_data ["impact" ] = self .pre_impact .mean (dim = ["chain" , "draw" ]).values
455- post_data ["impact" ] = self .post_impact .mean (dim = ["chain" , "draw" ]).values
456- pre_data [[impact_lower_col , impact_upper_col ]] = get_hdi_to_df (
457- self .pre_impact , hdi_prob = hdi_prob
458- ).set_index (pre_data .index )
459- post_data [[impact_lower_col , impact_upper_col ]] = get_hdi_to_df (
460- self .post_impact , hdi_prob = hdi_prob
461- ).set_index (post_data .index )
462-
463- self .plot_data = pd .concat ([pre_data , post_data ])
464-
465- return self .plot_data
466- else :
467- raise ValueError ("Unsupported model type" )
437+ pre_data = self .datapre .copy ()
438+ post_data = self .datapost .copy ()
439+
440+ pre_data ["prediction" ] = (
441+ az .extract (self .pre_pred , group = "posterior_predictive" , var_names = "mu" )
442+ .mean ("sample" )
443+ .values
444+ )
445+ post_data ["prediction" ] = (
446+ az .extract (self .post_pred , group = "posterior_predictive" , var_names = "mu" )
447+ .mean ("sample" )
448+ .values
449+ )
450+ pre_data [[pred_lower_col , pred_upper_col ]] = get_hdi_to_df (
451+ self .pre_pred ["posterior_predictive" ].mu , hdi_prob = hdi_prob
452+ ).set_index (pre_data .index )
453+ post_data [[pred_lower_col , pred_upper_col ]] = get_hdi_to_df (
454+ self .post_pred ["posterior_predictive" ].mu , hdi_prob = hdi_prob
455+ ).set_index (post_data .index )
456+
457+ pre_data ["impact" ] = (
458+ self .pre_impact .mean (dim = ["chain" , "draw" ]).isel (treated_units = 0 ).values
459+ )
460+ post_data ["impact" ] = (
461+ self .post_impact .mean (dim = ["chain" , "draw" ]).isel (treated_units = 0 ).values
462+ )
463+ pre_data [[impact_lower_col , impact_upper_col ]] = get_hdi_to_df (
464+ self .pre_impact , hdi_prob = hdi_prob
465+ ).set_index (pre_data .index )
466+ post_data [[impact_lower_col , impact_upper_col ]] = get_hdi_to_df (
467+ self .post_impact , hdi_prob = hdi_prob
468+ ).set_index (post_data .index )
469+
470+ self .plot_data = pd .concat ([pre_data , post_data ])
471+
472+ return self .plot_data
0 commit comments