
Commit 2f68e52

hlky and nikitabalabin authored and committed
Fix multi-prompt inference (#10103)
* Fix multi-prompt inference

  Fix generation of multiple images with multiple prompts, e.g. len(prompts) > 1 and num_images_per_prompt > 1

* make

* fix copies

---------

Co-authored-by: Nikita Balabin <[email protected]>
1 parent 69acaee commit 2f68e52
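
The scenario named in the commit message — several prompts combined with several images per prompt — is the case the old attention-mask duplication got wrong. Below is a minimal sketch of how that case is exercised, assuming the diffusers PixArtSigmaPipeline and a PixArt-Sigma checkpoint id chosen purely for illustration (neither is named in this commit):

# Reproduces the multi-prompt, multi-image setup the commit fixes:
# len(prompts) > 1 together with num_images_per_prompt > 1.
# Pipeline class and checkpoint id are illustrative assumptions.
import torch
from diffusers import PixArtSigmaPipeline

pipe = PixArtSigmaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", torch_dtype=torch.float16
).to("cuda")

prompts = ["a red fox in the snow", "a lighthouse at dusk"]  # len(prompts) > 1
images = pipe(prompt=prompts, num_images_per_prompt=2).images  # 2 prompts x 2 images = 4 outputs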

File tree

4 files changed

+24 -52 lines changed

src/diffusers/pipelines/allegro/pipeline_allegro.py

Lines changed: 6 additions & 13 deletions
@@ -251,13 +251,6 @@ def encode_prompt(
         if device is None:
             device = self._execution_device
 
-        if prompt is not None and isinstance(prompt, str):
-            batch_size = 1
-        elif prompt is not None and isinstance(prompt, list):
-            batch_size = len(prompt)
-        else:
-            batch_size = prompt_embeds.shape[0]
-
         # See Section 3.1. of the paper.
         max_length = max_sequence_length
 
@@ -302,12 +295,12 @@ def encode_prompt(
         # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
         prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
         prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1)
-        prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
-        prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1)
+        prompt_attention_mask = prompt_attention_mask.repeat(1, num_videos_per_prompt)
+        prompt_attention_mask = prompt_attention_mask.view(bs_embed * num_videos_per_prompt, -1)
 
         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance and negative_prompt_embeds is None:
-            uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
+            uncond_tokens = [negative_prompt] * bs_embed if isinstance(negative_prompt, str) else negative_prompt
             uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
             max_length = prompt_embeds.shape[1]
             uncond_input = self.tokenizer(
@@ -334,10 +327,10 @@ def encode_prompt(
             negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
 
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
-            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+            negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1)
 
-            negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
-            negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_videos_per_prompt, 1)
+            negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(1, num_videos_per_prompt)
+            negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed * num_videos_per_prompt, -1)
         else:
             negative_prompt_embeds = None
             negative_prompt_attention_mask = None
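
The substantive change, repeated in all four pipelines, is to duplicate the attention masks with the same repeat-then-view pattern already used for the embeddings. A standalone sketch (plain torch, with toy shapes chosen for illustration) of why the ordering matters once the batch holds more than one prompt:

# Two "prompts" represented by an all-zeros and an all-ones mask row.
import torch

bs_embed, num_images_per_prompt, seq_len = 2, 2, 3
mask = torch.stack([torch.zeros(seq_len), torch.ones(seq_len)])  # (2, 3)

# Old path: view, then repeat over the batch dim -> rows ordered [p0, p1, p0, p1]
old = mask.view(bs_embed, -1).repeat(num_images_per_prompt, 1)

# New path: repeat along the sequence dim, then view -> rows ordered [p0, p0, p1, p1],
# which matches how prompt_embeds are duplicated a few lines above.
new = mask.repeat(1, num_images_per_prompt).view(bs_embed * num_images_per_prompt, -1)

print(old[:, 0])  # tensor([0., 1., 0., 1.])  -> masks misaligned with their embeddings
print(new[:, 0])  # tensor([0., 0., 1., 1.])  -> masks line up with their embeddings

With a single prompt the two orderings coincide, which is why the bug only surfaced when len(prompt) > 1 and num_images_per_prompt > 1.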

src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py

Lines changed: 6 additions & 13 deletions
@@ -227,13 +227,6 @@ def encode_prompt(
         if device is None:
             device = self._execution_device
 
-        if prompt is not None and isinstance(prompt, str):
-            batch_size = 1
-        elif prompt is not None and isinstance(prompt, list):
-            batch_size = len(prompt)
-        else:
-            batch_size = prompt_embeds.shape[0]
-
         # See Section 3.1. of the paper.
         max_length = max_sequence_length
 
@@ -278,12 +271,12 @@ def encode_prompt(
         # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
         prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
-        prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
-        prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
+        prompt_attention_mask = prompt_attention_mask.repeat(1, num_images_per_prompt)
+        prompt_attention_mask = prompt_attention_mask.view(bs_embed * num_images_per_prompt, -1)
 
         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance and negative_prompt_embeds is None:
-            uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
+            uncond_tokens = [negative_prompt] * bs_embed if isinstance(negative_prompt, str) else negative_prompt
             uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
             max_length = prompt_embeds.shape[1]
             uncond_input = self.tokenizer(
@@ -310,10 +303,10 @@ def encode_prompt(
             negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
 
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
-            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+            negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
 
-            negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
-            negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
+            negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(1, num_images_per_prompt)
+            negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed * num_images_per_prompt, -1)
         else:
             negative_prompt_embeds = None
             negative_prompt_attention_mask = None
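
A side effect of dropping the local batch_size is that the unconditional prompt is now broadcast with bs_embed, the batch dimension read from prompt_embeds after encoding, which holds whether prompt was a string, a list, or skipped in favour of precomputed embeddings. A small illustration of that broadcast (values are made up):

# Hypothetical values showing the uncond_tokens broadcast after the change.
negative_prompt = "low quality"
bs_embed = 3  # prompt_embeds.shape[0] before per-prompt duplication
uncond_tokens = [negative_prompt] * bs_embed if isinstance(negative_prompt, str) else negative_prompt
# -> ["low quality", "low quality", "low quality"], one entry per prompt in the batch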

src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py

Lines changed: 6 additions & 13 deletions
@@ -338,13 +338,6 @@ def encode_prompt(
         if device is None:
             device = self._execution_device
 
-        if prompt is not None and isinstance(prompt, str):
-            batch_size = 1
-        elif prompt is not None and isinstance(prompt, list):
-            batch_size = len(prompt)
-        else:
-            batch_size = prompt_embeds.shape[0]
-
         # See Section 3.1. of the paper.
         max_length = max_sequence_length
 
@@ -389,12 +382,12 @@ def encode_prompt(
         # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
         prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
-        prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
-        prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
+        prompt_attention_mask = prompt_attention_mask.repeat(1, num_images_per_prompt)
+        prompt_attention_mask = prompt_attention_mask.view(bs_embed * num_images_per_prompt, -1)
 
         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance and negative_prompt_embeds is None:
-            uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
+            uncond_tokens = [negative_prompt] * bs_embed if isinstance(negative_prompt, str) else negative_prompt
             uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
             max_length = prompt_embeds.shape[1]
             uncond_input = self.tokenizer(
@@ -421,10 +414,10 @@ def encode_prompt(
             negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
 
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
-            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+            negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
 
-            negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
-            negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
+            negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(1, num_images_per_prompt)
+            negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed * num_images_per_prompt, -1)
         else:
             negative_prompt_embeds = None
             negative_prompt_attention_mask = None

src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py

Lines changed: 6 additions & 13 deletions
@@ -264,13 +264,6 @@ def encode_prompt(
         if device is None:
             device = self._execution_device
 
-        if prompt is not None and isinstance(prompt, str):
-            batch_size = 1
-        elif prompt is not None and isinstance(prompt, list):
-            batch_size = len(prompt)
-        else:
-            batch_size = prompt_embeds.shape[0]
-
         # See Section 3.1. of the paper.
         max_length = max_sequence_length
 
@@ -315,12 +308,12 @@ def encode_prompt(
         # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
         prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
-        prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
-        prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
+        prompt_attention_mask = prompt_attention_mask.repeat(1, num_images_per_prompt)
+        prompt_attention_mask = prompt_attention_mask.view(bs_embed * num_images_per_prompt, -1)
 
         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance and negative_prompt_embeds is None:
-            uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
+            uncond_tokens = [negative_prompt] * bs_embed if isinstance(negative_prompt, str) else negative_prompt
             uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
             max_length = prompt_embeds.shape[1]
             uncond_input = self.tokenizer(
@@ -347,10 +340,10 @@ def encode_prompt(
             negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
 
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
-            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+            negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
 
-            negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
-            negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
+            negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(1, num_images_per_prompt)
+            negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed * num_images_per_prompt, -1)
         else:
             negative_prompt_embeds = None
             negative_prompt_attention_mask = None
