Skip to content

Commit e7f0d03

Browse files
committed
Fix Beta sampling to match the paper
We need a beta distribution of timesteps, not sigmas. Also allow distribution parameters > 1.
1 parent 374bb6c commit e7f0d03

File tree

2 files changed

+18
-8
lines changed

2 files changed

+18
-8
lines changed

modules/sd_schedulers.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -116,15 +116,25 @@ def ddim_scheduler(n, sigma_min, sigma_max, inner_model, device):
116116
return torch.FloatTensor(sigs).to(device)
117117

118118

119-
def beta_scheduler(n, sigma_min, sigma_max, inner_model, device):
120-
# From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024) """
119+
def beta_scheduler(n, sigma_min, sigma_max, inner_model, device, sgm=False):
120+
start = inner_model.sigma_to_t(torch.tensor(sigma_max))
121+
end = inner_model.sigma_to_t(torch.tensor(sigma_min))
122+
123+
# From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)
121124
alpha = shared.opts.beta_dist_alpha
122125
beta = shared.opts.beta_dist_beta
123-
timesteps = 1 - np.linspace(0, 1, n)
126+
if sgm:
127+
timesteps = np.linspace(1, 0, n + 1)[:-1]
128+
else:
129+
timesteps = np.linspace(1, 0, n)
124130
timesteps = [stats.beta.ppf(x, alpha, beta) for x in timesteps]
125-
sigmas = [sigma_min + (x * (sigma_max-sigma_min)) for x in timesteps]
126-
sigmas += [0.0]
127-
return torch.FloatTensor(sigmas).to(device)
131+
132+
sigs = []
133+
for x in range(len(timesteps)):
134+
ts = end + (timesteps[x] * (start - end))
135+
sigs.append(inner_model.t_to_sigma(ts))
136+
sigs += [0.0]
137+
return torch.FloatTensor(sigs).to(device)
128138

129139

130140
schedulers = [

modules/shared_options.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -407,8 +407,8 @@
407407
'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final", infotext='UniPC lower order final'),
408408
'sd_noise_schedule': OptionInfo("Default", "Noise schedule for sampling", gr.Radio, {"choices": ["Default", "Zero Terminal SNR"]}, infotext="Noise Schedule").info("for use with zero terminal SNR trained models"),
409409
'skip_early_cond': OptionInfo(0.0, "Ignore negative prompt during early sampling", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext="Skip Early CFG").info("disables CFG on a proportion of steps at the beginning of generation; 0=skip none; 1=skip all; can both improve sample diversity/quality and speed up sampling; XYZ plot: Skip Early CFG"),
410-
'beta_dist_alpha': OptionInfo(0.6, "Beta scheduler - alpha", gr.Slider, {"minimum": 0.01, "maximum": 1.0, "step": 0.01}, infotext='Beta scheduler alpha').info('Default = 0.6; the alpha parameter of the beta distribution used in Beta sampling'),
411-
'beta_dist_beta': OptionInfo(0.6, "Beta scheduler - beta", gr.Slider, {"minimum": 0.01, "maximum": 1.0, "step": 0.01}, infotext='Beta scheduler beta').info('Default = 0.6; the beta parameter of the beta distribution used in Beta sampling'),
410+
'beta_dist_alpha': OptionInfo(0.6, "Beta scheduler - alpha", gr.Slider, {"minimum": 0.01, "maximum": 5.0, "step": 0.01}, infotext='Beta scheduler alpha').info('Default = 0.6; the alpha parameter of the beta distribution used in Beta sampling'),
411+
'beta_dist_beta': OptionInfo(0.6, "Beta scheduler - beta", gr.Slider, {"minimum": 0.01, "maximum": 5.0, "step": 0.01}, infotext='Beta scheduler beta').info('Default = 0.6; the beta parameter of the beta distribution used in Beta sampling'),
412412
}))
413413

414414
options_templates.update(options_section(('postprocessing', "Postprocessing", "postprocessing"), {

0 commit comments

Comments
 (0)