Skip to content

Commit 03bafb7

Browse files
authored
Merge pull request #208 from ACCLAB/feat-paired-coloroption
Updated slopegraph visuals and additional custom_palette options updated the design of slopegraphs to remove raw bars and include a group summary (central tendency line with error bars) Added additional coloring options: - custom_palette dict for unpaired prop plots (bar plots) can now take 0 and 1 as keys to color the filled and unfilled portions of the plots - custom_palette usage with in paired plots can now color the contrast bars and effect size curves
2 parents 8f45d3c + ef4d4b5 commit 03bafb7

File tree

104 files changed

+443
-136
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

104 files changed

+443
-136
lines changed

dabest/_effsize_objects.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1143,7 +1143,7 @@ def plot(
11431143
face_color=None,
11441144

11451145
raw_desat=0.5, # swarm_desat=0.5, OLD # bar_desat=0.5, OLD
1146-
contrast_desat=1, # halfviolin_desat=1, OLD
1146+
contrast_desat=1.0, # halfviolin_desat=1, OLD
11471147

11481148
raw_alpha=None, # NEW
11491149
contrast_alpha=0.8, # halfviolin_alpha=0.8, OLD
@@ -1478,7 +1478,8 @@ def plot(
14781478

14791479
if raw_alpha is None:
14801480
raw_alpha = (0.4 if self.is_proportional and self.is_paired
1481-
else 0.5 if self.is_paired
1481+
else 0.5 if self.is_paired and (color_col is not None or self.__delta2)
1482+
else 0.2 if self.is_paired and color_col is None
14821483
else 1.0
14831484
)
14841485

dabest/misc_tools.py

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,8 @@ def get_params(
203203

204204
def get_kwargs(
205205
plot_kwargs: dict,
206-
ytick_color
206+
ytick_color,
207+
is_paired: bool = False
207208
):
208209
"""
209210
Extracts the kwargs from the `plot_kwargs` object for use in the plotter function.
@@ -214,6 +215,8 @@ def get_kwargs(
214215
Kwargs passed to the plot function.
215216
ytick_color : str or color list
216217
Color of the yticks.
218+
is_paired : bool, optional
219+
A boolean flag to determine if the plot is for paired data. Default is False.
217220
"""
218221
from .misc_tools import merge_two_dicts
219222

@@ -334,7 +337,7 @@ def get_kwargs(
334337
default_group_summaries_kwargs = {
335338
"zorder": 3,
336339
"lw": 2,
337-
"alpha": 1,
340+
"alpha": 1 if not is_paired else 0.6,
338341
'gap_width_percent': 1.5,
339342
'offset': 0.1,
340343
'color': None
@@ -513,7 +516,7 @@ def get_color_palette(
513516
idx: list,
514517
all_plot_groups: list,
515518
delta2: bool,
516-
sankey: bool
519+
proportional: bool
517520
):
518521
"""
519522
Create the color palette to be used in the plotter function.
@@ -534,9 +537,11 @@ def get_color_palette(
534537
A list of all the group names.
535538
delta2 : bool
536539
A boolean flag to determine if the plot will have a delta-delta effect size.
537-
sankey : bool
538-
A boolean flag to determine if the plot is for a Sankey diagram.
540+
proportional : bool
541+
A boolean flag to determine if the plot is for a proportional plot.
539542
"""
543+
sankey = True if proportional and show_pairs else False
544+
540545
# Create color palette that will be shared across subplots.
541546
color_col = plot_kwargs["color_col"]
542547
if color_col is None:
@@ -548,7 +553,13 @@ def get_color_palette(
548553
color_groups = pd.unique(plot_data[color_col])
549554
bootstraps_color_by_group = False
550555
if show_pairs:
551-
bootstraps_color_by_group = False
556+
if plot_kwargs["custom_palette"] is not None:
557+
if delta2 or sankey:
558+
bootstraps_color_by_group = False
559+
else:
560+
bootstraps_color_by_group = True
561+
else:
562+
bootstraps_color_by_group = False
552563

553564
# Handle the color palette.
554565
filled = True
@@ -599,6 +610,17 @@ def get_color_palette(
599610
groups_in_palette = {
600611
k: custom_pal[k] for k in color_groups
601612
}
613+
elif proportional and not sankey: # barplots (unpaired proportional data)
614+
keys = list(custom_pal.keys())
615+
if all(k in keys for k in [1, 0]) and len(keys) == 2:
616+
groups_in_palette = {
617+
k: custom_pal[k] for k in [1, 0]
618+
}
619+
bootstraps_color_by_group = False
620+
else:
621+
groups_in_palette = {
622+
k: custom_pal[k] for k in all_plot_groups if k in color_groups
623+
}
602624
elif sankey:
603625
groups_in_palette = {
604626
k: custom_pal[k] for k in [1, 0]
@@ -1856,13 +1878,15 @@ def color_picker(color_type: str,
18561878
elements: list,
18571879
color_col: str,
18581880
show_pairs: bool,
1859-
color_palette: dict) -> list:
1881+
color_palette: dict,
1882+
bootstraps_color_by_group: bool) -> list:
18601883
num_of_elements = len(elements)
18611884
colors = (
18621885
[kwargs.pop('color')] * num_of_elements
18631886
if kwargs.get('color', None) is not None
18641887
else ['black'] * num_of_elements
1865-
if color_col is not None or show_pairs
1888+
# if color_col is not None or show_pairs
1889+
if color_col is not None or not bootstraps_color_by_group
18661890
else list(color_palette.values())
18671891
)
18681892
if color_type in ['contrast', 'summary', 'delta_text']:
@@ -1877,7 +1901,7 @@ def color_picker(color_type: str,
18771901
return final_colors
18781902

18791903

1880-
def prepare_bars_for_plot(bar_type, bar_kwargs, horizontal, plot_palette_raw, color_col, show_pairs,
1904+
def prepare_bars_for_plot(bar_type, bar_kwargs, horizontal, plot_palette_raw, color_col, show_pairs, bootstraps_color_by_group,
18811905
plot_data = None, xvar = None, yvar = None, # Raw data
18821906
results = None, ticks_to_plot = None, extra_delta = None, # Contrast data
18831907
reference_band = None, summary_axes = None, ci_type = None # Summary data
@@ -1951,7 +1975,8 @@ def prepare_bars_for_plot(bar_type, bar_kwargs, horizontal, plot_palette_raw, co
19511975
elements = ticks_to_plot if bar_type=='contrast' else ticks,
19521976
color_col = color_col,
19531977
show_pairs = show_pairs,
1954-
color_palette = plot_palette_raw
1978+
color_palette = plot_palette_raw,
1979+
bootstraps_color_by_group = bootstraps_color_by_group
19551980
)
19561981
if bar_type == 'contrast' and extra_delta is not None:
19571982
colors.append('black')

dabest/plot_tools.py

Lines changed: 61 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -904,6 +904,7 @@ def delta_text_plotter(
904904
show_pairs: bool,
905905
float_contrast: bool,
906906
extra_delta: float,
907+
bootstraps_color_by_group: bool = False
907908
):
908909
"""
909910
Add delta text to the contrast plot.
@@ -928,6 +929,8 @@ def delta_text_plotter(
928929
Whether the DABEST plot uses Gardner-Altman or Cummings.
929930
extra_delta : float or None
930931
The extra mini-meta or delta-delta value if applicable.
932+
bootstraps_color_by_group : bool, optional
933+
Whether to color the bootstraps by group. Default is False.
931934
"""
932935
# Colors
933936
from .misc_tools import color_picker
@@ -936,7 +939,8 @@ def delta_text_plotter(
936939
elements = ticks_to_plot,
937940
color_col = color_col,
938941
show_pairs = show_pairs,
939-
color_palette = plot_palette_raw
942+
color_palette = plot_palette_raw,
943+
bootstraps_color_by_group = bootstraps_color_by_group
940944
)
941945

942946
num_of_elements = len(ticks_to_plot) + 1 if extra_delta is not None else len(ticks_to_plot)
@@ -1091,7 +1095,8 @@ def slopegraph_plotter(
10911095
temp_idx: list,
10921096
horizontal: bool,
10931097
temp_all_plot_groups: list,
1094-
plot_kwargs: dict
1098+
plot_kwargs: dict,
1099+
group_summaries_kwargs: dict
10951100
):
10961101
"""
10971102
Add slopegraph to the rawdata axes.
@@ -1124,6 +1129,8 @@ def slopegraph_plotter(
11241129
List of all plot groups.
11251130
plot_kwargs : dict
11261131
Keyword arguments for the plot.
1132+
group_summaries_kwargs : dict, optional
1133+
Keyword arguments for group summaries, if applicable.
11271134
11281135
"""
11291136
# Jitter Kwargs
@@ -1178,6 +1185,45 @@ def slopegraph_plotter(
11781185
x_points, y_points = (y_points, x_points) if horizontal else (x_points, y_points)
11791186
rawdata_axes.plot(x_points, y_points, **slopegraph_kwargs)
11801187

1188+
# Add the group summaries if applicable.
1189+
group_summaries = plot_kwargs.get("group_summaries", None)
1190+
if group_summaries is not None:
1191+
for key in ['gap_width_percent', 'offset']:
1192+
group_summaries_kwargs.pop(key, None)
1193+
group_summaries_kwargs['color'] = 'black' if group_summaries_kwargs.get('color') is None else group_summaries_kwargs['color']
1194+
group_summaries_kwargs['capsize'] = 0 if group_summaries_kwargs.get('capsize') is None else group_summaries_kwargs['capsize']
1195+
1196+
index_points = [t for t in range(x_start, x_start + grp_count)]
1197+
av_points, err_points, lo_points, hi_points = [], [], [], []
1198+
for group in range(len(index_points)):
1199+
if group_summaries == "mean_sd":
1200+
av_points.append(current_pair.iloc[:, int(group)].mean())
1201+
err_points.append(current_pair.iloc[:, int(group)].std())
1202+
elif group_summaries == "median_quartiles":
1203+
median = current_pair.iloc[:, int(group)].median()
1204+
av_points.append(median)
1205+
lo_points.append(median - current_pair.iloc[:, int(group)].quantile(0.25))
1206+
hi_points.append(current_pair.iloc[:, int(group)].quantile(0.75) - median)
1207+
1208+
if group_summaries == "median_quartiles":
1209+
err_points = [lo_points, hi_points]
1210+
1211+
# Plot the lines
1212+
if horizontal:
1213+
rawdata_axes.errorbar(
1214+
av_points,
1215+
index_points,
1216+
xerr=err_points,
1217+
**group_summaries_kwargs
1218+
)
1219+
else:
1220+
rawdata_axes.errorbar(
1221+
index_points,
1222+
av_points,
1223+
yerr=err_points,
1224+
**group_summaries_kwargs
1225+
)
1226+
11811227
x_start = x_start + grp_count
11821228

11831229
# Set the tick labels, because the slopegraph plotting doesn't.
@@ -1839,6 +1885,9 @@ def barplotter(
18391885
horizontal : bool
18401886
If the plot is horizontal.
18411887
"""
1888+
# Check if the custom_palette is a dictionary with two keys 0 and 1 (for filled bar coloring)
1889+
filled_bars = True if len(plot_palette_raw.keys())==2 and all(k in plot_palette_raw for k in [1, 0]) else False
1890+
18421891
bar_width = barplot_kwargs.get('width', 0.5)
18431892
fontsize = barplot_kwargs.pop('fontsize')
18441893

@@ -1866,29 +1915,35 @@ def barplotter(
18661915
for hue_val in bar1_df[color_col]
18671916
]
18681917
else:
1869-
edge_colors = raw_colors
1918+
edge_colors = len(all_plot_groups)*['black',] if filled_bars else raw_colors
18701919

18711920
bar1 = sns.barplot(
18721921
data=bar1_df,
18731922
x=xvar,
18741923
y="proportion",
18751924
ax=rawdata_axes,
18761925
order=all_plot_groups,
1877-
linewidth=2,
1878-
facecolor=(1, 1, 1, 0),
1926+
linewidth=1 if filled_bars else 2,
1927+
facecolor=plot_palette_raw[0] if filled_bars else (1, 1, 1, 0),
18791928
edgecolor=edge_colors,
18801929
zorder=1,
18811930
orient=orient,
18821931
)
18831932

1933+
if filled_bars:
1934+
barplot_kwargs['facecolor'] = plot_palette_raw[1]
1935+
barplot_kwargs['edgecolor'] = 'black'
1936+
barplot_kwargs['linewidth'] = 1
1937+
else:
1938+
barplot_kwargs['palette'] = plot_palette_raw
1939+
18841940
bar2 = sns.barplot(
18851941
data=plot_data,
18861942
x=yvar if horizontal else xvar,
18871943
y=xvar if horizontal else yvar,
18881944
hue=xvar if color_col is None else color_col,
18891945
ax=rawdata_axes,
18901946
order=all_plot_groups,
1891-
palette=plot_palette_raw,
18921947
dodge=False,
18931948
zorder=1,
18941949
orient=orient,

dabest/plotter.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,8 @@ def effectsize_df_plotter(effectsize_df: object, **plot_kwargs) -> matplotlib.fi
137137
raw_bars_kwargs, contrast_bars_kwargs, table_kwargs, gridkey_kwargs, contrast_marker_kwargs,
138138
contrast_errorbar_kwargs, prop_sample_counts_kwargs, contrast_paired_lines_kwargs) = get_kwargs(
139139
plot_kwargs = plot_kwargs,
140-
ytick_color = ytick_color
140+
ytick_color = ytick_color,
141+
is_paired = effectsize_df.is_paired
141142
)
142143

143144
(dabest_obj, plot_data, xvar, yvar, is_paired, effect_size, proportional,
@@ -160,7 +161,7 @@ def effectsize_df_plotter(effectsize_df: object, **plot_kwargs) -> matplotlib.fi
160161
idx = idx,
161162
all_plot_groups = all_plot_groups,
162163
delta2 = effectsize_df.delta2,
163-
sankey = True if proportional and show_pairs else False,
164+
proportional = proportional
164165
)
165166

166167
# Initialise the figure.
@@ -219,6 +220,7 @@ def effectsize_df_plotter(effectsize_df: object, **plot_kwargs) -> matplotlib.fi
219220
horizontal = horizontal,
220221
temp_all_plot_groups = temp_all_plot_groups,
221222
plot_kwargs = plot_kwargs,
223+
group_summaries_kwargs = group_summaries_kwargs
222224
)
223225

224226
## Add delta dots to the contrast axes for paired plots.
@@ -333,7 +335,7 @@ def effectsize_df_plotter(effectsize_df: object, **plot_kwargs) -> matplotlib.fi
333335

334336
## Swarm bars
335337
raw_bars = plot_kwargs["raw_bars"]
336-
if raw_bars and not proportional and not horizontal: #Currently not supporting swarm bars for horizontal plots (looks weird)
338+
if raw_bars and not proportional and not is_paired and not horizontal: #Currently not supporting swarm bars for horizontal plots (looks weird)
337339
raw_bars_dict, raw_bars_kwargs = prepare_bars_for_plot(
338340
bar_type = 'raw',
339341
bar_kwargs = raw_bars_kwargs,
@@ -343,7 +345,8 @@ def effectsize_df_plotter(effectsize_df: object, **plot_kwargs) -> matplotlib.fi
343345
show_pairs = show_pairs,
344346
plot_data = plot_data,
345347
xvar = xvar,
346-
yvar = yvar,
348+
yvar = yvar,
349+
bootstraps_color_by_group = bootstraps_color_by_group,
347350
)
348351
add_bars_to_plot(bar_dict = raw_bars_dict,
349352
ax = rawdata_axes,
@@ -424,6 +427,7 @@ def effectsize_df_plotter(effectsize_df: object, **plot_kwargs) -> matplotlib.fi
424427
show_pairs = show_pairs,
425428
results = results,
426429
ticks_to_plot = ticks_to_plot,
430+
bootstraps_color_by_group = bootstraps_color_by_group,
427431
extra_delta = (effectsize_df.mini_meta.difference if show_mini_meta
428432
else effectsize_df.delta_delta.difference if show_delta2
429433
else None)
@@ -445,6 +449,7 @@ def effectsize_df_plotter(effectsize_df: object, **plot_kwargs) -> matplotlib.fi
445449
plot_palette_raw = plot_palette_raw,
446450
show_pairs = show_pairs,
447451
float_contrast = float_contrast,
452+
bootstraps_color_by_group = bootstraps_color_by_group,
448453
extra_delta = (effectsize_df.mini_meta.difference if show_mini_meta
449454
else effectsize_df.delta_delta.difference if show_delta2
450455
else None),
@@ -588,6 +593,7 @@ def effectsize_df_plotter(effectsize_df: object, **plot_kwargs) -> matplotlib.fi
588593
reference_band = reference_band,
589594
summary_axes = contrast_axes,
590595
ci_type = ci_type,
596+
bootstraps_color_by_group = bootstraps_color_by_group,
591597
)
592598

593599
add_bars_to_plot(bar_dict = reference_band_dict,

nbs/API/effsize_objects.ipynb

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1343,7 +1343,7 @@
13431343
" face_color=None,\n",
13441344
"\n",
13451345
" raw_desat=0.5, # swarm_desat=0.5, OLD # bar_desat=0.5, OLD\n",
1346-
" contrast_desat=1, # halfviolin_desat=1, OLD\n",
1346+
" contrast_desat=1.0, # halfviolin_desat=1, OLD\n",
13471347
"\n",
13481348
" raw_alpha=None, # NEW\n",
13491349
" contrast_alpha=0.8, # halfviolin_alpha=0.8, OLD\n",
@@ -1678,7 +1678,8 @@
16781678
"\n",
16791679
" if raw_alpha is None:\n",
16801680
" raw_alpha = (0.4 if self.is_proportional and self.is_paired \n",
1681-
" else 0.5 if self.is_paired\n",
1681+
" else 0.5 if self.is_paired and (color_col is not None or self.__delta2)\n",
1682+
" else 0.2 if self.is_paired and color_col is None\n",
16821683
" else 1.0\n",
16831684
" )\n",
16841685
"\n",

0 commit comments

Comments
 (0)