Skip to content

Commit 87b8df2

Browse files
authored
Merge pull request #195 from ACCLAB/feat-prop-deltadelta
Make delta-delta feature for proportional plots
2 parents f0730ab + fab4061 commit 87b8df2

18 files changed

+248
-142
lines changed

dabest/_dabest_object.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
# %% ../nbs/API/dabest_object.ipynb 5
99
# Import standard data science libraries
10+
import warnings
1011
from numpy import array, repeat, random, issubdtype, number
1112
import numpy as np
1213
import pandas as pd
@@ -62,7 +63,6 @@ def __init__(
6263

6364
# Check if there is NaN under any of the paired settings
6465
if self.__is_paired and self.__output_data.isnull().values.any():
65-
import warnings
6666
warn1 = f"NaN values detected under paired setting and removed,"
6767
warn2 = f" please check your data."
6868
warnings.warn(warn1 + warn2)
@@ -500,10 +500,10 @@ def _check_errors(self, x, y, idx, experiment, experiment_label, x1_level):
500500
if x is None:
501501
error_msg = "If `delta2` is True. `x` parameter cannot be None. String or list expected"
502502
raise ValueError(error_msg)
503-
503+
504504
if self.__proportional:
505-
err0 = "`proportional` and `delta2` cannot be True at the same time."
506-
raise ValueError(err0)
505+
mes1 = "Only mean_diff is supported for proportional data when `delta2` is True"
506+
warnings.warn(message=mes1, category=UserWarning)
507507

508508
# idx should not be specified
509509
if idx:
@@ -581,8 +581,6 @@ def _get_plot_data(self, x, y, all_plot_groups):
581581
"""
582582
# Check if there is NaN under any of the paired settings
583583
if self.__is_paired is not None and self.__output_data.isnull().values.any():
584-
print("Nan")
585-
import warnings
586584
warn1 = f"NaN values detected under paired setting and removed,"
587585
warn2 = f" please check your data."
588586
warnings.warn(warn1 + warn2)
@@ -634,7 +632,6 @@ def _get_plot_data(self, x, y, all_plot_groups):
634632

635633
# Check if there is NaN under any of the paired settings
636634
if self.__is_paired is not None and self.__output_data.isnull().values.any():
637-
import warnings
638635
warn1 = f"NaN values detected under paired setting and removed,"
639636
warn2 = f" please check your data."
640637
warnings.warn(warn1 + warn2)

dabest/_effsize_objects.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,8 @@ def _check_errors(self, control, test):
257257
raise ValueError(err1)
258258

259259
if self.__proportional and self.__effect_size not in ["mean_diff", "cohens_h"]:
260-
err1 = "`proportional` is True; therefore effect size other than mean_diff and cohens_h is not defined."
260+
err1 = "`proportional` is True; therefore effect size other than mean_diff and cohens_h is not defined." + \
261+
"If you are calculating deltas' g, it's the same as delta-delta when `proportional` is True"
261262
raise ValueError(err1)
262263

263264
if self.__proportional and (
@@ -884,6 +885,7 @@ def __pre_calc(self):
884885
self.__is_paired,
885886
self.__resamples,
886887
self.__random_seed,
888+
self.__proportional,
887889
)
888890

889891
for j, current_tuple in enumerate(idx):

dabest/_stats_tools/confint_2group_diff.py

Lines changed: 48 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -159,24 +159,26 @@ def compute_bootstrapped_diff(
159159

160160
return out
161161

162-
@njit(cache=True) # parallelization must be turned off for random number generation
163-
def delta2_bootstrap_loop(x1, x2, x3, x4, resamples, pooled_sd, rng_seed, is_paired):
162+
163+
@njit(cache=True)
164+
def delta2_bootstrap_loop(x1, x2, x3, x4, resamples, pooled_sd, rng_seed, is_paired, proportional=False):
165+
"""
166+
Compute bootstrapped differences for delta-delta, handling both regular and proportional data
167+
"""
164168
np.random.seed(rng_seed)
165-
out_delta_g = np.empty(resamples)
166169
deltadelta = np.empty(resamples)
170+
out_delta_g = np.empty(resamples)
167171

168172
n1, n2, n3, n4 = len(x1), len(x2), len(x3), len(x4)
169-
if is_paired:
170-
if n1 != n2 or n3 != n4:
171-
raise ValueError("Each control group must have the same length as its corresponding test group in paired analysis.")
172-
173+
if is_paired and (n1 != n2 or n3 != n4):
174+
raise ValueError("Each control group must have the same length as its corresponding test group in paired analysis.")
173175

174176
# Bootstrapping
175177
for i in range(resamples):
176178
# Paired or unpaired resampling
177179
if is_paired:
178-
indices_1 = np.random.choice(len(x1),len(x1))
179-
indices_2 = np.random.choice(len(x3),len(x3))
180+
indices_1 = np.random.choice(len(x1), len(x1))
181+
indices_2 = np.random.choice(len(x3), len(x3))
180182
x1_sample, x2_sample = x1[indices_1], x2[indices_1]
181183
x3_sample, x4_sample = x3[indices_2], x4[indices_2]
182184
else:
@@ -187,13 +189,14 @@ def delta2_bootstrap_loop(x1, x2, x3, x4, resamples, pooled_sd, rng_seed, is_pai
187189
x1_sample, x2_sample = x1[indices_1], x2[indices_2]
188190
x3_sample, x4_sample = x3[indices_3], x4[indices_4]
189191

190-
# Calculating deltas
192+
# Calculate deltas
191193
delta_1 = np.mean(x2_sample) - np.mean(x1_sample)
192194
delta_2 = np.mean(x4_sample) - np.mean(x3_sample)
193195
delta_delta = delta_2 - delta_1
194-
196+
195197
deltadelta[i] = delta_delta
196-
out_delta_g[i] = delta_delta / pooled_sd
198+
199+
out_delta_g[i] = delta_delta if proportional else delta_delta/pooled_sd
197200

198201
return out_delta_g, deltadelta
199202

@@ -204,39 +207,42 @@ def compute_delta2_bootstrapped_diff(
204207
x3: np.ndarray, # Control group 2
205208
x4: np.ndarray, # Test group 2
206209
is_paired: str = None,
207-
resamples: int = 5000, # The number of bootstrap resamples to be taken for the calculation of the confidence interval limits.
208-
random_seed: int = 12345, # `random_seed` is used to seed the random number generator during bootstrap resampling. This ensures that the confidence intervals reported are replicable.
209-
) -> (
210-
tuple
211-
): # bootstraped result and empirical result of deltas' g, and the bootstraped result of delta-delta
210+
resamples: int = 5000,
211+
random_seed: int = 12345,
212+
proportional: bool = False
213+
) -> tuple:
212214
"""
213-
Bootstraps the effect size deltas' g.
214-
215+
Bootstraps the effect size deltas' g or proportional delta-delta
215216
"""
216-
217217
x1, x2, x3, x4 = map(np.asarray, [x1, x2, x3, x4])
218-
219-
# Calculating pooled sample standard deviation
220-
stds = [np.std(x) for x in [x1, x2, x3, x4]]
221-
ns = [len(x) for x in [x1, x2, x3, x4]]
222-
223-
sd_numerator = sum((n - 1) * s**2 for n, s in zip(ns, stds))
224-
sd_denominator = sum(n - 1 for n in ns)
225-
226-
# Avoid division by zero
227-
if sd_denominator == 0:
228-
raise ValueError("Insufficient data to compute pooled standard deviation.")
229-
230-
pooled_sample_sd = np.sqrt(sd_numerator / sd_denominator)
231-
232-
# Ensure pooled_sample_sd is not NaN or zero (to avoid division by zero later)
233-
if np.isnan(pooled_sample_sd) or pooled_sample_sd == 0:
234-
raise ValueError("Pooled sample standard deviation is NaN or zero.")
235-
236-
out_delta_g, deltadelta = delta2_bootstrap_loop(x1, x2, x3, x4, resamples, pooled_sample_sd, random_seed, is_paired)
237-
238-
# Empirical delta_g calculation
239-
delta_g = ((np.mean(x4) - np.mean(x3)) - (np.mean(x2) - np.mean(x1))) / pooled_sample_sd
218+
219+
if proportional:
220+
# For proportional data, pass 1.0 as dummy pooled_sd (won't be used)
221+
out_delta_g, deltadelta = delta2_bootstrap_loop(
222+
x1, x2, x3, x4, resamples, 1.0, random_seed, is_paired, proportional=True
223+
)
224+
# For proportional data, delta_g is the empirical delta-delta
225+
delta_g = ((np.mean(x4) - np.mean(x3)) - (np.mean(x2) - np.mean(x1)))
226+
else:
227+
# Calculate pooled sample standard deviation for non-proportional data
228+
stds = [np.std(x) for x in [x1, x2, x3, x4]]
229+
ns = [len(x) for x in [x1, x2, x3, x4]]
230+
231+
sd_numerator = sum((n - 1) * s**2 for n, s in zip(ns, stds))
232+
sd_denominator = sum(n - 1 for n in ns)
233+
234+
if sd_denominator == 0:
235+
raise ValueError("Insufficient data to compute pooled standard deviation.")
236+
237+
pooled_sample_sd = np.sqrt(sd_numerator / sd_denominator)
238+
239+
if np.isnan(pooled_sample_sd) or pooled_sample_sd == 0:
240+
raise ValueError("Pooled sample standard deviation is NaN or zero.")
241+
242+
out_delta_g, deltadelta = delta2_bootstrap_loop(
243+
x1, x2, x3, x4, resamples, pooled_sample_sd, random_seed, is_paired, proportional=False
244+
)
245+
delta_g = ((np.mean(x4) - np.mean(x3)) - (np.mean(x2) - np.mean(x1))) / pooled_sample_sd
240246

241247
return out_delta_g, delta_g, deltadelta
242248

dabest/misc_tools.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -590,12 +590,12 @@ def get_color_palette(
590590
if color_by_subgroups:
591591
plot_palette_raw = dict()
592592
plot_palette_contrast = dict()
593-
# plot_palette_bar set to None because currently there is no empty_circle toggle for proportion plots
594-
plot_palette_bar = None
593+
plot_palette_bar = dict()
595594
for i in range(len(idx)):
596595
for names_i in idx[i]:
597596
plot_palette_raw[names_i] = swarm_colors[i]
598597
plot_palette_contrast[names_i] = contrast_colors[i]
598+
plot_palette_bar[names_i] = bar_color[i]
599599
else:
600600
plot_palette_raw = dict(zip(categories, swarm_colors))
601601
plot_palette_contrast = dict(zip(categories, contrast_colors))
@@ -612,11 +612,12 @@ def get_color_palette(
612612
if color_by_subgroups:
613613
plot_palette_raw = dict()
614614
plot_palette_contrast = dict()
615-
plot_palette_bar = None # plot_palette_bar set to None because currently there is no empty_circle toggle for proportion plots
615+
plot_palette_bar = dict()
616616
for i in range(len(idx)):
617617
for names_i in idx[i]:
618618
plot_palette_raw[names_i] = swarm_colors[i]
619619
plot_palette_contrast[names_i] = contrast_colors[i]
620+
plot_palette_bar[names_i] = bar_color[i]
620621
else:
621622
plot_palette_raw = dict(zip(names, swarm_colors))
622623
plot_palette_contrast = dict(zip(names, contrast_colors))
@@ -1018,6 +1019,7 @@ def lookup_value(text):
10181019
ticks_with_counts.append(f"{t}\n(N={value})")
10191020

10201021
fontsize_rawxlabel = plot_kwargs.get("fontsize_rawxlabel")
1022+
set_major_loc_method(plt.FixedLocator(get_ticks()))
10211023
set_label(ticks_with_counts, fontsize=fontsize_rawxlabel)
10221024

10231025
# Ensure ticks are at the correct locations

dabest/plot_tools.py

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -731,14 +731,17 @@ def sankeydiag(
731731
right_idx = []
732732
# Design for Sankey Flow Diagram
733733
sankey_idx = (
734-
[
735-
(control, test)
736-
for i in idx
737-
for control, test in zip(i[:], (i[1:] + (i[0],)))
738-
]
739-
if flow
740-
else temp_idx
741-
)
734+
[
735+
(control, test)
736+
for i in idx
737+
for control, test in zip(
738+
i[:],
739+
(tuple(i[1:]) + (i[0],)) if isinstance(i, tuple) else (list(i[1:]) + [i[0]])
740+
)
741+
]
742+
if flow
743+
else temp_idx
744+
)
742745
for i in sankey_idx:
743746
left_idx.append(i[0])
744747
right_idx.append(i[1])
@@ -2065,6 +2068,7 @@ def barplotter(
20652068
plot_data: pd.DataFrame,
20662069
bar_color: str,
20672070
plot_palette_bar: dict,
2071+
color_col: str,
20682072
plot_kwargs: dict,
20692073
barplot_kwargs: dict,
20702074
horizontal: bool
@@ -2088,6 +2092,8 @@ def barplotter(
20882092
Color of the bar.
20892093
plot_palette_bar : dict
20902094
Dictionary of colors used in the bar plot.
2095+
color_col : str
2096+
Column name of the color column.
20912097
plot_kwargs : dict
20922098
Keyword arguments for the plot.
20932099
barplot_kwargs : dict
@@ -2102,7 +2108,26 @@ def barplotter(
21022108
else:
21032109
x_var, y_var, orient = all_plot_groups, np.ones(len(all_plot_groups)), "v"
21042110

2105-
bar1_df = pd.DataFrame({xvar: x_var, "proportion": y_var})
2111+
# Create bar1_df with basic columns
2112+
bar1_df = pd.DataFrame({
2113+
xvar: x_var,
2114+
"proportion": y_var
2115+
})
2116+
2117+
# Handle colors
2118+
if color_col:
2119+
# Get first color value for each group
2120+
color_mapping = plot_data.groupby(xvar, observed=False)[color_col].first()
2121+
bar1_df[color_col] = [color_mapping.get(group) for group in all_plot_groups]
2122+
2123+
# Map colors, defaulting to bar_color if no match
2124+
edge_colors = [
2125+
plot_palette_bar.get(hue_val, bar_color)
2126+
for hue_val in bar1_df[color_col]
2127+
]
2128+
else:
2129+
edge_colors = bar_color
2130+
21062131

21072132
bar1 = sns.barplot(
21082133
data=bar1_df,
@@ -2112,7 +2137,7 @@ def barplotter(
21122137
order=all_plot_groups,
21132138
linewidth=2,
21142139
facecolor=(1, 1, 1, 0),
2115-
edgecolor=bar_color,
2140+
edgecolor=edge_colors,
21162141
zorder=1,
21172142
orient=orient,
21182143
)
@@ -2123,6 +2148,8 @@ def barplotter(
21232148
ax=rawdata_axes,
21242149
order=all_plot_groups,
21252150
palette=plot_palette_bar,
2151+
hue=color_col,
2152+
dodge=False,
21262153
zorder=1,
21272154
orient=orient,
21282155
**barplot_kwargs

dabest/plotter.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,7 @@ def effectsize_df_plotter(effectsize_df: object, **plot_kwargs) -> matplotlib.fi
277277
plot_data = plot_data,
278278
bar_color = bar_color,
279279
plot_palette_bar = plot_palette_bar,
280+
color_col = color_col,
280281
plot_kwargs = plot_kwargs,
281282
barplot_kwargs = barplot_kwargs,
282283
horizontal = horizontal,

0 commit comments

Comments
 (0)