diff --git a/dabest/misc_tools.py b/dabest/misc_tools.py index 8c7d0e96..05193fe1 100644 --- a/dabest/misc_tools.py +++ b/dabest/misc_tools.py @@ -542,10 +542,8 @@ def get_plot_groups(is_paired, idx, proportional, all_plot_groups): def add_counts_to_ticks(plot_data, xvar, yvar, rawdata_axes, plot_kwargs): # Add the counts to the rawdata axes xticks. counts = plot_data.groupby(xvar).count()[yvar] - ticks_with_counts = [] - ticks_loc = rawdata_axes.get_xticks() - rawdata_axes.xaxis.set_major_locator(matplotlib.ticker.FixedLocator(ticks_loc)) - def lookup_value(text, counts): + + def lookup_value(text): try: return str(counts.loc[text]) except KeyError: @@ -553,30 +551,24 @@ def lookup_value(text, counts): numeric_key = pd.to_numeric(text, errors='coerce') if pd.notnull(numeric_key): return str(counts.loc[numeric_key]) - else: - raise ValueError except (ValueError, KeyError): - print(f"Key '{text}' not found in counts.") - return "N/A" - for xticklab in rawdata_axes.xaxis.get_ticklabels(): - t = xticklab.get_text() - # Extract the text after the last newline, if present - if t.rfind("\n") != -1: - te = t[t.rfind("\n") + len("\n"):] - value = lookup_value(te, counts) - te = t - else: - te = t - value = lookup_value(te, counts) - - # Append the modified tick label with the count to the list - ticks_with_counts.append(f"{te}\nN = {value}") + pass + print(f"Key '{text}' not found in counts.") + return "N/A" + ticks_with_counts = [] + for xticklab in rawdata_axes.get_xticklabels(): + t = xticklab.get_text() + te = t.split('\n')[-1] # Get the last line of the label + value = lookup_value(te) + ticks_with_counts.append(f"{t}\nN = {value}") - if plot_kwargs["fontsize_rawxlabel"] is not None: - fontsize_rawxlabel = plot_kwargs["fontsize_rawxlabel"] + fontsize_rawxlabel = plot_kwargs.get("fontsize_rawxlabel") rawdata_axes.set_xticklabels(ticks_with_counts, fontsize=fontsize_rawxlabel) + # Ensure ticks are at the correct locations + rawdata_axes.xaxis.set_major_locator(plt.FixedLocator(rawdata_axes.get_xticks())) + def extract_contrast_plotting_ticks(is_paired, show_pairs, two_col_sankey, plot_groups, idx, sankey_control_group): diff --git a/dabest/plotter.py b/dabest/plotter.py index f2c30a1a..59c264d0 100644 --- a/dabest/plotter.py +++ b/dabest/plotter.py @@ -246,7 +246,6 @@ def effectsize_df_plotter(effectsize_df, **plot_kwargs): if color_col is None: rawdata_plot.legend().set_visible(False) - else: # Plot the raw data as a barplot. barplotter( @@ -589,5 +588,4 @@ def effectsize_df_plotter(effectsize_df, **plot_kwargs): plt.rcParams[parameter] = original_rcParams[parameter] # Return the figure. - fig.show() return fig diff --git a/nbs/API/misc_tools.ipynb b/nbs/API/misc_tools.ipynb index dd9fbdd0..d79fcd23 100644 --- a/nbs/API/misc_tools.ipynb +++ b/nbs/API/misc_tools.ipynb @@ -595,10 +595,8 @@ "def add_counts_to_ticks(plot_data, xvar, yvar, rawdata_axes, plot_kwargs):\n", " # Add the counts to the rawdata axes xticks.\n", " counts = plot_data.groupby(xvar).count()[yvar]\n", - " ticks_with_counts = []\n", - " ticks_loc = rawdata_axes.get_xticks()\n", - " rawdata_axes.xaxis.set_major_locator(matplotlib.ticker.FixedLocator(ticks_loc))\n", - " def lookup_value(text, counts):\n", + " \n", + " def lookup_value(text):\n", " try:\n", " return str(counts.loc[text])\n", " except KeyError:\n", @@ -606,30 +604,24 @@ " numeric_key = pd.to_numeric(text, errors='coerce')\n", " if pd.notnull(numeric_key):\n", " return str(counts.loc[numeric_key])\n", - " else:\n", - " raise ValueError\n", " except (ValueError, KeyError):\n", - " print(f\"Key '{text}' not found in counts.\")\n", - " return \"N/A\"\n", - " for xticklab in rawdata_axes.xaxis.get_ticklabels():\n", - " t = xticklab.get_text()\n", - " # Extract the text after the last newline, if present\n", - " if t.rfind(\"\\n\") != -1:\n", - " te = t[t.rfind(\"\\n\") + len(\"\\n\"):]\n", - " value = lookup_value(te, counts)\n", - " te = t\n", - " else:\n", - " te = t\n", - " value = lookup_value(te, counts)\n", - "\n", - " # Append the modified tick label with the count to the list\n", - " ticks_with_counts.append(f\"{te}\\nN = {value}\")\n", + " pass\n", + " print(f\"Key '{text}' not found in counts.\")\n", + " return \"N/A\"\n", "\n", + " ticks_with_counts = []\n", + " for xticklab in rawdata_axes.get_xticklabels():\n", + " t = xticklab.get_text()\n", + " te = t.split('\\n')[-1] # Get the last line of the label\n", + " value = lookup_value(te)\n", + " ticks_with_counts.append(f\"{t}\\nN = {value}\")\n", "\n", - " if plot_kwargs[\"fontsize_rawxlabel\"] is not None:\n", - " fontsize_rawxlabel = plot_kwargs[\"fontsize_rawxlabel\"]\n", + " fontsize_rawxlabel = plot_kwargs.get(\"fontsize_rawxlabel\")\n", " rawdata_axes.set_xticklabels(ticks_with_counts, fontsize=fontsize_rawxlabel)\n", "\n", + " # Ensure ticks are at the correct locations\n", + " rawdata_axes.xaxis.set_major_locator(plt.FixedLocator(rawdata_axes.get_xticks()))\n", + "\n", "\n", "def extract_contrast_plotting_ticks(is_paired, show_pairs, two_col_sankey, plot_groups, idx, sankey_control_group):\n", "\n", diff --git a/nbs/API/plotter.ipynb b/nbs/API/plotter.ipynb index 8b5f365c..a14a71cd 100644 --- a/nbs/API/plotter.ipynb +++ b/nbs/API/plotter.ipynb @@ -303,7 +303,6 @@ " if color_col is None:\n", " rawdata_plot.legend().set_visible(False)\n", "\n", - "\n", " else:\n", " # Plot the raw data as a barplot.\n", " barplotter(\n", @@ -646,7 +645,6 @@ " plt.rcParams[parameter] = original_rcParams[parameter]\n", "\n", " # Return the figure.\n", - " fig.show()\n", " return fig" ] } diff --git a/test.py b/test.py deleted file mode 100644 index 105136f1..00000000 --- a/test.py +++ /dev/null @@ -1,88 +0,0 @@ -import numpy as np -from scipy.stats import norm -import pandas as pd -import matplotlib as mpl -import os -from pathlib import Path - -import matplotlib.ticker as Ticker -import matplotlib.pyplot as plt - -from dabest._api import load - -import dabest - -columns = [1, 2.0] -columns_str = ["1", "2.0"] -# create a test database -N = 100 -df = pd.DataFrame(np.vstack([np.random.normal(loc=i, size=(N,)) for i in range(len(columns))]).T, columns=columns_str) -females = np.repeat("Female", N / 2).tolist() -males = np.repeat("Male", N / 2).tolist() -df['gender'] = females + males - -# Add an `id` column for paired data plotting. -df['ID'] = pd.Series(range(1, N + 1)) - - -db = dabest.load(data=df, idx=columns_str, paired="baseline", id_col="ID") -print(db.mean_diff) -db.mean_diff.plot(); - -# def create_demo_dataset(seed=9999, N=20): -# import numpy as np -# import pandas as pd -# from scipy.stats import norm # Used in generation of populations. - -# np.random.seed(9999) # Fix the seed so the results are replicable. -# # pop_size = 10000 # Size of each population. - -# # Create samples -# c1 = norm.rvs(loc=3, scale=0.4, size=N) -# c2 = norm.rvs(loc=3.5, scale=0.75, size=N) -# c3 = norm.rvs(loc=3.25, scale=0.4, size=N) - -# t1 = norm.rvs(loc=3.5, scale=0.5, size=N) -# t2 = norm.rvs(loc=2.5, scale=0.6, size=N) -# t3 = norm.rvs(loc=3, scale=0.75, size=N) -# t4 = norm.rvs(loc=3.5, scale=0.75, size=N) -# t5 = norm.rvs(loc=3.25, scale=0.4, size=N) -# t6 = norm.rvs(loc=3.25, scale=0.4, size=N) - -# # Add a `gender` column for coloring the data. -# females = np.repeat("Female", N / 2).tolist() -# males = np.repeat("Male", N / 2).tolist() -# gender = females + males - -# # Add an `id` column for paired data plotting. -# id_col = pd.Series(range(1, N + 1)) - -# # Combine samples and gender into a DataFrame. -# df = pd.DataFrame( -# { -# "Control 1": c1, -# "Test 1": t1, -# "Control 2": c2, -# "Test 2": t2, -# "Control 3": c3, -# "Test 3": t3, -# "Test 4": t4, -# "Test 5": t5, -# "Test 6": t6, -# "Gender": gender, -# "ID": id_col, -# } -# ) - -# return df - - -# df = create_demo_dataset() - -# two_groups_unpaired = load(df, idx=("Control 1", "Test 1")) - -# two_groups_paired = load( -# df, idx=("Control 1", "Test 1"), paired="baseline", id_col="ID" -# ) - -# two_groups_unpaired.mean_diff.plot()