Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
Original file line number Diff line number Diff line change
Expand Up @@ -941,8 +941,12 @@ def __call__(

# compute previous image: x_t -> x_t-1
if num_inference_steps == 1:
# For DMD one step sampling: https://arxiv.org/abs/2311.18828
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).pred_original_sample
try:
# For LCM one step sampling
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).denoised
except:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
except:
except Exception:

At minimum due to ruff. What's the exception type raised here? Is there anything else we can check instead of relying on try/except?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This one is due to the Output difference between

prev_sample: torch.Tensor
denoised: Optional[torch.Tensor] = None

and
prev_sample: torch.Tensor
pred_original_sample: Optional[torch.Tensor] = None

What do you think is better to write here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

gentle ping @hlky

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With return_dict=False a tuple is returned, both denoised and pred_original_sample are at index 1 so we can do

latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[1]

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes sense. updated.

# For DMD one step sampling: https://arxiv.org/abs/2311.18828
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).pred_original_sample
else:
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

Expand Down
Loading