From 21075fa665a5db55302aa7c075884e10faea71bb Mon Sep 17 00:00:00 2001
From: bghira
Date: Wed, 6 Nov 2024 14:12:03 -0600
Subject: [PATCH 1/4] add skip_layers argument to SD3 transformer model class

---
 .../models/transformers/transformer_sd3.py | 15 ++++++++++++---
 1 file changed, 12 insertions(+), 3 deletions(-)

diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py
index b28350b8ed9c..a06292f6682f 100644
--- a/src/diffusers/models/transformers/transformer_sd3.py
+++ b/src/diffusers/models/transformers/transformer_sd3.py
@@ -268,6 +268,7 @@ def forward(
         block_controlnet_hidden_states: List = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
+        skip_layers: Optional[List[int]] = None,
     ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
         """
         The [`SD3Transformer2DModel`] forward method.
@@ -279,9 +280,9 @@ def forward(
                 Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
             pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`):
                 Embeddings projected from the embeddings of input conditions.
-            timestep ( `torch.LongTensor`):
+            timestep (`torch.LongTensor`):
                 Used to indicate denoising step.
-            block_controlnet_hidden_states: (`list` of `torch.Tensor`):
+            block_controlnet_hidden_states (`list` of `torch.Tensor`):
                 A list of tensors that if specified are added to the residuals of transformer blocks.
             joint_attention_kwargs (`dict`, *optional*):
                 A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
@@ -290,6 +291,8 @@ def forward(
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
                 tuple.
+            skip_layers (`list` of `int`, *optional*):
+                A list of layer indices to skip during the forward pass.
 
         Returns:
             If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
@@ -317,6 +320,13 @@ def forward(
         encoder_hidden_states = self.context_embedder(encoder_hidden_states)
 
         for index_block, block in enumerate(self.transformer_blocks):
+            # Skip specified layers
+            if skip_layers is not None and index_block in skip_layers:
+                if block_controlnet_hidden_states is not None and block.context_pre_only is False:
+                    interval_control = len(self.transformer_blocks) // len(block_controlnet_hidden_states)
+                    hidden_states = hidden_states + block_controlnet_hidden_states[index_block // interval_control]
+                continue
+
             if self.training and self.gradient_checkpointing:
 
                 def create_custom_forward(module, return_dict=None):
@@ -336,7 +346,6 @@ def custom_forward(*inputs):
                     temb,
                     **ckpt_kwargs,
                 )
-
             else:
                 encoder_hidden_states, hidden_states = block(
                     hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb

From fd4a22921b81d906d16b8b269462094d63e5ec41 Mon Sep 17 00:00:00 2001
From: bghira
Date: Wed, 6 Nov 2024 14:14:55 -0600
Subject: [PATCH 2/4] add unit test for skip_layers in stable diffusion 3

---
 .../test_models_transformer_sd3.py | 21 +++++++++++++++++++
 1 file changed, 21 insertions(+)

diff --git a/tests/models/transformers/test_models_transformer_sd3.py b/tests/models/transformers/test_models_transformer_sd3.py
index af86fa9c3bc1..1ed9812e967d 100644
--- a/tests/models/transformers/test_models_transformer_sd3.py
+++ b/tests/models/transformers/test_models_transformer_sd3.py
@@ -147,3 +147,24 @@ def test_set_attn_processor_for_determinism(self):
     def test_gradient_checkpointing_is_applied(self):
         expected_set = {"SD3Transformer2DModel"}
         super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+    def test_skip_layers(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+
+        # Forward pass without skipping layers
+        output_full = model(**inputs_dict).sample
+
+        # Forward pass skipping layer 0 (the only layer in this test setup)
+        inputs_dict_with_skip = inputs_dict.copy()
+        inputs_dict_with_skip['skip_layers'] = [0]
+        output_skip = model(**inputs_dict_with_skip).sample
+
+        # Check that the outputs are different
+        self.assertFalse(
+            torch.allclose(output_full, output_skip, atol=1e-5),
+            "Outputs should differ when layers are skipped"
+        )
+
+        # Check that the outputs have the same shape
+        self.assertEqual(output_full.shape, output_skip.shape, "Outputs should have the same shape")

From 103536fda23d940dc4bba4b5d4978f5ebf90c50d Mon Sep 17 00:00:00 2001
From: bghira
Date: Thu, 7 Nov 2024 06:24:09 -0600
Subject: [PATCH 3/4] sd3: pipeline should support skip layer guidance

---
 .../pipeline_stable_diffusion_3.py | 38 ++++++++++++++++++-
 1 file changed, 37 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
index 43cb40e6e733..0cf183b4e23f 100644
--- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
+++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
@@ -694,6 +694,10 @@ def __call__(
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 256,
+        skip_guidance_layers: List[int] = None,
+        skip_layer_guidance_scale: float = 2.8,
+        skip_layer_guidance_stop: float = 0.2,
+        skip_layer_guidance_start: float = 0.01,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -778,6 +782,22 @@ def __call__(
                 will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                 `._callback_tensor_inputs` attribute of your pipeline class.
             max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
+            skip_guidance_layers (`List[int]`, *optional*): A list of integers that specify layers to skip during guidance.
+                If not provided, all layers will be used for guidance. If provided, the guidance will only be applied
+                to the layers specified in the list. Recommended value by Stability AI for Stable Diffusion 3.5 Medium is
+                [7, 8, 9].
+            skip_layer_guidance_scale (`float`, *optional*): The scale of the guidance for the layers specified in
+                `skip_guidance_layers`. The guidance will be applied to the layers specified in `skip_guidance_layers`
+                with a scale of `skip_layer_guidance_scale`. The guidance will be applied to the rest of the layers with
+                a scale of `1`.
+            skip_layer_guidance_stop (`float`, *optional*): The step at which the guidance for the layers specified in
+                `skip_guidance_layers` will stop. The guidance will be applied to the layers specified in
+                `skip_guidance_layers` until the fraction specified in `skip_layer_guidance_stop`. Recommended value by
+                Stability AI for Stable Diffusion 3.5 Medium is 0.2.
+            skip_layer_guidance_start (`float`, *optional*): The step at which the guidance for the layers specified in
+                `skip_guidance_layers` will start. The guidance will be applied to the layers specified in
+                `skip_guidance_layers` from the fraction specified in `skip_layer_guidance_start`. Recommended value by
+                Stability AI for Stable Diffusion 3.5 Medium is 0.01.
 
         Examples:
 
@@ -809,6 +829,7 @@ def __call__(
         )
 
         self._guidance_scale = guidance_scale
+        self._skip_layer_guidance_scale = skip_layer_guidance_scale
         self._clip_skip = clip_skip
         self._joint_attention_kwargs = joint_attention_kwargs
         self._interrupt = False
@@ -851,6 +872,9 @@ def __call__(
         )
 
         if self.do_classifier_free_guidance:
+            if skip_guidance_layers is not None:
+                original_prompt_embeds = prompt_embeds
+                original_pooled_prompt_embeds = pooled_prompt_embeds
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
             pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
@@ -879,7 +903,7 @@ def __call__(
                     continue
 
                 # expand the latents if we are doing classifier free guidance
-                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance and skip_guidance_layers is None else latents
 
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latent_model_input.shape[0])
@@ -896,6 +920,18 @@ def __call__(
                 if self.do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                     noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+                    should_skip_layers = True if i > num_inference_steps * skip_layer_guidance_start and i < num_inference_steps* skip_layer_guidance_stop else False
+                    if skip_guidance_layers is not None and should_skip_layers:
+                        noise_pred_skip_layers = self.transformer(
+                            hidden_states=latent_model_input,
+                            timestep=timestep,
+                            encoder_hidden_states=original_prompt_embeds,
+                            pooled_projections=original_pooled_prompt_embeds,
+                            joint_attention_kwargs=self.joint_attention_kwargs,
+                            return_dict=False,
+                            skip_layers=skip_guidance_layers,
+                        )[0]
+                        noise_pred = noise_pred + (noise_pred_text - noise_pred_skip_layers) * self._skip_layer_guidance_scale
 
                 # compute the previous noisy sample x_t -> x_t-1
                 latents_dtype = latents.dtype

From 481591a0f844636e6183df854294721cb9857d13 Mon Sep 17 00:00:00 2001
From: yiyixuxu
Date: Tue, 19 Nov 2024 06:08:39 +0100
Subject: [PATCH 4/4] up

---
 .../models/transformers/transformer_sd3.py | 10 ++----
 .../pipeline_stable_diffusion_3.py         | 33 ++++++++++++++-----
 .../test_models_transformer_sd3.py         |  5 ++-
 3 files changed, 29 insertions(+), 19 deletions(-)

diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py
index e9149fdc629d..a89a5e26ee97 100644
--- a/src/diffusers/models/transformers/transformer_sd3.py
+++ b/src/diffusers/models/transformers/transformer_sd3.py
@@ -321,13 +321,9 @@ def forward(
 
         for index_block, block in enumerate(self.transformer_blocks):
             # Skip specified layers
-            if skip_layers is not None and index_block in skip_layers:
-                if block_controlnet_hidden_states is not None and block.context_pre_only is False:
-                    interval_control = len(self.transformer_blocks) // len(block_controlnet_hidden_states)
-                    hidden_states = hidden_states + block_controlnet_hidden_states[index_block // interval_control]
-                continue
+            is_skip = True if skip_layers is not None and index_block in skip_layers else False
 
-            if torch.is_grad_enabled() and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing and not is_skip:
 
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -346,7 +342,7 @@ def custom_forward(*inputs):
                     temb,
                     **ckpt_kwargs,
                 )
-            else:
+            elif not is_skip:
                 encoder_hidden_states, hidden_states = block(
                     hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
                 )

diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
index 0cf183b4e23f..a77231cdc02d 100644
--- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
+++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
@@ -642,6 +642,10 @@ def prepare_latents(
     def guidance_scale(self):
         return self._guidance_scale
 
+    @property
+    def skip_guidance_layers(self):
+        return self._skip_guidance_layers
+
     @property
     def clip_skip(self):
         return self._clip_skip
@@ -782,14 +786,14 @@ def __call__(
                 will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                 `._callback_tensor_inputs` attribute of your pipeline class.
             max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
-            skip_guidance_layers (`List[int]`, *optional*): A list of integers that specify layers to skip during guidance.
-                If not provided, all layers will be used for guidance. If provided, the guidance will only be applied
-                to the layers specified in the list. Recommended value by Stability AI for Stable Diffusion 3.5 Medium is
-                [7, 8, 9].
+            skip_guidance_layers (`List[int]`, *optional*):
+                A list of integers that specify layers to skip during guidance. If not provided, all layers will be
+                used for guidance. If provided, the guidance will only be applied to the layers specified in the list.
+                Recommended value by Stability AI for Stable Diffusion 3.5 Medium is [7, 8, 9].
             skip_layer_guidance_scale (`float`, *optional*): The scale of the guidance for the layers specified in
                 `skip_guidance_layers`. The guidance will be applied to the layers specified in `skip_guidance_layers`
-                with a scale of `skip_layer_guidance_scale`. The guidance will be applied to the rest of the layers with
-                a scale of `1`.
+                with a scale of `skip_layer_guidance_scale`. The guidance will be applied to the rest of the layers
+                with a scale of `1`.
             skip_layer_guidance_stop (`float`, *optional*): The step at which the guidance for the layers specified in
                 `skip_guidance_layers` will stop. The guidance will be applied to the layers specified in
                 `skip_guidance_layers` until the fraction specified in `skip_layer_guidance_stop`. Recommended value by
                Stability AI for Stable Diffusion 3.5 Medium is 0.2.
             skip_layer_guidance_start (`float`, *optional*): The step at which the guidance for the layers specified in
                 `skip_guidance_layers` will start. The guidance will be applied to the layers specified in
                 `skip_guidance_layers` from the fraction specified in `skip_layer_guidance_start`. Recommended value by
                 Stability AI for Stable Diffusion 3.5 Medium is 0.01.
@@ -903,7 +907,11 @@ def __call__(
                     continue
 
                 # expand the latents if we are doing classifier free guidance
-                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance and skip_guidance_layers is None else latents
+                latent_model_input = (
+                    torch.cat([latents] * 2)
+                    if self.do_classifier_free_guidance and skip_guidance_layers is None
+                    else latents
+                )
 
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latent_model_input.shape[0])
@@ -920,7 +928,12 @@ def __call__(
                 if self.do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                     noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
-                    should_skip_layers = True if i > num_inference_steps * skip_layer_guidance_start and i < num_inference_steps* skip_layer_guidance_stop else False
+                    should_skip_layers = (
+                        True
+                        if i > num_inference_steps * skip_layer_guidance_start
+                        and i < num_inference_steps * skip_layer_guidance_stop
+                        else False
+                    )
                     if skip_guidance_layers is not None and should_skip_layers:
                         noise_pred_skip_layers = self.transformer(
                             hidden_states=latent_model_input,
                             timestep=timestep,
                             encoder_hidden_states=original_prompt_embeds,
                             pooled_projections=original_pooled_prompt_embeds,
                             joint_attention_kwargs=self.joint_attention_kwargs,
                             return_dict=False,
                             skip_layers=skip_guidance_layers,
                         )[0]
-                        noise_pred = noise_pred + (noise_pred_text - noise_pred_skip_layers) * self._skip_layer_guidance_scale
+                        noise_pred = (
+                            noise_pred + (noise_pred_text - noise_pred_skip_layers) * self._skip_layer_guidance_scale
+                        )
 
                 # compute the previous noisy sample x_t -> x_t-1
                 latents_dtype = latents.dtype

diff --git a/tests/models/transformers/test_models_transformer_sd3.py b/tests/models/transformers/test_models_transformer_sd3.py
index 1ed9812e967d..b9e12a11fafa 100644
--- a/tests/models/transformers/test_models_transformer_sd3.py
+++ b/tests/models/transformers/test_models_transformer_sd3.py
@@ -157,13 +157,12 @@ def test_skip_layers(self):
 
         # Forward pass skipping layer 0 (the only layer in this test setup)
         inputs_dict_with_skip = inputs_dict.copy()
-        inputs_dict_with_skip['skip_layers'] = [0]
+        inputs_dict_with_skip["skip_layers"] = [0]
         output_skip = model(**inputs_dict_with_skip).sample
 
         # Check that the outputs are different
         self.assertFalse(
-            torch.allclose(output_full, output_skip, atol=1e-5),
-            "Outputs should differ when layers are skipped"
+            torch.allclose(output_full, output_skip, atol=1e-5), "Outputs should differ when layers are skipped"
        )
 
         # Check that the outputs have the same shape
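
For quick reference, below is a minimal usage sketch of the skip layer guidance (SLG) feature this series adds. It is not part of the patches: it assumes the post-merge diffusers API shown above, a CUDA device, and the `stabilityai/stable-diffusion-3.5-medium` checkpoint, with the parameter values taken from the recommendations in the docstrings. As PATCH 3 shows, while SLG is active the pipeline runs one extra transformer forward with `skip_layers` set and adds `skip_layer_guidance_scale * (noise_pred_text - noise_pred_skip_layers)` on top of the usual classifier-free guidance result.

    # Hedged sketch, not part of this series: assumes the merged API above,
    # a CUDA GPU, and the SD3.5 Medium checkpoint named in the docstrings.
    import torch
    from diffusers import StableDiffusion3Pipeline

    pipe = StableDiffusion3Pipeline.from_pretrained(
        "stabilityai/stable-diffusion-3.5-medium", torch_dtype=torch.bfloat16
    ).to("cuda")

    image = pipe(
        prompt="a photo of an astronaut riding a horse",
        num_inference_steps=28,
        guidance_scale=4.5,
        skip_guidance_layers=[7, 8, 9],  # layers skipped in the extra guidance pass
        skip_layer_guidance_scale=2.8,   # weight on the (cond - skip-layer) delta
        skip_layer_guidance_start=0.01,  # SLG becomes active after 1% of steps
        skip_layer_guidance_stop=0.2,    # and is disabled after 20% of steps
    ).images[0]
    image.save("slg_example.png")

Because the skip-layer pass is an additional transformer forward, SLG adds per-step cost only inside the `skip_layer_guidance_start`/`skip_layer_guidance_stop` window, i.e. roughly the first 20% of the schedule with the defaults above.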