|
9 | 9 | __all__ = ['halfviolin', 'get_swarm_spans', 'error_bar', 'check_data_matches_labels', 'normalize_dict', 'width_determine', |
10 | 10 | 'single_sankey', 'sankeydiag', 'summary_bars_plotter', 'contrast_bars_plotter', 'swarm_bars_plotter', |
11 | 11 | 'delta_text_plotter', 'DeltaDotsPlotter', 'slopegraph_plotter', 'plot_minimeta_or_deltadelta_violins', |
12 | | - 'effect_size_curve_plotter', 'gridkey_plotter', 'barplotter', 'table_for_horizontal_plots', 'swarmplot', |
13 | | - 'SwarmPlot'] |
| 12 | + 'effect_size_curve_plotter', 'gridkey_plotter', 'barplotter', 'table_for_horizontal_plots', |
| 13 | + 'add_counts_to_prop_plots', 'swarmplot', 'SwarmPlot'] |
14 | 14 |
|
15 | 15 | # %% ../nbs/API/plot_tools.ipynb 4 |
16 | 16 | import math |
@@ -1919,9 +1919,8 @@ def barplotter(xvar, yvar, all_plot_groups, rawdata_axes, plot_data, bar_color, |
1919 | 1919 | horizontal : bool |
1920 | 1920 | If the plot is horizontal. |
1921 | 1921 | """ |
1922 | | - |
1923 | | - x_label, y_label = rawdata_axes.get_xlabel(), rawdata_axes.get_ylabel() |
1924 | 1922 |
|
| 1923 | + x_label, y_label = rawdata_axes.get_xlabel(), rawdata_axes.get_ylabel() |
1925 | 1924 | if horizontal: |
1926 | 1925 | x_var, y_var, orient = np.ones(len(all_plot_groups)), all_plot_groups, "h" |
1927 | 1926 | else: |
@@ -1969,7 +1968,7 @@ def barplotter(xvar, yvar, all_plot_groups, rawdata_axes, plot_data, bar_color, |
1969 | 1968 | centre = x + width / 2.0 |
1970 | 1969 | bar.set_x(centre - bar_width / 2.0) |
1971 | 1970 | bar.set_width(bar_width) |
1972 | | - |
| 1971 | + |
1973 | 1972 | # reset the x and y labels |
1974 | 1973 | rawdata_axes.set_xlabel(x_label) |
1975 | 1974 | rawdata_axes.set_ylabel(y_label) |
@@ -2047,6 +2046,61 @@ def table_for_horizontal_plots(effectsize_df, ax, contrast_axes, ticks_to_plot, |
2047 | 2046 | ax.set_xlabel(label, fontsize=fontsize_label) # Set the x-axis label - hardcoded for now |
2048 | 2047 | sns.despine(ax=ax, left=True, bottom=True) |
2049 | 2048 |
|
| 2049 | + |
| 2050 | +def add_counts_to_prop_plots(plot_data, xvar, yvar, rawdata_axes, horizontal, is_paired, prop_sample_counts_kwargs): |
| 2051 | + """ |
| 2052 | + Add counts to the proportion plots. |
| 2053 | +
|
| 2054 | + Parameters |
| 2055 | + ---------- |
| 2056 | + plot_data : object (Dataframe) |
| 2057 | + Dataframe of the plot data. |
| 2058 | + xvar : str |
| 2059 | + Column name of the x variable. |
| 2060 | + yvar : str |
| 2061 | + Column name of the y variable. |
| 2062 | + rawdata_axes : object |
| 2063 | + Matplotlib axis object to plot on. |
| 2064 | + horizontal : bool |
| 2065 | + If the plot is horizontal. |
| 2066 | + is_paired : bool |
| 2067 | + Whether the data is paired. |
| 2068 | + prop_sample_counts_kwargs : dict |
| 2069 | + Keyword arguments for the sample counts. |
| 2070 | + """ |
| 2071 | + |
| 2072 | + # Group orders |
| 2073 | + if isinstance(plot_data[xvar].dtype, pd.CategoricalDtype): |
| 2074 | + sample_size_text_order = pd.unique(plot_data[xvar]).categories |
| 2075 | + else: |
| 2076 | + sample_size_text_order = pd.unique(plot_data[xvar]) |
| 2077 | + |
| 2078 | + # Get the sample size values |
| 2079 | + ones, zeros = plot_data[plot_data[yvar] == 1], plot_data[plot_data[yvar] == 0] |
| 2080 | + |
| 2081 | + sample_size_val1 = ones.groupby(xvar, observed=False)[yvar].count().reindex(index=sample_size_text_order) |
| 2082 | + sample_size_val0 = zeros.groupby(xvar, observed=False)[yvar].count().reindex(index=sample_size_text_order) |
| 2083 | + |
| 2084 | + fontsize = 8 if horizontal else 10 |
| 2085 | + fontsize -= 2 if is_paired else 0 |
| 2086 | + |
| 2087 | + if "fontsize" not in prop_sample_counts_kwargs.keys(): |
| 2088 | + fontsize = 8 if horizontal else 10 |
| 2089 | + fontsize -= 2 if is_paired else 0 |
| 2090 | + prop_sample_counts_kwargs.update({'fontsize': fontsize}) |
| 2091 | + |
| 2092 | + for sample_text_x, sample_text_y0, sample_text_y1 in zip( |
| 2093 | + np.arange(0,len(sample_size_text_order)+1,1), |
| 2094 | + sample_size_val0, |
| 2095 | + sample_size_val1, |
| 2096 | + ): |
| 2097 | + if horizontal: |
| 2098 | + rawdata_axes.text(0.05, sample_text_x, sample_text_y1, **prop_sample_counts_kwargs) |
| 2099 | + rawdata_axes.text(0.95, sample_text_x, sample_text_y0, **prop_sample_counts_kwargs) |
| 2100 | + else: |
| 2101 | + rawdata_axes.text(sample_text_x, 0.05, sample_text_y1, **prop_sample_counts_kwargs) |
| 2102 | + rawdata_axes.text(sample_text_x, 0.95, sample_text_y0, **prop_sample_counts_kwargs) |
| 2103 | + |
2050 | 2104 | # %% ../nbs/API/plot_tools.ipynb 6 |
2051 | 2105 | def swarmplot( |
2052 | 2106 | data: pd.DataFrame, |
|
0 commit comments