Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions dabest/_effsize_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -1143,7 +1143,7 @@ def plot(
face_color=None,

raw_desat=0.5, # swarm_desat=0.5, OLD # bar_desat=0.5, OLD
contrast_desat=1, # halfviolin_desat=1, OLD
contrast_desat=1.0, # halfviolin_desat=1, OLD

raw_alpha=None, # NEW
contrast_alpha=0.8, # halfviolin_alpha=0.8, OLD
Expand Down Expand Up @@ -1478,7 +1478,8 @@ def plot(

if raw_alpha is None:
raw_alpha = (0.4 if self.is_proportional and self.is_paired
else 0.5 if self.is_paired
else 0.5 if self.is_paired and (color_col is not None or self.__delta2)
else 0.2 if self.is_paired and color_col is None
else 1.0
)

Expand Down
45 changes: 35 additions & 10 deletions dabest/misc_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,8 @@ def get_params(

def get_kwargs(
plot_kwargs: dict,
ytick_color
ytick_color,
is_paired: bool = False
):
"""
Extracts the kwargs from the `plot_kwargs` object for use in the plotter function.
Expand All @@ -214,6 +215,8 @@ def get_kwargs(
Kwargs passed to the plot function.
ytick_color : str or color list
Color of the yticks.
is_paired : bool, optional
A boolean flag to determine if the plot is for paired data. Default is False.
"""
from .misc_tools import merge_two_dicts

Expand Down Expand Up @@ -334,7 +337,7 @@ def get_kwargs(
default_group_summaries_kwargs = {
"zorder": 3,
"lw": 2,
"alpha": 1,
"alpha": 1 if not is_paired else 0.6,
'gap_width_percent': 1.5,
'offset': 0.1,
'color': None
Expand Down Expand Up @@ -513,7 +516,7 @@ def get_color_palette(
idx: list,
all_plot_groups: list,
delta2: bool,
sankey: bool
proportional: bool
):
"""
Create the color palette to be used in the plotter function.
Expand All @@ -534,9 +537,11 @@ def get_color_palette(
A list of all the group names.
delta2 : bool
A boolean flag to determine if the plot will have a delta-delta effect size.
sankey : bool
A boolean flag to determine if the plot is for a Sankey diagram.
proportional : bool
A boolean flag to determine if the plot is for a proportional plot.
"""
sankey = True if proportional and show_pairs else False

# Create color palette that will be shared across subplots.
color_col = plot_kwargs["color_col"]
if color_col is None:
Expand All @@ -548,7 +553,13 @@ def get_color_palette(
color_groups = pd.unique(plot_data[color_col])
bootstraps_color_by_group = False
if show_pairs:
bootstraps_color_by_group = False
if plot_kwargs["custom_palette"] is not None:
if delta2 or sankey:
bootstraps_color_by_group = False
else:
bootstraps_color_by_group = True
else:
bootstraps_color_by_group = False

# Handle the color palette.
filled = True
Expand Down Expand Up @@ -599,6 +610,17 @@ def get_color_palette(
groups_in_palette = {
k: custom_pal[k] for k in color_groups
}
elif proportional and not sankey: # barplots (unpaired proportional data)
keys = list(custom_pal.keys())
if all(k in keys for k in [1, 0]) and len(keys) == 2:
groups_in_palette = {
k: custom_pal[k] for k in [1, 0]
}
bootstraps_color_by_group = False
else:
groups_in_palette = {
k: custom_pal[k] for k in all_plot_groups if k in color_groups
}
elif sankey:
groups_in_palette = {
k: custom_pal[k] for k in [1, 0]
Expand Down Expand Up @@ -1856,13 +1878,15 @@ def color_picker(color_type: str,
elements: list,
color_col: str,
show_pairs: bool,
color_palette: dict) -> list:
color_palette: dict,
bootstraps_color_by_group: bool) -> list:
num_of_elements = len(elements)
colors = (
[kwargs.pop('color')] * num_of_elements
if kwargs.get('color', None) is not None
else ['black'] * num_of_elements
if color_col is not None or show_pairs
# if color_col is not None or show_pairs
if color_col is not None or not bootstraps_color_by_group
else list(color_palette.values())
)
if color_type in ['contrast', 'summary', 'delta_text']:
Expand All @@ -1877,7 +1901,7 @@ def color_picker(color_type: str,
return final_colors


def prepare_bars_for_plot(bar_type, bar_kwargs, horizontal, plot_palette_raw, color_col, show_pairs,
def prepare_bars_for_plot(bar_type, bar_kwargs, horizontal, plot_palette_raw, color_col, show_pairs, bootstraps_color_by_group,
plot_data = None, xvar = None, yvar = None, # Raw data
results = None, ticks_to_plot = None, extra_delta = None, # Contrast data
reference_band = None, summary_axes = None, ci_type = None # Summary data
Expand Down Expand Up @@ -1951,7 +1975,8 @@ def prepare_bars_for_plot(bar_type, bar_kwargs, horizontal, plot_palette_raw, co
elements = ticks_to_plot if bar_type=='contrast' else ticks,
color_col = color_col,
show_pairs = show_pairs,
color_palette = plot_palette_raw
color_palette = plot_palette_raw,
bootstraps_color_by_group = bootstraps_color_by_group
)
if bar_type == 'contrast' and extra_delta is not None:
colors.append('black')
Expand Down
67 changes: 61 additions & 6 deletions dabest/plot_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -904,6 +904,7 @@ def delta_text_plotter(
show_pairs: bool,
float_contrast: bool,
extra_delta: float,
bootstraps_color_by_group: bool = False
):
"""
Add delta text to the contrast plot.
Expand All @@ -928,6 +929,8 @@ def delta_text_plotter(
Whether the DABEST plot uses Gardner-Altman or Cummings.
extra_delta : float or None
The extra mini-meta or delta-delta value if applicable.
bootstraps_color_by_group : bool, optional
Whether to color the bootstraps by group. Default is False.
"""
# Colors
from .misc_tools import color_picker
Expand All @@ -936,7 +939,8 @@ def delta_text_plotter(
elements = ticks_to_plot,
color_col = color_col,
show_pairs = show_pairs,
color_palette = plot_palette_raw
color_palette = plot_palette_raw,
bootstraps_color_by_group = bootstraps_color_by_group
)

num_of_elements = len(ticks_to_plot) + 1 if extra_delta is not None else len(ticks_to_plot)
Expand Down Expand Up @@ -1091,7 +1095,8 @@ def slopegraph_plotter(
temp_idx: list,
horizontal: bool,
temp_all_plot_groups: list,
plot_kwargs: dict
plot_kwargs: dict,
group_summaries_kwargs: dict
):
"""
Add slopegraph to the rawdata axes.
Expand Down Expand Up @@ -1124,6 +1129,8 @@ def slopegraph_plotter(
List of all plot groups.
plot_kwargs : dict
Keyword arguments for the plot.
group_summaries_kwargs : dict, optional
Keyword arguments for group summaries, if applicable.

"""
# Jitter Kwargs
Expand Down Expand Up @@ -1178,6 +1185,45 @@ def slopegraph_plotter(
x_points, y_points = (y_points, x_points) if horizontal else (x_points, y_points)
rawdata_axes.plot(x_points, y_points, **slopegraph_kwargs)

# Add the group summaries if applicable.
group_summaries = plot_kwargs.get("group_summaries", None)
if group_summaries is not None:
for key in ['gap_width_percent', 'offset']:
group_summaries_kwargs.pop(key, None)
group_summaries_kwargs['color'] = 'black' if group_summaries_kwargs.get('color') is None else group_summaries_kwargs['color']
group_summaries_kwargs['capsize'] = 0 if group_summaries_kwargs.get('capsize') is None else group_summaries_kwargs['capsize']

index_points = [t for t in range(x_start, x_start + grp_count)]
av_points, err_points, lo_points, hi_points = [], [], [], []
for group in range(len(index_points)):
if group_summaries == "mean_sd":
av_points.append(current_pair.iloc[:, int(group)].mean())
err_points.append(current_pair.iloc[:, int(group)].std())
elif group_summaries == "median_quartiles":
median = current_pair.iloc[:, int(group)].median()
av_points.append(median)
lo_points.append(median - current_pair.iloc[:, int(group)].quantile(0.25))
hi_points.append(current_pair.iloc[:, int(group)].quantile(0.75) - median)

if group_summaries == "median_quartiles":
err_points = [lo_points, hi_points]

# Plot the lines
if horizontal:
rawdata_axes.errorbar(
av_points,
index_points,
xerr=err_points,
**group_summaries_kwargs
)
else:
rawdata_axes.errorbar(
index_points,
av_points,
yerr=err_points,
**group_summaries_kwargs
)

x_start = x_start + grp_count

# Set the tick labels, because the slopegraph plotting doesn't.
Expand Down Expand Up @@ -1839,6 +1885,9 @@ def barplotter(
horizontal : bool
If the plot is horizontal.
"""
# Check if the custom_palette is a dictionary with two keys 0 and 1 (for filled bar coloring)
filled_bars = True if len(plot_palette_raw.keys())==2 and all(k in plot_palette_raw for k in [1, 0]) else False

bar_width = barplot_kwargs.get('width', 0.5)
fontsize = barplot_kwargs.pop('fontsize')

Expand Down Expand Up @@ -1866,29 +1915,35 @@ def barplotter(
for hue_val in bar1_df[color_col]
]
else:
edge_colors = raw_colors
edge_colors = len(all_plot_groups)*['black',] if filled_bars else raw_colors

bar1 = sns.barplot(
data=bar1_df,
x=xvar,
y="proportion",
ax=rawdata_axes,
order=all_plot_groups,
linewidth=2,
facecolor=(1, 1, 1, 0),
linewidth=1 if filled_bars else 2,
facecolor=plot_palette_raw[0] if filled_bars else (1, 1, 1, 0),
edgecolor=edge_colors,
zorder=1,
orient=orient,
)

if filled_bars:
barplot_kwargs['facecolor'] = plot_palette_raw[1]
barplot_kwargs['edgecolor'] = 'black'
barplot_kwargs['linewidth'] = 1
else:
barplot_kwargs['palette'] = plot_palette_raw

bar2 = sns.barplot(
data=plot_data,
x=yvar if horizontal else xvar,
y=xvar if horizontal else yvar,
hue=xvar if color_col is None else color_col,
ax=rawdata_axes,
order=all_plot_groups,
palette=plot_palette_raw,
dodge=False,
zorder=1,
orient=orient,
Expand Down
14 changes: 10 additions & 4 deletions dabest/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,8 @@ def effectsize_df_plotter(effectsize_df: object, **plot_kwargs) -> matplotlib.fi
raw_bars_kwargs, contrast_bars_kwargs, table_kwargs, gridkey_kwargs, contrast_marker_kwargs,
contrast_errorbar_kwargs, prop_sample_counts_kwargs, contrast_paired_lines_kwargs) = get_kwargs(
plot_kwargs = plot_kwargs,
ytick_color = ytick_color
ytick_color = ytick_color,
is_paired = effectsize_df.is_paired
)

(dabest_obj, plot_data, xvar, yvar, is_paired, effect_size, proportional,
Expand All @@ -160,7 +161,7 @@ def effectsize_df_plotter(effectsize_df: object, **plot_kwargs) -> matplotlib.fi
idx = idx,
all_plot_groups = all_plot_groups,
delta2 = effectsize_df.delta2,
sankey = True if proportional and show_pairs else False,
proportional = proportional
)

# Initialise the figure.
Expand Down Expand Up @@ -219,6 +220,7 @@ def effectsize_df_plotter(effectsize_df: object, **plot_kwargs) -> matplotlib.fi
horizontal = horizontal,
temp_all_plot_groups = temp_all_plot_groups,
plot_kwargs = plot_kwargs,
group_summaries_kwargs = group_summaries_kwargs
)

## Add delta dots to the contrast axes for paired plots.
Expand Down Expand Up @@ -333,7 +335,7 @@ def effectsize_df_plotter(effectsize_df: object, **plot_kwargs) -> matplotlib.fi

## Swarm bars
raw_bars = plot_kwargs["raw_bars"]
if raw_bars and not proportional and not horizontal: #Currently not supporting swarm bars for horizontal plots (looks weird)
if raw_bars and not proportional and not is_paired and not horizontal: #Currently not supporting swarm bars for horizontal plots (looks weird)
raw_bars_dict, raw_bars_kwargs = prepare_bars_for_plot(
bar_type = 'raw',
bar_kwargs = raw_bars_kwargs,
Expand All @@ -343,7 +345,8 @@ def effectsize_df_plotter(effectsize_df: object, **plot_kwargs) -> matplotlib.fi
show_pairs = show_pairs,
plot_data = plot_data,
xvar = xvar,
yvar = yvar,
yvar = yvar,
bootstraps_color_by_group = bootstraps_color_by_group,
)
add_bars_to_plot(bar_dict = raw_bars_dict,
ax = rawdata_axes,
Expand Down Expand Up @@ -424,6 +427,7 @@ def effectsize_df_plotter(effectsize_df: object, **plot_kwargs) -> matplotlib.fi
show_pairs = show_pairs,
results = results,
ticks_to_plot = ticks_to_plot,
bootstraps_color_by_group = bootstraps_color_by_group,
extra_delta = (effectsize_df.mini_meta.difference if show_mini_meta
else effectsize_df.delta_delta.difference if show_delta2
else None)
Expand All @@ -445,6 +449,7 @@ def effectsize_df_plotter(effectsize_df: object, **plot_kwargs) -> matplotlib.fi
plot_palette_raw = plot_palette_raw,
show_pairs = show_pairs,
float_contrast = float_contrast,
bootstraps_color_by_group = bootstraps_color_by_group,
extra_delta = (effectsize_df.mini_meta.difference if show_mini_meta
else effectsize_df.delta_delta.difference if show_delta2
else None),
Expand Down Expand Up @@ -588,6 +593,7 @@ def effectsize_df_plotter(effectsize_df: object, **plot_kwargs) -> matplotlib.fi
reference_band = reference_band,
summary_axes = contrast_axes,
ci_type = ci_type,
bootstraps_color_by_group = bootstraps_color_by_group,
)

add_bars_to_plot(bar_dict = reference_band_dict,
Expand Down
5 changes: 3 additions & 2 deletions nbs/API/effsize_objects.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1343,7 +1343,7 @@
" face_color=None,\n",
"\n",
" raw_desat=0.5, # swarm_desat=0.5, OLD # bar_desat=0.5, OLD\n",
" contrast_desat=1, # halfviolin_desat=1, OLD\n",
" contrast_desat=1.0, # halfviolin_desat=1, OLD\n",
"\n",
" raw_alpha=None, # NEW\n",
" contrast_alpha=0.8, # halfviolin_alpha=0.8, OLD\n",
Expand Down Expand Up @@ -1678,7 +1678,8 @@
"\n",
" if raw_alpha is None:\n",
" raw_alpha = (0.4 if self.is_proportional and self.is_paired \n",
" else 0.5 if self.is_paired\n",
" else 0.5 if self.is_paired and (color_col is not None or self.__delta2)\n",
" else 0.2 if self.is_paired and color_col is None\n",
" else 1.0\n",
" )\n",
"\n",
Expand Down
Loading