Skip to content

Commit 8c07fcc

Browse files
committed
up
1 parent 3734af8 commit 8c07fcc

File tree

1 file changed

+21
-2
lines changed

1 file changed

+21
-2
lines changed

src/diffusers/pipelines/sana/pipeline_sana_sprint.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,24 @@ def encode_prompt(
394394

395395
return prompt_embeds, prompt_attention_mask
396396

397+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
398+
def prepare_extra_step_kwargs(self, generator, eta):
399+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
400+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
401+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
402+
# and should be between [0, 1]
403+
404+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
405+
extra_step_kwargs = {}
406+
if accepts_eta:
407+
extra_step_kwargs["eta"] = eta
408+
409+
# check if the scheduler accepts generator
410+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
411+
if accepts_generator:
412+
extra_step_kwargs["generator"] = generator
413+
return extra_step_kwargs
414+
397415
def check_inputs(
398416
self,
399417
prompt,
@@ -835,10 +853,11 @@ def __call__(
835853
guidance = guidance.expand(latents.shape[0]).to(prompt_embeds.dtype)
836854
guidance = guidance * self.transformer.config.guidance_embeds_scale
837855

838-
# YiYi TODO: refactor this
839-
timesteps = timesteps[:-1]
856+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
857+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
840858

841859
# 7. Denoising loop
860+
timesteps = timesteps[:-1]
842861
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
843862
self._num_timesteps = len(timesteps)
844863

0 commit comments

Comments
 (0)