Skip to content

Commit 40b3a7e

Browse files
Merge pull request #3917 from MartinCairnsSQL/adjust-ddim-uniform-steps
Certain step counts for DDIM cause out of bounds error
2 parents dd02889 + b885059 commit 40b3a7e

File tree

1 file changed

+15
-13
lines changed

1 file changed

+15
-13
lines changed

modules/sd_samplers.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from collections import namedtuple
22
import numpy as np
3+
from math import floor
34
import torch
45
import tqdm
56
from PIL import Image
@@ -205,17 +206,22 @@ def initialize(self, p):
205206
self.mask = p.mask if hasattr(p, 'mask') else None
206207
self.nmask = p.nmask if hasattr(p, 'nmask') else None
207208

209+
210+
def adjust_steps_if_invalid(self, p, num_steps):
211+
if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'):
212+
valid_step = 999 / (1000 // num_steps)
213+
if valid_step == floor(valid_step):
214+
return int(valid_step) + 1
215+
216+
return num_steps
217+
218+
208219
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
209220
steps, t_enc = setup_img2img_steps(p, steps)
210-
221+
steps = self.adjust_steps_if_invalid(p, steps)
211222
self.initialize(p)
212223

213-
# existing code fails with certain step counts, like 9
214-
try:
215-
self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
216-
except Exception:
217-
self.sampler.make_schedule(ddim_num_steps=steps+1, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
218-
224+
self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
219225
x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
220226

221227
self.init_latent = x
@@ -239,18 +245,14 @@ def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, ima
239245
self.last_latent = x
240246
self.step = 0
241247

242-
steps = steps or p.steps
248+
steps = self.adjust_steps_if_invalid(p, steps or p.steps)
243249

244250
# Wrap the conditioning models with additional image conditioning for inpainting model
245251
if image_conditioning is not None:
246252
conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
247253
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
248254

249-
# existing code fails with certain step counts, like 9
250-
try:
251-
samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
252-
except Exception:
253-
samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps+1, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
255+
samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
254256

255257
return samples_ddim
256258

0 commit comments

Comments
 (0)