Skip to content

Commit 905a4ad

Browse files
committed
Add custom_palette option for coloring barplots with 0 and 1 keys
now custom_palette dictionary can accept 0 and 1 to fill barplots
1 parent fd1b16c commit 905a4ad

File tree

9 files changed

+87
-17
lines changed

9 files changed

+87
-17
lines changed

dabest/misc_tools.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -516,7 +516,7 @@ def get_color_palette(
516516
idx: list,
517517
all_plot_groups: list,
518518
delta2: bool,
519-
sankey: bool
519+
proportional: bool
520520
):
521521
"""
522522
Create the color palette to be used in the plotter function.
@@ -537,9 +537,11 @@ def get_color_palette(
537537
A list of all the group names.
538538
delta2 : bool
539539
A boolean flag to determine if the plot will have a delta-delta effect size.
540-
sankey : bool
541-
A boolean flag to determine if the plot is for a Sankey diagram.
540+
proportional : bool
541+
A boolean flag to determine if the plot is for a proportional plot.
542542
"""
543+
sankey = True if proportional and show_pairs else False
544+
543545
# Create color palette that will be shared across subplots.
544546
color_col = plot_kwargs["color_col"]
545547
if color_col is None:
@@ -608,6 +610,17 @@ def get_color_palette(
608610
groups_in_palette = {
609611
k: custom_pal[k] for k in color_groups
610612
}
613+
elif proportional and not sankey: # barplots (unpaired proportional data)
614+
keys = list(custom_pal.keys())
615+
if all(k in keys for k in [1, 0]) and len(keys) == 2:
616+
groups_in_palette = {
617+
k: custom_pal[k] for k in [1, 0]
618+
}
619+
bootstraps_color_by_group = False
620+
else:
621+
groups_in_palette = {
622+
k: custom_pal[k] for k in all_plot_groups if k in color_groups
623+
}
611624
elif sankey:
612625
groups_in_palette = {
613626
k: custom_pal[k] for k in [1, 0]

dabest/plot_tools.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1885,6 +1885,9 @@ def barplotter(
18851885
horizontal : bool
18861886
If the plot is horizontal.
18871887
"""
1888+
# Check if the custom_palette is a dictionary with two keys 0 and 1 (for filled bar coloring)
1889+
filled_bars = True if len(plot_palette_raw.keys())==2 and all(k in plot_palette_raw for k in [1, 0]) else False
1890+
18881891
bar_width = barplot_kwargs.get('width', 0.5)
18891892
fontsize = barplot_kwargs.pop('fontsize')
18901893

@@ -1912,7 +1915,7 @@ def barplotter(
19121915
for hue_val in bar1_df[color_col]
19131916
]
19141917
else:
1915-
edge_colors = raw_colors
1918+
edge_colors = len(all_plot_groups)*['black',] if filled_bars else raw_colors
19161919

19171920
bar1 = sns.barplot(
19181921
data=bar1_df,
@@ -1921,20 +1924,26 @@ def barplotter(
19211924
ax=rawdata_axes,
19221925
order=all_plot_groups,
19231926
linewidth=2,
1924-
facecolor=(1, 1, 1, 0),
1927+
facecolor=plot_palette_raw[0] if filled_bars else (1, 1, 1, 0),
19251928
edgecolor=edge_colors,
19261929
zorder=1,
19271930
orient=orient,
19281931
)
19291932

1933+
if filled_bars:
1934+
barplot_kwargs['facecolor'] = plot_palette_raw[1]
1935+
barplot_kwargs['edgecolor'] = 'black'
1936+
barplot_kwargs['linewidth'] = 2
1937+
else:
1938+
barplot_kwargs['palette'] = plot_palette_raw
1939+
19301940
bar2 = sns.barplot(
19311941
data=plot_data,
19321942
x=yvar if horizontal else xvar,
19331943
y=xvar if horizontal else yvar,
19341944
hue=xvar if color_col is None else color_col,
19351945
ax=rawdata_axes,
19361946
order=all_plot_groups,
1937-
palette=plot_palette_raw,
19381947
dodge=False,
19391948
zorder=1,
19401949
orient=orient,

dabest/plotter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ def effectsize_df_plotter(effectsize_df: object, **plot_kwargs) -> matplotlib.fi
161161
idx = idx,
162162
all_plot_groups = all_plot_groups,
163163
delta2 = effectsize_df.delta2,
164-
sankey = True if proportional and show_pairs else False,
164+
proportional = proportional
165165
)
166166

167167
# Initialise the figure.

nbs/API/misc_tools.ipynb

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -569,7 +569,7 @@
569569
" idx: list, \n",
570570
" all_plot_groups: list,\n",
571571
" delta2: bool,\n",
572-
" sankey: bool\n",
572+
" proportional: bool\n",
573573
" ):\n",
574574
" \"\"\"\n",
575575
" Create the color palette to be used in the plotter function.\n",
@@ -590,9 +590,11 @@
590590
" A list of all the group names.\n",
591591
" delta2 : bool\n",
592592
" A boolean flag to determine if the plot will have a delta-delta effect size.\n",
593-
" sankey : bool\n",
594-
" A boolean flag to determine if the plot is for a Sankey diagram.\n",
593+
" proportional : bool\n",
594+
" A boolean flag to determine if the plot is for a proportional plot.\n",
595595
" \"\"\"\n",
596+
" sankey = True if proportional and show_pairs else False\n",
597+
"\n",
596598
" # Create color palette that will be shared across subplots.\n",
597599
" color_col = plot_kwargs[\"color_col\"]\n",
598600
" if color_col is None:\n",
@@ -661,6 +663,17 @@
661663
" groups_in_palette = {\n",
662664
" k: custom_pal[k] for k in color_groups\n",
663665
" }\n",
666+
" elif proportional and not sankey: # barplots (unpaired proportional data)\n",
667+
" keys = list(custom_pal.keys())\n",
668+
" if all(k in keys for k in [1, 0]) and len(keys) == 2:\n",
669+
" groups_in_palette = {\n",
670+
" k: custom_pal[k] for k in [1, 0]\n",
671+
" }\n",
672+
" bootstraps_color_by_group = False\n",
673+
" else:\n",
674+
" groups_in_palette = {\n",
675+
" k: custom_pal[k] for k in all_plot_groups if k in color_groups\n",
676+
" }\n",
664677
" elif sankey:\n",
665678
" groups_in_palette = {\n",
666679
" k: custom_pal[k] for k in [1, 0]\n",

nbs/API/plot_tools.ipynb

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1936,6 +1936,9 @@
19361936
" horizontal : bool\n",
19371937
" If the plot is horizontal.\n",
19381938
" \"\"\"\n",
1939+
" # Check if the custom_palette is a dictionary with two keys 0 and 1 (for filled bar coloring)\n",
1940+
" filled_bars = True if len(plot_palette_raw.keys())==2 and all(k in plot_palette_raw for k in [1, 0]) else False\n",
1941+
"\n",
19391942
" bar_width = barplot_kwargs.get('width', 0.5)\n",
19401943
" fontsize = barplot_kwargs.pop('fontsize')\n",
19411944
"\n",
@@ -1963,7 +1966,7 @@
19631966
" for hue_val in bar1_df[color_col]\n",
19641967
" ]\n",
19651968
" else:\n",
1966-
" edge_colors = raw_colors\n",
1969+
" edge_colors = len(all_plot_groups)*['black',] if filled_bars else raw_colors\n",
19671970
"\n",
19681971
" bar1 = sns.barplot(\n",
19691972
" data=bar1_df,\n",
@@ -1972,20 +1975,26 @@
19721975
" ax=rawdata_axes,\n",
19731976
" order=all_plot_groups,\n",
19741977
" linewidth=2,\n",
1975-
" facecolor=(1, 1, 1, 0),\n",
1978+
" facecolor=plot_palette_raw[0] if filled_bars else (1, 1, 1, 0),\n",
19761979
" edgecolor=edge_colors,\n",
19771980
" zorder=1,\n",
19781981
" orient=orient,\n",
19791982
" )\n",
19801983
"\n",
1984+
" if filled_bars:\n",
1985+
" barplot_kwargs['facecolor'] = plot_palette_raw[1]\n",
1986+
" barplot_kwargs['edgecolor'] = 'black'\n",
1987+
" barplot_kwargs['linewidth'] = 2\n",
1988+
" else:\n",
1989+
" barplot_kwargs['palette'] = plot_palette_raw\n",
1990+
"\n",
19811991
" bar2 = sns.barplot(\n",
19821992
" data=plot_data,\n",
19831993
" x=yvar if horizontal else xvar,\n",
19841994
" y=xvar if horizontal else yvar,\n",
19851995
" hue=xvar if color_col is None else color_col,\n",
19861996
" ax=rawdata_axes,\n",
19871997
" order=all_plot_groups,\n",
1988-
" palette=plot_palette_raw,\n",
19891998
" dodge=False,\n",
19901999
" zorder=1,\n",
19912000
" orient=orient,\n",

nbs/API/plotter.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@
218218
" idx = idx,\n",
219219
" all_plot_groups = all_plot_groups,\n",
220220
" delta2 = effectsize_df.delta2,\n",
221-
" sankey = True if proportional and show_pairs else False,\n",
221+
" proportional = proportional\n",
222222
" )\n",
223223
"\n",
224224
" # Initialise the figure.\n",
30.6 KB
Loading

nbs/tests/mpl_image_tests/test_10_proportion_plot.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -459,6 +459,11 @@ def test_143_sankey_change_palette_c():
459459
plt.rcdefaults()
460460
return multi_groups_paired.mean_diff.plot(custom_palette=['red', 'blue'])
461461

462+
@pytest.mark.mpl_image_compare(tolerance=8)
463+
def test_144_change_palette_d():
464+
plt.rcdefaults()
465+
return multi_2group.mean_diff.plot(custom_palette={0:'blue', 1: 'red'})
466+
462467
@pytest.mark.mpl_image_compare(tolerance=8)
463468
def test_136_style_sheets():
464469
# Perform this test last so we don't have to reset the plot style.

nbs/tutorials/09-plot_aesthetics.ipynb

Lines changed: 24 additions & 3 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)