Skip to content

Commit 42bd9bf

Browse files
committed
Add dynamic_shifting to SD3
1 parent 5fb3a98 commit 42bd9bf

File tree

1 file changed

+14
-6
lines changed

1 file changed

+14
-6
lines changed

src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -702,6 +702,7 @@ def __call__(
702702
skip_layer_guidance_scale: int = 2.8,
703703
skip_layer_guidance_stop: int = 0.2,
704704
skip_layer_guidance_start: int = 0.01,
705+
mu: Optional[float] = None,
705706
):
706707
r"""
707708
Function invoked when calling the pipeline for generation.
@@ -802,6 +803,7 @@ def __call__(
802803
`skip_guidance_layers` will start. The guidance will be applied to the layers specified in
803804
`skip_guidance_layers` from the fraction specified in `skip_layer_guidance_start`. Recommended value by
804805
StabiltyAI for Stable Diffusion 3.5 Medium is 0.01.
806+
mu (`float`, *optional*): `mu` value used for `dynamic_shifting`.
805807
806808
Examples:
807809
@@ -882,12 +884,7 @@ def __call__(
882884
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
883885
pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
884886

885-
# 4. Prepare timesteps
886-
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
887-
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
888-
self._num_timesteps = len(timesteps)
889-
890-
# 5. Prepare latent variables
887+
# 4. Prepare latent variables
891888
num_channels_latents = self.transformer.config.in_channels
892889
latents = self.prepare_latents(
893890
batch_size * num_images_per_prompt,
@@ -900,6 +897,17 @@ def __call__(
900897
latents,
901898
)
902899

900+
# 5. Prepare timesteps
901+
timesteps, num_inference_steps = retrieve_timesteps(
902+
self.scheduler,
903+
num_inference_steps,
904+
device,
905+
sigmas=sigmas,
906+
mu=mu,
907+
)
908+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
909+
self._num_timesteps = len(timesteps)
910+
903911
# 6. Denoising loop
904912
with self.progress_bar(total=num_inference_steps) as progress_bar:
905913
for i, t in enumerate(timesteps):

0 commit comments

Comments
 (0)