1
1
from collections import namedtuple
2
2
import numpy as np
3
+ from math import floor
3
4
import torch
4
5
import tqdm
5
6
from PIL import Image
@@ -205,17 +206,22 @@ def initialize(self, p):
205
206
self .mask = p .mask if hasattr (p , 'mask' ) else None
206
207
self .nmask = p .nmask if hasattr (p , 'nmask' ) else None
207
208
209
+
210
+ def adjust_steps_if_invalid (self , p , num_steps ):
211
+ if (self .config .name == 'DDIM' and p .ddim_discretize == 'uniform' ) or (self .config .name == 'PLMS' ):
212
+ valid_step = 999 / (1000 // num_steps )
213
+ if valid_step == floor (valid_step ):
214
+ return int (valid_step ) + 1
215
+
216
+ return num_steps
217
+
218
+
208
219
def sample_img2img (self , p , x , noise , conditioning , unconditional_conditioning , steps = None , image_conditioning = None ):
209
220
steps , t_enc = setup_img2img_steps (p , steps )
210
-
221
+ steps = self . adjust_steps_if_invalid ( p , steps )
211
222
self .initialize (p )
212
223
213
- # existing code fails with certain step counts, like 9
214
- try :
215
- self .sampler .make_schedule (ddim_num_steps = steps , ddim_eta = self .eta , ddim_discretize = p .ddim_discretize , verbose = False )
216
- except Exception :
217
- self .sampler .make_schedule (ddim_num_steps = steps + 1 , ddim_eta = self .eta , ddim_discretize = p .ddim_discretize , verbose = False )
218
-
224
+ self .sampler .make_schedule (ddim_num_steps = steps , ddim_eta = self .eta , ddim_discretize = p .ddim_discretize , verbose = False )
219
225
x1 = self .sampler .stochastic_encode (x , torch .tensor ([t_enc ] * int (x .shape [0 ])).to (shared .device ), noise = noise )
220
226
221
227
self .init_latent = x
@@ -239,18 +245,14 @@ def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, ima
239
245
self .last_latent = x
240
246
self .step = 0
241
247
242
- steps = steps or p .steps
248
+ steps = self . adjust_steps_if_invalid ( p , steps or p .steps )
243
249
244
250
# Wrap the conditioning models with additional image conditioning for inpainting model
245
251
if image_conditioning is not None :
246
252
conditioning = {"c_concat" : [image_conditioning ], "c_crossattn" : [conditioning ]}
247
253
unconditional_conditioning = {"c_concat" : [image_conditioning ], "c_crossattn" : [unconditional_conditioning ]}
248
254
249
- # existing code fails with certain step counts, like 9
250
- try :
251
- samples_ddim = self .launch_sampling (steps , lambda : self .sampler .sample (S = steps , conditioning = conditioning , batch_size = int (x .shape [0 ]), shape = x [0 ].shape , verbose = False , unconditional_guidance_scale = p .cfg_scale , unconditional_conditioning = unconditional_conditioning , x_T = x , eta = self .eta )[0 ])
252
- except Exception :
253
- samples_ddim = self .launch_sampling (steps , lambda : self .sampler .sample (S = steps + 1 , conditioning = conditioning , batch_size = int (x .shape [0 ]), shape = x [0 ].shape , verbose = False , unconditional_guidance_scale = p .cfg_scale , unconditional_conditioning = unconditional_conditioning , x_T = x , eta = self .eta )[0 ])
255
+ samples_ddim = self .launch_sampling (steps , lambda : self .sampler .sample (S = steps , conditioning = conditioning , batch_size = int (x .shape [0 ]), shape = x [0 ].shape , verbose = False , unconditional_guidance_scale = p .cfg_scale , unconditional_conditioning = unconditional_conditioning , x_T = x , eta = self .eta )[0 ])
254
256
255
257
return samples_ddim
256
258
0 commit comments