diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py index f39a102c7256..a89a5e26ee97 100644 --- a/src/diffusers/models/transformers/transformer_sd3.py +++ b/src/diffusers/models/transformers/transformer_sd3.py @@ -268,6 +268,7 @@ def forward( block_controlnet_hidden_states: List = None, joint_attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, + skip_layers: Optional[List[int]] = None, ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: """ The [`SD3Transformer2DModel`] forward method. @@ -279,9 +280,9 @@ def forward( Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected from the embeddings of input conditions. - timestep ( `torch.LongTensor`): + timestep (`torch.LongTensor`): Used to indicate denoising step. - block_controlnet_hidden_states: (`list` of `torch.Tensor`): + block_controlnet_hidden_states (`list` of `torch.Tensor`): A list of tensors that if specified are added to the residuals of transformer blocks. joint_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under @@ -290,6 +291,8 @@ def forward( return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain tuple. + skip_layers (`list` of `int`, *optional*): + A list of layer indices to skip during the forward pass. Returns: If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a @@ -317,7 +320,10 @@ def forward( encoder_hidden_states = self.context_embedder(encoder_hidden_states) for index_block, block in enumerate(self.transformer_blocks): - if torch.is_grad_enabled() and self.gradient_checkpointing: + # Skip specified layers + is_skip = True if skip_layers is not None and index_block in skip_layers else False + + if torch.is_grad_enabled() and self.gradient_checkpointing and not is_skip: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): @@ -336,8 +342,7 @@ def custom_forward(*inputs): temb, **ckpt_kwargs, ) - - else: + elif not is_skip: encoder_hidden_states, hidden_states = block( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb ) diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py index 43cb40e6e733..a77231cdc02d 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py @@ -642,6 +642,10 @@ def prepare_latents( def guidance_scale(self): return self._guidance_scale + @property + def skip_guidance_layers(self): + return self._skip_guidance_layers + @property def clip_skip(self): return self._clip_skip @@ -694,6 +698,10 @@ def __call__( callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 256, + skip_guidance_layers: List[int] = None, + skip_layer_guidance_scale: int = 2.8, + skip_layer_guidance_stop: int = 0.2, + skip_layer_guidance_start: int = 0.01, ): r""" Function invoked when calling the pipeline for generation. @@ -778,6 +786,22 @@ def __call__( will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`. + skip_guidance_layers (`List[int]`, *optional*): + A list of integers that specify layers to skip during guidance. If not provided, all layers will be + used for guidance. If provided, the guidance will only be applied to the layers specified in the list. + Recommended value by StabiltyAI for Stable Diffusion 3.5 Medium is [7, 8, 9]. + skip_layer_guidance_scale (`int`, *optional*): The scale of the guidance for the layers specified in + `skip_guidance_layers`. The guidance will be applied to the layers specified in `skip_guidance_layers` + with a scale of `skip_layer_guidance_scale`. The guidance will be applied to the rest of the layers + with a scale of `1`. + skip_layer_guidance_stop (`int`, *optional*): The step at which the guidance for the layers specified in + `skip_guidance_layers` will stop. The guidance will be applied to the layers specified in + `skip_guidance_layers` until the fraction specified in `skip_layer_guidance_stop`. Recommended value by + StabiltyAI for Stable Diffusion 3.5 Medium is 0.2. + skip_layer_guidance_start (`int`, *optional*): The step at which the guidance for the layers specified in + `skip_guidance_layers` will start. The guidance will be applied to the layers specified in + `skip_guidance_layers` from the fraction specified in `skip_layer_guidance_start`. Recommended value by + StabiltyAI for Stable Diffusion 3.5 Medium is 0.01. Examples: @@ -809,6 +833,7 @@ def __call__( ) self._guidance_scale = guidance_scale + self._skip_layer_guidance_scale = skip_layer_guidance_scale self._clip_skip = clip_skip self._joint_attention_kwargs = joint_attention_kwargs self._interrupt = False @@ -851,6 +876,9 @@ def __call__( ) if self.do_classifier_free_guidance: + if skip_guidance_layers is not None: + original_prompt_embeds = prompt_embeds + original_pooled_prompt_embeds = pooled_prompt_embeds prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) @@ -879,7 +907,11 @@ def __call__( continue # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = ( + torch.cat([latents] * 2) + if self.do_classifier_free_guidance and skip_guidance_layers is None + else latents + ) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latent_model_input.shape[0]) @@ -896,6 +928,25 @@ def __call__( if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + should_skip_layers = ( + True + if i > num_inference_steps * skip_layer_guidance_start + and i < num_inference_steps * skip_layer_guidance_stop + else False + ) + if skip_guidance_layers is not None and should_skip_layers: + noise_pred_skip_layers = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=original_prompt_embeds, + pooled_projections=original_pooled_prompt_embeds, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + skip_layers=skip_guidance_layers, + )[0] + noise_pred = ( + noise_pred + (noise_pred_text - noise_pred_skip_layers) * self._skip_layer_guidance_scale + ) # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype diff --git a/tests/models/transformers/test_models_transformer_sd3.py b/tests/models/transformers/test_models_transformer_sd3.py index af86fa9c3bc1..b9e12a11fafa 100644 --- a/tests/models/transformers/test_models_transformer_sd3.py +++ b/tests/models/transformers/test_models_transformer_sd3.py @@ -147,3 +147,23 @@ def test_set_attn_processor_for_determinism(self): def test_gradient_checkpointing_is_applied(self): expected_set = {"SD3Transformer2DModel"} super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + + def test_skip_layers(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict).to(torch_device) + + # Forward pass without skipping layers + output_full = model(**inputs_dict).sample + + # Forward pass with skipping layers 0 (since there's only one layer in this test setup) + inputs_dict_with_skip = inputs_dict.copy() + inputs_dict_with_skip["skip_layers"] = [0] + output_skip = model(**inputs_dict_with_skip).sample + + # Check that the outputs are different + self.assertFalse( + torch.allclose(output_full, output_skip, atol=1e-5), "Outputs should differ when layers are skipped" + ) + + # Check that the outputs have the same shape + self.assertEqual(output_full.shape, output_skip.shape, "Outputs should have the same shape")