72
72
>>> image = pipe(
73
73
... prompt,
74
74
... control_image=control_image,
75
- ... controlnet_conditioning_scale=0.6,
75
+ ... control_guidance_start=0.2,
76
+ ... control_guidance_end=0.8,
77
+ ... controlnet_conditioning_scale=1.0,
76
78
... num_inference_steps=28,
77
79
... guidance_scale=3.5,
78
80
... ).images[0]
@@ -572,6 +574,8 @@ def __call__(
572
574
num_inference_steps : int = 28 ,
573
575
timesteps : List [int ] = None ,
574
576
guidance_scale : float = 7.0 ,
577
+ control_guidance_start : Union [float , List [float ]] = 0.0 ,
578
+ control_guidance_end : Union [float , List [float ]] = 1.0 ,
575
579
control_image : PipelineImageInput = None ,
576
580
control_mode : Optional [Union [int , List [int ]]] = None ,
577
581
controlnet_conditioning_scale : Union [float , List [float ]] = 1.0 ,
@@ -614,6 +618,10 @@ def __call__(
614
618
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
615
619
1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
616
620
usually at the expense of lower image quality.
621
+ control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
622
+ The fraction of total steps (between 0.0 and 1.0) at which the ControlNet starts applying.
623
+ control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
624
+ The fraction of total steps (between 0.0 and 1.0) at which the ControlNet stops applying.
617
625
control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
618
626
`List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
619
627
The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
@@ -674,6 +682,17 @@ def __call__(
674
682
height = height or self .default_sample_size * self .vae_scale_factor
675
683
width = width or self .default_sample_size * self .vae_scale_factor
676
684
685
+ if not isinstance (control_guidance_start , list ) and isinstance (control_guidance_end , list ):
686
+ control_guidance_start = len (control_guidance_end ) * [control_guidance_start ]
687
+ elif not isinstance (control_guidance_end , list ) and isinstance (control_guidance_start , list ):
688
+ control_guidance_end = len (control_guidance_start ) * [control_guidance_end ]
689
+ elif not isinstance (control_guidance_start , list ) and not isinstance (control_guidance_end , list ):
690
+ mult = len (self .controlnet .nets ) if isinstance (self .controlnet , FluxMultiControlNetModel ) else 1
691
+ control_guidance_start , control_guidance_end = (
692
+ mult * [control_guidance_start ],
693
+ mult * [control_guidance_end ],
694
+ )
695
+
677
696
# 1. Check inputs. Raise error if not correct
678
697
self .check_inputs (
679
698
prompt ,
@@ -839,7 +858,16 @@ def __call__(
839
858
num_warmup_steps = max (len (timesteps ) - num_inference_steps * self .scheduler .order , 0 )
840
859
self ._num_timesteps = len (timesteps )
841
860
842
- # 6. Denoising loop
861
+ # 6. Create tensor stating which controlnets to keep
862
+ controlnet_keep = []
863
+ for i in range (len (timesteps )):
864
+ keeps = [
865
+ 1.0 - float (i / len (timesteps ) < s or (i + 1 ) / len (timesteps ) > e )
866
+ for s , e in zip (control_guidance_start , control_guidance_end )
867
+ ]
868
+ controlnet_keep .append (keeps [0 ] if isinstance (self .controlnet , FluxControlNetModel ) else keeps )
869
+
870
+ # 7. Denoising loop
843
871
with self .progress_bar (total = num_inference_steps ) as progress_bar :
844
872
for i , t in enumerate (timesteps ):
845
873
if self .interrupt :
@@ -856,12 +884,20 @@ def __call__(
856
884
guidance = torch .tensor ([guidance_scale ], device = device ) if use_guidance else None
857
885
guidance = guidance .expand (latents .shape [0 ]) if guidance is not None else None
858
886
887
+ if isinstance (controlnet_keep [i ], list ):
888
+ cond_scale = [c * s for c , s in zip (controlnet_conditioning_scale , controlnet_keep [i ])]
889
+ else :
890
+ controlnet_cond_scale = controlnet_conditioning_scale
891
+ if isinstance (controlnet_cond_scale , list ):
892
+ controlnet_cond_scale = controlnet_cond_scale [0 ]
893
+ cond_scale = controlnet_cond_scale * controlnet_keep [i ]
894
+
859
895
# controlnet
860
896
controlnet_block_samples , controlnet_single_block_samples = self .controlnet (
861
897
hidden_states = latents ,
862
898
controlnet_cond = control_image ,
863
899
controlnet_mode = control_mode ,
864
- conditioning_scale = controlnet_conditioning_scale ,
900
+ conditioning_scale = cond_scale ,
865
901
timestep = timestep / 1000 ,
866
902
guidance = guidance ,
867
903
pooled_projections = pooled_prompt_embeds ,
0 commit comments