Skip to content

Commit 751283a

Browse files
authored
fix img2img variations/MPS (#353)
* fix img2img variations * fix assert for variation_amount
1 parent c22c3de commit 751283a

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

ldm/simplet2i.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ def process_image(image,seed):
286286
0.0 <= variation_amount <= 1.0
287287
), '-v --variation_amount must be in [0.0, 1.0]'
288288

289-
if len(with_variations) > 0 or variation_amount > 1.0:
289+
if len(with_variations) > 0 or variation_amount > 0.0:
290290
assert seed is not None,\
291291
'seed must be specified when using with_variations'
292292
if variation_amount == 0.0:
@@ -336,6 +336,7 @@ def process_image(image,seed):
336336
callback=step_callback,
337337
)
338338
else:
339+
init_latent = None
339340
make_image = self._txt2img(
340341
prompt,
341342
steps=steps,
@@ -351,11 +352,11 @@ def process_image(image,seed):
351352
if variation_amount > 0 or len(with_variations) > 0:
352353
# use fixed initial noise plus random noise per iteration
353354
seed_everything(seed)
354-
initial_noise = self._get_noise(init_img,width,height)
355+
initial_noise = self._get_noise(init_latent,width,height)
355356
for v_seed, v_weight in with_variations:
356357
seed = v_seed
357358
seed_everything(seed)
358-
next_noise = self._get_noise(init_img,width,height)
359+
next_noise = self._get_noise(init_latent,width,height)
359360
initial_noise = self.slerp(v_weight, initial_noise, next_noise)
360361
if variation_amount > 0:
361362
random.seed() # reset RNG to an actually random state, so we can get a random seed for variations
@@ -367,15 +368,15 @@ def process_image(image,seed):
367368
x_T = None
368369
if variation_amount > 0:
369370
seed_everything(seed)
370-
target_noise = self._get_noise(init_img,width,height)
371+
target_noise = self._get_noise(init_latent,width,height)
371372
x_T = self.slerp(variation_amount, initial_noise, target_noise)
372373
elif initial_noise is not None:
373374
# i.e. we specified particular variations
374375
x_T = initial_noise
375376
else:
376377
seed_everything(seed)
377378
if self.device.type == 'mps':
378-
x_T = self._get_noise(init_img,width,height)
379+
x_T = self._get_noise(init_latent,width,height)
379380
# make_image will do the equivalent of get_noise itself
380381
print(f' DEBUG: seed at make_image() invocation time ={seed}')
381382
image = make_image(x_T)
@@ -606,8 +607,8 @@ def load_model(self):
606607
return self.model
607608

608609
# returns a tensor filled with random numbers from a normal distribution
609-
def _get_noise(self,init_img,width,height):
610-
if init_img:
610+
def _get_noise(self,init_latent,width,height):
611+
if init_latent is not None:
611612
if self.device.type == 'mps':
612613
return torch.randn_like(init_latent, device='cpu').to(self.device)
613614
else:

0 commit comments

Comments
 (0)