@@ -2594,6 +2594,8 @@ def colorbar(
25942594 """
25952595 # Backwards compatibility
25962596 ax = kwargs .pop ("ax" , None )
2597+ ref = kwargs .pop ("ref" , None )
2598+ loc_ax = ref if ref is not None else ax
25972599 cax = kwargs .pop ("cax" , None )
25982600 if isinstance (values , maxes .Axes ):
25992601 cax = _not_none (cax_positional = values , cax = cax )
@@ -2613,20 +2615,96 @@ def colorbar(
26132615 with context ._state_context (cax , _internal_call = True ): # do not wrap pcolor
26142616 cb = super ().colorbar (mappable , cax = cax , ** kwargs )
26152617 # Axes panel colorbar
2616- elif ax is not None :
2618+ elif loc_ax is not None :
26172619 # Check if span parameters are provided
26182620 has_span = _not_none (span , row , col , rows , cols ) is not None
26192621
2622+ # Infer span from loc_ax if it is a list and no span provided
2623+ if (
2624+ not has_span
2625+ and np .iterable (loc_ax )
2626+ and not isinstance (loc_ax , (str , maxes .Axes ))
2627+ ):
2628+ loc_trans = _translate_loc (loc , "colorbar" , default = rc ["colorbar.loc" ])
2629+ side = (
2630+ loc_trans
2631+ if loc_trans in ("left" , "right" , "top" , "bottom" )
2632+ else None
2633+ )
2634+
2635+ if side :
2636+ r_min , r_max = float ("inf" ), float ("-inf" )
2637+ c_min , c_max = float ("inf" ), float ("-inf" )
2638+ valid_ax = False
2639+ for axi in loc_ax :
2640+ if not hasattr (axi , "get_subplotspec" ):
2641+ continue
2642+ ss = axi .get_subplotspec ().get_topmost_subplotspec ()
2643+ r1 , r2 , c1 , c2 = ss ._get_rows_columns ()
2644+ r_min = min (r_min , r1 )
2645+ r_max = max (r_max , r2 )
2646+ c_min = min (c_min , c1 )
2647+ c_max = max (c_max , c2 )
2648+ valid_ax = True
2649+
2650+ if valid_ax :
2651+ if side in ("left" , "right" ):
2652+ rows = (r_min + 1 , r_max + 1 )
2653+ else :
2654+ cols = (c_min + 1 , c_max + 1 )
2655+ has_span = True
2656+
26202657 # Extract a single axes from array if span is provided
26212658 # Otherwise, pass the array as-is for normal colorbar behavior
2622- if has_span and np .iterable (ax ) and not isinstance (ax , (str , maxes .Axes )):
2623- try :
2624- ax_single = next (iter (ax ))
2659+ if (
2660+ has_span
2661+ and np .iterable (loc_ax )
2662+ and not isinstance (loc_ax , (str , maxes .Axes ))
2663+ ):
2664+ # Pick the best axis to anchor to based on the colorbar side
2665+ loc_trans = _translate_loc (loc , "colorbar" , default = rc ["colorbar.loc" ])
2666+ side = (
2667+ loc_trans
2668+ if loc_trans in ("left" , "right" , "top" , "bottom" )
2669+ else None
2670+ )
2671+
2672+ best_ax = None
2673+ best_coord = float ("-inf" )
2674+
2675+ # If side is determined, search for the edge axis
2676+ if side :
2677+ for axi in loc_ax :
2678+ if not hasattr (axi , "get_subplotspec" ):
2679+ continue
2680+ ss = axi .get_subplotspec ().get_topmost_subplotspec ()
2681+ r1 , r2 , c1 , c2 = ss ._get_rows_columns ()
2682+
2683+ if side == "right" :
2684+ val = c2 # Maximize column index
2685+ elif side == "left" :
2686+ val = - c1 # Minimize column index
2687+ elif side == "bottom" :
2688+ val = r2 # Maximize row index
2689+ elif side == "top" :
2690+ val = - r1 # Minimize row index
2691+ else :
2692+ val = 0
2693+
2694+ if val > best_coord :
2695+ best_coord = val
2696+ best_ax = axi
26252697
2626- except (TypeError , StopIteration ):
2627- ax_single = ax
2698+ # Fallback to first axis
2699+ if best_ax is None :
2700+ try :
2701+ ax_single = next (iter (loc_ax ))
2702+ except (TypeError , StopIteration ):
2703+ ax_single = loc_ax
2704+ else :
2705+ ax_single = best_ax
26282706 else :
2629- ax_single = ax
2707+ ax_single = loc_ax
26302708
26312709 # Pass span parameters through to axes colorbar
26322710 cb = ax_single .colorbar (
0 commit comments