@@ -213,32 +213,40 @@ def summary(self, round_to=None) -> None:
213213 self .print_coefficients (round_to )
214214
215215 def _bayesian_plot (
216- self , round_to = None , treated_unit = None , ** kwargs
216+ self , round_to = None , treated_unit : str | None = None , ** kwargs
217217 ) -> tuple [plt .Figure , List [plt .Axes ]]:
218218 """
219219 Plot the results for a specific treated unit
220220
221221 :param round_to:
222222 Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers.
223223 :param treated_unit:
224- Which treated unit to plot. Can be an integer index or string name .
224+ Which treated unit to plot. Must be a string name of the treated unit .
225225 If None, plots the first treated unit.
226226 """
227227 counterfactual_label = "Counterfactual"
228228
229229 fig , ax = plt .subplots (3 , 1 , sharex = True , figsize = (7 , 8 ))
230230 # TOP PLOT --------------------------------------------------
231231 # pre-intervention period
232- primary_unit_idx = self ._get_primary_treated_unit_index (treated_unit )
233- primary_unit_name = self .treated_units [primary_unit_idx ]
232+
233+ # Get treated unit name - default to first unit if None
234+ primary_unit_name = (
235+ treated_unit if treated_unit is not None else self .treated_units [0 ]
236+ )
237+
238+ if primary_unit_name not in self .treated_units :
239+ raise ValueError (
240+ f"treated_unit '{ primary_unit_name } ' not found. Available units: { self .treated_units } "
241+ )
234242
235243 # For multi-unit, select primary unit for main plot
236244 if len (self .treated_units ) > 1 :
237- pre_pred_plot = self .pre_pred ["posterior_predictive" ].mu .isel (
238- treated_units = primary_unit_idx
245+ pre_pred_plot = self .pre_pred ["posterior_predictive" ].mu .sel (
246+ treated_units = primary_unit_name
239247 )
240- post_pred_plot = self .post_pred ["posterior_predictive" ].mu .isel (
241- treated_units = primary_unit_idx
248+ post_pred_plot = self .post_pred ["posterior_predictive" ].mu .sel (
249+ treated_units = primary_unit_name
242250 )
243251 else :
244252 pre_pred_plot = self .pre_pred ["posterior_predictive" ].mu
@@ -256,12 +264,12 @@ def _bayesian_plot(
256264 # Plot observations for primary treated unit
257265 (h ,) = ax [0 ].plot (
258266 self .datapre .index ,
259- self .datapre_treated .isel (treated_units = primary_unit_idx ),
267+ self .datapre_treated .sel (treated_units = primary_unit_name ),
260268 "k." ,
261- label = f"Observations ({ self . treated_units [ primary_unit_idx ] } )" ,
269+ label = f"Observations ({ primary_unit_name } )" ,
262270 )
263271 handles .append (h )
264- labels .append (f"Observations ({ self . treated_units [ primary_unit_idx ] } )" )
272+ labels .append (f"Observations ({ primary_unit_name } )" )
265273
266274 # post intervention period
267275 h_line , h_patch = plot_xY (
@@ -275,14 +283,14 @@ def _bayesian_plot(
275283
276284 ax [0 ].plot (
277285 self .datapost .index ,
278- self .datapost_treated .isel (treated_units = primary_unit_idx ),
286+ self .datapost_treated .sel (treated_units = primary_unit_name ),
279287 "k." ,
280288 )
281289 # Shaded causal effect for primary treated unit
282290 h = ax [0 ].fill_between (
283291 self .datapost .index ,
284292 y1 = post_pred_plot .mean (dim = ["chain" , "draw" ]).values ,
285- y2 = self .datapost_treated .isel (treated_units = primary_unit_idx ).values ,
293+ y2 = self .datapost_treated .sel (treated_units = primary_unit_name ).values ,
286294 color = "C2" ,
287295 alpha = 0.25 ,
288296 label = "Causal impact" ,
@@ -295,21 +303,21 @@ def _bayesian_plot(
295303 # MIDDLE PLOT -----------------------------------------------
296304 plot_xY (
297305 self .datapre .index ,
298- self .pre_impact .sel (treated_units = self . treated_units [ primary_unit_idx ] ),
306+ self .pre_impact .sel (treated_units = primary_unit_name ),
299307 ax = ax [1 ],
300308 plot_hdi_kwargs = {"color" : "C0" },
301309 )
302310 plot_xY (
303311 self .datapost .index ,
304- self .post_impact .sel (treated_units = self . treated_units [ primary_unit_idx ] ),
312+ self .post_impact .sel (treated_units = primary_unit_name ),
305313 ax = ax [1 ],
306314 plot_hdi_kwargs = {"color" : "C1" },
307315 )
308316 ax [1 ].axhline (y = 0 , c = "k" )
309317 ax [1 ].fill_between (
310318 self .datapost .index ,
311319 y1 = self .post_impact .mean (["chain" , "draw" ]).sel (
312- treated_units = self . treated_units [ primary_unit_idx ]
320+ treated_units = primary_unit_name
313321 ),
314322 color = "C0" ,
315323 alpha = 0.25 ,
@@ -321,9 +329,7 @@ def _bayesian_plot(
321329 ax [2 ].set (title = f"Cumulative Causal Impact ({ primary_unit_name } )" )
322330 plot_xY (
323331 self .datapost .index ,
324- self .post_impact_cumulative .sel (
325- treated_units = self .treated_units [primary_unit_idx ]
326- ),
332+ self .post_impact_cumulative .sel (treated_units = primary_unit_name ),
327333 ax = ax [2 ],
328334 plot_hdi_kwargs = {"color" : "C1" },
329335 )
@@ -365,31 +371,39 @@ def _bayesian_plot(
365371 return fig , ax
366372
367373 def _ols_plot (
368- self , round_to = None , treated_unit = None , ** kwargs
374+ self , round_to = None , treated_unit : str | None = None , ** kwargs
369375 ) -> tuple [plt .Figure , List [plt .Axes ]]:
370376 """
371377 Plot the results for OLS model for a specific treated unit
372378
373379 :param round_to:
374380 Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers.
375381 :param treated_unit:
376- Which treated unit to plot. Can be an integer index or string name .
382+ Which treated unit to plot. Must be a string name of the treated unit .
377383 If None, plots the first treated unit.
378384 """
379385 counterfactual_label = "Counterfactual"
380- primary_unit_idx = self ._get_primary_treated_unit_index (treated_unit )
381- primary_unit_name = self .treated_units [primary_unit_idx ]
386+
387+ # Get treated unit name - default to first unit if None
388+ primary_unit_name = (
389+ treated_unit if treated_unit is not None else self .treated_units [0 ]
390+ )
391+
392+ if primary_unit_name not in self .treated_units :
393+ raise ValueError (
394+ f"treated_unit '{ primary_unit_name } ' not found. Available units: { self .treated_units } "
395+ )
382396
383397 fig , ax = plt .subplots (3 , 1 , sharex = True , figsize = (7 , 8 ))
384398
385399 ax [0 ].plot (
386400 self .datapre_treated ["obs_ind" ],
387- self .datapre_treated .isel (treated_units = primary_unit_idx ),
401+ self .datapre_treated .sel (treated_units = primary_unit_name ),
388402 "k." ,
389403 )
390404 ax [0 ].plot (
391405 self .datapost_treated ["obs_ind" ],
392- self .datapost_treated .isel (treated_units = primary_unit_idx ),
406+ self .datapost_treated .sel (treated_units = primary_unit_name ),
393407 "k." ,
394408 )
395409
@@ -422,7 +436,7 @@ def _ols_plot(
422436 self .datapost .index ,
423437 y1 = post_pred_values ,
424438 y2 = np .squeeze (
425- self .datapost_treated .isel (treated_units = primary_unit_idx ).data
439+ self .datapost_treated .sel (treated_units = primary_unit_name ).data
426440 ),
427441 color = "C0" ,
428442 alpha = 0.25 ,
@@ -482,15 +496,15 @@ def get_plot_data_ols(self) -> pd.DataFrame:
482496 return self .plot_data
483497
484498 def get_plot_data_bayesian (
485- self , hdi_prob : float = 0.94 , treated_unit = None
499+ self , hdi_prob : float = 0.94 , treated_unit : str | None = None
486500 ) -> pd .DataFrame :
487501 """
488502 Recover the data of the PrePostFit experiment along with the prediction and causal impact information.
489503
490504 :param hdi_prob:
491505 Prob for which the highest density interval will be computed. The default value is defined as the default from the :func:`arviz.hdi` function.
492506 :param treated_unit:
493- Which treated unit to extract data for. Can be an integer index or string name .
507+ Which treated unit to extract data for. Must be a string name of the treated unit .
494508 If None, uses the first treated unit.
495509 """
496510 if not isinstance (self .model , PyMCModel ):
@@ -506,8 +520,15 @@ def get_plot_data_bayesian(
506520 pre_data = self .datapre .copy ()
507521 post_data = self .datapost .copy ()
508522
509- # Get primary treated unit index for data extraction
510- primary_unit_idx = self ._get_primary_treated_unit_index (treated_unit )
523+ # Get treated unit name - default to first unit if None
524+ primary_unit_name = (
525+ treated_unit if treated_unit is not None else self .treated_units [0 ]
526+ )
527+
528+ if primary_unit_name not in self .treated_units :
529+ raise ValueError (
530+ f"treated_unit '{ primary_unit_name } ' not found. Available units: { self .treated_units } "
531+ )
511532
512533 # Extract predictions - handle multi-unit case
513534 pre_pred_vals = az .extract (
@@ -519,11 +540,11 @@ def get_plot_data_bayesian(
519540
520541 if len (self .treated_units ) > 1 :
521542 # Multi-unit case: extract primary unit
522- pre_data ["prediction" ] = pre_pred_vals .isel (
523- treated_units = primary_unit_idx
543+ pre_data ["prediction" ] = pre_pred_vals .sel (
544+ treated_units = primary_unit_name
524545 ).values
525- post_data ["prediction" ] = post_pred_vals .isel (
526- treated_units = primary_unit_idx
546+ post_data ["prediction" ] = post_pred_vals .sel (
547+ treated_units = primary_unit_name
527548 ).values
528549 else :
529550 # Single unit case
@@ -533,14 +554,14 @@ def get_plot_data_bayesian(
533554 # HDI intervals for predictions
534555 if len (self .treated_units ) > 1 :
535556 pre_hdi = get_hdi_to_df (
536- self .pre_pred ["posterior_predictive" ].mu .isel (
537- treated_units = primary_unit_idx
557+ self .pre_pred ["posterior_predictive" ].mu .sel (
558+ treated_units = primary_unit_name
538559 ),
539560 hdi_prob = hdi_prob ,
540561 )
541562 post_hdi = get_hdi_to_df (
542- self .post_pred ["posterior_predictive" ].mu .isel (
543- treated_units = primary_unit_idx
563+ self .post_pred ["posterior_predictive" ].mu .sel (
564+ treated_units = primary_unit_name
544565 ),
545566 hdi_prob = hdi_prob ,
546567 )
@@ -562,21 +583,21 @@ def get_plot_data_bayesian(
562583 # Impact data - always use primary unit for main dataframe
563584 pre_data ["impact" ] = (
564585 self .pre_impact .mean (dim = ["chain" , "draw" ])
565- .isel (treated_units = primary_unit_idx )
586+ .sel (treated_units = primary_unit_name )
566587 .values
567588 )
568589 post_data ["impact" ] = (
569590 self .post_impact .mean (dim = ["chain" , "draw" ])
570- .isel (treated_units = primary_unit_idx )
591+ .sel (treated_units = primary_unit_name )
571592 .values
572593 )
573594 # Impact HDI intervals - use primary unit
574595 if len (self .treated_units ) > 1 :
575596 pre_impact_hdi = get_hdi_to_df (
576- self .pre_impact .isel (treated_units = primary_unit_idx ), hdi_prob = hdi_prob
597+ self .pre_impact .sel (treated_units = primary_unit_name ), hdi_prob = hdi_prob
577598 )
578599 post_impact_hdi = get_hdi_to_df (
579- self .post_impact .isel (treated_units = primary_unit_idx ), hdi_prob = hdi_prob
600+ self .post_impact .sel (treated_units = primary_unit_name ), hdi_prob = hdi_prob
580601 )
581602 else :
582603 pre_impact_hdi = get_hdi_to_df (self .pre_impact , hdi_prob = hdi_prob )
@@ -617,30 +638,3 @@ def _get_score_title(self, round_to=None):
617638 else :
618639 # OLS model - score is typically a simple float
619640 return f"$R^2$ on pre-intervention data = { round_num (self .score , round_to )} "
620-
621- def _get_primary_treated_unit_index (self , treated_unit = None ):
622- """Get the index for the treated unit to plot.
623-
624- :param treated_unit: Optional. Either an integer index or string name of the treated unit.
625- If None, defaults to the first treated unit (index 0).
626- """
627- if treated_unit is None :
628- return 0
629- elif isinstance (treated_unit , int ):
630- if 0 <= treated_unit < len (self .treated_units ):
631- return treated_unit
632- else :
633- raise ValueError (
634- f"treated_unit index { treated_unit } out of range. Valid range: 0-{ len (self .treated_units ) - 1 } "
635- )
636- elif isinstance (treated_unit , str ):
637- if treated_unit in self .treated_units :
638- return self .treated_units .index (treated_unit )
639- else :
640- raise ValueError (
641- f"treated_unit '{ treated_unit } ' not found. Available units: { self .treated_units } "
642- )
643- else :
644- raise ValueError (
645- "treated_unit must be an integer index, string name, or None"
646- )
0 commit comments