Skip to content

Commit 99ce716

Browse files
committed
Added tests for gridkey, horizontal plots, delta text, summary bars, swarm bars, and contrast bars
1 parent fa60e76 commit 99ce716

File tree

164 files changed

+1343
-81
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

164 files changed

+1343
-81
lines changed

dabest/misc_tools.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -872,7 +872,7 @@ def get_plot_groups(is_paired, idx, proportional, all_plot_groups):
872872
return temp_idx, temp_all_plot_groups
873873

874874

875-
def add_counts_to_ticks(plot_data, xvar, yvar, rawdata_axes, plot_kwargs, horizontal):
875+
def add_counts_to_ticks(plot_data, xvar, yvar, rawdata_axes, plot_kwargs, flow, horizontal):
876876
"""
877877
878878
Add the counts to the raw data axes labels.
@@ -889,6 +889,8 @@ def add_counts_to_ticks(plot_data, xvar, yvar, rawdata_axes, plot_kwargs, horizo
889889
The raw data axes.
890890
plot_kwargs : dict
891891
Kwargs passed to the plot function.
892+
flow : bool
893+
Whether sankey flow is enabled or not.
892894
horizontal : bool
893895
A boolean flag to determine if the plot is for horizontal plotting.
894896
"""
@@ -919,7 +921,12 @@ def lookup_value(text):
919921

