66__all__ = ['merge_two_dicts' , 'unpack_and_add' , 'print_greeting' , 'get_varname' , 'get_unique_categories' , 'get_params' ,
77 'get_kwargs' , 'get_color_palette' , 'initialize_fig' , 'get_plot_groups' , 'add_counts_to_ticks' ,
88 'extract_contrast_plotting_ticks' , 'set_xaxis_ticks_and_lims' , 'show_legend' , 'gardner_altman_adjustments' ,
9- 'draw_zeroline' , 'redraw_independent_spines' , 'redraw_dependent_spines' , 'extract_group_summaries' ]
9+ 'draw_zeroline' , 'redraw_independent_spines' , 'redraw_dependent_spines' , 'extract_group_summaries' ,
10+ 'color_picker' , 'prepare_bars_for_plot' ]
1011
1112# %% ../nbs/API/misc_tools.ipynb 4
1213import datetime as dt
@@ -167,7 +168,6 @@ def get_params(
167168 group_summaries = None if barplot_kwargs ['errorbar' ] is not None else group_summaries
168169
169170 # Contrast Axes kwargs
170- contrast_alpha = plot_kwargs ["contrast_alpha" ]
171171 ci_type = plot_kwargs ["ci_type" ]
172172 if ci_type not in ["bca" , "pct" ]:
173173 raise ValueError ("Invalid `ci_type`. Must be either 'bca' or 'pct'." )
@@ -196,7 +196,7 @@ def get_params(
196196
197197 return (dabest_obj , plot_data , xvar , yvar , is_paired , effect_size , proportional , all_plot_groups ,
198198 idx , show_delta2 , show_mini_meta , float_contrast , show_pairs , group_summaries ,
199- horizontal , results , contrast_alpha , ci_type , x1_level , experiment_label , show_baseline_ec ,
199+ horizontal , results , ci_type , x1_level , experiment_label , show_baseline_ec ,
200200 one_sankey , two_col_sankey , asymmetric_side )
201201
202202def get_kwargs (
@@ -265,6 +265,7 @@ def get_kwargs(
265265 "orientation" : 'vertical' ,
266266 "showextrema" : False ,
267267 "showmedians" : False ,
268+ "alpha" : plot_kwargs ["contrast_alpha" ],
268269
269270 }
270271 if plot_kwargs ["contrast_kwargs" ] is None :
@@ -365,13 +366,11 @@ def get_kwargs(
365366
366367 # Delta text kwargs.
367368 default_delta_text_kwargs = {
368- "color" : None ,
369369 "alpha" : 1 ,
370370 "fontsize" : 10 ,
371371 "ha" : 'center' ,
372372 "va" : 'center' ,
373373 "rotation" : 0 ,
374- "x_location" : 'right' ,
375374 "x_coordinates" : None ,
376375 "y_coordinates" : None ,
377376 "offset" : 0
@@ -384,7 +383,6 @@ def get_kwargs(
384383 # Summary bars kwargs.
385384 default_summary_bars_kwargs = {
386385 "span_ax" : False ,
387- "color" : None ,
388386 "alpha" : 0.15 ,
389387 "zorder" :- 3
390388 }
@@ -395,8 +393,8 @@ def get_kwargs(
395393
396394 # Swarm bars kwargs.
397395 default_raw_bars_kwargs = {
398- "color" : None ,
399- "zorder" : - 3
396+ "zorder" : - 3 ,
397+ "alpha" : 0.2
400398 }
401399 if plot_kwargs ["raw_bars_kwargs" ] is None :
402400 raw_bars_kwargs = default_raw_bars_kwargs
@@ -405,8 +403,8 @@ def get_kwargs(
405403
406404 # Contrast bars kwargs.
407405 default_contrast_bars_kwargs = {
408- "color" : None ,
409- "zorder" : - 3
406+ "zorder" : - 3 ,
407+ "alpha" : 0.2
410408 }
411409 if plot_kwargs ["contrast_bars_kwargs" ] is None :
412410 contrast_bars_kwargs = default_contrast_bars_kwargs
@@ -1115,7 +1113,6 @@ def extract_contrast_plotting_ticks(
11151113 ticks_to_start_twocol_sankey .pop ()
11161114 ticks_to_start_twocol_sankey .insert (0 , 0 )
11171115 else :
1118-
11191116 ticks_to_skip = np .cumsum ([len (t ) for t in idx ])[:- 1 ].tolist ()
11201117 ticks_to_skip .insert (0 , 0 )
11211118 # Then obtain the ticks where we have to plot the effect sizes.
@@ -1848,3 +1845,111 @@ def extract_group_summaries(
18481845 group_summaries_kwargs .pop ("offset" )
18491846
18501847 return group_summaries_method , group_summaries_offset , group_summaries_line_color
1848+
1849+ def color_picker (color_type : str ,
1850+ kwargs : dict ,
1851+ elements : list ,
1852+ color_col : str ,
1853+ show_pairs : bool ,
1854+ color_palette : dict ) -> list :
1855+ num_of_elements = len (elements )
1856+ colors = (
1857+ [kwargs .pop ('color' )] * num_of_elements
1858+ if kwargs .get ('color' , None ) is not None
1859+ else ['black' ] * num_of_elements
1860+ if color_col is not None or show_pairs
1861+ else list (color_palette .values ())
1862+ )
1863+ if color_type in ['contrast' , 'summary' , 'delta_text' ]:
1864+ if len (colors ) == num_of_elements :
1865+ final_colors = colors
1866+ else :
1867+ final_colors = []
1868+ for tick in elements :
1869+ final_colors .append (colors [int (tick )])
1870+ else :
1871+ final_colors = colors
1872+ return final_colors
1873+
1874+
1875+ def prepare_bars_for_plot (bar_type , bar_kwargs , horizontal , plot_palette_raw , color_col , show_pairs ,
1876+ plot_data = None , xvar = None , yvar = None , # Raw data
1877+ results = None , ticks_to_plot = None , extra_delta = None , # Contrast data
1878+ summary_bars = None , summary_axes = None , ci_type = None # Summary data
1879+ ):
1880+ from .misc_tools import color_picker
1881+ bar_dict = {}
1882+ if bar_type in ['raw' , 'contrast' ]:
1883+ if bar_type == 'raw' :
1884+ if isinstance (plot_data [xvar ].dtype , pd .CategoricalDtype ):
1885+ order = pd .unique (plot_data [xvar ]).categories
1886+ else :
1887+ order = pd .unique (plot_data [xvar ])
1888+ means = plot_data .groupby (xvar , observed = False )[yvar ].mean ().reindex (index = order ).values
1889+ ticks = list (range (len (order )))
1890+ elif bar_type == 'contrast' :
1891+ means = results .difference .to_list ()
1892+ ticks = ticks_to_plot .copy ()
1893+ if extra_delta is not None :
1894+ ticks .append (ticks [- 1 ]+ 1 ) # Add an extra tick
1895+ means .append (extra_delta )
1896+
1897+ num_of_bars = len (means )
1898+ y_start_values , y_distances = [0 ]* num_of_bars , means
1899+ x_start_values , x_distances = [num - (0.5 if horizontal else 0.25 ) for num in ticks ], [0.5 ,]* num_of_bars
1900+
1901+ elif bar_type == 'summary' :
1902+ # Begin checks
1903+ if not isinstance (summary_bars , list ):
1904+ raise TypeError ("summary_bars must be a list of indices (ints)." )
1905+ if not all (isinstance (i , int ) for i in summary_bars ):
1906+ raise TypeError ("summary_bars must be a list of indices (ints)." )
1907+ if any (i >= len (results ) for i in summary_bars ):
1908+ raise ValueError ("Index {} chosen is out of range for the contrast objects." .format ([i for i in summary_bars if i >= len (results )]))
1909+
1910+ ticks = [ticks_to_plot [tick ] for tick in summary_bars ]
1911+ summary_xmin , summary_xmax = summary_axes .get_xlim ()
1912+ summary_ymin , summary_ymax = summary_axes .get_ylim ()
1913+ span_ax = bar_kwargs .pop ("span_ax" )
1914+
1915+ x_start_values , y_start_values , x_distances , y_distances = [], [], [], []
1916+ for summary_index in summary_bars :
1917+ summary_ci_low = results .get (ci_type + '_low' )[summary_index ]
1918+ summary_ci_high = results .get (ci_type + '_high' )[summary_index ]
1919+
1920+ if span_ax == True :
1921+ starting_location = summary_ymax if horizontal else summary_xmin
1922+ else :
1923+ starting_location = ticks_to_plot [summary_index ]
1924+ x_distance = summary_ymin if horizontal else summary_xmax
1925+
1926+ x_start_values .append (starting_location )
1927+ y_start_values .append (summary_ci_low )
1928+ x_distances .append (x_distance + 1 )
1929+ y_distances .append (summary_ci_high - summary_ci_low )
1930+ else :
1931+ raise ValueError ("Invalid bar_type. Must be 'raw' or 'contrast'." )
1932+
1933+ if horizontal :
1934+ x_start_values , y_start_values = y_start_values , x_start_values
1935+ x_distances , y_distances = y_distances , x_distances
1936+
1937+ for name , values in zip (['x_start_values' , 'x_distances' , 'y_start_values' , 'y_distance' ],
1938+ [x_start_values , x_distances , y_start_values , y_distances ]
1939+ ):
1940+ bar_dict [name ] = values
1941+
1942+ # Colors
1943+ colors = color_picker (
1944+ color_type = bar_type ,
1945+ kwargs = bar_kwargs ,
1946+ elements = ticks_to_plot if bar_type == 'contrast' else ticks ,
1947+ color_col = color_col ,
1948+ show_pairs = show_pairs ,
1949+ color_palette = plot_palette_raw
1950+ )
1951+ if bar_type == 'contrast' and extra_delta is not None :
1952+ colors .append ('black' )
1953+ bar_dict ['colors' ] = colors
1954+
1955+ return bar_dict , bar_kwargs
0 commit comments