@@ -642,6 +642,10 @@ def prepare_latents(
def guidance_scale(self):
    """Return the classifier-free guidance scale stored on the pipeline (set in `__call__`)."""
    return self._guidance_scale
644644
@property
def skip_guidance_layers(self):
    """Layer indices skipped during the extra skip-layer-guidance pass, or None when unset.

    NOTE(review): in this diff, `__call__` stores `self._skip_layer_guidance_scale`
    but never assigns `self._skip_guidance_layers`, so a bare attribute read would
    raise AttributeError. Read defensively with a `None` default until the
    corresponding assignment is added in `__call__`.
    """
    # getattr fallback keeps the property usable even before __call__ stores the value
    return getattr(self, "_skip_guidance_layers", None)
648+
@property
def clip_skip(self):
    """Return the `clip_skip` value stored on the pipeline (set in `__call__`)."""
    return self._clip_skip
@@ -694,6 +698,10 @@ def __call__(
694698 callback_on_step_end : Optional [Callable [[int , int , Dict ], None ]] = None ,
695699 callback_on_step_end_tensor_inputs : List [str ] = ["latents" ],
696700 max_sequence_length : int = 256 ,
701+ skip_guidance_layers : List [int ] = None ,
702+ skip_layer_guidance_scale : int = 2.8 ,
703+ skip_layer_guidance_stop : int = 0.2 ,
704+ skip_layer_guidance_start : int = 0.01 ,
697705 ):
698706 r"""
699707 Function invoked when calling the pipeline for generation.
@@ -778,6 +786,22 @@ def __call__(
778786 will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
779787 `._callback_tensor_inputs` attribute of your pipeline class.
780788 max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
789+ skip_guidance_layers (`List[int]`, *optional*):
790+ A list of integers that specify layers to skip during guidance. If not provided, all layers will be
791+ used for guidance. If provided, the guidance will only be applied to the layers specified in the list.
792+ Recommended value by Stability AI for Stable Diffusion 3.5 Medium is [7, 8, 9].
793+ skip_layer_guidance_scale (`float`, *optional*): The scale of the guidance for the layers specified in
794+ `skip_guidance_layers`. The guidance will be applied to the layers specified in `skip_guidance_layers`
795+ with a scale of `skip_layer_guidance_scale`. The guidance will be applied to the rest of the layers
796+ with a scale of `1`.
797+ skip_layer_guidance_stop (`float`, *optional*): The fraction of total steps at which the guidance for the
798+ layers specified in `skip_guidance_layers` will stop. The guidance will be applied to the layers specified
799+ in `skip_guidance_layers` until the fraction specified in `skip_layer_guidance_stop`. Recommended value by
800+ Stability AI for Stable Diffusion 3.5 Medium is 0.2.
801+ skip_layer_guidance_start (`float`, *optional*): The fraction of total steps at which the guidance for the
802+ layers specified in `skip_guidance_layers` will start. The guidance will be applied to the layers specified
803+ in `skip_guidance_layers` from the fraction specified in `skip_layer_guidance_start`. Recommended value by
804+ Stability AI for Stable Diffusion 3.5 Medium is 0.01.
781805
782806 Examples:
783807
@@ -809,6 +833,7 @@ def __call__(
809833 )
810834
811835 self ._guidance_scale = guidance_scale
836+ self ._skip_layer_guidance_scale = skip_layer_guidance_scale
812837 self ._clip_skip = clip_skip
813838 self ._joint_attention_kwargs = joint_attention_kwargs
814839 self ._interrupt = False
@@ -851,6 +876,9 @@ def __call__(
851876 )
852877
853878 if self .do_classifier_free_guidance :
879+ if skip_guidance_layers is not None :
880+ original_prompt_embeds = prompt_embeds
881+ original_pooled_prompt_embeds = pooled_prompt_embeds
854882 prompt_embeds = torch .cat ([negative_prompt_embeds , prompt_embeds ], dim = 0 )
855883 pooled_prompt_embeds = torch .cat ([negative_pooled_prompt_embeds , pooled_prompt_embeds ], dim = 0 )
856884
@@ -879,7 +907,11 @@ def __call__(
879907 continue
880908
881909 # expand the latents if we are doing classifier free guidance
882- latent_model_input = torch .cat ([latents ] * 2 ) if self .do_classifier_free_guidance else latents
910+ latent_model_input = (
911+ torch .cat ([latents ] * 2 )
912+ if self .do_classifier_free_guidance and skip_guidance_layers is None
913+ else latents
914+ )
883915 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
884916 timestep = t .expand (latent_model_input .shape [0 ])
885917
@@ -896,6 +928,25 @@ def __call__(
896928 if self .do_classifier_free_guidance :
897929 noise_pred_uncond , noise_pred_text = noise_pred .chunk (2 )
898930 noise_pred = noise_pred_uncond + self .guidance_scale * (noise_pred_text - noise_pred_uncond )
931+ should_skip_layers = (
932+ True
933+ if i > num_inference_steps * skip_layer_guidance_start
934+ and i < num_inference_steps * skip_layer_guidance_stop
935+ else False
936+ )
937+ if skip_guidance_layers is not None and should_skip_layers :
938+ noise_pred_skip_layers = self .transformer (
939+ hidden_states = latent_model_input ,
940+ timestep = timestep ,
941+ encoder_hidden_states = original_prompt_embeds ,
942+ pooled_projections = original_pooled_prompt_embeds ,
943+ joint_attention_kwargs = self .joint_attention_kwargs ,
944+ return_dict = False ,
945+ skip_layers = skip_guidance_layers ,
946+ )[0 ]
947+ noise_pred = (
948+ noise_pred + (noise_pred_text - noise_pred_skip_layers ) * self ._skip_layer_guidance_scale
949+ )
899950
900951 # compute the previous noisy sample x_t -> x_t-1
901952 latents_dtype = latents .dtype
0 commit comments