920922
for ticklab in get_label():
921923
t = ticklab.get_text()
922-
te = t.split('\n')[-1] # Get the last line of the label
924+
925+
if horizontal and not flow:
926+
te = t.split('v.s. ')[-1] # Get the last line of the label
927+
else:
928+
te = t.split('\n')[-1] # Get the last line of the label
929+
923930
value = lookup_value(te)
924931
if horizontal:
925932
ticks_with_counts.append(f"{t} (N={value})")
@@ -992,7 +999,7 @@ def extract_contrast_plotting_ticks(is_paired, show_pairs, two_col_sankey, plot_
992999
return ticks_to_skip, ticks_to_plot, ticks_to_skip_contrast, ticks_to_start_twocol_sankey
9931000

9941001
def set_xaxis_ticks_and_lims(show_delta2, show_mini_meta, rawdata_axes, contrast_axes, show_pairs, float_contrast,
995-
ticks_to_skip, contrast_xtick_labels, plot_kwargs, horizontal):
1002+
ticks_to_skip, contrast_xtick_labels, plot_kwargs, proportional, horizontal):
9961003
"""
9971004
Set the x-axis/yaxis ticks and limits for the plotter function.
9981005
@@ -1016,6 +1023,8 @@ def set_xaxis_ticks_and_lims(show_delta2, show_mini_meta, rawdata_axes, contrast
10161023
A list of contrast xtick labels.
10171024
plot_kwargs : dict
10181025
Kwargs passed to the plot function.
1026+
proportional: bool
1027+
A boolean flag to determine if the plot is a proportional plot.
10191028
horizontal : bool
10201029
A boolean flag to determine if the plot is for horizontal plotting.
10211030
"""
@@ -1034,6 +1043,9 @@ def set_xaxis_ticks_and_lims(show_delta2, show_mini_meta, rawdata_axes, contrast
10341043
max_x = contrast_axes.get_ylim()[1]
10351044
rawdata_axes.set_ylim(-0.375, max_x)
10361045

1046+
if proportional:
1047+
rawdata_axes.set_ylim(-0.375, max_x+0.1)
1048+
10371049
if show_delta2 or show_mini_meta:
10381050
# Increase the ylim of raw data by 2
10391051
temp = rawdata_axes.get_ylim()

dabest/plot_tools.py

Lines changed: 32 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -850,8 +850,8 @@ def sankeydiag(
850850
sankey_ticks = (
851851
[f"{left}" for left in broadcasted_left]
852852
if flow
853-
else [
854-
f"{left}\n v.s.\n{right}"
853+
else [f"{left} v.s. {right}" if horizontal
854+
else f"{left}\n v.s.\n{right}"
855855
for left, right in zip(broadcasted_left, right_idx)
856856
]
857857
)
@@ -917,9 +917,9 @@ def summary_bars_plotter(summary_bars: list, results: object, ax_to_plot: object
917917
else:
918918
summary_xmin, summary_xmax = ax_to_plot.get_xlim()
919919
summary_bars_colors = (
920-
[summary_bars_kwargs.get('color')]*(max(ticks_to_plot)+1)
920+
[summary_bars_kwargs.get('color')]*int(max(ticks_to_plot)+1)
921921
if summary_bars_kwargs.get('color') is not None
922-
else ['black']*(max(ticks_to_plot)+1)
922+
else ['black']*int(max(ticks_to_plot)+1)
923923
if color_col is not None or (proportional and is_paired) or is_paired
924924
else list(plot_palette_raw.values())
925925
)
@@ -932,7 +932,7 @@ def summary_bars_plotter(summary_bars: list, results: object, ax_to_plot: object
932932
summary_ci_low = results.pct_low[summary_index]
933933
summary_ci_high = results.pct_high[summary_index]
934934

935-
summary_color = summary_bars_colors[ticks_to_plot[summary_index]]
935+
summary_color = summary_bars_colors[int(ticks_to_plot[summary_index])]
936936

937937
ax_to_plot.add_patch(mpatches.Rectangle((summary_xmin,summary_ci_low),summary_xmax+1,
938938
summary_ci_high-summary_ci_low, zorder=-2, color=summary_color, **summary_bars_kwargs))
@@ -981,14 +981,14 @@ def contrast_bars_plotter(results: object, ax_to_plot: object, swarm_plot_ax: o
981981

982982
contrast_means = []
983983
for j, tick in enumerate(ticks_to_plot):
984-
contrast_means.append(results.difference[j])
984+
contrast_means.append(results.difference[int(j)])
985985

986986
unpacked_idx = [element for innerList in idx for element in innerList]
987987

988988
contrast_bars_colors = (
989-
[contrast_bars_kwargs.get('color')] * (max(ticks_to_plot) + 1)
989+
[contrast_bars_kwargs.get('color')] * int(max(ticks_to_plot) + 1)
990990
if contrast_bars_kwargs.get('color') is not None
991-
else ['black'] * (max(ticks_to_plot) + 1)
991+
else ['black'] * int(max(ticks_to_plot) + 1)
992992
if color_col is not None or is_paired
993993
else plot_palette_raw
994994
)
@@ -1136,9 +1136,9 @@ def delta_text_plotter(results: object, ax_to_plot: object, swarm_plot_ax: objec
11361136
delta_text_kwargs.pop('x_location')
11371137

11381138
delta_text_colors = (
1139-
[delta_text_kwargs.get('color')]*(max(ticks_to_plot)+1)
1139+
[delta_text_kwargs.get('color')]*int(max(ticks_to_plot)+1)
11401140
if delta_text_kwargs.get('color') is not None
1141-
else ['black']*(max(ticks_to_plot)+1)
1141+
else ['black']*int(max(ticks_to_plot)+1)
11421142
if color_col is not None or (proportional and is_paired) or is_paired
11431143
else plot_palette_raw
11441144
)
@@ -1156,7 +1156,7 @@ def delta_text_plotter(results: object, ax_to_plot: object, swarm_plot_ax: objec
11561156
# Collect the Y-values for the delta text
11571157
Delta_Values = []
11581158
for j, tick in enumerate(ticks_to_plot):
1159-
Delta_Values.append(results.difference[j])
1159+
Delta_Values.append(results.difference[int(j)])
11601160
if show_delta2: Delta_Values.append(delta_delta.difference)
11611161
if show_mini_meta: Delta_Values.append(mini_meta_delta.difference)
11621162

@@ -1197,9 +1197,9 @@ def delta_text_plotter(results: object, ax_to_plot: object, swarm_plot_ax: objec
11971197
for x,y,t,tick in zip(delta_text_x_coordinates, delta_text_y_coordinates,Delta_Values,ticks_to_plot):
11981198
Delta_Text = np.format_float_positional(t, precision=2, sign=True, trim="k", min_digits=2)
11991199
idx_selector = (
1200-
tick
1200+
int(tick)
12011201
if type(delta_text_colors) == list
1202-
else unpacked_idx[tick]
1202+
else unpacked_idx[int(tick)]
12031203
)
12041204
ax_to_plot.text(x, y, Delta_Text, color=delta_text_colors[idx_selector], zorder=5, **delta_text_kwargs)
12051205

@@ -1349,7 +1349,16 @@ def slopegraph_plotter(dabest_obj, plot_data, xvar, yvar, color_col, plot_palett
13491349
current_pair = pivoted_plot_data.loc[
13501350
:, pd.MultiIndex.from_product([pivot_values, current_tuple])
13511351
].dropna()
1352+
1353+
# Check for correct pairing
1354+
if len(current_pair) == 0:
1355+
raise ValueError('There are no pairs to plot... check original dataframe for correct ID pairing')
1356+
1357+
current_pair = pivoted_plot_data.loc[
1358+
:, pd.MultiIndex.from_product([pivot_values, current_tuple])
1359+
]
13521360
grp_count = len(current_tuple)
1361+
13531362
# Iterate through the data for the current tuple.
13541363
for ID, observation in current_pair.iterrows():
13551364
x_points = [t + 0.15*jitter*rng.standard_t(df=6, size=None) for t in range(x_start, x_start + grp_count)] # devMJBL
@@ -1545,16 +1554,16 @@ def effect_size_curve_plotter(ticks_to_plot, results, ci_type, contrast_axes, vi
15451554
# Plot the curves
15461555
contrast_xtick_labels = []
15471556
for j, tick in enumerate(ticks_to_plot):
1548-
current_group = results.test[j]
1549-
current_control = results.control[j]
1550-
current_bootstrap = results.bootstraps[j]
1551-
current_effsize = results.difference[j]
1557+
current_group = results.test[int(j)]
1558+
current_control = results.control[int(j)]
1559+
current_bootstrap = results.bootstraps[int(j)]
1560+
current_effsize = results.difference[int(j)]
15521561
if ci_type == "bca":
1553-
current_ci_low = results.bca_low[j]
1554-
current_ci_high = results.bca_high[j]
1562+
current_ci_low = results.bca_low[int(j)]
1563+
current_ci_high = results.bca_high[int(j)]
15551564
else:
1556-
current_ci_low = results.pct_low[j]
1557-
current_ci_high = results.pct_high[j]
1565+
current_ci_low = results.pct_low[int(j)]
1566+
current_ci_high = results.pct_high[int(j)]
15581567

15591568
# Create the violinplot.
15601569
# New in v0.2.6: drop negative infinities before plotting.
@@ -2195,6 +2204,8 @@ def __init__(
21952204
h = (ax.get_position().ymax - ax.get_position().ymin) * figh
21962205
ax_xspan = ax.get_xlim()[1] - ax.get_xlim()[0]
21972206
ax_yspan = ax.get_ylim()[1] - ax.get_ylim()[0]
2207+
if horizontal:
2208+
ax_xspan, ax_yspan = ax_yspan, ax_xspan
21982209

21992210
# increases jitter distance based on number of swarms that is going to be drawn
22002211
jitter = jitter * (1 + 0.05 * (math.log(ax_xspan)))

dabest/plotter.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,7 @@ def effectsize_df_plotter(effectsize_df, **plot_kwargs):
175175
)
176176
if not proportional:
177177
# Plot the raw data as a slopegraph.
178+
178179
slopegraph_plotter(
179180
dabest_obj=dabest_obj,
180181
plot_data=plot_data,
@@ -188,7 +189,7 @@ def effectsize_df_plotter(effectsize_df, **plot_kwargs):
188189
temp_idx=temp_idx,
189190
horizontal=horizontal
190191
)
191-
192+
192193
# DELTA PTS ON CONTRAST PLOT WIP
193194
show_delta_dots = plot_kwargs["delta_dot"]
194195
if show_delta_dots and is_paired is not None:
@@ -313,6 +314,7 @@ def effectsize_df_plotter(effectsize_df, **plot_kwargs):
313314
yvar=yvar,
314315
rawdata_axes=rawdata_axes,
315316
plot_kwargs=plot_kwargs,
317+
flow = sankey_kwargs["flow"],
316318
horizontal=horizontal,
317319
)
318320

@@ -388,6 +390,7 @@ def effectsize_df_plotter(effectsize_df, **plot_kwargs):
388390
ticks_to_skip=ticks_to_skip,
389391
contrast_xtick_labels=contrast_xtick_labels,
390392
plot_kwargs=plot_kwargs,
393+
proportional=proportional,
391394
horizontal=horizontal,
392395
)
393396
# Legend

nbs/API/misc_tools.ipynb

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -925,7 +925,7 @@
925925
" return temp_idx, temp_all_plot_groups\n",
926926
"\n",
927927
"\n",
928-
"def add_counts_to_ticks(plot_data, xvar, yvar, rawdata_axes, plot_kwargs, horizontal):\n",
928+
"def add_counts_to_ticks(plot_data, xvar, yvar, rawdata_axes, plot_kwargs, flow, horizontal):\n",
929929
" \"\"\"\n",
930930
"\n",
931931
" Add the counts to the raw data axes labels.\n",
@@ -942,6 +942,8 @@
942942
" The raw data axes.\n",
943943
" plot_kwargs : dict\n",
944944
" Kwargs passed to the plot function.\n",
945+
" flow : bool\n",
946+
" Whether sankey flow is enabled or not.\n",
945947
" horizontal : bool\n",
946948
" A boolean flag to determine if the plot is for horizontal plotting.\n",
947949
" \"\"\"\n",
@@ -972,7 +974,12 @@
972974
" \n",
973975
" for ticklab in get_label():\n",
974976
" t = ticklab.get_text()\n",
975-
" te = t.split('\\n')[-1] # Get the last line of the label\n",
977+
"\n",
978+
" if horizontal and not flow:\n",
979+
" te = t.split('v.s. ')[-1] # Get the last line of the label\n",
980+
" else:\n",
981+
" te = t.split('\\n')[-1] # Get the last line of the label\n",
982+
"\n",
976983
" value = lookup_value(te)\n",
977984
" if horizontal:\n",
978985
" ticks_with_counts.append(f\"{t} (N={value})\")\n",
@@ -1045,7 +1052,7 @@
10451052
" return ticks_to_skip, ticks_to_plot, ticks_to_skip_contrast, ticks_to_start_twocol_sankey\n",
10461053
"\n",
10471054
"def set_xaxis_ticks_and_lims(show_delta2, show_mini_meta, rawdata_axes, contrast_axes, show_pairs, float_contrast,\n",
1048-
" ticks_to_skip, contrast_xtick_labels, plot_kwargs, horizontal):\n",
1055+
" ticks_to_skip, contrast_xtick_labels, plot_kwargs, proportional, horizontal):\n",
10491056
" \"\"\"\n",
10501057
" Set the x-axis/yaxis ticks and limits for the plotter function.\n",
10511058
"\n",
@@ -1069,6 +1076,8 @@
10691076
" A list of contrast xtick labels.\n",
10701077
" plot_kwargs : dict\n",
10711078
" Kwargs passed to the plot function.\n",
1079+
" proportional: bool\n",
1080+
" A boolean flag to determine if the plot is a proportional plot.\n",
10721081
" horizontal : bool\n",
10731082
" A boolean flag to determine if the plot is for horizontal plotting.\n",
10741083
" \"\"\"\n",
@@ -1087,6 +1096,9 @@
10871096
" max_x = contrast_axes.get_ylim()[1]\n",
10881097
" rawdata_axes.set_ylim(-0.375, max_x)\n",
10891098
"\n",
1099+
" if proportional:\n",
1100+
" rawdata_axes.set_ylim(-0.375, max_x+0.1)\n",
1101+
"\n",
10901102
" if show_delta2 or show_mini_meta:\n",
10911103
" # Increase the ylim of raw data by 2\n",
10921104
" temp = rawdata_axes.get_ylim()\n",

0 commit comments

Comments
 (0)