Skip to content

Commit 68267bd

Browse files
committed
calculate_shift
1 parent 9195ef7 commit 68267bd

File tree

1 file changed

+38
-12
lines changed

1 file changed

+38
-12
lines changed

src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,20 @@
6868
"""
6969

7070

71+
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
72+
def calculate_shift(
73+
image_seq_len,
74+
base_seq_len: int = 256,
75+
max_seq_len: int = 4096,
76+
base_shift: float = 0.5,
77+
max_shift: float = 1.16,
78+
):
79+
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
80+
b = base_shift - m * base_seq_len
81+
mu = image_seq_len * m + b
82+
return mu
83+
84+
7185
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
7286
def retrieve_timesteps(
7387
scheduler,
@@ -884,18 +898,7 @@ def __call__(
884898
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
885899
pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
886900

887-
# 4. Prepare timesteps
888-
timesteps, num_inference_steps = retrieve_timesteps(
889-
self.scheduler,
890-
num_inference_steps,
891-
device,
892-
sigmas=sigmas,
893-
mu=mu,
894-
)
895-
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
896-
self._num_timesteps = len(timesteps)
897-
898-
# 5. Prepare latent variables
901+
# 4. Prepare latent variables
899902
num_channels_latents = self.transformer.config.in_channels
900903
latents = self.prepare_latents(
901904
batch_size * num_images_per_prompt,
@@ -908,6 +911,29 @@ def __call__(
908911
latents,
909912
)
910913

914+
# 5. Prepare timesteps
915+
if self.scheduler.config.use_dynamic_shifting and mu is None:
916+
_, _, height, width = latents.shape
917+
image_seq_len = (height // self.transformer.config.patch_size) * (
918+
width // self.transformer.config.patch_size
919+
)
920+
mu = calculate_shift(
921+
image_seq_len,
922+
self.scheduler.config.base_image_seq_len,
923+
self.scheduler.config.max_image_seq_len,
924+
self.scheduler.config.base_shift,
925+
self.scheduler.config.max_shift,
926+
)
927+
timesteps, num_inference_steps = retrieve_timesteps(
928+
self.scheduler,
929+
num_inference_steps,
930+
device,
931+
sigmas=sigmas,
932+
mu=mu,
933+
)
934+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
935+
self._num_timesteps = len(timesteps)
936+
911937
# 6. Denoising loop
912938
with self.progress_bar(total=num_inference_steps) as progress_bar:
913939
for i, t in enumerate(timesteps):

0 commit comments

Comments
 (0)