@@ -2594,6 +2594,8 @@ def colorbar(
25942594 """
25952595 # Backwards compatibility
25962596 ax = kwargs .pop ("ax" , None )
2597+ ref = kwargs .pop ("ref" , None )
2598+ loc_ax = ref if ref is not None else ax
25972599 cax = kwargs .pop ("cax" , None )
25982600 if isinstance (values , maxes .Axes ):
25992601 cax = _not_none (cax_positional = values , cax = cax )
@@ -2613,20 +2615,102 @@ def colorbar(
26132615 with context ._state_context (cax , _internal_call = True ): # do not wrap pcolor
26142616 cb = super ().colorbar (mappable , cax = cax , ** kwargs )
26152617 # Axes panel colorbar
2616- elif ax is not None :
2618+ elif loc_ax is not None :
26172619 # Check if span parameters are provided
26182620 has_span = _not_none (span , row , col , rows , cols ) is not None
26192621
2622+ # Infer span from loc_ax if it is a list and no span provided
2623+ if (
2624+ not has_span
2625+ and np .iterable (loc_ax )
2626+ and not isinstance (loc_ax , (str , maxes .Axes ))
2627+ ):
2628+ loc_trans = _translate_loc (loc , "colorbar" , default = rc ["colorbar.loc" ])
2629+ side = (
2630+ loc_trans
2631+ if loc_trans in ("left" , "right" , "top" , "bottom" )
2632+ else None
2633+ )
2634+
2635+ if side :
2636+ r_min , r_max = float ("inf" ), float ("-inf" )
2637+ c_min , c_max = float ("inf" ), float ("-inf" )
2638+ valid_ax = False
2639+ for axi in loc_ax :
2640+ if not hasattr (axi , "get_subplotspec" ):
2641+ continue
2642+ ss = axi .get_subplotspec ()
2643+ if ss is None :
2644+ continue
2645+ ss = ss .get_topmost_subplotspec ()
2646+ r1 , r2 , c1 , c2 = ss ._get_rows_columns ()
2647+ r_min = min (r_min , r1 )
2648+ r_max = max (r_max , r2 )
2649+ c_min = min (c_min , c1 )
2650+ c_max = max (c_max , c2 )
2651+ valid_ax = True
2652+
2653+ if valid_ax :
2654+ if side in ("left" , "right" ):
2655+ rows = (r_min + 1 , r_max + 1 )
2656+ else :
2657+ cols = (c_min + 1 , c_max + 1 )
2658+ has_span = True
2659+
26202660 # Extract a single axes from array if span is provided
26212661 # Otherwise, pass the array as-is for normal colorbar behavior
2622- if has_span and np .iterable (ax ) and not isinstance (ax , (str , maxes .Axes )):
2623- try :
2624- ax_single = next (iter (ax ))
2662+ if (
2663+ has_span
2664+ and np .iterable (loc_ax )
2665+ and not isinstance (loc_ax , (str , maxes .Axes ))
2666+ ):
2667+ # Pick the best axis to anchor to based on the colorbar side
2668+ loc_trans = _translate_loc (loc , "colorbar" , default = rc ["colorbar.loc" ])
2669+ side = (
2670+ loc_trans
2671+ if loc_trans in ("left" , "right" , "top" , "bottom" )
2672+ else None
2673+ )
26252674
2626- except (TypeError , StopIteration ):
2627- ax_single = ax
2675+ best_ax = None
2676+ best_coord = float ("-inf" )
2677+
2678+ # If side is determined, search for the edge axis
2679+ if side :
2680+ for axi in loc_ax :
2681+ if not hasattr (axi , "get_subplotspec" ):
2682+ continue
2683+ ss = axi .get_subplotspec ()
2684+ if ss is None :
2685+ continue
2686+ ss = ss .get_topmost_subplotspec ()
2687+ r1 , r2 , c1 , c2 = ss ._get_rows_columns ()
2688+
2689+ if side == "right" :
2690+ val = c2 # Maximize column index
2691+ elif side == "left" :
2692+ val = - c1 # Minimize column index
2693+ elif side == "bottom" :
2694+ val = r2 # Maximize row index
2695+ elif side == "top" :
2696+ val = - r1 # Minimize row index
2697+ else :
2698+ val = 0
2699+
2700+ if val > best_coord :
2701+ best_coord = val
2702+ best_ax = axi
2703+
2704+ # Fallback to first axis
2705+ if best_ax is None :
2706+ try :
2707+ ax_single = next (iter (loc_ax ))
2708+ except (TypeError , StopIteration ):
2709+ ax_single = loc_ax
2710+ else :
2711+ ax_single = best_ax
26282712 else :
2629- ax_single = ax
2713+ ax_single = loc_ax
26302714
26312715 # Pass span parameters through to axes colorbar
26322716 cb = ax_single .colorbar (
@@ -2700,27 +2784,136 @@ def legend(
27002784 matplotlib.axes.Axes.legend
27012785 """
27022786 ax = kwargs .pop ("ax" , None )
2787+ ref = kwargs .pop ("ref" , None )
2788+ loc_ax = ref if ref is not None else ax
2789+
27032790 # Axes panel legend
2704- if ax is not None :
2791+ if loc_ax is not None :
2792+ content_ax = ax if ax is not None else loc_ax
27052793 # Check if span parameters are provided
27062794 has_span = _not_none (span , row , col , rows , cols ) is not None
2707- # Extract a single axes from array if span is provided
2708- # Otherwise, pass the array as-is for normal legend behavior
2709- # Automatically collect handles and labels from spanned axes if not provided
2710- if has_span and np .iterable (ax ) and not isinstance (ax , (str , maxes .Axes )):
2711- # Auto-collect handles and labels if not explicitly provided
2712- if handles is None and labels is None :
2713- handles , labels = [], []
2714- for axi in ax :
2795+
2796+ # Automatically collect handles and labels from content axes if not provided
2797+ # Case 1: content_ax is a list (we must auto-collect)
2798+ # Case 2: content_ax != loc_ax (we must auto-collect because loc_ax.legend won't find content_ax handles)
2799+ must_collect = (
2800+ np .iterable (content_ax )
2801+ and not isinstance (content_ax , (str , maxes .Axes ))
2802+ ) or (content_ax is not loc_ax )
2803+
2804+ if must_collect and handles is None and labels is None :
2805+ handles , labels = [], []
2806+ # Handle list of axes
2807+ if np .iterable (content_ax ) and not isinstance (
2808+ content_ax , (str , maxes .Axes )
2809+ ):
2810+ for axi in content_ax :
27152811 h , l = axi .get_legend_handles_labels ()
27162812 handles .extend (h )
27172813 labels .extend (l )
2718- try :
2719- ax_single = next (iter (ax ))
2720- except (TypeError , StopIteration ):
2721- ax_single = ax
2814+ # Handle single axis
2815+ else :
2816+ handles , labels = content_ax .get_legend_handles_labels ()
2817+
2818+ # Infer span from loc_ax if it is a list and no span provided
2819+ if (
2820+ not has_span
2821+ and np .iterable (loc_ax )
2822+ and not isinstance (loc_ax , (str , maxes .Axes ))
2823+ ):
2824+ loc_trans = _translate_loc (loc , "legend" , default = rc ["legend.loc" ])
2825+ side = (
2826+ loc_trans
2827+ if loc_trans in ("left" , "right" , "top" , "bottom" )
2828+ else None
2829+ )
2830+
2831+ if side :
2832+ r_min , r_max = float ("inf" ), float ("-inf" )
2833+ c_min , c_max = float ("inf" ), float ("-inf" )
2834+ valid_ax = False
2835+ for axi in loc_ax :
2836+ if not hasattr (axi , "get_subplotspec" ):
2837+ continue
2838+ ss = axi .get_subplotspec ()
2839+ if ss is None :
2840+ continue
2841+ ss = ss .get_topmost_subplotspec ()
2842+ r1 , r2 , c1 , c2 = ss ._get_rows_columns ()
2843+ r_min = min (r_min , r1 )
2844+ r_max = max (r_max , r2 )
2845+ c_min = min (c_min , c1 )
2846+ c_max = max (c_max , c2 )
2847+ valid_ax = True
2848+
2849+ if valid_ax :
2850+ if side in ("left" , "right" ):
2851+ rows = (r_min + 1 , r_max + 1 )
2852+ else :
2853+ cols = (c_min + 1 , c_max + 1 )
2854+ has_span = True
2855+
2856+ # Extract a single axes from array if span is provided (or if ref is a list)
2857+ # Otherwise, pass the array as-is for normal legend behavior (only if loc_ax is list)
2858+ if (
2859+ has_span
2860+ and np .iterable (loc_ax )
2861+ and not isinstance (loc_ax , (str , maxes .Axes ))
2862+ ):
2863+ # Pick the best axis to anchor to based on the legend side
2864+ loc_trans = _translate_loc (loc , "legend" , default = rc ["legend.loc" ])
2865+ side = (
2866+ loc_trans
2867+ if loc_trans in ("left" , "right" , "top" , "bottom" )
2868+ else None
2869+ )
2870+
2871+ best_ax = None
2872+ best_coord = float ("-inf" )
2873+
2874+ # If side is determined, search for the edge axis
2875+ if side :
2876+ for axi in loc_ax :
2877+ if not hasattr (axi , "get_subplotspec" ):
2878+ continue
2879+ ss = axi .get_subplotspec ()
2880+ if ss is None :
2881+ continue
2882+ ss = ss .get_topmost_subplotspec ()
2883+ r1 , r2 , c1 , c2 = ss ._get_rows_columns ()
2884+
2885+ if side == "right" :
2886+ val = c2 # Maximize column index
2887+ elif side == "left" :
2888+ val = - c1 # Minimize column index
2889+ elif side == "bottom" :
2890+ val = r2 # Maximize row index
2891+ elif side == "top" :
2892+ val = - r1 # Minimize row index
2893+ else :
2894+ val = 0
2895+
2896+ if val > best_coord :
2897+ best_coord = val
2898+ best_ax = axi
2899+
2900+ # Fallback to first axis if no best axis found (or side is None)
2901+ if best_ax is None :
2902+ try :
2903+ ax_single = next (iter (loc_ax ))
2904+ except (TypeError , StopIteration ):
2905+ ax_single = loc_ax
2906+ else :
2907+ ax_single = best_ax
2908+
27222909 else :
2723- ax_single = ax
2910+ ax_single = loc_ax
2911+ if isinstance (ax_single , list ):
2912+ try :
2913+ ax_single = pgridspec .SubplotGrid (ax_single )
2914+ except ValueError :
2915+ ax_single = ax_single [0 ]
2916+
27242917 leg = ax_single .legend (
27252918 handles ,
27262919 labels ,
0 commit comments