Merged (changes from 6 commits)
15 changes: 12 additions & 3 deletions src/diffusers/models/transformers/transformer_sd3.py
@@ -268,6 +268,7 @@ def forward(
block_controlnet_hidden_states: List = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
skip_layers: Optional[List[int]] = None,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
"""
The [`SD3Transformer2DModel`] forward method.
@@ -279,9 +280,9 @@ def forward(
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
from the embeddings of input conditions.
- timestep ( `torch.LongTensor`):
+ timestep (`torch.LongTensor`):
Used to indicate denoising step.
- block_controlnet_hidden_states: (`list` of `torch.Tensor`):
+ block_controlnet_hidden_states (`list` of `torch.Tensor`):
A list of tensors that if specified are added to the residuals of transformer blocks.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
@@ -290,6 +291,8 @@
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
tuple.
skip_layers (`list` of `int`, *optional*):
A list of layer indices to skip during the forward pass.

Returns:
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
@@ -317,6 +320,13 @@ def forward(
encoder_hidden_states = self.context_embedder(encoder_hidden_states)

for index_block, block in enumerate(self.transformer_blocks):
# Skip specified layers
if skip_layers is not None and index_block in skip_layers:
if block_controlnet_hidden_states is not None and block.context_pre_only is False:
Collaborator comment: can we skip the block of code that needs to be skipped instead of adding duplicated code here? Otherwise, if we have to change this part of the code that handles the controlnet residual in the future, we have to remember to change both places, which is not great.
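A hypothetical sketch of the loop structure that comment is pointing toward (not the code in this PR; the gradient-checkpointing branch is omitted for brevity): gate only the block call on `skip_layers`, so the controlnet-residual handling lives in a single place.

for index_block, block in enumerate(self.transformer_blocks):
    # run the block unless its index is listed in skip_layers
    is_skipped = skip_layers is not None and index_block in skip_layers
    if not is_skipped:
        encoder_hidden_states, hidden_states = block(
            hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
        )

    # controlnet residual handled once, for skipped and non-skipped blocks alike
    if block_controlnet_hidden_states is not None and block.context_pre_only is False:
        interval_control = len(self.transformer_blocks) // len(block_controlnet_hidden_states)
        hidden_states = hidden_states + block_controlnet_hidden_states[index_block // interval_control]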

interval_control = len(self.transformer_blocks) // len(block_controlnet_hidden_states)
hidden_states = hidden_states + block_controlnet_hidden_states[index_block // interval_control]
continue

if torch.is_grad_enabled() and self.gradient_checkpointing:

def create_custom_forward(module, return_dict=None):
@@ -336,7 +346,6 @@ def custom_forward(*inputs):
temb,
**ckpt_kwargs,
)

else:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
@@ -694,6 +694,10 @@ def __call__(
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 256,
skip_guidance_layers: List[int] = None,
skip_layer_guidance_scale: float = 2.8,
skip_layer_guidance_stop: float = 0.2,
skip_layer_guidance_start: float = 0.01,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -778,6 +782,22 @@ def __call__(
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
skip_guidance_layers (`List[int]`, *optional*): A list of integers specifying the layers to skip during
guidance. If not provided, all layers are used for guidance; if provided, guidance is only applied to the
layers in the list. The value recommended by Stability AI for Stable Diffusion 3.5 Medium is [7, 8, 9].
skip_layer_guidance_scale (`float`, *optional*, defaults to 2.8): The guidance scale applied to the layers
specified in `skip_guidance_layers`; the remaining layers use a scale of `1`.
skip_layer_guidance_stop (`float`, *optional*, defaults to 0.2): The fraction of the total inference steps
after which guidance for the layers specified in `skip_guidance_layers` is no longer applied. The value
recommended by Stability AI for Stable Diffusion 3.5 Medium is 0.2.
skip_layer_guidance_start (`float`, *optional*, defaults to 0.01): The fraction of the total inference steps
at which guidance for the layers specified in `skip_guidance_layers` starts being applied. The value
recommended by Stability AI for Stable Diffusion 3.5 Medium is 0.01.

Examples:

@@ -809,6 +829,7 @@
)

self._guidance_scale = guidance_scale
self._skip_layer_guidance_scale = skip_layer_guidance_scale
Comment from @yiyixuxu (Nov 18, 2024): need to add a decorator for this too, like this
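The example the comment links to is not shown above; presumably it refers to the `@property` accessors the pipeline already defines for values such as `guidance_scale`. A sketch of that pattern for the new attribute (an assumption, not code from this PR):

@property
def skip_layer_guidance_scale(self):
    # mirror the existing guidance_scale property
    return self._skip_layer_guidance_scale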

self._clip_skip = clip_skip
self._joint_attention_kwargs = joint_attention_kwargs
self._interrupt = False
@@ -851,6 +872,9 @@ def __call__(
)

if self.do_classifier_free_guidance:
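# keep the un-batched (conditional-only) embeddings so the extra skip-layer pass below
# can be run on the conditional branch without the classifier-free guidance concatenation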
if skip_guidance_layers is not None:
original_prompt_embeds = prompt_embeds
original_pooled_prompt_embeds = pooled_prompt_embeds
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)

@@ -879,7 +903,7 @@ def __call__(
continue

# expand the latents if we are doing classifier free guidance
- latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance and skip_guidance_layers is None else latents
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0])

@@ -896,6 +920,18 @@
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
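# the extra skip-layer pass is only applied between the start and stop fractions of the schedule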
should_skip_layers = num_inference_steps * skip_layer_guidance_start < i < num_inference_steps * skip_layer_guidance_stop
if skip_guidance_layers is not None and should_skip_layers:
noise_pred_skip_layers = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=original_prompt_embeds,
pooled_projections=original_pooled_prompt_embeds,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
skip_layers=skip_guidance_layers,
)[0]
noise_pred = noise_pred + (noise_pred_text - noise_pred_skip_layers) * self._skip_layer_guidance_scale
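# net prediction: eps = eps_uncond + guidance_scale * (eps_text - eps_uncond)
#                 + skip_layer_guidance_scale * (eps_text - eps_skip_layers)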

# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
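For orientation, a usage sketch of the new arguments, assuming they are exposed on `StableDiffusion3Pipeline` (the checkpoint name and prompt below are placeholders):

import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-medium", torch_dtype=torch.bfloat16
).to("cuda")

image = pipe(
    prompt="a photo of an astronaut riding a horse",
    num_inference_steps=28,
    guidance_scale=4.5,
    # values recommended for SD3.5 Medium in the docstring above
    skip_guidance_layers=[7, 8, 9],
    skip_layer_guidance_scale=2.8,
    skip_layer_guidance_start=0.01,
    skip_layer_guidance_stop=0.2,
).images[0]
image.save("slg_example.png")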
21 changes: 21 additions & 0 deletions tests/models/transformers/test_models_transformer_sd3.py
@@ -147,3 +147,24 @@ def test_set_attn_processor_for_determinism(self):
def test_gradient_checkpointing_is_applied(self):
expected_set = {"SD3Transformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

def test_skip_layers(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)

# Forward pass without skipping layers
output_full = model(**inputs_dict).sample

# Forward pass with layer 0 skipped (there is only one layer in this test setup)
inputs_dict_with_skip = inputs_dict.copy()
inputs_dict_with_skip["skip_layers"] = [0]
output_skip = model(**inputs_dict_with_skip).sample

# Check that the outputs are different
self.assertFalse(
torch.allclose(output_full, output_skip, atol=1e-5),
"Outputs should differ when layers are skipped"
)

# Check that the outputs have the same shape
self.assertEqual(output_full.shape, output_skip.shape, "Outputs should have the same shape")
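The new test can be run in isolation with `pytest tests/models/transformers/test_models_transformer_sd3.py -k test_skip_layers`.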