Skip to content

Commit ce5a1d5

Browse files
committed
Reformat some plotting functions
1 parent 4f48598 commit ce5a1d5

File tree

293 files changed

+532
-805
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

293 files changed

+532
-805
lines changed

dabest/_modidx.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@
7979
'dabest.forest_plot.load_plot_data': ('API/forest_plot.html#load_plot_data', 'dabest/forest_plot.py')},
8080
'dabest.misc_tools': { 'dabest.misc_tools.add_counts_to_ticks': ( 'API/misc_tools.html#add_counts_to_ticks',
8181
'dabest/misc_tools.py'),
82+
'dabest.misc_tools.color_picker': ('API/misc_tools.html#color_picker', 'dabest/misc_tools.py'),
8283
'dabest.misc_tools.draw_zeroline': ('API/misc_tools.html#draw_zeroline', 'dabest/misc_tools.py'),
8384
'dabest.misc_tools.extract_contrast_plotting_ticks': ( 'API/misc_tools.html#extract_contrast_plotting_ticks',
8485
'dabest/misc_tools.py'),
@@ -95,6 +96,8 @@
9596
'dabest.misc_tools.get_varname': ('API/misc_tools.html#get_varname', 'dabest/misc_tools.py'),
9697
'dabest.misc_tools.initialize_fig': ('API/misc_tools.html#initialize_fig', 'dabest/misc_tools.py'),
9798
'dabest.misc_tools.merge_two_dicts': ('API/misc_tools.html#merge_two_dicts', 'dabest/misc_tools.py'),
99+
'dabest.misc_tools.prepare_bars_for_plot': ( 'API/misc_tools.html#prepare_bars_for_plot',
100+
'dabest/misc_tools.py'),
98101
'dabest.misc_tools.print_greeting': ('API/misc_tools.html#print_greeting', 'dabest/misc_tools.py'),
99102
'dabest.misc_tools.redraw_dependent_spines': ( 'API/misc_tools.html#redraw_dependent_spines',
100103
'dabest/misc_tools.py'),
@@ -117,12 +120,12 @@
117120
'dabest/plot_tools.py'),
118121
'dabest.plot_tools.SwarmPlot._swarm': ('API/plot_tools.html#swarmplot._swarm', 'dabest/plot_tools.py'),
119122
'dabest.plot_tools.SwarmPlot.plot': ('API/plot_tools.html#swarmplot.plot', 'dabest/plot_tools.py'),
123+
'dabest.plot_tools.add_bars_to_plot': ('API/plot_tools.html#add_bars_to_plot', 'dabest/plot_tools.py'),
120124
'dabest.plot_tools.add_counts_to_prop_plots': ( 'API/plot_tools.html#add_counts_to_prop_plots',
121125
'dabest/plot_tools.py'),
122126
'dabest.plot_tools.barplotter': ('API/plot_tools.html#barplotter', 'dabest/plot_tools.py'),
123127
'dabest.plot_tools.check_data_matches_labels': ( 'API/plot_tools.html#check_data_matches_labels',
124128
'dabest/plot_tools.py'),
125-
'dabest.plot_tools.color_picker': ('API/plot_tools.html#color_picker', 'dabest/plot_tools.py'),
126129
'dabest.plot_tools.delta_dots_plotter': ( 'API/plot_tools.html#delta_dots_plotter',
127130
'dabest/plot_tools.py'),
128131
'dabest.plot_tools.delta_text_plotter': ( 'API/plot_tools.html#delta_text_plotter',
@@ -136,14 +139,10 @@
136139
'dabest.plot_tools.normalize_dict': ('API/plot_tools.html#normalize_dict', 'dabest/plot_tools.py'),
137140
'dabest.plot_tools.plot_minimeta_or_deltadelta_violins': ( 'API/plot_tools.html#plot_minimeta_or_deltadelta_violins',
138141
'dabest/plot_tools.py'),
139-
'dabest.plot_tools.raw_contrast_bar_plotter': ( 'API/plot_tools.html#raw_contrast_bar_plotter',
140-
'dabest/plot_tools.py'),
141142
'dabest.plot_tools.sankeydiag': ('API/plot_tools.html#sankeydiag', 'dabest/plot_tools.py'),
142143
'dabest.plot_tools.single_sankey': ('API/plot_tools.html#single_sankey', 'dabest/plot_tools.py'),
143144
'dabest.plot_tools.slopegraph_plotter': ( 'API/plot_tools.html#slopegraph_plotter',
144145
'dabest/plot_tools.py'),
145-
'dabest.plot_tools.summary_bars_plotter': ( 'API/plot_tools.html#summary_bars_plotter',
146-
'dabest/plot_tools.py'),
147146
'dabest.plot_tools.swarmplot': ('API/plot_tools.html#swarmplot', 'dabest/plot_tools.py'),
148147
'dabest.plot_tools.table_for_horizontal_plots': ( 'API/plot_tools.html#table_for_horizontal_plots',
149148
'dabest/plot_tools.py'),

dabest/misc_tools.py

Lines changed: 116 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
__all__ = ['merge_two_dicts', 'unpack_and_add', 'print_greeting', 'get_varname', 'get_unique_categories', 'get_params',
77
'get_kwargs', 'get_color_palette', 'initialize_fig', 'get_plot_groups', 'add_counts_to_ticks',
88
'extract_contrast_plotting_ticks', 'set_xaxis_ticks_and_lims', 'show_legend', 'gardner_altman_adjustments',
9-
'draw_zeroline', 'redraw_independent_spines', 'redraw_dependent_spines', 'extract_group_summaries']
9+
'draw_zeroline', 'redraw_independent_spines', 'redraw_dependent_spines', 'extract_group_summaries',
10+
'color_picker', 'prepare_bars_for_plot']
1011

