Commit f1ade2e

Merge branch 'main' into update_ptxla_training
2 parents df31c9d + cfdeebd commit f1ade2e

File tree

9 files changed: +65 / -76 lines changed

src/diffusers/pipelines/allegro/pipeline_allegro.py

Lines changed: 6 additions & 13 deletions

@@ -251,13 +251,6 @@ def encode_prompt(
         if device is None:
             device = self._execution_device
 
-        if prompt is not None and isinstance(prompt, str):
-            batch_size = 1
-        elif prompt is not None and isinstance(prompt, list):
-            batch_size = len(prompt)
-        else:
-            batch_size = prompt_embeds.shape[0]
-
         # See Section 3.1. of the paper.
         max_length = max_sequence_length
 
@@ -302,12 +295,12 @@ def encode_prompt(
         # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
         prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
         prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1)
-        prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
-        prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1)
+        prompt_attention_mask = prompt_attention_mask.repeat(1, num_videos_per_prompt)
+        prompt_attention_mask = prompt_attention_mask.view(bs_embed * num_videos_per_prompt, -1)
 
         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance and negative_prompt_embeds is None:
-            uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
+            uncond_tokens = [negative_prompt] * bs_embed if isinstance(negative_prompt, str) else negative_prompt
             uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
             max_length = prompt_embeds.shape[1]
             uncond_input = self.tokenizer(
@@ -334,10 +327,10 @@ def encode_prompt(
             negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
 
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
-            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+            negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1)
 
-            negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
-            negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_videos_per_prompt, 1)
+            negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(1, num_videos_per_prompt)
+            negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed * num_videos_per_prompt, -1)
         else:
             negative_prompt_embeds = None
             negative_prompt_attention_mask = None
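Not part of the commit — a minimal sketch with made-up toy shapes showing why the new repeat-then-view ordering matters: it groups the duplicated mask rows per prompt ([p0, p0, p1, p1]), matching how prompt_embeds is duplicated, whereas the old view-then-repeat ordering tiled them ([p0, p1, p0, p1]) and misaligned masks with embeddings whenever there was more than one prompt and more than one video per prompt.

import torch

# Toy example (assumed shapes, not from the commit): two prompts, three tokens,
# two videos per prompt.
bs_embed, num_videos_per_prompt = 2, 2
prompt_attention_mask = torch.tensor([[1, 1, 0],   # prompt 0
                                      [1, 0, 0]])  # prompt 1

# New ordering: repeat along the token axis, then fold the copies into the batch axis.
new = prompt_attention_mask.repeat(1, num_videos_per_prompt).view(bs_embed * num_videos_per_prompt, -1)
# rows: [p0, p0, p1, p1] -- same grouping as prompt_embeds.repeat(1, n, 1).view(bs * n, seq, -1)

# Old ordering tiled the batch instead:
old = prompt_attention_mask.view(bs_embed, -1).repeat(num_videos_per_prompt, 1)
# rows: [p0, p1, p0, p1] -- misaligned with the duplicated embeddings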

src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py

Lines changed: 6 additions & 13 deletions

@@ -227,13 +227,6 @@ def encode_prompt(
         if device is None:
             device = self._execution_device
 
-        if prompt is not None and isinstance(prompt, str):
-            batch_size = 1
-        elif prompt is not None and isinstance(prompt, list):
-            batch_size = len(prompt)
-        else:
-            batch_size = prompt_embeds.shape[0]
-
         # See Section 3.1. of the paper.
         max_length = max_sequence_length
 
@@ -278,12 +271,12 @@ def encode_prompt(
         # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
         prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
-        prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
-        prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
+        prompt_attention_mask = prompt_attention_mask.repeat(1, num_images_per_prompt)
+        prompt_attention_mask = prompt_attention_mask.view(bs_embed * num_images_per_prompt, -1)
 
         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance and negative_prompt_embeds is None:
-            uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
+            uncond_tokens = [negative_prompt] * bs_embed if isinstance(negative_prompt, str) else negative_prompt
             uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
             max_length = prompt_embeds.shape[1]
             uncond_input = self.tokenizer(
@@ -310,10 +303,10 @@ def encode_prompt(
             negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
 
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
-            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+            negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
 
-            negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
-            negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
+            negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(1, num_images_per_prompt)
+            negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed * num_images_per_prompt, -1)
         else:
             negative_prompt_embeds = None
             negative_prompt_attention_mask = None

src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py

Lines changed: 6 additions & 13 deletions

@@ -338,13 +338,6 @@ def encode_prompt(
         if device is None:
             device = self._execution_device
 
-        if prompt is not None and isinstance(prompt, str):
-            batch_size = 1
-        elif prompt is not None and isinstance(prompt, list):
-            batch_size = len(prompt)
-        else:
-            batch_size = prompt_embeds.shape[0]
-
         # See Section 3.1. of the paper.
         max_length = max_sequence_length
 
@@ -389,12 +382,12 @@ def encode_prompt(
         # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
         prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
-        prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
-        prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
+        prompt_attention_mask = prompt_attention_mask.repeat(1, num_images_per_prompt)
+        prompt_attention_mask = prompt_attention_mask.view(bs_embed * num_images_per_prompt, -1)
 
         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance and negative_prompt_embeds is None:
-            uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
+            uncond_tokens = [negative_prompt] * bs_embed if isinstance(negative_prompt, str) else negative_prompt
             uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
             max_length = prompt_embeds.shape[1]
             uncond_input = self.tokenizer(
@@ -421,10 +414,10 @@ def encode_prompt(
             negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
 
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
-            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+            negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
 
-            negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
-            negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
+            negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(1, num_images_per_prompt)
+            negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed * num_images_per_prompt, -1)
         else:
             negative_prompt_embeds = None
             negative_prompt_attention_mask = None

src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py

Lines changed: 6 additions & 13 deletions

@@ -264,13 +264,6 @@ def encode_prompt(
         if device is None:
             device = self._execution_device
 
-        if prompt is not None and isinstance(prompt, str):
-            batch_size = 1
-        elif prompt is not None and isinstance(prompt, list):
-            batch_size = len(prompt)
-        else:
-            batch_size = prompt_embeds.shape[0]
-
         # See Section 3.1. of the paper.
         max_length = max_sequence_length
 
@@ -315,12 +308,12 @@ def encode_prompt(
         # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
         prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
-        prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
-        prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
+        prompt_attention_mask = prompt_attention_mask.repeat(1, num_images_per_prompt)
+        prompt_attention_mask = prompt_attention_mask.view(bs_embed * num_images_per_prompt, -1)
 
         # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
-            uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
+            uncond_tokens = [negative_prompt] * bs_embed if isinstance(negative_prompt, str) else negative_prompt
             uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
             max_length = prompt_embeds.shape[1]
             uncond_input = self.tokenizer(
@@ -347,10 +340,10 @@ def encode_prompt(
             negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
 
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
-            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+            negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
 
-            negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
-            negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
+            negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(1, num_images_per_prompt)
+            negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed * num_images_per_prompt, -1)
         else:
             negative_prompt_embeds = None
             negative_prompt_attention_mask = None

src/diffusers/schedulers/scheduling_ddpm.py

Lines changed: 2 additions & 6 deletions

@@ -548,16 +548,12 @@ def __len__(self):
         return self.config.num_train_timesteps
 
     def previous_timestep(self, timestep):
-        if self.custom_timesteps:
+        if self.custom_timesteps or self.num_inference_steps:
             index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
             if index == self.timesteps.shape[0] - 1:
                 prev_t = torch.tensor(-1)
             else:
                 prev_t = self.timesteps[index + 1]
         else:
-            num_inference_steps = (
-                self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
-            )
-            prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
-
+            prev_t = timestep - 1
         return prev_t
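A hypothetical usage sketch (not from the commit) of how previous_timestep behaves after this change, assuming the standard DDPMScheduler API: once set_timesteps() has been called, the previous timestep is looked up in the actual inference schedule, and the plain timestep - 1 fallback only applies when no inference grid has been set.

from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)

# No set_timesteps() yet: num_inference_steps is None, so the fallback branch
# simply steps back by one training timestep.
assert scheduler.previous_timestep(999) == 998

# After set_timesteps(), the branch that indexes self.timesteps is taken, so the
# previous timestep is the next entry of the inference schedule.
scheduler.set_timesteps(num_inference_steps=50)
t = scheduler.timesteps[0]
assert scheduler.previous_timestep(t) == scheduler.timesteps[1]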

src/diffusers/schedulers/scheduling_ddpm_parallel.py

Lines changed: 2 additions & 6 deletions

@@ -639,16 +639,12 @@ def __len__(self):
 
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
     def previous_timestep(self, timestep):
-        if self.custom_timesteps:
+        if self.custom_timesteps or self.num_inference_steps:
             index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
             if index == self.timesteps.shape[0] - 1:
                 prev_t = torch.tensor(-1)
             else:
                 prev_t = self.timesteps[index + 1]
         else:
-            num_inference_steps = (
-                self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
-            )
-            prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
-
+            prev_t = timestep - 1
         return prev_t

src/diffusers/schedulers/scheduling_lcm.py

Lines changed: 2 additions & 6 deletions

@@ -643,16 +643,12 @@ def __len__(self):
 
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
    def previous_timestep(self, timestep):
-        if self.custom_timesteps:
+        if self.custom_timesteps or self.num_inference_steps:
             index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
             if index == self.timesteps.shape[0] - 1:
                 prev_t = torch.tensor(-1)
             else:
                 prev_t = self.timesteps[index + 1]
         else:
-            num_inference_steps = (
-                self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
-            )
-            prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
-
+            prev_t = timestep - 1
         return prev_t

src/diffusers/schedulers/scheduling_tcd.py

Lines changed: 2 additions & 6 deletions

@@ -680,16 +680,12 @@ def __len__(self):
 
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
     def previous_timestep(self, timestep):
-        if self.custom_timesteps:
+        if self.custom_timesteps or self.num_inference_steps:
             index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
             if index == self.timesteps.shape[0] - 1:
                 prev_t = torch.tensor(-1)
             else:
                 prev_t = self.timesteps[index + 1]
         else:
-            num_inference_steps = (
-                self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
-            )
-            prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
-
+            prev_t = timestep - 1
         return prev_t

tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py

Lines changed: 33 additions & 0 deletions

@@ -225,6 +225,39 @@ def test_fused_qkv_projections(self):
             original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
         ), "Original outputs should match when fused QKV projections are disabled."
 
+    def test_skip_guidance_layers(self):
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(torch_device)
+
+        output_full = pipe(**inputs)[0]
+
+        inputs_with_skip = inputs.copy()
+        inputs_with_skip["skip_guidance_layers"] = [0]
+        output_skip = pipe(**inputs_with_skip)[0]
+
+        self.assertFalse(
+            np.allclose(output_full, output_skip, atol=1e-5), "Outputs should differ when layers are skipped"
+        )
+
+        self.assertEqual(output_full.shape, output_skip.shape, "Outputs should have the same shape")
+
+        inputs["num_images_per_prompt"] = 2
+        output_full = pipe(**inputs)[0]
+
+        inputs_with_skip = inputs.copy()
+        inputs_with_skip["skip_guidance_layers"] = [0]
+        output_skip = pipe(**inputs_with_skip)[0]
+
+        self.assertFalse(
+            np.allclose(output_full, output_skip, atol=1e-5), "Outputs should differ when layers are skipped"
+        )
+
+        self.assertEqual(output_full.shape, output_skip.shape, "Outputs should have the same shape")
+
 
 @slow
 @require_big_gpu_with_torch_cuda
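For context, a hypothetical end-user sketch of the skip_guidance_layers argument this new test exercises; the checkpoint name, layer indices, and generation settings below are assumptions for illustration, not taken from the commit.

import torch
from diffusers import StableDiffusion3Pipeline

# Assumed checkpoint and settings, for illustration only.
pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large", torch_dtype=torch.bfloat16
).to("cuda")

image = pipe(
    "a photo of an astronaut riding a horse",
    num_inference_steps=28,
    guidance_scale=7.0,
    skip_guidance_layers=[7, 8, 9],  # transformer layer indices to skip (the test above passes [0])
).images[0]
image.save("slg_example.png")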
