From 2dae26c0eb799ef87a0be79df15197861bc69d5d Mon Sep 17 00:00:00 2001 From: junsong Date: Sun, 16 Mar 2025 22:25:16 -0700 Subject: [PATCH 1/2] fix bug when pixart-dmd inference with `num_inference_steps=1` --- .../pipelines/pixart_alpha/pipeline_pixart_alpha.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py index b550a442fe15..82a9a0fc4ec7 100644 --- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py @@ -941,8 +941,12 @@ def __call__( # compute previous image: x_t -> x_t-1 if num_inference_steps == 1: - # For DMD one step sampling: https://arxiv.org/abs/2311.18828 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).pred_original_sample + try: + # For LCM one step sampling + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).denoised + except: + # For DMD one step sampling: https://arxiv.org/abs/2311.18828 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).pred_original_sample else: latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] From ae78a65b94d0adda7ad8d43dc2a0edea58d2c44e Mon Sep 17 00:00:00 2001 From: lawrence-cj Date: Thu, 20 Mar 2025 15:02:16 +0800 Subject: [PATCH 2/2] use return_dict=False and return [1] element for 1-step pixart model, which works for both lcm and dmd --- .../pipelines/pixart_alpha/pipeline_pixart_alpha.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py index 82a9a0fc4ec7..988e049dd684 100644 --- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py @@ -941,12 +941,7 @@ def __call__( # compute previous image: x_t -> x_t-1 if num_inference_steps == 1: - try: - # For LCM one step sampling - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).denoised - except: - # For DMD one step sampling: https://arxiv.org/abs/2311.18828 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).pred_original_sample + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[1] else: latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]