15 changes: 10 additions & 5 deletions src/diffusers/models/transformers/transformer_sd3.py
@@ -268,6 +268,7 @@ def forward(
block_controlnet_hidden_states: List = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
skip_layers: Optional[List[int]] = None,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
"""
The [`SD3Transformer2DModel`] forward method.
@@ -279,9 +280,9 @@ def forward(
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
from the embeddings of input conditions.
timestep (`torch.LongTensor`):
Used to indicate denoising step.
block_controlnet_hidden_states (`list` of `torch.Tensor`):
A list of tensors that, if specified, are added to the residuals of transformer blocks.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
@@ -290,6 +291,8 @@
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
tuple.
skip_layers (`list` of `int`, *optional*):
A list of layer indices to skip during the forward pass.

Returns:
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
@@ -317,7 +320,10 @@ def forward(
encoder_hidden_states = self.context_embedder(encoder_hidden_states)

for index_block, block in enumerate(self.transformer_blocks):
# Skip specified layers
is_skip = skip_layers is not None and index_block in skip_layers

if torch.is_grad_enabled() and self.gradient_checkpointing and not is_skip:

def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
@@ -336,8 +342,7 @@ def custom_forward(*inputs):
temb,
**ckpt_kwargs,
)

elif not is_skip:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
)
@@ -642,6 +642,10 @@ def prepare_latents(
def guidance_scale(self):
return self._guidance_scale

@property
def skip_guidance_layers(self):
return self._skip_guidance_layers

@property
def clip_skip(self):
return self._clip_skip
@@ -694,6 +698,10 @@ def __call__(
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 256,
skip_guidance_layers: Optional[List[int]] = None,
skip_layer_guidance_scale: float = 2.8,
skip_layer_guidance_stop: float = 0.2,
skip_layer_guidance_start: float = 0.01,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -778,6 +786,22 @@ def __call__(
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
skip_guidance_layers (`List[int]`, *optional*):
A list of layer indices to skip during the guidance forward pass. If not provided, no layers are
skipped and skip-layer guidance is disabled. If provided, an extra conditional pass is run with these
layers skipped and used to steer the prediction. The value recommended by Stability AI for Stable
Diffusion 3.5 Medium is [7, 8, 9].
skip_layer_guidance_scale (`float`, *optional*, defaults to 2.8): The scale of the guidance for the layers
specified in `skip_guidance_layers`. The prediction computed with these layers skipped is pushed away
from the full conditional prediction with this scale, while the remaining layers receive the regular
`guidance_scale`.
skip_layer_guidance_stop (`float`, *optional*, defaults to 0.2): The fraction of the total inference steps
at which the guidance for the layers specified in `skip_guidance_layers` stops. The value recommended
by Stability AI for Stable Diffusion 3.5 Medium is 0.2.
skip_layer_guidance_start (`float`, *optional*, defaults to 0.01): The fraction of the total inference
steps at which the guidance for the layers specified in `skip_guidance_layers` starts. The value
recommended by Stability AI for Stable Diffusion 3.5 Medium is 0.01.

Examples:

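As a usage sketch of the new arguments together (the checkpoint id, device, and dtype are assumptions; the keyword names and recommended values come from this diff):

```python
import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-medium", torch_dtype=torch.bfloat16
).to("cuda")

image = pipe(
    "a photo of a cat holding a sign that says hello world",
    num_inference_steps=28,
    guidance_scale=7.0,
    skip_guidance_layers=[7, 8, 9],  # Stability AI's recommendation for SD 3.5 Medium
    skip_layer_guidance_scale=2.8,
    skip_layer_guidance_start=0.01,  # skip-layer guidance active between 1% and 20% of steps
    skip_layer_guidance_stop=0.2,
).images[0]
image.save("slg.png")
```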
@@ -809,6 +833,7 @@ def __call__(
)

self._guidance_scale = guidance_scale
self._skip_layer_guidance_scale = skip_layer_guidance_scale
Review comment from @yiyixuxu (Collaborator), Nov 18, 2024:

need to add a decorator for this too, like this
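Following the pattern of the existing `guidance_scale` and `skip_guidance_layers` accessors above, the requested property would be a three-line addition (a sketch, not part of the diff as shown):

```python
@property
def skip_layer_guidance_scale(self):
    return self._skip_layer_guidance_scale
```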
self._clip_skip = clip_skip
self._joint_attention_kwargs = joint_attention_kwargs
self._interrupt = False
@@ -851,6 +876,9 @@ def __call__(
)

if self.do_classifier_free_guidance:
if skip_guidance_layers is not None:
original_prompt_embeds = prompt_embeds
original_pooled_prompt_embeds = pooled_prompt_embeds
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)

@@ -879,7 +907,11 @@ def __call__(
continue

# expand the latents if we are doing classifier free guidance
latent_model_input = (
torch.cat([latents] * 2)
if self.do_classifier_free_guidance and skip_guidance_layers is None
else latents
)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0])

@@ -896,6 +928,25 @@ def __call__(
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
should_skip_layers = (
i > num_inference_steps * skip_layer_guidance_start
and i < num_inference_steps * skip_layer_guidance_stop
)
if skip_guidance_layers is not None and should_skip_layers:
noise_pred_skip_layers = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=original_prompt_embeds,
pooled_projections=original_pooled_prompt_embeds,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
skip_layers=skip_guidance_layers,
)[0]
noise_pred = (
noise_pred + (noise_pred_text - noise_pred_skip_layers) * self._skip_layer_guidance_scale
)

# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
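Numerically, the two guidance terms compose exactly as in the loop above. A toy check with dummy tensors (values chosen only to make the arithmetic visible):

```python
import torch

# Stand-ins for the three predictions produced in the denoising loop.
noise_pred_uncond = torch.zeros(1, 4)
noise_pred_text = torch.ones(1, 4)
noise_pred_skip_layers = torch.full((1, 4), 0.8)

guidance_scale = 7.0
skip_layer_guidance_scale = 2.8

# Standard classifier-free guidance: 0 + 7.0 * (1 - 0) = 7.0
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# Skip-layer guidance pushes away from the skip-layer prediction, toward
# the full conditional one: 7.0 + (1.0 - 0.8) * 2.8 = 7.56
noise_pred = noise_pred + (noise_pred_text - noise_pred_skip_layers) * skip_layer_guidance_scale

print(noise_pred)  # tensor([[7.5600, 7.5600, 7.5600, 7.5600]])
```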
20 changes: 20 additions & 0 deletions tests/models/transformers/test_models_transformer_sd3.py
@@ -147,3 +147,23 @@ def test_set_attn_processor_for_determinism(self):
def test_gradient_checkpointing_is_applied(self):
expected_set = {"SD3Transformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

def test_skip_layers(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)

# Forward pass without skipping layers
output_full = model(**inputs_dict).sample

# Forward pass with layer 0 skipped (the test config has only one transformer block)
inputs_dict_with_skip = inputs_dict.copy()
inputs_dict_with_skip["skip_layers"] = [0]
output_skip = model(**inputs_dict_with_skip).sample

# Check that the outputs are different
self.assertFalse(
torch.allclose(output_full, output_skip, atol=1e-5), "Outputs should differ when layers are skipped"
)

# Check that the outputs have the same shape
self.assertEqual(output_full.shape, output_skip.shape, "Outputs should have the same shape")