Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 26 additions & 3 deletions src/diffusers/hooks/layer_skip.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,17 @@ def __torch_function__(self, func, types, args=(), kwargs=None):


class AttentionProcessorSkipHook(ModelHook):
def __init__(self, skip_processor_output_fn: Callable, skip_attention_scores: bool = False, dropout: float = 1.0):
def __init__(
self,
skip_processor_output_fn: Callable,
skip_attention_scores: bool = False,
dropout: float = 1.0,
skip_attn_scores_fn: Callable | None = None,
):
super().__init__()
self.skip_processor_output_fn = skip_processor_output_fn
# STG default: return the values as attention output
self.skip_attn_scores_fn = skip_attn_scores_fn or (lambda attn, q, k, v: v)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it safe to assume that the skip_attn_scores_fn will only take those four (attn, q, k, v) as inputs and always return v as the output?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lambda attn, q, k, v: v is intended to be the default skip_attn_scores_fn (unless I made a mistake), which performs standard STG by returning the values v. I don't think this constrains the signature or behavior of skip_attn_scores_fn if we pass it as an argument (and we can always manually implement the skip logic in the if self._skip_attn_scores: branch in the attention processor itself if necessary).

self.skip_attention_scores = skip_attention_scores
self.dropout = dropout

Expand All @@ -119,8 +128,22 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
raise ValueError(
"Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0."
)
with AttentionScoreSkipFunctionMode():
output = self.fn_ref.original_forward(*args, **kwargs)
processor_supports_skip_fn = hasattr(module.processor, "_skip_attn_scores")
if processor_supports_skip_fn:
module.processor._skip_attn_scores = True
module.processor._skip_attn_scores_fn = self.skip_attn_scores_fn
# Use try block in case attn processor raises an exception
try:
if processor_supports_skip_fn:
output = self.fn_ref.original_forward(*args, **kwargs)
else:
# Fallback to torch native SDPA intercept approach
with AttentionScoreSkipFunctionMode():
output = self.fn_ref.original_forward(*args, **kwargs)
Comment on lines +137 to +142
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we have to condition like this?

Copy link
Collaborator Author

@dg845 dg845 Mar 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We condition on processor_supports_skip_fn here in case the attention processor doesn't define a _skip_attn_scores attribute. If it doesn't, we will fall back to the current behavior, which is to intercept a torch.nn.functional.scaled_dot_product_attention call and return the value from there. (The AttentionScoreSkipFunctionMode context manager performs the interception.)

finally:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should also not raise an exception for user clarity?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm confused by this question — is your suggestion to catch any exceptions in an except block? Maybe something like

            try:
                ...
            except Exception as e:
                logger.error(f"Tried to skip attn scores but got error {e}", exc_info=True)
                raise
            finally:
                # Clean up if necessary
                if processor_supports_skip_fn:
                    module.processor._skip_attn_scores = False
                    module.processor._skip_attn_scores_fn = None

?

if processor_supports_skip_fn:
module.processor._skip_attn_scores = False
module.processor._skip_attn_scores_fn = None
else:
if math.isclose(self.dropout, 1.0):
output = self.skip_processor_output_fn(module, *args, **kwargs)
Expand Down
Loading