
Commit 38a3e4d

flux controlnet control_guidance_start and control_guidance_end implement (#9571)
* flux controlnet control_guidance_start and control_guidance_end implement
* minor fix - added docstrings, consistent controlnet scale flux and SD3
1 parent e16fd93 commit 38a3e4d
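The new arguments drop straight into the existing pipeline call. A minimal usage sketch, assuming a FLUX.1-dev base model paired with a compatible Flux ControlNet checkpoint; the model IDs, prompt, and image path below are illustrative placeholders, not taken from this commit:

import torch
from diffusers import FluxControlNetModel, FluxControlNetPipeline
from diffusers.utils import load_image

# Placeholder checkpoints -- substitute the base model and ControlNet you actually use.
controlnet = FluxControlNetModel.from_pretrained(
    "InstantX/FLUX.1-dev-Controlnet-Canny", torch_dtype=torch.bfloat16
)
pipe = FluxControlNetPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", controlnet=controlnet, torch_dtype=torch.bfloat16
).to("cuda")

control_image = load_image("path/to/conditioning_image.png")  # your edge map, depth map, etc.

image = pipe(
    "a photorealistic city street at dusk",
    control_image=control_image,
    control_guidance_start=0.2,         # ControlNet is inactive for the first 20% of steps
    control_guidance_end=0.8,           # ...and switched off again for the last 20%
    controlnet_conditioning_scale=1.0,
    num_inference_steps=28,
    guidance_scale=3.5,
).images[0]
image.save("output.png")

With a FluxMultiControlNetModel, lists can be passed for control_guidance_start, control_guidance_end, and controlnet_conditioning_scale to give each ControlNet its own window and scale.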

3 files changed (+108, -6 lines)

src/diffusers/pipelines/flux/pipeline_flux_controlnet.py

Lines changed: 39 additions & 3 deletions
@@ -72,7 +72,9 @@
 >>> image = pipe(
 ...     prompt,
 ...     control_image=control_image,
-...     controlnet_conditioning_scale=0.6,
+...     control_guidance_start=0.2,
+...     control_guidance_end=0.8,
+...     controlnet_conditioning_scale=1.0,
 ...     num_inference_steps=28,
 ...     guidance_scale=3.5,
 ... ).images[0]
@@ -572,6 +574,8 @@ def __call__(
         num_inference_steps: int = 28,
         timesteps: List[int] = None,
         guidance_scale: float = 7.0,
+        control_guidance_start: Union[float, List[float]] = 0.0,
+        control_guidance_end: Union[float, List[float]] = 1.0,
         control_image: PipelineImageInput = None,
         control_mode: Optional[Union[int, List[int]]] = None,
         controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
@@ -614,6 +618,10 @@ def __call__(
                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                 usually at the expense of lower image quality.
+            control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
+                The percentage of total steps at which the ControlNet starts applying.
+            control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
+                The percentage of total steps at which the ControlNet stops applying.
             control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
                 `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
                 The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
@@ -674,6 +682,17 @@ def __call__(
         height = height or self.default_sample_size * self.vae_scale_factor
         width = width or self.default_sample_size * self.vae_scale_factor

+        if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
+            control_guidance_start = len(control_guidance_end) * [control_guidance_start]
+        elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
+            control_guidance_end = len(control_guidance_start) * [control_guidance_end]
+        elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
+            mult = len(self.controlnet.nets) if isinstance(self.controlnet, FluxMultiControlNetModel) else 1
+            control_guidance_start, control_guidance_end = (
+                mult * [control_guidance_start],
+                mult * [control_guidance_end],
+            )
+
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
             prompt,
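The block added above normalizes scalar start/end values into per-ControlNet lists before validation runs. A standalone sketch of that broadcasting, with the FluxMultiControlNetModel check replaced by an explicit count for illustration (the helper name is made up):

def broadcast_guidance_window(start, end, num_controlnets=1):
    # Mirror of the normalization above: every ControlNet ends up with its own (start, end) pair.
    if not isinstance(start, list) and isinstance(end, list):
        start = len(end) * [start]
    elif not isinstance(end, list) and isinstance(start, list):
        end = len(start) * [end]
    elif not isinstance(start, list) and not isinstance(end, list):
        start, end = num_controlnets * [start], num_controlnets * [end]
    return start, end

print(broadcast_guidance_window(0.2, 0.8))             # ([0.2], [0.8])
print(broadcast_guidance_window(0.0, [0.5, 0.9], 2))   # ([0.0, 0.0], [0.5, 0.9])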
@@ -839,7 +858,16 @@ def __call__(
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
         self._num_timesteps = len(timesteps)

-        # 6. Denoising loop
+        # 6. Create tensor stating which controlnets to keep
+        controlnet_keep = []
+        for i in range(len(timesteps)):
+            keeps = [
+                1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
+                for s, e in zip(control_guidance_start, control_guidance_end)
+            ]
+            controlnet_keep.append(keeps[0] if isinstance(self.controlnet, FluxControlNetModel) else keeps)
+
+        # 7. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
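controlnet_keep turns the start/end fractions into a per-step 0.0/1.0 multiplier: a step is kept only if it falls inside the [start, end] window of the run. A self-contained sketch of the same arithmetic (the helper name is hypothetical):

def keep_schedule(num_steps, start, end):
    # 1.0 once i/num_steps has reached `start` and while (i + 1)/num_steps stays within `end`, else 0.0
    return [
        1.0 - float(i / num_steps < start or (i + 1) / num_steps > end)
        for i in range(num_steps)
    ]

keeps = keep_schedule(28, 0.2, 0.8)
print(keeps[:8])                       # [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0]
print(sum(k == 1.0 for k in keeps))    # 16 -> ControlNet guidance is applied on steps 6..21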
@@ -856,12 +884,20 @@ def __call__(
                 guidance = torch.tensor([guidance_scale], device=device) if use_guidance else None
                 guidance = guidance.expand(latents.shape[0]) if guidance is not None else None

+                if isinstance(controlnet_keep[i], list):
+                    cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
+                else:
+                    controlnet_cond_scale = controlnet_conditioning_scale
+                    if isinstance(controlnet_cond_scale, list):
+                        controlnet_cond_scale = controlnet_cond_scale[0]
+                    cond_scale = controlnet_cond_scale * controlnet_keep[i]
+
                 # controlnet
                 controlnet_block_samples, controlnet_single_block_samples = self.controlnet(
                     hidden_states=latents,
                     controlnet_cond=control_image,
                     controlnet_mode=control_mode,
-                    conditioning_scale=controlnet_conditioning_scale,
+                    conditioning_scale=cond_scale,
                     timestep=timestep / 1000,
                     guidance=guidance,
                     pooled_projections=pooled_prompt_embeds,
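Inside the denoising loop the per-step keep value is folded into controlnet_conditioning_scale, so ControlNet residuals are scaled to zero outside the window rather than skipped structurally. A sketch of that cond_scale branch, with plain values standing in for the pipeline state (the function name is invented):

def effective_cond_scale(conditioning_scale, keep):
    if isinstance(keep, list):  # multi-ControlNet: one keep flag per net
        return [c * k for c, k in zip(conditioning_scale, keep)]
    scale = conditioning_scale
    if isinstance(scale, list):  # single ControlNet handed a list -> use the first entry
        scale = scale[0]
    return scale * keep

print(effective_cond_scale(1.0, 0.0))                # 0.0  (step outside the window)
print(effective_cond_scale([0.7, 1.0], [1.0, 0.0]))  # [0.7, 0.0]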

src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py

Lines changed: 33 additions & 2 deletions
@@ -69,7 +69,9 @@
 ...     prompt,
 ...     image=init_image,
 ...     control_image=control_image,
-...     controlnet_conditioning_scale=0.6,
+...     control_guidance_start=0.2,
+...     control_guidance_end=0.8,
+...     controlnet_conditioning_scale=1.0,
 ...     strength=0.7,
 ...     num_inference_steps=2,
 ...     guidance_scale=3.5,
@@ -631,6 +633,8 @@ def __call__(
         num_inference_steps: int = 28,
         timesteps: List[int] = None,
         guidance_scale: float = 7.0,
+        control_guidance_start: Union[float, List[float]] = 0.0,
+        control_guidance_end: Union[float, List[float]] = 1.0,
         control_mode: Optional[Union[int, List[int]]] = None,
         controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
         num_images_per_prompt: Optional[int] = 1,
@@ -710,6 +714,17 @@ def __call__(
         height = height or self.default_sample_size * self.vae_scale_factor
         width = width or self.default_sample_size * self.vae_scale_factor

+        if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
+            control_guidance_start = len(control_guidance_end) * [control_guidance_start]
+        elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
+            control_guidance_end = len(control_guidance_start) * [control_guidance_end]
+        elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
+            mult = len(self.controlnet.nets) if isinstance(self.controlnet, FluxMultiControlNetModel) else 1
+            control_guidance_start, control_guidance_end = (
+                mult * [control_guidance_start],
+                mult * [control_guidance_end],
+            )
+
         self.check_inputs(
             prompt,
             prompt_2,
@@ -862,6 +877,14 @@ def __call__(
             latents,
         )

+        controlnet_keep = []
+        for i in range(len(timesteps)):
+            keeps = [
+                1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
+                for s, e in zip(control_guidance_start, control_guidance_end)
+            ]
+            controlnet_keep.append(keeps[0] if isinstance(self.controlnet, FluxControlNetModel) else keeps)
+
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
         self._num_timesteps = len(timesteps)

@@ -877,11 +900,19 @@ def __call__(
                 )
                 guidance = guidance.expand(latents.shape[0]) if guidance is not None else None

+                if isinstance(controlnet_keep[i], list):
+                    cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
+                else:
+                    controlnet_cond_scale = controlnet_conditioning_scale
+                    if isinstance(controlnet_cond_scale, list):
+                        controlnet_cond_scale = controlnet_cond_scale[0]
+                    cond_scale = controlnet_cond_scale * controlnet_keep[i]
+
                 controlnet_block_samples, controlnet_single_block_samples = self.controlnet(
                     hidden_states=latents,
                     controlnet_cond=control_image,
                     controlnet_mode=control_mode,
-                    conditioning_scale=controlnet_conditioning_scale,
+                    conditioning_scale=cond_scale,
                     timestep=timestep / 1000,
                     guidance=guidance,
                     pooled_projections=pooled_prompt_embeds,

src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py

Lines changed: 36 additions & 1 deletion
@@ -71,6 +71,8 @@
 ...     image=init_image,
 ...     mask_image=mask_image,
 ...     control_image=control_image,
+...     control_guidance_start=0.2,
+...     control_guidance_end=0.8,
 ...     controlnet_conditioning_scale=0.7,
 ...     strength=0.7,
 ...     num_inference_steps=28,
@@ -737,6 +739,8 @@ def __call__(
         timesteps: List[int] = None,
         num_inference_steps: int = 28,
         guidance_scale: float = 7.0,
+        control_guidance_start: Union[float, List[float]] = 0.0,
+        control_guidance_end: Union[float, List[float]] = 1.0,
         control_mode: Optional[Union[int, List[int]]] = None,
         controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
         num_images_per_prompt: Optional[int] = 1,
@@ -783,6 +787,10 @@ def __call__(
                 Custom timesteps to use for the denoising process.
             guidance_scale (`float`, *optional*, defaults to 7.0):
                 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+            control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
+                The percentage of total steps at which the ControlNet starts applying.
+            control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
+                The percentage of total steps at which the ControlNet stops applying.
             control_mode (`int` or `List[int]`, *optional*):
                 The mode for the ControlNet. If multiple ControlNets are used, this should be a list.
             controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
@@ -826,6 +834,17 @@ def __call__(
         global_height = height
         global_width = width

+        if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
+            control_guidance_start = len(control_guidance_end) * [control_guidance_start]
+        elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
+            control_guidance_end = len(control_guidance_start) * [control_guidance_end]
+        elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
+            mult = len(self.controlnet.nets) if isinstance(self.controlnet, FluxMultiControlNetModel) else 1
+            control_guidance_start, control_guidance_end = (
+                mult * [control_guidance_start],
+                mult * [control_guidance_end],
+            )
+
         # 1. Check inputs
         self.check_inputs(
             prompt,
@@ -1031,6 +1050,14 @@ def __call__(
             generator,
         )

+        controlnet_keep = []
+        for i in range(len(timesteps)):
+            keeps = [
+                1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
+                for s, e in zip(control_guidance_start, control_guidance_end)
+            ]
+            controlnet_keep.append(keeps[0] if isinstance(self.controlnet, FluxControlNetModel) else keeps)
+
         # 9. Denoising loop
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
         self._num_timesteps = len(timesteps)
@@ -1049,11 +1076,19 @@ def __call__(
                 else:
                     guidance = None

+                if isinstance(controlnet_keep[i], list):
+                    cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
+                else:
+                    controlnet_cond_scale = controlnet_conditioning_scale
+                    if isinstance(controlnet_cond_scale, list):
+                        controlnet_cond_scale = controlnet_cond_scale[0]
+                    cond_scale = controlnet_cond_scale * controlnet_keep[i]
+
                 controlnet_block_samples, controlnet_single_block_samples = self.controlnet(
                     hidden_states=latents,
                     controlnet_cond=control_image,
                     controlnet_mode=control_mode,
-                    conditioning_scale=controlnet_conditioning_scale,
+                    conditioning_scale=cond_scale,
                     timestep=timestep / 1000,
                     guidance=guidance,
                     pooled_projections=pooled_prompt_embeds,
