Skip to content

Commit d3d531f

Browse files
authored
Merge pull request #184 from siemdejong/issue-183
fix: changed Subplot to Axes, changed the minimum support version to Python 3.9, upgraded dependencies.
2 parents 360fdc3 + 3ac4852 commit d3d531f

File tree

10 files changed

+124
-98
lines changed

10 files changed

+124
-98
lines changed

.github/workflows/test-pytest.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ jobs:
88
- uses: actions/checkout@v3
99
- uses: actions/setup-python@v4
1010
with:
11-
python-version: 3.8
11+
python-version: 3.9
1212
cache: "pip"
1313
cache-dependency-path: settings.ini
1414
- name: Run pytest

dabest/_dabest_object.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
# %% ../nbs/API/dabest_object.ipynb 4
99
# Import standard data science libraries
1010
from numpy import array, repeat, random, issubdtype, number
11+
import numpy as np
1112
import pandas as pd
1213
from scipy.stats import norm
1314
from scipy.stats import randint
@@ -479,7 +480,7 @@ def _check_errors(self, x, y, idx, experiment, experiment_label, x1_level):
479480

480481
# Handling str type condition
481482
if is_str_condition_met:
482-
if len(pd.unique(idx).tolist()) != 2:
483+
if len(np.unique(idx).tolist()) != 2:
483484
err0 = "`mini_meta` is True, but `idx` ({})".format(idx)
484485
err1 = "does not contain exactly 2 unique columns."
485486
raise ValueError(err0 + err1)
@@ -667,7 +668,7 @@ def _get_plot_data(self, x, y, all_plot_groups):
667668
all_plot_groups, ordered=True, inplace=True
668669
)
669670
else:
670-
plot_data.loc[:, self.__xvar] = pd.Categorical(
671+
plot_data[self.__xvar] = pd.Categorical(
671672
plot_data[self.__xvar], categories=all_plot_groups, ordered=True
672673
)
673674

dabest/_modidx.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,8 @@
8181
'dabest.misc_tools.get_kwargs': ('API/misc_tools.html#get_kwargs', 'dabest/misc_tools.py'),
8282
'dabest.misc_tools.get_params': ('API/misc_tools.html#get_params', 'dabest/misc_tools.py'),
8383
'dabest.misc_tools.get_plot_groups': ('API/misc_tools.html#get_plot_groups', 'dabest/misc_tools.py'),
84+
'dabest.misc_tools.get_unique_categories': ( 'API/misc_tools.html#get_unique_categories',
85+
'dabest/misc_tools.py'),
8486
'dabest.misc_tools.get_varname': ('API/misc_tools.html#get_varname', 'dabest/misc_tools.py'),
8587
'dabest.misc_tools.initialize_fig': ('API/misc_tools.html#initialize_fig', 'dabest/misc_tools.py'),
8688
'dabest.misc_tools.merge_two_dicts': ('API/misc_tools.html#merge_two_dicts', 'dabest/misc_tools.py'),

dabest/misc_tools.py

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/API/misc_tools.ipynb.
44

55
# %% auto 0
6-
__all__ = ['merge_two_dicts', 'unpack_and_add', 'print_greeting', 'get_varname', 'get_params', 'get_kwargs', 'get_color_palette',
7-
'initialize_fig', 'get_plot_groups', 'add_counts_to_ticks', 'extract_contrast_plotting_ticks',
8-
'set_xaxis_ticks_and_lims', 'show_legend', 'Gardner_Altman_Plot_Aesthetic_Adjustments',
9-
'Cumming_Plot_Aesthetic_Adjustments', 'General_Plot_Aesthetic_Adjustments']
6+
__all__ = ['merge_two_dicts', 'unpack_and_add', 'print_greeting', 'get_varname', 'get_unique_categories', 'get_params',
7+
'get_kwargs', 'get_color_palette', 'initialize_fig', 'get_plot_groups', 'add_counts_to_ticks',
8+
'extract_contrast_plotting_ticks', 'set_xaxis_ticks_and_lims', 'show_legend',
9+
'Gardner_Altman_Plot_Aesthetic_Adjustments', 'Cumming_Plot_Aesthetic_Adjustments',
10+
'General_Plot_Aesthetic_Adjustments']
1011

1112
# %% ../nbs/API/misc_tools.ipynb 4
1213
import datetime as dt
@@ -78,6 +79,19 @@ def get_varname(obj):
7879
if len(matching_vars) > 0:
7980
return matching_vars[0]
8081
return ""
82+
83+
84+
def get_unique_categories(names):
85+
"""
86+
Extract unique categories from various input types.
87+
"""
88+
if isinstance(names, np.ndarray):
89+
return names # numpy.unique() returns a sorted array
90+
elif isinstance(names, (pd.Categorical, pd.Series)):
91+
return names.cat.categories if hasattr(names, 'cat') else names.unique()
92+
else:
93+
# For dict_keys and other iterables
94+
return np.unique(list(names))
8195

8296
def get_params(effectsize_df, plot_kwargs):
8397
"""
@@ -369,6 +383,7 @@ def get_color_palette(plot_kwargs, plot_data, xvar, show_pairs, idx, all_plot_gr
369383
raise ValueError(err1 + err2)
370384

371385
if custom_pal is None and color_col is None:
386+
categories = get_unique_categories(names)
372387
swarm_colors = [sns.desaturate(c, swarm_desat) for c in unsat_colors]
373388
contrast_colors = [sns.desaturate(c, contrast_desat) for c in unsat_colors]
374389
bar_color = [sns.desaturate(c, bar_desat) for c in unsat_colors]
@@ -382,9 +397,9 @@ def get_color_palette(plot_kwargs, plot_data, xvar, show_pairs, idx, all_plot_gr
382397
plot_palette_raw[names_i] = swarm_colors[i]
383398
plot_palette_contrast[names_i] = contrast_colors[i]
384399
else:
385-
plot_palette_raw = dict(zip(names.categories, swarm_colors))
386-
plot_palette_contrast = dict(zip(names.categories, contrast_colors))
387-
plot_palette_bar = dict(zip(names.categories, bar_color))
400+
plot_palette_raw = dict(zip(categories, swarm_colors))
401+
plot_palette_contrast = dict(zip(categories, contrast_colors))
402+
plot_palette_bar = dict(zip(categories, bar_color))
388403

389404
# For Sankey Diagram plot, no need to worry about the color, each bar will have the same two colors
390405
# default color palette will be set to "hls"
@@ -541,7 +556,7 @@ def get_plot_groups(is_paired, idx, proportional, all_plot_groups):
541556

542557
def add_counts_to_ticks(plot_data, xvar, yvar, rawdata_axes, plot_kwargs):
543558
# Add the counts to the rawdata axes xticks.
544-
counts = plot_data.groupby(xvar).count()[yvar]
559+
counts = plot_data.groupby(xvar, observed=False).count()[yvar]
545560

546561
def lookup_value(text):
547562
try:
@@ -695,19 +710,19 @@ def Gardner_Altman_Plot_Aesthetic_Adjustments(effect_size_type, plot_data, xvar,
695710
# Check that the effect size is within the swarm ylims.
696711
if effect_size_type in ["mean_diff", "cohens_d", "hedges_g", "cohens_h"]:
697712
control_group_summary = (
698-
plot_data.groupby(xvar)
713+
plot_data.groupby(xvar, observed=False)
699714
.mean(numeric_only=True)
700715
.loc[current_control, yvar]
701716
)
702717
test_group_summary = (
703-
plot_data.groupby(xvar).mean(numeric_only=True).loc[current_group, yvar]
718+
plot_data.groupby(xvar, observed=False).mean(numeric_only=True).loc[current_group, yvar]
704719
)
705720
elif effect_size_type == "median_diff":
706721
control_group_summary = (
707-
plot_data.groupby(xvar).median().loc[current_control, yvar]
722+
plot_data.groupby(xvar, observed=False).median(numeric_only=True).loc[current_control, yvar]
708723
)
709724
test_group_summary = (
710-
plot_data.groupby(xvar).median().loc[current_group, yvar]
725+
plot_data.groupby(xvar, observed=False).median(numeric_only=True).loc[current_group, yvar]
711726
)
712727

713728
if swarm_ylim is None:
@@ -751,7 +766,7 @@ def Gardner_Altman_Plot_Aesthetic_Adjustments(effect_size_type, plot_data, xvar,
751766
pooled_sd = stds[0]
752767

753768
if effect_size_type == "hedges_g":
754-
gby_count = plot_data.groupby(xvar).count()
769+
gby_count = plot_data.groupby(xvar, observed=False).count()
755770
len_control = gby_count.loc[current_control, yvar]
756771
len_test = gby_count.loc[current_group, yvar]
757772

dabest/plot_tools.py

Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -120,25 +120,25 @@ def error_bar(
120120
else:
121121
group_order = pd.unique(data[x])
122122

123-
means = data.groupby(x)[y].mean().reindex(index=group_order)
123+
means = data.groupby(x, observed=False)[y].mean().reindex(index=group_order)
124124

125125
if method in ["proportional_error_bar", "sankey_error_bar"]:
126126
g = lambda x: np.sqrt(
127127
(np.sum(x) * (len(x) - np.sum(x))) / (len(x) * len(x) * len(x))
128128
)
129-
sd = data.groupby(x)[y].apply(g)
129+
sd = data.groupby(x, observed=False)[y].apply(g)
130130
else:
131-
sd = data.groupby(x)[y].std().reindex(index=group_order)
131+
sd = data.groupby(x, observed=False)[y].std().reindex(index=group_order)
132132

133133
lower_sd = means - sd
134134
upper_sd = means + sd
135135

136136
if (lower_sd < ax_ylims[0]).any() or (upper_sd > ax_ylims[1]).any():
137137
kwargs["clip_on"] = True
138138

139-
medians = data.groupby(x)[y].median().reindex(index=group_order)
139+
medians = data.groupby(x, observed=False)[y].median().reindex(index=group_order)
140140
quantiles = (
141-
data.groupby(x)[y].quantile([0.25, 0.75]).unstack().reindex(index=group_order)
141+
data.groupby(x, observed=False)[y].quantile([0.25, 0.75]).unstack().reindex(index=group_order)
142142
)
143143
lower_quartiles = quantiles[0.25]
144144
upper_quartiles = quantiles[0.75]
@@ -978,7 +978,7 @@ def swarm_bars_plotter(plot_data: object, xvar: str, yvar: str, ax: object,
978978
else:
979979
swarm_bars_order = pd.unique(plot_data[xvar])
980980

981-
swarm_means = plot_data.groupby(xvar)[yvar].mean().reindex(index=swarm_bars_order)
981+
swarm_means = plot_data.groupby(xvar, observed=False)[yvar].mean().reindex(index=swarm_bars_order)
982982
swarm_bars_colors = (
983983
[swarm_bars_kwargs.get('color')] * (max(swarm_bars_order) + 1)
984984
if swarm_bars_kwargs.get('color') is not None
@@ -1199,7 +1199,7 @@ def slopegraph_plotter(dabest_obj, plot_data, xvar, yvar, color_col, plot_palett
11991199
if color_col is None:
12001200
slopegraph_kwargs["color"] = ytick_color
12011201
else:
1202-
color_key = observation[color_col][0]
1202+
color_key = observation[color_col].iloc[0]
12031203
if isinstance(color_key, (str, np.int64, np.float64)):
12041204
slopegraph_kwargs["color"] = plot_palette_raw[color_key]
12051205
slopegraph_kwargs["label"] = color_key
@@ -1497,7 +1497,7 @@ def swarmplot(
14971497
data: pd.DataFrame,
14981498
x: str,
14991499
y: str,
1500-
ax: axes.Subplot,
1500+
ax: axes.Axes,
15011501
order: List = None,
15021502
hue: str = None,
15031503
palette: Union[Iterable, str] = "black",
@@ -1521,8 +1521,8 @@ def swarmplot(
15211521
The column in the DataFrame to be used as the x-axis.
15221522
y : str
15231523
The column in the DataFrame to be used as the y-axis.
1524-
ax : axes._subplots.Subplot | axes._axes.Axes
1525-
Matplotlib AxesSubplot object for which the plot would be drawn on. Default is None.
1524+
ax : axes.Axes
1525+
Matplotlib axes.Axes object for which the plot would be drawn on. Default is None.
15261526
order : List
15271527
The order in which x-axis categories should be displayed. Default is None.
15281528
hue : str
@@ -1552,8 +1552,8 @@ def swarmplot(
15521552
15531553
Returns
15541554
-------
1555-
axes._subplots.Subplot | axes._axes.Axes
1556-
Matplotlib AxesSubplot object for which the swarm plot has been drawn on.
1555+
axes.Axes
1556+
Matplotlib axes.Axes object for which the swarm plot has been drawn on.
15571557
"""
15581558
s = SwarmPlot(data, x, y, ax, order, hue, palette, zorder, size, side, jitter)
15591559
ax = s.plot(is_drop_gutter, gutter_limit, ax, filled, **kwargs)
@@ -1566,7 +1566,7 @@ def __init__(
15661566
data: pd.DataFrame,
15671567
x: str,
15681568
y: str,
1569-
ax: axes.Subplot,
1569+
ax: axes.Axes,
15701570
order: List = None,
15711571
hue: str = None,
15721572
palette: Union[Iterable, str] = "black",
@@ -1586,8 +1586,8 @@ def __init__(
15861586
The column in the DataFrame to be used as the x-axis.
15871587
y : str
15881588
The column in the DataFrame to be used as the y-axis.
1589-
ax : axes.Subplot
1590-
Matplotlib AxesSubplot object for which the plot would be drawn on.
1589+
ax : axes.Axes
1590+
Matplotlib axes.Axes object for which the plot would be drawn on.
15911591
order : List
15921592
The order in which x-axis categories should be displayed. Default is None.
15931593
hue : str
@@ -1674,7 +1674,7 @@ def __init__(
16741674
self.__dsize = dsize
16751675

16761676
def _check_errors(
1677-
self, data: pd.DataFrame, ax: axes.Subplot, size: float, side: str
1677+
self, data: pd.DataFrame, ax: axes.Axes, size: float, side: str
16781678
) -> None:
16791679
"""
16801680
Check the validity of input parameters. Raises exceptions if detected.
@@ -1683,8 +1683,8 @@ def _check_errors(
16831683
----------
16841684
data : pd.Dataframe
16851685
Input data used for generation of the swarmplot.
1686-
ax : axes.Subplot
1687-
Matplotlib AxesSubplot object for which the plot would be drawn on.
1686+
ax : axes.Axes
1687+
Matplotlib axes.Axes object for which the plot would be drawn on.
16881688
size : int | float
16891689
scalar value determining size of dots of the swarmplot.
16901690
side: str
@@ -1697,9 +1697,9 @@ def _check_errors(
16971697
# Type enforcement
16981698
if not isinstance(data, pd.DataFrame):
16991699
raise ValueError("`data` must be a Pandas Dataframe.")
1700-
if not isinstance(ax, (axes._subplots.Subplot, axes._axes.Axes)):
1700+
if not isinstance(ax, axes.Axes):
17011701
raise ValueError(
1702-
f"`ax` must be a Matplotlib AxesSubplot. The current `ax` is a {type(ax)}"
1702+
f"`ax` must be a Matplotlib axes.Axes. The current `ax` is a {type(ax)}"
17031703
)
17041704
if not isinstance(size, (int, float)):
17051705
raise ValueError("`size` must be a scalar or float.")
@@ -1859,9 +1859,10 @@ def _swarm(
18591859
raise ValueError("`dsize` must be a scalar or float.")
18601860

18611861
# Sorting algorithm based off of: https://github.com/mgymrek/pybeeswarm
1862-
points_data = pd.DataFrame(
1863-
{"y": [yval * 1.0 / dsize for yval in values], "x": [0] * len(values)}
1864-
)
1862+
points_data = pd.DataFrame({
1863+
"y": [yval * 1.0 / dsize for yval in values],
1864+
"x": np.zeros(len(values), dtype=float) # Initialize with float zeros
1865+
})
18651866
for i in range(1, points_data.shape[0]):
18661867
y_i = points_data["y"].values[i]
18671868
points_placed = points_data[0:i]
@@ -1968,7 +1969,7 @@ def plot(
19681969
ax: axes.Subplot,
19691970
filled: Union[bool, List, Tuple],
19701971
**kwargs,
1971-
) -> axes.Subplot:
1972+
) -> axes.Axes:
19721973
"""
19731974
Generate a swarm plot.
19741975
@@ -1978,7 +1979,7 @@ def plot(
19781979
If True, drop points that hit the gutters; otherwise, readjust them.
19791980
gutter_limit : int | float
19801981
The limit for points hitting the gutters.
1981-
ax : axes.Subplot
1982+
ax : axes.Axes
19821983
The matplotlib figure object to which the swarm plot will be added.
19831984
filled : bool | List | Tuple
19841985
Determines whether the dots in the swarmplot are filled or not. If set to False,
@@ -1990,8 +1991,8 @@ def plot(
19901991
19911992
Returns
19921993
-------
1993-
axes.Subplot:
1994-
The matplotlib figure containing the swarm plot.
1994+
axes.Axes:
1995+
The matplotlib axes containing the swarm plot.
19951996
"""
19961997
# Input validation
19971998
if not isinstance(is_drop_gutter, bool):
@@ -2019,8 +2020,7 @@ def plot(
20192020
0 # x-coordinate of center of each individual swarm of the swarm plot
20202021
)
20212022
x_tick_tabels = []
2022-
2023-
for group_i, values_i in self.__data_copy.groupby(self.__x):
2023+
for group_i, values_i in self.__data_copy.groupby(self.__x, observed=False):
20242024
x_new = []
20252025
values_i_y = values_i[self.__y]
20262026
x_offset = self._swarm(

nbs/API/dabest_object.ipynb

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
"#| export\n",
5858
"# Import standard data science libraries\n",
5959
"from numpy import array, repeat, random, issubdtype, number\n",
60+
"import numpy as np\n",
6061
"import pandas as pd\n",
6162
"from scipy.stats import norm\n",
6263
"from scipy.stats import randint"
@@ -547,7 +548,7 @@
547548
"\n",
548549
" # Handling str type condition\n",
549550
" if is_str_condition_met:\n",
550-
" if len(pd.unique(idx).tolist()) != 2:\n",
551+
" if len(np.unique(idx).tolist()) != 2:\n",
551552
" err0 = \"`mini_meta` is True, but `idx` ({})\".format(idx)\n",
552553
" err1 = \"does not contain exactly 2 unique columns.\"\n",
553554
" raise ValueError(err0 + err1)\n",
@@ -735,7 +736,7 @@
735736
" all_plot_groups, ordered=True, inplace=True\n",
736737
" )\n",
737738
" else:\n",
738-
" plot_data.loc[:, self.__xvar] = pd.Categorical(\n",
739+
" plot_data[self.__xvar] = pd.Categorical(\n",
739740
" plot_data[self.__xvar], categories=all_plot_groups, ordered=True\n",
740741
" )\n",
741742
"\n",

0 commit comments

Comments
 (0)