Skip to content

Commit 4820d2a

Browse files
committed
Inpaint/img2img
1 parent 75a9feb commit 4820d2a

File tree

2 files changed

+68
-2
lines changed

2 files changed

+68
-2
lines changed

src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,20 @@
7575
"""
7676

7777

78+
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.16,
):
    """Compute the timestep-shift value ``mu`` for dynamic shifting.

    ``mu`` is obtained by linear interpolation: it equals ``base_shift`` when
    ``image_seq_len == base_seq_len`` and ``max_shift`` when
    ``image_seq_len == max_seq_len``, extrapolating linearly outside that range.

    Args:
        image_seq_len: Sequence length of the image latents (number of patches).
        base_seq_len: Sequence length mapped to ``base_shift``.
        max_seq_len: Sequence length mapped to ``max_shift``.
        base_shift: Shift value at ``base_seq_len``.
        max_shift: Shift value at ``max_seq_len``.

    Returns:
        The interpolated shift value ``mu`` as a float.
    """
    # Line through (base_seq_len, base_shift) and (max_seq_len, max_shift).
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept
90+
91+
7892
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
7993
def retrieve_latents(
8094
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
@@ -748,6 +762,7 @@ def __call__(
748762
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
749763
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
750764
max_sequence_length: int = 256,
765+
mu: Optional[float] = None,
751766
):
752767
r"""
753768
Function invoked when calling the pipeline for generation.
@@ -832,6 +847,7 @@ def __call__(
832847
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
833848
`._callback_tensor_inputs` attribute of your pipeline class.
834849
max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
850+
mu (`float`, *optional*): `mu` value used for `dynamic_shifting`.
835851
836852
Examples:
837853
@@ -913,7 +929,24 @@ def __call__(
913929
image = self.image_processor.preprocess(image, height=height, width=width)
914930

915931
# 4. Prepare timesteps
916-
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
932+
scheduler_kwargs = {}
933+
if self.scheduler.config.get("use_dynamic_shifting", None) and mu is None:
934+
image_seq_len = (int(height) // self.vae_scale_factor // self.transformer.config.patch_size) * (
935+
int(width) // self.vae_scale_factor // self.transformer.config.patch_size
936+
)
937+
mu = calculate_shift(
938+
image_seq_len,
939+
self.scheduler.config.base_image_seq_len,
940+
self.scheduler.config.max_image_seq_len,
941+
self.scheduler.config.base_shift,
942+
self.scheduler.config.max_shift,
943+
)
944+
scheduler_kwargs["mu"] = mu
945+
elif mu is not None:
946+
scheduler_kwargs["mu"] = mu
947+
timesteps, num_inference_steps = retrieve_timesteps(
948+
self.scheduler, num_inference_steps, device, sigmas=sigmas, **scheduler_kwargs
949+
)
917950
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
918951
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
919952

src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,20 @@
7474
"""
7575

7676

77+
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.16,
):
    """Compute the timestep-shift value ``mu`` for dynamic shifting.

    ``mu`` lies on the straight line through the points
    ``(base_seq_len, base_shift)`` and ``(max_seq_len, max_shift)``, evaluated
    at ``image_seq_len`` (extrapolated linearly outside the two endpoints).

    Args:
        image_seq_len: Sequence length of the image latents (number of patches).
        base_seq_len: Sequence length mapped to ``base_shift``.
        max_seq_len: Sequence length mapped to ``max_shift``.
        base_shift: Shift value at ``base_seq_len``.
        max_shift: Shift value at ``max_seq_len``.

    Returns:
        The interpolated shift value ``mu`` as a float.
    """
    # Slope/intercept of the interpolating line, then evaluate at image_seq_len.
    gradient = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    offset = base_shift - gradient * base_seq_len
    return image_seq_len * gradient + offset
89+
90+
7791
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
7892
def retrieve_latents(
7993
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
@@ -838,6 +852,7 @@ def __call__(
838852
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
839853
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
840854
max_sequence_length: int = 256,
855+
mu: Optional[float] = None,
841856
):
842857
r"""
843858
Function invoked when calling the pipeline for generation.
@@ -947,6 +962,7 @@ def __call__(
947962
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
948963
`._callback_tensor_inputs` attribute of your pipeline class.
949964
max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
965+
mu (`float`, *optional*): `mu` value used for `dynamic_shifting`.
950966
951967
Examples:
952968
@@ -1023,7 +1039,24 @@ def __call__(
10231039
pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
10241040

10251041
# 3. Prepare timesteps
1026-
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
1042+
scheduler_kwargs = {}
1043+
if self.scheduler.config.get("use_dynamic_shifting", None) and mu is None:
1044+
image_seq_len = (int(height) // self.vae_scale_factor // self.transformer.config.patch_size) * (
1045+
int(width) // self.vae_scale_factor // self.transformer.config.patch_size
1046+
)
1047+
mu = calculate_shift(
1048+
image_seq_len,
1049+
self.scheduler.config.base_image_seq_len,
1050+
self.scheduler.config.max_image_seq_len,
1051+
self.scheduler.config.base_shift,
1052+
self.scheduler.config.max_shift,
1053+
)
1054+
scheduler_kwargs["mu"] = mu
1055+
elif mu is not None:
1056+
scheduler_kwargs["mu"] = mu
1057+
timesteps, num_inference_steps = retrieve_timesteps(
1058+
self.scheduler, num_inference_steps, device, sigmas=sigmas, **scheduler_kwargs
1059+
)
10271060
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
10281061
# check that number of inference steps is not < 1 - as this doesn't make sense
10291062
if num_inference_steps < 1:

0 commit comments

Comments
 (0)