@@ -120,25 +120,25 @@ def error_bar(
120120 else :
121121 group_order = pd .unique (data [x ])
122122
123- means = data .groupby (x )[y ].mean ().reindex (index = group_order )
123+ means = data .groupby (x , observed = False )[y ].mean ().reindex (index = group_order )
124124
125125 if method in ["proportional_error_bar" , "sankey_error_bar" ]:
126126 g = lambda x : np .sqrt (
127127 (np .sum (x ) * (len (x ) - np .sum (x ))) / (len (x ) * len (x ) * len (x ))
128128 )
129- sd = data .groupby (x )[y ].apply (g )
129+ sd = data .groupby (x , observed = False )[y ].apply (g )
130130 else :
131- sd = data .groupby (x )[y ].std ().reindex (index = group_order )
131+ sd = data .groupby (x , observed = False )[y ].std ().reindex (index = group_order )
132132
133133 lower_sd = means - sd
134134 upper_sd = means + sd
135135
136136 if (lower_sd < ax_ylims [0 ]).any () or (upper_sd > ax_ylims [1 ]).any ():
137137 kwargs ["clip_on" ] = True
138138
139- medians = data .groupby (x )[y ].median ().reindex (index = group_order )
139+ medians = data .groupby (x , observed = False )[y ].median ().reindex (index = group_order )
140140 quantiles = (
141- data .groupby (x )[y ].quantile ([0.25 , 0.75 ]).unstack ().reindex (index = group_order )
141+ data .groupby (x , observed = False )[y ].quantile ([0.25 , 0.75 ]).unstack ().reindex (index = group_order )
142142 )
143143 lower_quartiles = quantiles [0.25 ]
144144 upper_quartiles = quantiles [0.75 ]
@@ -978,7 +978,7 @@ def swarm_bars_plotter(plot_data: object, xvar: str, yvar: str, ax: object,
978978 else :
979979 swarm_bars_order = pd .unique (plot_data [xvar ])
980980
981- swarm_means = plot_data .groupby (xvar )[yvar ].mean ().reindex (index = swarm_bars_order )
981+ swarm_means = plot_data .groupby (xvar , observed = False )[yvar ].mean ().reindex (index = swarm_bars_order )
982982 swarm_bars_colors = (
983983 [swarm_bars_kwargs .get ('color' )] * (max (swarm_bars_order ) + 1 )
984984 if swarm_bars_kwargs .get ('color' ) is not None
@@ -1199,7 +1199,7 @@ def slopegraph_plotter(dabest_obj, plot_data, xvar, yvar, color_col, plot_palett
11991199 if color_col is None :
12001200 slopegraph_kwargs ["color" ] = ytick_color
12011201 else :
1202- color_key = observation [color_col ][0 ]
1202+ color_key = observation [color_col ]. iloc [0 ]
12031203 if isinstance (color_key , (str , np .int64 , np .float64 )):
12041204 slopegraph_kwargs ["color" ] = plot_palette_raw [color_key ]
12051205 slopegraph_kwargs ["label" ] = color_key
@@ -1497,7 +1497,7 @@ def swarmplot(
14971497 data : pd .DataFrame ,
14981498 x : str ,
14991499 y : str ,
1500- ax : axes .Subplot ,
1500+ ax : axes .Axes ,
15011501 order : List = None ,
15021502 hue : str = None ,
15031503 palette : Union [Iterable , str ] = "black" ,
@@ -1521,8 +1521,8 @@ def swarmplot(
15211521 The column in the DataFrame to be used as the x-axis.
15221522 y : str
15231523 The column in the DataFrame to be used as the y-axis.
1524- ax : axes._subplots.Subplot | axes._axes. Axes
1525- Matplotlib AxesSubplot object for which the plot would be drawn on. Default is None.
1524+ ax : axes.Axes
1525+ Matplotlib axes.Axes object for which the plot would be drawn on. Default is None.
15261526 order : List
15271527 The order in which x-axis categories should be displayed. Default is None.
15281528 hue : str
@@ -1552,8 +1552,8 @@ def swarmplot(
15521552
15531553 Returns
15541554 -------
1555- axes._subplots.Subplot | axes._axes. Axes
1556- Matplotlib AxesSubplot object for which the swarm plot has been drawn on.
1555+ axes.Axes
1556+ Matplotlib axes.Axes object for which the swarm plot has been drawn on.
15571557 """
15581558 s = SwarmPlot (data , x , y , ax , order , hue , palette , zorder , size , side , jitter )
15591559 ax = s .plot (is_drop_gutter , gutter_limit , ax , filled , ** kwargs )
@@ -1566,7 +1566,7 @@ def __init__(
15661566 data : pd .DataFrame ,
15671567 x : str ,
15681568 y : str ,
1569- ax : axes .Subplot ,
1569+ ax : axes .Axes ,
15701570 order : List = None ,
15711571 hue : str = None ,
15721572 palette : Union [Iterable , str ] = "black" ,
@@ -1586,8 +1586,8 @@ def __init__(
15861586 The column in the DataFrame to be used as the x-axis.
15871587 y : str
15881588 The column in the DataFrame to be used as the y-axis.
1589- ax : axes.Subplot
1590- Matplotlib AxesSubplot object for which the plot would be drawn on.
1589+ ax : axes.Axes
1590+ Matplotlib axes.Axes object for which the plot would be drawn on.
15911591 order : List
15921592 The order in which x-axis categories should be displayed. Default is None.
15931593 hue : str
@@ -1674,7 +1674,7 @@ def __init__(
16741674 self .__dsize = dsize
16751675
16761676 def _check_errors (
1677- self , data : pd .DataFrame , ax : axes .Subplot , size : float , side : str
1677+ self , data : pd .DataFrame , ax : axes .Axes , size : float , side : str
16781678 ) -> None :
16791679 """
16801680 Check the validity of input parameters. Raises exceptions if detected.
@@ -1683,8 +1683,8 @@ def _check_errors(
16831683 ----------
16841684 data : pd.Dataframe
16851685 Input data used for generation of the swarmplot.
1686- ax : axes.Subplot
1687- Matplotlib AxesSubplot object for which the plot would be drawn on.
1686+ ax : axes.Axes
1687+ Matplotlib axes.Axes object for which the plot would be drawn on.
16881688 size : int | float
16891689 scalar value determining size of dots of the swarmplot.
16901690 side: str
@@ -1697,9 +1697,9 @@ def _check_errors(
16971697 # Type enforcement
16981698 if not isinstance (data , pd .DataFrame ):
16991699 raise ValueError ("`data` must be a Pandas Dataframe." )
1700- if not isinstance (ax , ( axes ._subplots . Subplot , axes . _axes . Axes ) ):
1700+ if not isinstance (ax , axes .Axes ):
17011701 raise ValueError (
1702- f"`ax` must be a Matplotlib AxesSubplot . The current `ax` is a { type (ax )} "
1702+ f"`ax` must be a Matplotlib axes.Axes . The current `ax` is a { type (ax )} "
17031703 )
17041704 if not isinstance (size , (int , float )):
17051705 raise ValueError ("`size` must be a scalar or float." )
@@ -1859,9 +1859,10 @@ def _swarm(
18591859 raise ValueError ("`dsize` must be a scalar or float." )
18601860
18611861 # Sorting algorithm based off of: https://github.com/mgymrek/pybeeswarm
1862- points_data = pd .DataFrame (
1863- {"y" : [yval * 1.0 / dsize for yval in values ], "x" : [0 ] * len (values )}
1864- )
1862+ points_data = pd .DataFrame ({
1863+ "y" : [yval * 1.0 / dsize for yval in values ],
1864+ "x" : np .zeros (len (values ), dtype = float ) # Initialize with float zeros
1865+ })
18651866 for i in range (1 , points_data .shape [0 ]):
18661867 y_i = points_data ["y" ].values [i ]
18671868 points_placed = points_data [0 :i ]
@@ -1968,7 +1969,7 @@ def plot(
19681969 ax : axes .Subplot ,
19691970 filled : Union [bool , List , Tuple ],
19701971 ** kwargs ,
1971- ) -> axes .Subplot :
1972+ ) -> axes .Axes :
19721973 """
19731974 Generate a swarm plot.
19741975
@@ -1978,7 +1979,7 @@ def plot(
19781979 If True, drop points that hit the gutters; otherwise, readjust them.
19791980 gutter_limit : int | float
19801981 The limit for points hitting the gutters.
1981- ax : axes.Subplot
1982+ ax : axes.Axes
19821983 The matplotlib figure object to which the swarm plot will be added.
19831984 filled : bool | List | Tuple
19841985 Determines whether the dots in the swarmplot are filled or not. If set to False,
@@ -1990,8 +1991,8 @@ def plot(
19901991
19911992 Returns
19921993 -------
1993- axes.Subplot :
1994- The matplotlib figure containing the swarm plot.
1994+ axes.Axes :
1995+ The matplotlib axes containing the swarm plot.
19951996 """
19961997 # Input validation
19971998 if not isinstance (is_drop_gutter , bool ):
@@ -2019,8 +2020,7 @@ def plot(
20192020 0 # x-coordinate of center of each individual swarm of the swarm plot
20202021 )
20212022 x_tick_tabels = []
2022-
2023- for group_i , values_i in self .__data_copy .groupby (self .__x ):
2023+ for group_i , values_i in self .__data_copy .groupby (self .__x , observed = False ):
20242024 x_new = []
20252025 values_i_y = values_i [self .__y ]
20262026 x_offset = self ._swarm (
0 commit comments