Skip to content

Commit 2dae26c

Browse files
committed
fix bug when pixart-dmd inference with num_inference_steps=1
1 parent 1001425 commit 2dae26c

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -941,8 +941,12 @@ def __call__(
941941

942942
# compute previous image: x_t -> x_t-1
943943
if num_inference_steps == 1:
944-
# For DMD one step sampling: https://arxiv.org/abs/2311.18828
945-
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).pred_original_sample
944+
try:
945+
# For LCM one step sampling
946+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).denoised
947+
except:
948+
# For DMD one step sampling: https://arxiv.org/abs/2311.18828
949+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).pred_original_sample
946950
else:
947951
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
948952

0 commit comments

Comments
 (0)