Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 4 additions & 7 deletions dabest/_dabest_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

# %% ../nbs/API/dabest_object.ipynb 5
# Import standard data science libraries
import warnings
from numpy import array, repeat, random, issubdtype, number
import numpy as np
import pandas as pd
Expand Down Expand Up @@ -62,7 +63,6 @@ def __init__(

# Check if there is NaN under any of the paired settings
if self.__is_paired and self.__output_data.isnull().values.any():
import warnings
warn1 = f"NaN values detected under paired setting and removed,"
warn2 = f" please check your data."
warnings.warn(warn1 + warn2)
Expand Down Expand Up @@ -500,10 +500,10 @@ def _check_errors(self, x, y, idx, experiment, experiment_label, x1_level):
if x is None:
error_msg = "If `delta2` is True. `x` parameter cannot be None. String or list expected"
raise ValueError(error_msg)

if self.__proportional:
err0 = "`proportional` and `delta2` cannot be True at the same time."
raise ValueError(err0)
mes1 = "Only mean_diff is supported for proportional data when `delta2` is True"
warnings.warn(message=mes1, category=UserWarning)

# idx should not be specified
if idx:
Expand Down Expand Up @@ -581,8 +581,6 @@ def _get_plot_data(self, x, y, all_plot_groups):
"""
# Check if there is NaN under any of the paired settings
if self.__is_paired is not None and self.__output_data.isnull().values.any():
print("Nan")
import warnings
warn1 = f"NaN values detected under paired setting and removed,"
warn2 = f" please check your data."
warnings.warn(warn1 + warn2)
Expand Down Expand Up @@ -634,7 +632,6 @@ def _get_plot_data(self, x, y, all_plot_groups):

# Check if there is NaN under any of the paired settings
if self.__is_paired is not None and self.__output_data.isnull().values.any():
import warnings
warn1 = f"NaN values detected under paired setting and removed,"
warn2 = f" please check your data."
warnings.warn(warn1 + warn2)
Expand Down
4 changes: 3 additions & 1 deletion dabest/_effsize_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,8 @@ def _check_errors(self, control, test):
raise ValueError(err1)

if self.__proportional and self.__effect_size not in ["mean_diff", "cohens_h"]:
err1 = "`proportional` is True; therefore effect size other than mean_diff and cohens_h is not defined."
err1 = "`proportional` is True; therefore effect size other than mean_diff and cohens_h is not defined." + \
"If you are calculating deltas' g, it's the same as delta-delta when `proportional` is True"
raise ValueError(err1)

if self.__proportional and (
Expand Down Expand Up @@ -884,6 +885,7 @@ def __pre_calc(self):
self.__is_paired,
self.__resamples,
self.__random_seed,
self.__proportional,
)

for j, current_tuple in enumerate(idx):
Expand Down
90 changes: 48 additions & 42 deletions dabest/_stats_tools/confint_2group_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,24 +159,26 @@ def compute_bootstrapped_diff(

return out

@njit(cache=True) # parallelization must be turned off for random number generation
def delta2_bootstrap_loop(x1, x2, x3, x4, resamples, pooled_sd, rng_seed, is_paired):

@njit(cache=True)
def delta2_bootstrap_loop(x1, x2, x3, x4, resamples, pooled_sd, rng_seed, is_paired, proportional=False):
"""
Compute bootstrapped differences for delta-delta, handling both regular and proportional data
"""
np.random.seed(rng_seed)
out_delta_g = np.empty(resamples)
deltadelta = np.empty(resamples)
out_delta_g = np.empty(resamples)

n1, n2, n3, n4 = len(x1), len(x2), len(x3), len(x4)
if is_paired:
if n1 != n2 or n3 != n4:
raise ValueError("Each control group must have the same length as its corresponding test group in paired analysis.")

if is_paired and (n1 != n2 or n3 != n4):
raise ValueError("Each control group must have the same length as its corresponding test group in paired analysis.")

# Bootstrapping
for i in range(resamples):
# Paired or unpaired resampling
if is_paired:
indices_1 = np.random.choice(len(x1),len(x1))
indices_2 = np.random.choice(len(x3),len(x3))
indices_1 = np.random.choice(len(x1), len(x1))
indices_2 = np.random.choice(len(x3), len(x3))
x1_sample, x2_sample = x1[indices_1], x2[indices_1]
x3_sample, x4_sample = x3[indices_2], x4[indices_2]
else:
Expand All @@ -187,13 +189,14 @@ def delta2_bootstrap_loop(x1, x2, x3, x4, resamples, pooled_sd, rng_seed, is_pai
x1_sample, x2_sample = x1[indices_1], x2[indices_2]
x3_sample, x4_sample = x3[indices_3], x4[indices_4]

# Calculating deltas
# Calculate deltas
delta_1 = np.mean(x2_sample) - np.mean(x1_sample)
delta_2 = np.mean(x4_sample) - np.mean(x3_sample)
delta_delta = delta_2 - delta_1

deltadelta[i] = delta_delta
out_delta_g[i] = delta_delta / pooled_sd

out_delta_g[i] = delta_delta if proportional else delta_delta/pooled_sd

return out_delta_g, deltadelta

Expand All @@ -204,39 +207,42 @@ def compute_delta2_bootstrapped_diff(
x3: np.ndarray, # Control group 2
x4: np.ndarray, # Test group 2
is_paired: str = None,
resamples: int = 5000, # The number of bootstrap resamples to be taken for the calculation of the confidence interval limits.
random_seed: int = 12345, # `random_seed` is used to seed the random number generator during bootstrap resampling. This ensures that the confidence intervals reported are replicable.
) -> (
tuple
): # bootstraped result and empirical result of deltas' g, and the bootstraped result of delta-delta
resamples: int = 5000,
random_seed: int = 12345,
proportional: bool = False
) -> tuple:
"""
Bootstraps the effect size deltas' g.

Bootstraps the effect size deltas' g or proportional delta-delta
"""

x1, x2, x3, x4 = map(np.asarray, [x1, x2, x3, x4])

# Calculating pooled sample standard deviation
stds = [np.std(x) for x in [x1, x2, x3, x4]]
ns = [len(x) for x in [x1, x2, x3, x4]]

sd_numerator = sum((n - 1) * s**2 for n, s in zip(ns, stds))
sd_denominator = sum(n - 1 for n in ns)

# Avoid division by zero
if sd_denominator == 0:
raise ValueError("Insufficient data to compute pooled standard deviation.")

pooled_sample_sd = np.sqrt(sd_numerator / sd_denominator)

# Ensure pooled_sample_sd is not NaN or zero (to avoid division by zero later)
if np.isnan(pooled_sample_sd) or pooled_sample_sd == 0:
raise ValueError("Pooled sample standard deviation is NaN or zero.")

out_delta_g, deltadelta = delta2_bootstrap_loop(x1, x2, x3, x4, resamples, pooled_sample_sd, random_seed, is_paired)

# Empirical delta_g calculation
delta_g = ((np.mean(x4) - np.mean(x3)) - (np.mean(x2) - np.mean(x1))) / pooled_sample_sd

if proportional:
# For proportional data, pass 1.0 as dummy pooled_sd (won't be used)
out_delta_g, deltadelta = delta2_bootstrap_loop(
x1, x2, x3, x4, resamples, 1.0, random_seed, is_paired, proportional=True
)
# For proportional data, delta_g is the empirical delta-delta
delta_g = ((np.mean(x4) - np.mean(x3)) - (np.mean(x2) - np.mean(x1)))
else:
# Calculate pooled sample standard deviation for non-proportional data
stds = [np.std(x) for x in [x1, x2, x3, x4]]
ns = [len(x) for x in [x1, x2, x3, x4]]

sd_numerator = sum((n - 1) * s**2 for n, s in zip(ns, stds))
sd_denominator = sum(n - 1 for n in ns)

if sd_denominator == 0:
raise ValueError("Insufficient data to compute pooled standard deviation.")

pooled_sample_sd = np.sqrt(sd_numerator / sd_denominator)

if np.isnan(pooled_sample_sd) or pooled_sample_sd == 0:
raise ValueError("Pooled sample standard deviation is NaN or zero.")

out_delta_g, deltadelta = delta2_bootstrap_loop(
x1, x2, x3, x4, resamples, pooled_sample_sd, random_seed, is_paired, proportional=False
)
delta_g = ((np.mean(x4) - np.mean(x3)) - (np.mean(x2) - np.mean(x1))) / pooled_sample_sd

return out_delta_g, delta_g, deltadelta

Expand Down
8 changes: 5 additions & 3 deletions dabest/misc_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,12 +590,12 @@ def get_color_palette(
if color_by_subgroups:
plot_palette_raw = dict()
plot_palette_contrast = dict()
# plot_palette_bar set to None because currently there is no empty_circle toggle for proportion plots
plot_palette_bar = None
plot_palette_bar = dict()
for i in range(len(idx)):
for names_i in idx[i]:
plot_palette_raw[names_i] = swarm_colors[i]
plot_palette_contrast[names_i] = contrast_colors[i]
plot_palette_bar[names_i] = bar_color[i]
else:
plot_palette_raw = dict(zip(categories, swarm_colors))
plot_palette_contrast = dict(zip(categories, contrast_colors))
Expand All @@ -612,11 +612,12 @@ def get_color_palette(
if color_by_subgroups:
plot_palette_raw = dict()
plot_palette_contrast = dict()
plot_palette_bar = None # plot_palette_bar set to None because currently there is no empty_circle toggle for proportion plots
plot_palette_bar = dict()
for i in range(len(idx)):
for names_i in idx[i]:
plot_palette_raw[names_i] = swarm_colors[i]
plot_palette_contrast[names_i] = contrast_colors[i]
plot_palette_bar[names_i] = bar_color[i]
else:
plot_palette_raw = dict(zip(names, swarm_colors))
plot_palette_contrast = dict(zip(names, contrast_colors))
Expand Down Expand Up @@ -1018,6 +1019,7 @@ def lookup_value(text):
ticks_with_counts.append(f"{t}\n(N={value})")

fontsize_rawxlabel = plot_kwargs.get("fontsize_rawxlabel")
set_major_loc_method(plt.FixedLocator(get_ticks()))
set_label(ticks_with_counts, fontsize=fontsize_rawxlabel)

# Ensure ticks are at the correct locations
Expand Down
47 changes: 37 additions & 10 deletions dabest/plot_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,14 +731,17 @@ def sankeydiag(
right_idx = []
# Design for Sankey Flow Diagram
sankey_idx = (
[
(control, test)
for i in idx
for control, test in zip(i[:], (i[1:] + (i[0],)))
]
if flow
else temp_idx
)
[
(control, test)
for i in idx
for control, test in zip(
i[:],
(tuple(i[1:]) + (i[0],)) if isinstance(i, tuple) else (list(i[1:]) + [i[0]])
)
]
if flow
else temp_idx
)
for i in sankey_idx:
left_idx.append(i[0])
right_idx.append(i[1])
Expand Down Expand Up @@ -2065,6 +2068,7 @@ def barplotter(
plot_data: pd.DataFrame,
bar_color: str,
plot_palette_bar: dict,
color_col: str,
plot_kwargs: dict,
barplot_kwargs: dict,
horizontal: bool
Expand All @@ -2088,6 +2092,8 @@ def barplotter(
Color of the bar.
plot_palette_bar : dict
Dictionary of colors used in the bar plot.
color_col : str
Column name of the color column.
plot_kwargs : dict
Keyword arguments for the plot.
barplot_kwargs : dict
Expand All @@ -2102,7 +2108,26 @@ def barplotter(
else:
x_var, y_var, orient = all_plot_groups, np.ones(len(all_plot_groups)), "v"

bar1_df = pd.DataFrame({xvar: x_var, "proportion": y_var})
# Create bar1_df with basic columns
bar1_df = pd.DataFrame({
xvar: x_var,
"proportion": y_var
})

# Handle colors
if color_col:
# Get first color value for each group
color_mapping = plot_data.groupby(xvar, observed=False)[color_col].first()
bar1_df[color_col] = [color_mapping.get(group) for group in all_plot_groups]

# Map colors, defaulting to bar_color if no match
edge_colors = [
plot_palette_bar.get(hue_val, bar_color)
for hue_val in bar1_df[color_col]
]
else:
edge_colors = bar_color


bar1 = sns.barplot(
data=bar1_df,
Expand All @@ -2112,7 +2137,7 @@ def barplotter(
order=all_plot_groups,
linewidth=2,
facecolor=(1, 1, 1, 0),
edgecolor=bar_color,
edgecolor=edge_colors,
zorder=1,
orient=orient,
)
Expand All @@ -2123,6 +2148,8 @@ def barplotter(
ax=rawdata_axes,
order=all_plot_groups,
palette=plot_palette_bar,
hue=color_col,
dodge=False,
zorder=1,
orient=orient,
**barplot_kwargs
Expand Down
1 change: 1 addition & 0 deletions dabest/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,7 @@ def effectsize_df_plotter(effectsize_df: object, **plot_kwargs) -> matplotlib.fi
plot_data = plot_data,
bar_color = bar_color,
plot_palette_bar = plot_palette_bar,
color_col = color_col,
plot_kwargs = plot_kwargs,
barplot_kwargs = barplot_kwargs,
horizontal = horizontal,
Expand Down
Loading
Loading