1112
# %% ../nbs/API/misc_tools.ipynb 4
1213
import datetime as dt
@@ -167,7 +168,6 @@ def get_params(
167168
group_summaries = None if barplot_kwargs['errorbar'] is not None else group_summaries
168169

169170
# Contrast Axes kwargs
170-
contrast_alpha = plot_kwargs["contrast_alpha"]
171171
ci_type = plot_kwargs["ci_type"]
172172
if ci_type not in ["bca", "pct"]:
173173
raise ValueError("Invalid `ci_type`. Must be either 'bca' or 'pct'.")
@@ -196,7 +196,7 @@ def get_params(
196196

197197
return (dabest_obj, plot_data, xvar, yvar, is_paired, effect_size, proportional, all_plot_groups,
198198
idx, show_delta2, show_mini_meta, float_contrast, show_pairs, group_summaries,
199-
horizontal, results, contrast_alpha, ci_type, x1_level, experiment_label, show_baseline_ec,
199+
horizontal, results, ci_type, x1_level, experiment_label, show_baseline_ec,
200200
one_sankey, two_col_sankey, asymmetric_side)
201201

202202
def get_kwargs(
@@ -265,6 +265,7 @@ def get_kwargs(
265265
"orientation": 'vertical',
266266
"showextrema": False,
267267
"showmedians": False,
268+
"alpha": plot_kwargs["contrast_alpha"],
268269

269270
}
270271
if plot_kwargs["contrast_kwargs"] is None:
@@ -365,13 +366,11 @@ def get_kwargs(
365366

366367
# Delta text kwargs.
367368
default_delta_text_kwargs = {
368-
"color": None,
369369
"alpha": 1,
370370
"fontsize": 10,
371371
"ha": 'center',
372372
"va": 'center',
373373
"rotation": 0,
374-
"x_location": 'right',
375374
"x_coordinates": None,
376375
"y_coordinates": None,
377376
"offset": 0
@@ -384,7 +383,6 @@ def get_kwargs(
384383
# Summary bars kwargs.
385384
default_summary_bars_kwargs = {
386385
"span_ax": False,
387-
"color": None,
388386
"alpha": 0.15,
389387
"zorder":-3
390388
}
@@ -395,8 +393,8 @@ def get_kwargs(
395393

396394
# Swarm bars kwargs.
397395
default_raw_bars_kwargs = {
398-
"color": None,
399-
"zorder":-3
396+
"zorder":-3,
397+
"alpha": 0.2
400398
}
401399
if plot_kwargs["raw_bars_kwargs"] is None:
402400
raw_bars_kwargs = default_raw_bars_kwargs
@@ -405,8 +403,8 @@ def get_kwargs(
405403

406404
# Contrast bars kwargs.
407405
default_contrast_bars_kwargs = {
408-
"color": None,
409-
"zorder":-3
406+
"zorder":-3,
407+
"alpha": 0.2
410408
}
411409
if plot_kwargs["contrast_bars_kwargs"] is None:
412410
contrast_bars_kwargs = default_contrast_bars_kwargs
@@ -1115,7 +1113,6 @@ def extract_contrast_plotting_ticks(
11151113
ticks_to_start_twocol_sankey.pop()
11161114
ticks_to_start_twocol_sankey.insert(0, 0)
11171115
else:
1118-
11191116
ticks_to_skip = np.cumsum([len(t) for t in idx])[:-1].tolist()
11201117
ticks_to_skip.insert(0, 0)
11211118
# Then obtain the ticks where we have to plot the effect sizes.
@@ -1848,3 +1845,111 @@ def extract_group_summaries(
18481845
group_summaries_kwargs.pop("offset")
18491846

18501847
return group_summaries_method, group_summaries_offset, group_summaries_line_color
1848+
1849+
def color_picker(color_type: str,
1850+
kwargs: dict,
1851+
elements: list,
1852+
color_col: str,
1853+
show_pairs: bool,
1854+
color_palette: dict) -> list:
1855+
num_of_elements = len(elements)
1856+
colors = (
1857+
[kwargs.pop('color')] * num_of_elements
1858+
if kwargs.get('color', None) is not None
1859+
else ['black'] * num_of_elements
1860+
if color_col is not None or show_pairs
1861+
else list(color_palette.values())
1862+
)
1863+
if color_type in ['contrast', 'summary', 'delta_text']:
1864+
if len(colors) == num_of_elements:
1865+
final_colors = colors
1866+
else:
1867+
final_colors = []
1868+
for tick in elements:
1869+
final_colors.append(colors[int(tick)])
1870+
else:
1871+
final_colors = colors
1872+
return final_colors
1873+
1874+
1875+
def prepare_bars_for_plot(bar_type, bar_kwargs, horizontal, plot_palette_raw, color_col, show_pairs,
1876+
plot_data = None, xvar = None, yvar = None, # Raw data
1877+
results = None, ticks_to_plot = None, extra_delta = None, # Contrast data
1878+
summary_bars = None, summary_axes = None, ci_type = None # Summary data
1879+
):
1880+
from .misc_tools import color_picker
1881+
bar_dict = {}
1882+
if bar_type in ['raw', 'contrast']:
1883+
if bar_type == 'raw':
1884+
if isinstance(plot_data[xvar].dtype, pd.CategoricalDtype):
1885+
order = pd.unique(plot_data[xvar]).categories
1886+
else:
1887+
order = pd.unique(plot_data[xvar])
1888+
means = plot_data.groupby(xvar, observed=False)[yvar].mean().reindex(index=order).values
1889+
ticks = list(range(len(order)))
1890+
elif bar_type == 'contrast':
1891+
means = results.difference.to_list()
1892+
ticks = ticks_to_plot.copy()
1893+
if extra_delta is not None:
1894+
ticks.append(ticks[-1]+1) # Add an extra tick
1895+
means.append(extra_delta)
1896+
1897+
num_of_bars = len(means)
1898+
y_start_values, y_distances = [0]*num_of_bars, means
1899+
x_start_values, x_distances = [num - (0.5 if horizontal else 0.25) for num in ticks], [0.5,]*num_of_bars
1900+
1901+
elif bar_type == 'summary':
1902+
# Begin checks
1903+
if not isinstance(summary_bars, list):
1904+
raise TypeError("summary_bars must be a list of indices (ints).")
1905+
if not all(isinstance(i, int) for i in summary_bars):
1906+
raise TypeError("summary_bars must be a list of indices (ints).")
1907+
if any(i >= len(results) for i in summary_bars):
1908+
raise ValueError("Index {} chosen is out of range for the contrast objects.".format([i for i in summary_bars if i >= len(results)]))
1909+
1910+
ticks = [ticks_to_plot[tick] for tick in summary_bars]
1911+
summary_xmin, summary_xmax = summary_axes.get_xlim()
1912+
summary_ymin, summary_ymax = summary_axes.get_ylim()
1913+
span_ax = bar_kwargs.pop("span_ax")
1914+
1915+
x_start_values, y_start_values, x_distances, y_distances = [], [], [], []
1916+
for summary_index in summary_bars:
1917+
summary_ci_low = results.get(ci_type+'_low')[summary_index]
1918+
summary_ci_high = results.get(ci_type+'_high')[summary_index]
1919+
1920+
if span_ax == True:
1921+
starting_location = summary_ymax if horizontal else summary_xmin
1922+
else:
1923+
starting_location = ticks_to_plot[summary_index]
1924+
x_distance = summary_ymin if horizontal else summary_xmax
1925+
1926+
x_start_values.append(starting_location)
1927+
y_start_values.append(summary_ci_low)
1928+
x_distances.append(x_distance + 1)
1929+
y_distances.append(summary_ci_high - summary_ci_low)
1930+
else:
1931+
raise ValueError("Invalid bar_type. Must be 'raw' or 'contrast'.")
1932+
1933+
if horizontal:
1934+
x_start_values, y_start_values = y_start_values, x_start_values
1935+
x_distances, y_distances = y_distances, x_distances
1936+
1937+
for name, values in zip(['x_start_values', 'x_distances', 'y_start_values', 'y_distance'],
1938+
[x_start_values, x_distances, y_start_values, y_distances]
1939+
):
1940+
bar_dict[name] = values
1941+
1942+
# Colors
1943+
colors = color_picker(
1944+
color_type = bar_type,
1945+
kwargs = bar_kwargs,
1946+
elements = ticks_to_plot if bar_type=='contrast' else ticks,
1947+
color_col = color_col,
1948+
show_pairs = show_pairs,
1949+
color_palette = plot_palette_raw
1950+
)
1951+
if bar_type == 'contrast' and extra_delta is not None:
1952+
colors.append('black')
1953+
bar_dict['colors'] = colors
1954+
1955+
return bar_dict, bar_kwargs

0 commit comments

Comments
 (0)