Skip to content

Commit fd1b16c

Browse files
committed
Updated slopegraph format to include group summaries
- Added group summary lines to slopegraphs - Added ability to color effect size curves and the constrast bars and delta text via using custom_palette to paired plots
1 parent 8f45d3c commit fd1b16c

File tree

102 files changed

+355
-118
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

102 files changed

+355
-118
lines changed

dabest/_effsize_objects.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1143,7 +1143,7 @@ def plot(
11431143
face_color=None,
11441144

11451145
raw_desat=0.5, # swarm_desat=0.5, OLD # bar_desat=0.5, OLD
1146-
contrast_desat=1, # halfviolin_desat=1, OLD
1146+
contrast_desat=1.0, # halfviolin_desat=1, OLD
11471147

11481148
raw_alpha=None, # NEW
11491149
contrast_alpha=0.8, # halfviolin_alpha=0.8, OLD
@@ -1478,7 +1478,8 @@ def plot(
14781478

14791479
if raw_alpha is None:
14801480
raw_alpha = (0.4 if self.is_proportional and self.is_paired
1481-
else 0.5 if self.is_paired
1481+
else 0.5 if self.is_paired and (color_col is not None or self.__delta2)
1482+
else 0.2 if self.is_paired and color_col is None
14821483
else 1.0
14831484
)
14841485

dabest/misc_tools.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,8 @@ def get_params(
203203

204204
def get_kwargs(
205205
plot_kwargs: dict,
206-
ytick_color
206+
ytick_color,
207+
is_paired: bool = False
207208
):
208209
"""
209210
Extracts the kwargs from the `plot_kwargs` object for use in the plotter function.
@@ -214,6 +215,8 @@ def get_kwargs(
214215
Kwargs passed to the plot function.
215216
ytick_color : str or color list
216217
Color of the yticks.
218+
is_paired : bool, optional
219+
A boolean flag to determine if the plot is for paired data. Default is False.
217220
"""
218221
from .misc_tools import merge_two_dicts
219222

@@ -334,7 +337,7 @@ def get_kwargs(
334337
default_group_summaries_kwargs = {
335338
"zorder": 3,
336339
"lw": 2,
337-
"alpha": 1,
340+
"alpha": 1 if not is_paired else 0.6,
338341
'gap_width_percent': 1.5,
339342
'offset': 0.1,
340343
'color': None
@@ -548,7 +551,13 @@ def get_color_palette(
548551
color_groups = pd.unique(plot_data[color_col])
549552
bootstraps_color_by_group = False
550553
if show_pairs:
551-
bootstraps_color_by_group = False
554+
if plot_kwargs["custom_palette"] is not None:
555+
if delta2 or sankey:
556+
bootstraps_color_by_group = False
557+
else:
558+
bootstraps_color_by_group = True
559+
else:
560+
bootstraps_color_by_group = False
552561

553562
# Handle the color palette.
554563
filled = True
@@ -1856,13 +1865,15 @@ def color_picker(color_type: str,
18561865
elements: list,
18571866
color_col: str,
18581867
show_pairs: bool,
1859-
color_palette: dict) -> list:
1868+
color_palette: dict,
1869+
bootstraps_color_by_group: bool) -> list:
18601870
num_of_elements = len(elements)
18611871
colors = (
18621872
[kwargs.pop('color')] * num_of_elements
18631873
if kwargs.get('color', None) is not None
18641874
else ['black'] * num_of_elements
1865-
if color_col is not None or show_pairs
1875+
# if color_col is not None or show_pairs
1876+
if color_col is not None or not bootstraps_color_by_group
18661877
else list(color_palette.values())
18671878
)
18681879
if color_type in ['contrast', 'summary', 'delta_text']:
@@ -1877,7 +1888,7 @@ def color_picker(color_type: str,
18771888
return final_colors
18781889

18791890

1880-
def prepare_bars_for_plot(bar_type, bar_kwargs, horizontal, plot_palette_raw, color_col, show_pairs,
1891+
def prepare_bars_for_plot(bar_type, bar_kwargs, horizontal, plot_palette_raw, color_col, show_pairs, bootstraps_color_by_group,
18811892
plot_data = None, xvar = None, yvar = None, # Raw data
18821893
results = None, ticks_to_plot = None, extra_delta = None, # Contrast data
18831894
reference_band = None, summary_axes = None, ci_type = None # Summary data
@@ -1951,7 +1962,8 @@ def prepare_bars_for_plot(bar_type, bar_kwargs, horizontal, plot_palette_raw, co
19511962
elements = ticks_to_plot if bar_type=='contrast' else ticks,
19521963
color_col = color_col,
19531964
show_pairs = show_pairs,
1954-
color_palette = plot_palette_raw
1965+
color_palette = plot_palette_raw,
1966+
bootstraps_color_by_group = bootstraps_color_by_group
19551967
)
19561968
if bar_type == 'contrast' and extra_delta is not None:
19571969
colors.append('black')

dabest/plot_tools.py

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -904,6 +904,7 @@ def delta_text_plotter(
904904
show_pairs: bool,
905905
float_contrast: bool,
906906
extra_delta: float,
907+
bootstraps_color_by_group: bool = False
907908
):
908909
"""
909910
Add delta text to the contrast plot.
@@ -928,6 +929,8 @@ def delta_text_plotter(
928929
Whether the DABEST plot uses Gardner-Altman or Cummings.
929930
extra_delta : float or None
930931
The extra mini-meta or delta-delta value if applicable.
932+
bootstraps_color_by_group : bool, optional
933+
Whether to color the bootstraps by group. Default is False.
931934
"""
932935
# Colors
933936
from .misc_tools import color_picker
@@ -936,7 +939,8 @@ def delta_text_plotter(
936939
elements = ticks_to_plot,
937940
color_col = color_col,
938941
show_pairs = show_pairs,
939-
color_palette = plot_palette_raw
942+
color_palette = plot_palette_raw,
943+
bootstraps_color_by_group = bootstraps_color_by_group
940944
)
941945

942946
num_of_elements = len(ticks_to_plot) + 1 if extra_delta is not None else len(ticks_to_plot)
@@ -1091,7 +1095,8 @@ def slopegraph_plotter(
10911095
temp_idx: list,
10921096
horizontal: bool,
10931097
temp_all_plot_groups: list,
1094-
plot_kwargs: dict
1098+
plot_kwargs: dict,
1099+
group_summaries_kwargs: dict
10951100
):
10961101
"""
10971102
Add slopegraph to the rawdata axes.
@@ -1124,6 +1129,8 @@ def slopegraph_plotter(
11241129
List of all plot groups.
11251130
plot_kwargs : dict
11261131
Keyword arguments for the plot.
1132+
group_summaries_kwargs : dict, optional
1133+
Keyword arguments for group summaries, if applicable.
11271134
11281135
"""
11291136
# Jitter Kwargs
@@ -1178,6 +1185,45 @@ def slopegraph_plotter(
11781185
x_points, y_points = (y_points, x_points) if horizontal else (x_points, y_points)
11791186
rawdata_axes.plot(x_points, y_points, **slopegraph_kwargs)
11801187

1188+
# Add the group summaries if applicable.
1189+
group_summaries = plot_kwargs.get("group_summaries", None)
1190+
if group_summaries is not None:
1191+
for key in ['gap_width_percent', 'offset']:
1192+
group_summaries_kwargs.pop(key, None)
1193+
group_summaries_kwargs['color'] = 'black' if group_summaries_kwargs.get('color') is None else group_summaries_kwargs['color']
1194+
group_summaries_kwargs['capsize'] = 0 if group_summaries_kwargs.get('capsize') is None else group_summaries_kwargs['capsize']
1195+
1196+
index_points = [t for t in range(x_start, x_start + grp_count)]
1197+
av_points, err_points, lo_points, hi_points = [], [], [], []
1198+
for group in range(len(index_points)):
1199+
if group_summaries == "mean_sd":
1200+
av_points.append(current_pair.iloc[:, int(group)].mean())
1201+
err_points.append(current_pair.iloc[:, int(group)].std())
1202+
elif group_summaries == "median_quartiles":
1203+
median = current_pair.iloc[:, int(group)].median()
1204+
av_points.append(median)
1205+
lo_points.append(median - current_pair.iloc[:, int(group)].quantile(0.25))
1206+
hi_points.append(current_pair.iloc[:, int(group)].quantile(0.75) - median)
1207+
1208+
if group_summaries == "median_quartiles":
1209+
err_points = [lo_points, hi_points]
1210+
1211+
# Plot the lines
1212+
if horizontal:
1213+
rawdata_axes.errorbar(
1214+
av_points,
1215+
index_points,
1216+
xerr=err_points,
1217+
**group_summaries_kwargs
1218+
)
1219+
else:
1220+
rawdata_axes.errorbar(
1221+
index_points,
1222+
av_points,
1223+
yerr=err_points,
1224+
**group_summaries_kwargs
1225+
)
1226+
11811227
x_start = x_start + grp_count
11821228

11831229
# Set the tick labels, because the slopegraph plotting doesn't.

dabest/plotter.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,8 @@ def effectsize_df_plotter(effectsize_df: object, **plot_kwargs) -> matplotlib.fi
137137
raw_bars_kwargs, contrast_bars_kwargs, table_kwargs, gridkey_kwargs, contrast_marker_kwargs,
138138
contrast_errorbar_kwargs, prop_sample_counts_kwargs, contrast_paired_lines_kwargs) = get_kwargs(
139139
plot_kwargs = plot_kwargs,
140-
ytick_color = ytick_color
140+
ytick_color = ytick_color,
141+
is_paired = effectsize_df.is_paired
141142
)
142143

143144
(dabest_obj, plot_data, xvar, yvar, is_paired, effect_size, proportional,
@@ -219,6 +220,7 @@ def effectsize_df_plotter(effectsize_df: object, **plot_kwargs) -> matplotlib.fi
219220
horizontal = horizontal,
220221
temp_all_plot_groups = temp_all_plot_groups,
221222
plot_kwargs = plot_kwargs,
223+
group_summaries_kwargs = group_summaries_kwargs
222224
)
223225

224226
## Add delta dots to the contrast axes for paired plots.
@@ -333,7 +335,7 @@ def effectsize_df_plotter(effectsize_df: object, **plot_kwargs) -> matplotlib.fi
333335

334336
## Swarm bars
335337
raw_bars = plot_kwargs["raw_bars"]
336-
if raw_bars and not proportional and not horizontal: #Currently not supporting swarm bars for horizontal plots (looks weird)
338+
if raw_bars and not proportional and not is_paired and not horizontal: #Currently not supporting swarm bars for horizontal plots (looks weird)
337339
raw_bars_dict, raw_bars_kwargs = prepare_bars_for_plot(
338340
bar_type = 'raw',
339341
bar_kwargs = raw_bars_kwargs,
@@ -343,7 +345,8 @@ def effectsize_df_plotter(effectsize_df: object, **plot_kwargs) -> matplotlib.fi
343345
show_pairs = show_pairs,
344346
plot_data = plot_data,
345347
xvar = xvar,
346-
yvar = yvar,
348+
yvar = yvar,
349+
bootstraps_color_by_group = bootstraps_color_by_group,
347350
)
348351
add_bars_to_plot(bar_dict = raw_bars_dict,
349352
ax = rawdata_axes,
@@ -424,6 +427,7 @@ def effectsize_df_plotter(effectsize_df: object, **plot_kwargs) -> matplotlib.fi
424427
show_pairs = show_pairs,
425428
results = results,
426429
ticks_to_plot = ticks_to_plot,
430+
bootstraps_color_by_group = bootstraps_color_by_group,
427431
extra_delta = (effectsize_df.mini_meta.difference if show_mini_meta
428432
else effectsize_df.delta_delta.difference if show_delta2
429433
else None)
@@ -445,6 +449,7 @@ def effectsize_df_plotter(effectsize_df: object, **plot_kwargs) -> matplotlib.fi
445449
plot_palette_raw = plot_palette_raw,
446450
show_pairs = show_pairs,
447451
float_contrast = float_contrast,
452+
bootstraps_color_by_group = bootstraps_color_by_group,
448453
extra_delta = (effectsize_df.mini_meta.difference if show_mini_meta
449454
else effectsize_df.delta_delta.difference if show_delta2
450455
else None),
@@ -588,6 +593,7 @@ def effectsize_df_plotter(effectsize_df: object, **plot_kwargs) -> matplotlib.fi
588593
reference_band = reference_band,
589594
summary_axes = contrast_axes,
590595
ci_type = ci_type,
596+
bootstraps_color_by_group = bootstraps_color_by_group,
591597
)
592598

593599
add_bars_to_plot(bar_dict = reference_band_dict,

nbs/API/effsize_objects.ipynb

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1343,7 +1343,7 @@
13431343
" face_color=None,\n",
13441344
"\n",
13451345
" raw_desat=0.5, # swarm_desat=0.5, OLD # bar_desat=0.5, OLD\n",
1346-
" contrast_desat=1, # halfviolin_desat=1, OLD\n",
1346+
" contrast_desat=1.0, # halfviolin_desat=1, OLD\n",
13471347
"\n",
13481348
" raw_alpha=None, # NEW\n",
13491349
" contrast_alpha=0.8, # halfviolin_alpha=0.8, OLD\n",
@@ -1678,7 +1678,8 @@
16781678
"\n",
16791679
" if raw_alpha is None:\n",
16801680
" raw_alpha = (0.4 if self.is_proportional and self.is_paired \n",
1681-
" else 0.5 if self.is_paired\n",
1681+
" else 0.5 if self.is_paired and (color_col is not None or self.__delta2)\n",
1682+
" else 0.2 if self.is_paired and color_col is None\n",
16821683
" else 1.0\n",
16831684
" )\n",
16841685
"\n",

nbs/API/misc_tools.ipynb

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,8 @@
256256
"\n",
257257
"def get_kwargs(\n",
258258
" plot_kwargs: dict, \n",
259-
" ytick_color\n",
259+
" ytick_color,\n",
260+
" is_paired: bool = False\n",
260261
" ):\n",
261262
" \"\"\"\n",
262263
" Extracts the kwargs from the `plot_kwargs` object for use in the plotter function.\n",
@@ -267,6 +268,8 @@
267268
" Kwargs passed to the plot function.\n",
268269
" ytick_color : str or color list\n",
269270
" Color of the yticks.\n",
271+
" is_paired : bool, optional\n",
272+
" A boolean flag to determine if the plot is for paired data. Default is False.\n",
270273
" \"\"\"\n",
271274
" from .misc_tools import merge_two_dicts\n",
272275
"\n",
@@ -387,7 +390,7 @@
387390
" default_group_summaries_kwargs = {\n",
388391
" \"zorder\": 3, \n",
389392
" \"lw\": 2, \n",
390-
" \"alpha\": 1,\n",
393+
" \"alpha\": 1 if not is_paired else 0.6,\n",
391394
" 'gap_width_percent': 1.5,\n",
392395
" 'offset': 0.1,\n",
393396
" 'color': None\n",
@@ -601,7 +604,13 @@
601604
" color_groups = pd.unique(plot_data[color_col])\n",
602605
" bootstraps_color_by_group = False\n",
603606
" if show_pairs:\n",
604-
" bootstraps_color_by_group = False\n",
607+
" if plot_kwargs[\"custom_palette\"] is not None:\n",
608+
" if delta2 or sankey:\n",
609+
" bootstraps_color_by_group = False\n",
610+
" else:\n",
611+
" bootstraps_color_by_group = True\n",
612+
" else:\n",
613+
" bootstraps_color_by_group = False\n",
605614
"\n",
606615
" # Handle the color palette.\n",
607616
" filled = True\n",
@@ -1909,13 +1918,15 @@
19091918
" elements: list, \n",
19101919
" color_col: str, \n",
19111920
" show_pairs: bool, \n",
1912-
" color_palette: dict) -> list:\n",
1921+
" color_palette: dict,\n",
1922+
" bootstraps_color_by_group: bool) -> list:\n",
19131923
" num_of_elements = len(elements)\n",
19141924
" colors = (\n",
19151925
" [kwargs.pop('color')] * num_of_elements\n",
19161926
" if kwargs.get('color', None) is not None\n",
19171927
" else ['black'] * num_of_elements\n",
1918-
" if color_col is not None or show_pairs \n",
1928+
" # if color_col is not None or show_pairs\n",
1929+
" if color_col is not None or not bootstraps_color_by_group\n",
19191930
" else list(color_palette.values())\n",
19201931
" )\n",
19211932
" if color_type in ['contrast', 'summary', 'delta_text']:\n",
@@ -1930,7 +1941,7 @@
19301941
" return final_colors\n",
19311942
"\n",
19321943
"\n",
1933-
"def prepare_bars_for_plot(bar_type, bar_kwargs, horizontal, plot_palette_raw, color_col, show_pairs,\n",
1944+
"def prepare_bars_for_plot(bar_type, bar_kwargs, horizontal, plot_palette_raw, color_col, show_pairs, bootstraps_color_by_group,\n",
19341945
" plot_data = None, xvar = None, yvar = None, # Raw data\n",
19351946
" results = None, ticks_to_plot = None, extra_delta = None, # Contrast data\n",
19361947
" reference_band = None, summary_axes = None, ci_type = None # Summary data\n",
@@ -2004,7 +2015,8 @@
20042015
" elements = ticks_to_plot if bar_type=='contrast' else ticks, \n",
20052016
" color_col = color_col, \n",
20062017
" show_pairs = show_pairs, \n",
2007-
" color_palette = plot_palette_raw\n",
2018+
" color_palette = plot_palette_raw,\n",
2019+
" bootstraps_color_by_group = bootstraps_color_by_group\n",
20082020
" )\n",
20092021
" if bar_type == 'contrast' and extra_delta is not None:\n",
20102022
" colors.append('black')\n",

0 commit comments

Comments
 (0)