
Commit 383cd48

hlky authored and sayakpaul committed
Add dynamic_shifting to SD3 (#10236)
* Add `dynamic_shifting` to SD3
* calculate_shift
* FlowMatchHeunDiscreteScheduler doesn't support mu
* Inpaint/img2img
1 parent 0f34f91 commit 383cd48

File tree: 3 files changed (+112 −8 lines changed)

  src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
  src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py
  src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py

src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py

Lines changed: 44 additions & 6 deletions

@@ -68,6 +68,20 @@
 """
 
 
+# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
+def calculate_shift(
+    image_seq_len,
+    base_seq_len: int = 256,
+    max_seq_len: int = 4096,
+    base_shift: float = 0.5,
+    max_shift: float = 1.16,
+):
+    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+    b = base_shift - m * base_seq_len
+    mu = image_seq_len * m + b
+    return mu
+
+
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
 def retrieve_timesteps(
     scheduler,
@@ -702,6 +716,7 @@ def __call__(
         skip_layer_guidance_scale: int = 2.8,
         skip_layer_guidance_stop: int = 0.2,
         skip_layer_guidance_start: int = 0.01,
+        mu: Optional[float] = None,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -802,6 +817,7 @@ def __call__(
                 `skip_guidance_layers` will start. The guidance will be applied to the layers specified in
                 `skip_guidance_layers` from the fraction specified in `skip_layer_guidance_start`. Recommended value by
                 StabiltyAI for Stable Diffusion 3.5 Medium is 0.01.
+            mu (`float`, *optional*): `mu` value used for `dynamic_shifting`.
 
         Examples:
 
@@ -882,12 +898,7 @@ def __call__(
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
             pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
 
-        # 4. Prepare timesteps
-        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
-        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
-        self._num_timesteps = len(timesteps)
-
-        # 5. Prepare latent variables
+        # 4. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels
         latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
@@ -900,6 +911,33 @@ def __call__(
             latents,
         )
 
+        # 5. Prepare timesteps
+        scheduler_kwargs = {}
+        if self.scheduler.config.get("use_dynamic_shifting", None) and mu is None:
+            _, _, height, width = latents.shape
+            image_seq_len = (height // self.transformer.config.patch_size) * (
+                width // self.transformer.config.patch_size
+            )
+            mu = calculate_shift(
+                image_seq_len,
+                self.scheduler.config.base_image_seq_len,
+                self.scheduler.config.max_image_seq_len,
+                self.scheduler.config.base_shift,
+                self.scheduler.config.max_shift,
+            )
+            scheduler_kwargs["mu"] = mu
+        elif mu is not None:
+            scheduler_kwargs["mu"] = mu
+        timesteps, num_inference_steps = retrieve_timesteps(
+            self.scheduler,
+            num_inference_steps,
+            device,
+            sigmas=sigmas,
+            **scheduler_kwargs,
+        )
+        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+        self._num_timesteps = len(timesteps)
+
         # 6. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
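For context, a minimal usage sketch (not part of the diff) of the new code path in the text-to-image pipeline. The checkpoint id, dtype, and prompt are illustrative assumptions; the point is that enabling `use_dynamic_shifting` on the flow-match scheduler lets the pipeline derive `mu` via `calculate_shift`, while passing `mu` explicitly skips that computation.

import torch
from diffusers import FlowMatchEulerDiscreteScheduler, StableDiffusion3Pipeline

# Illustrative checkpoint; any SD3-family model should exercise the same path.
pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-medium", torch_dtype=torch.bfloat16
).to("cuda")

# Enable dynamic shifting so the pipeline computes `mu` from the image
# sequence length and forwards it to the scheduler via retrieve_timesteps.
pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(
    pipe.scheduler.config, use_dynamic_shifting=True
)

image = pipe("a photo of a red barn at sunset", height=1024, width=1024).images[0]

# Alternatively, bypass the automatic calculation and pass `mu` directly.
image = pipe("a photo of a red barn at sunset", mu=1.16).images[0]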

src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py

Lines changed: 34 additions & 1 deletion

@@ -75,6 +75,20 @@
 """
 
 
+# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
+def calculate_shift(
+    image_seq_len,
+    base_seq_len: int = 256,
+    max_seq_len: int = 4096,
+    base_shift: float = 0.5,
+    max_shift: float = 1.16,
+):
+    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+    b = base_shift - m * base_seq_len
+    mu = image_seq_len * m + b
+    return mu
+
+
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
 def retrieve_latents(
     encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
@@ -748,6 +762,7 @@ def __call__(
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 256,
+        mu: Optional[float] = None,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -832,6 +847,7 @@ def __call__(
                 will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                 `._callback_tensor_inputs` attribute of your pipeline class.
             max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
+            mu (`float`, *optional*): `mu` value used for `dynamic_shifting`.
 
         Examples:
 
@@ -913,7 +929,24 @@ def __call__(
         image = self.image_processor.preprocess(image, height=height, width=width)
 
         # 4. Prepare timesteps
-        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
+        scheduler_kwargs = {}
+        if self.scheduler.config.get("use_dynamic_shifting", None) and mu is None:
+            image_seq_len = (int(height) // self.vae_scale_factor // self.transformer.config.patch_size) * (
+                int(width) // self.vae_scale_factor // self.transformer.config.patch_size
+            )
+            mu = calculate_shift(
+                image_seq_len,
+                self.scheduler.config.base_image_seq_len,
+                self.scheduler.config.max_image_seq_len,
+                self.scheduler.config.base_shift,
+                self.scheduler.config.max_shift,
+            )
+            scheduler_kwargs["mu"] = mu
+        elif mu is not None:
+            scheduler_kwargs["mu"] = mu
+        timesteps, num_inference_steps = retrieve_timesteps(
+            self.scheduler, num_inference_steps, device, sigmas=sigmas, **scheduler_kwargs
+        )
         timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py

Lines changed: 34 additions & 1 deletion

@@ -74,6 +74,20 @@
 """
 
 
+# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
+def calculate_shift(
+    image_seq_len,
+    base_seq_len: int = 256,
+    max_seq_len: int = 4096,
+    base_shift: float = 0.5,
+    max_shift: float = 1.16,
+):
+    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+    b = base_shift - m * base_seq_len
+    mu = image_seq_len * m + b
+    return mu
+
+
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
 def retrieve_latents(
     encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
@@ -838,6 +852,7 @@ def __call__(
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 256,
+        mu: Optional[float] = None,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -947,6 +962,7 @@ def __call__(
                 will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                 `._callback_tensor_inputs` attribute of your pipeline class.
             max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
+            mu (`float`, *optional*): `mu` value used for `dynamic_shifting`.
 
         Examples:
 
@@ -1023,7 +1039,24 @@ def __call__(
         pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
 
         # 3. Prepare timesteps
-        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
+        scheduler_kwargs = {}
+        if self.scheduler.config.get("use_dynamic_shifting", None) and mu is None:
+            image_seq_len = (int(height) // self.vae_scale_factor // self.transformer.config.patch_size) * (
+                int(width) // self.vae_scale_factor // self.transformer.config.patch_size
+            )
+            mu = calculate_shift(
+                image_seq_len,
+                self.scheduler.config.base_image_seq_len,
+                self.scheduler.config.max_image_seq_len,
+                self.scheduler.config.base_shift,
+                self.scheduler.config.max_shift,
+            )
+            scheduler_kwargs["mu"] = mu
+        elif mu is not None:
+            scheduler_kwargs["mu"] = mu
+        timesteps, num_inference_steps = retrieve_timesteps(
+            self.scheduler, num_inference_steps, device, sigmas=sigmas, **scheduler_kwargs
+        )
         timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
         # check that number of inference steps is not < 1 - as this doesn't make sense
         if num_inference_steps < 1:
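The commit message notes that FlowMatchHeunDiscreteScheduler doesn't support `mu`; the pipelines therefore only add `mu` to `scheduler_kwargs` when the scheduler config enables `use_dynamic_shifting` or the caller passes it explicitly. A quick, hedged way to check whether a given scheduler's `set_timesteps` accepts the keyword (the helper name below is made up for illustration):

import inspect

from diffusers import FlowMatchEulerDiscreteScheduler, FlowMatchHeunDiscreteScheduler

def accepts_mu(scheduler_cls) -> bool:
    # Hypothetical helper: inspect set_timesteps for a `mu` parameter.
    return "mu" in inspect.signature(scheduler_cls.set_timesteps).parameters

print(accepts_mu(FlowMatchEulerDiscreteScheduler))  # expected True
print(accepts_mu(FlowMatchHeunDiscreteScheduler))   # expected False, per the commit message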
