Skip to content

Commit a388daf

Browse files
committed
Implement control_guidance_start and control_guidance_end for the Flux ControlNet pipelines
1 parent 61d3764 commit a388daf

File tree

3 files changed

+97
-6
lines changed

3 files changed

+97
-6
lines changed

src/diffusers/pipelines/flux/pipeline_flux_controlnet.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,9 @@
7272
>>> image = pipe(
7373
... prompt,
7474
... control_image=control_image,
75-
... controlnet_conditioning_scale=0.6,
75+
... control_guidance_start=0.2,
76+
... control_guidance_end=0.8,
77+
... controlnet_conditioning_scale=1.0,
7678
... num_inference_steps=28,
7779
... guidance_scale=3.5,
7880
... ).images[0]
@@ -572,6 +574,8 @@ def __call__(
572574
num_inference_steps: int = 28,
573575
timesteps: List[int] = None,
574576
guidance_scale: float = 7.0,
577+
control_guidance_start: Union[float, List[float]] = 0.0,
578+
control_guidance_end: Union[float, List[float]] = 1.0,
575579
control_image: PipelineImageInput = None,
576580
control_mode: Optional[Union[int, List[int]]] = None,
577581
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
@@ -674,6 +678,17 @@ def __call__(
674678
height = height or self.default_sample_size * self.vae_scale_factor
675679
width = width or self.default_sample_size * self.vae_scale_factor
676680

681+
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
682+
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
683+
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
684+
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
685+
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
686+
mult = len(self.controlnet.nets) if isinstance(self.controlnet, FluxMultiControlNetModel) else 1
687+
control_guidance_start, control_guidance_end = (
688+
mult * [control_guidance_start],
689+
mult * [control_guidance_end],
690+
)
691+
677692
# 1. Check inputs. Raise error if not correct
678693
self.check_inputs(
679694
prompt,
@@ -839,7 +854,16 @@ def __call__(
839854
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
840855
self._num_timesteps = len(timesteps)
841856

842-
# 6. Denoising loop
857+
# 6. Create list stating which controlnets to keep at each timestep
858+
controlnet_keep = []
859+
for i in range(len(timesteps)):
860+
keeps = [
861+
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
862+
for s, e in zip(control_guidance_start, control_guidance_end)
863+
]
864+
controlnet_keep.append(keeps[0] if isinstance(self.controlnet, FluxControlNetModel) else keeps)
865+
866+
# 7. Denoising loop
843867
with self.progress_bar(total=num_inference_steps) as progress_bar:
844868
for i, t in enumerate(timesteps):
845869
if self.interrupt:
@@ -856,12 +880,19 @@ def __call__(
856880
guidance = torch.tensor([guidance_scale], device=device) if use_guidance else None
857881
guidance = guidance.expand(latents.shape[0]) if guidance is not None else None
858882

883+
if isinstance(controlnet_keep[i], list):
884+
current_controlnet_conditioning_scale = [
885+
c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])
886+
]
887+
else:
888+
current_controlnet_conditioning_scale = controlnet_conditioning_scale * controlnet_keep[i]
889+
859890
# controlnet
860891
controlnet_block_samples, controlnet_single_block_samples = self.controlnet(
861892
hidden_states=latents,
862893
controlnet_cond=control_image,
863894
controlnet_mode=control_mode,
864-
conditioning_scale=controlnet_conditioning_scale,
895+
conditioning_scale=current_controlnet_conditioning_scale,
865896
timestep=timestep / 1000,
866897
guidance=guidance,
867898
pooled_projections=pooled_prompt_embeds,

src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,9 @@
6969
... prompt,
7070
... image=init_image,
7171
... control_image=control_image,
72-
... controlnet_conditioning_scale=0.6,
72+
... control_guidance_start=0.2,
73+
... control_guidance_end=0.8,
74+
... controlnet_conditioning_scale=1.0,
7375
... strength=0.7,
7476
... num_inference_steps=2,
7577
... guidance_scale=3.5,
@@ -631,6 +633,8 @@ def __call__(
631633
num_inference_steps: int = 28,
632634
timesteps: List[int] = None,
633635
guidance_scale: float = 7.0,
636+
control_guidance_start: Union[float, List[float]] = 0.0,
637+
control_guidance_end: Union[float, List[float]] = 1.0,
634638
control_mode: Optional[Union[int, List[int]]] = None,
635639
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
636640
num_images_per_prompt: Optional[int] = 1,
@@ -710,6 +714,17 @@ def __call__(
710714
height = height or self.default_sample_size * self.vae_scale_factor
711715
width = width or self.default_sample_size * self.vae_scale_factor
712716

717+
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
718+
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
719+
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
720+
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
721+
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
722+
mult = len(self.controlnet.nets) if isinstance(self.controlnet, FluxMultiControlNetModel) else 1
723+
control_guidance_start, control_guidance_end = (
724+
mult * [control_guidance_start],
725+
mult * [control_guidance_end],
726+
)
727+
713728
self.check_inputs(
714729
prompt,
715730
prompt_2,
@@ -862,6 +877,14 @@ def __call__(
862877
latents,
863878
)
864879

880+
controlnet_keep = []
881+
for i in range(len(timesteps)):
882+
keeps = [
883+
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
884+
for s, e in zip(control_guidance_start, control_guidance_end)
885+
]
886+
controlnet_keep.append(keeps[0] if isinstance(self.controlnet, FluxControlNetModel) else keeps)
887+
865888
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
866889
self._num_timesteps = len(timesteps)
867890

@@ -877,11 +900,18 @@ def __call__(
877900
)
878901
guidance = guidance.expand(latents.shape[0]) if guidance is not None else None
879902

903+
if isinstance(controlnet_keep[i], list):
904+
current_controlnet_conditioning_scale = [
905+
c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])
906+
]
907+
else:
908+
current_controlnet_conditioning_scale = controlnet_conditioning_scale * controlnet_keep[i]
909+
880910
controlnet_block_samples, controlnet_single_block_samples = self.controlnet(
881911
hidden_states=latents,
882912
controlnet_cond=control_image,
883913
controlnet_mode=control_mode,
884-
conditioning_scale=controlnet_conditioning_scale,
914+
conditioning_scale=current_controlnet_conditioning_scale,
885915
timestep=timestep / 1000,
886916
guidance=guidance,
887917
pooled_projections=pooled_prompt_embeds,

src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@
7171
... image=init_image,
7272
... mask_image=mask_image,
7373
... control_image=control_image,
74+
... control_guidance_start=0.2,
75+
... control_guidance_end=0.8,
7476
... controlnet_conditioning_scale=0.7,
7577
... strength=0.7,
7678
... num_inference_steps=28,
@@ -737,6 +739,8 @@ def __call__(
737739
timesteps: List[int] = None,
738740
num_inference_steps: int = 28,
739741
guidance_scale: float = 7.0,
742+
control_guidance_start: Union[float, List[float]] = 0.0,
743+
control_guidance_end: Union[float, List[float]] = 1.0,
740744
control_mode: Optional[Union[int, List[int]]] = None,
741745
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
742746
num_images_per_prompt: Optional[int] = 1,
@@ -826,6 +830,17 @@ def __call__(
826830
global_height = height
827831
global_width = width
828832

833+
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
834+
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
835+
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
836+
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
837+
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
838+
mult = len(self.controlnet.nets) if isinstance(self.controlnet, FluxMultiControlNetModel) else 1
839+
control_guidance_start, control_guidance_end = (
840+
mult * [control_guidance_start],
841+
mult * [control_guidance_end],
842+
)
843+
829844
# 1. Check inputs
830845
self.check_inputs(
831846
prompt,
@@ -1031,6 +1046,14 @@ def __call__(
10311046
generator,
10321047
)
10331048

1049+
controlnet_keep = []
1050+
for i in range(len(timesteps)):
1051+
keeps = [
1052+
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
1053+
for s, e in zip(control_guidance_start, control_guidance_end)
1054+
]
1055+
controlnet_keep.append(keeps[0] if isinstance(self.controlnet, FluxControlNetModel) else keeps)
1056+
10341057
# 9. Denoising loop
10351058
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
10361059
self._num_timesteps = len(timesteps)
@@ -1049,11 +1072,18 @@ def __call__(
10491072
else:
10501073
guidance = None
10511074

1075+
if isinstance(controlnet_keep[i], list):
1076+
current_controlnet_conditioning_scale = [
1077+
c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])
1078+
]
1079+
else:
1080+
current_controlnet_conditioning_scale = controlnet_conditioning_scale * controlnet_keep[i]
1081+
10521082
controlnet_block_samples, controlnet_single_block_samples = self.controlnet(
10531083
hidden_states=latents,
10541084
controlnet_cond=control_image,
10551085
controlnet_mode=control_mode,
1056-
conditioning_scale=controlnet_conditioning_scale,
1086+
conditioning_scale=current_controlnet_conditioning_scale,
10571087
timestep=timestep / 1000,
10581088
guidance=guidance,
10591089
pooled_projections=pooled_prompt_embeds,

0 commit comments

Comments
 (0)