7272 >>> image = pipe(
7373 ... prompt,
7474 ... control_image=control_image,
75- ... controlnet_conditioning_scale=0.6,
75+ ... control_guidance_start=0.2,
76+ ... control_guidance_end=0.8,
77+ ... controlnet_conditioning_scale=1.0,
7678 ... num_inference_steps=28,
7779 ... guidance_scale=3.5,
7880 ... ).images[0]
@@ -572,6 +574,8 @@ def __call__(
572574 num_inference_steps : int = 28 ,
573575 timesteps : List [int ] = None ,
574576 guidance_scale : float = 7.0 ,
577+ control_guidance_start : Union [float , List [float ]] = 0.0 ,
578+ control_guidance_end : Union [float , List [float ]] = 1.0 ,
575579 control_image : PipelineImageInput = None ,
576580 control_mode : Optional [Union [int , List [int ]]] = None ,
577581 controlnet_conditioning_scale : Union [float , List [float ]] = 1.0 ,
@@ -614,6 +618,10 @@ def __call__(
614618 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
615619 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
616620 usually at the expense of lower image quality.
621+ control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
622+ The percentage of total steps at which the ControlNet starts applying.
623+ control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
624+ The percentage of total steps at which the ControlNet stops applying.
617625 control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
618626 `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
619627 The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
@@ -674,6 +682,17 @@ def __call__(
674682 height = height or self .default_sample_size * self .vae_scale_factor
675683 width = width or self .default_sample_size * self .vae_scale_factor
676684
685+ if not isinstance (control_guidance_start , list ) and isinstance (control_guidance_end , list ):
686+ control_guidance_start = len (control_guidance_end ) * [control_guidance_start ]
687+ elif not isinstance (control_guidance_end , list ) and isinstance (control_guidance_start , list ):
688+ control_guidance_end = len (control_guidance_start ) * [control_guidance_end ]
689+ elif not isinstance (control_guidance_start , list ) and not isinstance (control_guidance_end , list ):
690+ mult = len (self .controlnet .nets ) if isinstance (self .controlnet , FluxMultiControlNetModel ) else 1
691+ control_guidance_start , control_guidance_end = (
692+ mult * [control_guidance_start ],
693+ mult * [control_guidance_end ],
694+ )
695+
677696 # 1. Check inputs. Raise error if not correct
678697 self .check_inputs (
679698 prompt ,
@@ -839,7 +858,16 @@ def __call__(
839858 num_warmup_steps = max (len (timesteps ) - num_inference_steps * self .scheduler .order , 0 )
840859 self ._num_timesteps = len (timesteps )
841860
842- # 6. Denoising loop
861+ # 6. Create tensor stating which controlnets to keep
862+ controlnet_keep = []
863+ for i in range (len (timesteps )):
864+ keeps = [
865+ 1.0 - float (i / len (timesteps ) < s or (i + 1 ) / len (timesteps ) > e )
866+ for s , e in zip (control_guidance_start , control_guidance_end )
867+ ]
868+ controlnet_keep .append (keeps [0 ] if isinstance (self .controlnet , FluxControlNetModel ) else keeps )
869+
870+ # 7. Denoising loop
843871 with self .progress_bar (total = num_inference_steps ) as progress_bar :
844872 for i , t in enumerate (timesteps ):
845873 if self .interrupt :
@@ -856,12 +884,20 @@ def __call__(
856884 guidance = torch .tensor ([guidance_scale ], device = device ) if use_guidance else None
857885 guidance = guidance .expand (latents .shape [0 ]) if guidance is not None else None
858886
887+ if isinstance (controlnet_keep [i ], list ):
888+ cond_scale = [c * s for c , s in zip (controlnet_conditioning_scale , controlnet_keep [i ])]
889+ else :
890+ controlnet_cond_scale = controlnet_conditioning_scale
891+ if isinstance (controlnet_cond_scale , list ):
892+ controlnet_cond_scale = controlnet_cond_scale [0 ]
893+ cond_scale = controlnet_cond_scale * controlnet_keep [i ]
894+
859895 # controlnet
860896 controlnet_block_samples , controlnet_single_block_samples = self .controlnet (
861897 hidden_states = latents ,
862898 controlnet_cond = control_image ,
863899 controlnet_mode = control_mode ,
864- conditioning_scale = controlnet_conditioning_scale ,
900+ conditioning_scale = cond_scale ,
865901 timestep = timestep / 1000 ,
866902 guidance = guidance ,
867903 pooled_projections = pooled_prompt_embeds ,
0 commit comments