
Commit 087bee1 (1 parent: fdb6fb2)

fix copies

3 files changed: +18 −39 lines

src/diffusers/pipelines/allegro/pipeline_allegro.py

Lines changed: 6 additions & 13 deletions

@@ -251,13 +251,6 @@ def encode_prompt(
         if device is None:
             device = self._execution_device
 
-        if prompt is not None and isinstance(prompt, str):
-            batch_size = 1
-        elif prompt is not None and isinstance(prompt, list):
-            batch_size = len(prompt)
-        else:
-            batch_size = prompt_embeds.shape[0]
-
         # See Section 3.1. of the paper.
         max_length = max_sequence_length
 
@@ -302,12 +295,12 @@ def encode_prompt(
         # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
         prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
         prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1)
-        prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
-        prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1)
+        prompt_attention_mask = prompt_attention_mask.repeat(1, num_videos_per_prompt)
+        prompt_attention_mask = prompt_attention_mask.view(bs_embed * num_videos_per_prompt, -1)
 
         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance and negative_prompt_embeds is None:
-            uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
+            uncond_tokens = [negative_prompt] * bs_embed if isinstance(negative_prompt, str) else negative_prompt
             uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
             max_length = prompt_embeds.shape[1]
             uncond_input = self.tokenizer(
@@ -334,10 +327,10 @@ def encode_prompt(
             negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
 
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
-            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+            negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1)
 
-            negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
-            negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_videos_per_prompt, 1)
+            negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(1, num_videos_per_prompt)
+            negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed * num_videos_per_prompt, -1)
         else:
             negative_prompt_embeds = None
             negative_prompt_attention_mask = None
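The attention-mask change mirrors how the embeddings themselves are duplicated (repeat along dim 1, then view back to a flat batch), so each prompt's copies stay contiguous and aligned with its embeddings; the old view-then-repeat order interleaved masks across prompts instead. A minimal PyTorch sketch of the two orderings, with made-up shapes and illustrative names (mask, num_per_prompt):

import torch

bs_embed, seq_len, num_per_prompt = 2, 4, 3
# Distinct masks so the ordering is visible: prompt 0 -> zeros, prompt 1 -> ones.
mask = torch.stack([torch.zeros(seq_len), torch.ones(seq_len)])  # shape (2, 4)

# Old ordering: view, then repeat along dim 0 -> rows cycle across prompts,
# i.e. [p0, p1, p0, p1, p0, p1].
old = mask.view(bs_embed, -1).repeat(num_per_prompt, 1)

# New ordering: repeat along dim 1, then view -> rows are grouped per prompt,
# i.e. [p0, p0, p0, p1, p1, p1], matching prompt_embeds.repeat(1, n, 1).view(...).
new = mask.repeat(1, num_per_prompt).view(bs_embed * num_per_prompt, -1)

print(old[:, 0])  # tensor([0., 1., 0., 1., 0., 1.])
print(new[:, 0])  # tensor([0., 0., 0., 1., 1., 1.])

With a single generation per prompt the two orderings coincide, so the mismatch only shows up when num_videos_per_prompt (or num_images_per_prompt in the PixArt pipelines below) is greater than 1 with a batch of prompts.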

src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py

Lines changed: 6 additions & 13 deletions

@@ -227,13 +227,6 @@ def encode_prompt(
         if device is None:
             device = self._execution_device
 
-        if prompt is not None and isinstance(prompt, str):
-            batch_size = 1
-        elif prompt is not None and isinstance(prompt, list):
-            batch_size = len(prompt)
-        else:
-            batch_size = prompt_embeds.shape[0]
-
         # See Section 3.1. of the paper.
         max_length = max_sequence_length
 
@@ -278,12 +271,12 @@ def encode_prompt(
         # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
         prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
-        prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
-        prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
+        prompt_attention_mask = prompt_attention_mask.repeat(1, num_images_per_prompt)
+        prompt_attention_mask = prompt_attention_mask.view(bs_embed * num_images_per_prompt, -1)
 
         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance and negative_prompt_embeds is None:
-            uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
+            uncond_tokens = [negative_prompt] * bs_embed if isinstance(negative_prompt, str) else negative_prompt
             uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
             max_length = prompt_embeds.shape[1]
             uncond_input = self.tokenizer(
@@ -310,10 +303,10 @@ def encode_prompt(
             negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
 
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
-            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+            negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
 
-            negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
-            negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
+            negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(1, num_images_per_prompt)
+            negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed * num_images_per_prompt, -1)
         else:
             negative_prompt_embeds = None
             negative_prompt_attention_mask = None

src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -264,13 +264,6 @@ def encode_prompt(
264264
if device is None:
265265
device = self._execution_device
266266

267-
if prompt is not None and isinstance(prompt, str):
268-
batch_size = 1
269-
elif prompt is not None and isinstance(prompt, list):
270-
batch_size = len(prompt)
271-
else:
272-
batch_size = prompt_embeds.shape[0]
273-
274267
# See Section 3.1. of the paper.
275268
max_length = max_sequence_length
276269

@@ -315,12 +308,12 @@ def encode_prompt(
315308
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
316309
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
317310
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
318-
prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
319-
prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
311+
prompt_attention_mask = prompt_attention_mask.repeat(1, num_images_per_prompt)
312+
prompt_attention_mask = prompt_attention_mask.view(bs_embed * num_images_per_prompt, -1)
320313

321314
# get unconditional embeddings for classifier free guidance
322315
if do_classifier_free_guidance and negative_prompt_embeds is None:
323-
uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
316+
uncond_tokens = [negative_prompt] * bs_embed if isinstance(negative_prompt, str) else negative_prompt
324317
uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
325318
max_length = prompt_embeds.shape[1]
326319
uncond_input = self.tokenizer(
@@ -347,10 +340,10 @@ def encode_prompt(
347340
negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
348341

349342
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
350-
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
343+
negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
351344

352-
negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
353-
negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
345+
negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(1, num_images_per_prompt)
346+
negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed * num_images_per_prompt, -1)
354347
else:
355348
negative_prompt_embeds = None
356349
negative_prompt_attention_mask = None
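The deleted batch_size block is redundant rather than wrong: the batch dimension is already available from the prompt embeddings, presumably via the bs_embed, seq_len unpacking of prompt_embeds.shape earlier in the function, and that holds for both the raw-prompt and precomputed-embeddings paths. A small sketch of the equivalence (names and shapes here are illustrative, not taken from the pipelines):

import torch

# Illustrative stand-ins for the two ways encode_prompt can be called.
prompt = ["a cat", "a dog"]                 # raw-prompt path
prompt_embeds = torch.randn(2, 300, 4096)   # precomputed-embeddings path (dims made up)

# What the deleted block computed:
if isinstance(prompt, str):
    batch_size = 1
elif isinstance(prompt, list):
    batch_size = len(prompt)
else:
    batch_size = prompt_embeds.shape[0]

# What the surviving code relies on instead: the embeddings' own batch dimension.
bs_embed, seq_len, _ = prompt_embeds.shape
assert bs_embed == batch_size  # holds in both paths once embeddings exist

All three files carry the same change, which is consistent with the commit message: diffusers keeps "# Copied from" blocks synchronized with make fix-copies, so a fix made in one pipeline is propagated to its copies.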
