Skip to content

Commit 33f5e50

Browse files
committed
Add dynamic_shifting to SD3
1 parent 5fb3a98 commit 33f5e50

File tree

1 file changed

+36
-6
lines changed

1 file changed

+36
-6
lines changed

src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,20 @@
6868
"""
6969

7070

71+
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
72+
def calculate_shift(
73+
image_seq_len,
74+
base_seq_len: int = 256,
75+
max_seq_len: int = 4096,
76+
base_shift: float = 0.5,
77+
max_shift: float = 1.16,
78+
):
79+
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
80+
b = base_shift - m * base_seq_len
81+
mu = image_seq_len * m + b
82+
return mu
83+
84+
7185
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
7286
def retrieve_timesteps(
7387
scheduler,
@@ -882,12 +896,7 @@ def __call__(
882896
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
883897
pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
884898

885-
# 4. Prepare timesteps
886-
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
887-
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
888-
self._num_timesteps = len(timesteps)
889-
890-
# 5. Prepare latent variables
899+
# 4. Prepare latent variables
891900
num_channels_latents = self.transformer.config.in_channels
892901
latents = self.prepare_latents(
893902
batch_size * num_images_per_prompt,
@@ -900,6 +909,27 @@ def __call__(
900909
latents,
901910
)
902911

912+
# 5. Prepare timesteps
913+
mu = None
914+
if self.scheduler.config.use_dynamic_shifting:
915+
image_seq_len = latents.shape[1]
916+
mu = calculate_shift(
917+
image_seq_len,
918+
self.scheduler.config.base_image_seq_len,
919+
self.scheduler.config.max_image_seq_len,
920+
self.scheduler.config.base_shift,
921+
self.scheduler.config.max_shift,
922+
)
923+
timesteps, num_inference_steps = retrieve_timesteps(
924+
self.scheduler,
925+
num_inference_steps,
926+
device,
927+
sigmas=sigmas,
928+
mu=mu,
929+
)
930+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
931+
self._num_timesteps = len(timesteps)
932+
903933
# 6. Denoising loop
904934
with self.progress_bar(total=num_inference_steps) as progress_bar:
905935
for i, t in enumerate(timesteps):

0 commit comments

Comments
 (0)