-
Notifications
You must be signed in to change notification settings - Fork 6.8k
Refactor AttentionProcessorSkipHook to Support Custom STG Logic
#13220
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -108,8 +108,17 @@ def __torch_function__(self, func, types, args=(), kwargs=None): | |
|
|
||
|
|
||
| class AttentionProcessorSkipHook(ModelHook): | ||
| def __init__(self, skip_processor_output_fn: Callable, skip_attention_scores: bool = False, dropout: float = 1.0): | ||
| def __init__( | ||
| self, | ||
| skip_processor_output_fn: Callable, | ||
| skip_attn_scores_fn: Callable | None = None, | ||
| skip_attention_scores: bool = False, | ||
| dropout: float = 1.0, | ||
| ): | ||
| super().__init__() | ||
| self.skip_processor_output_fn = skip_processor_output_fn | ||
| # STG default: return the values as attention output | ||
| self.skip_attn_scores_fn = skip_attn_scores_fn or (lambda attn, q, k, v: v) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is it safe to assume that the
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| self.skip_attention_scores = skip_attention_scores | ||
| self.dropout = dropout | ||
|
|
||
|
|
@@ -119,8 +128,22 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): | |
| raise ValueError( | ||
| "Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0." | ||
| ) | ||
| with AttentionScoreSkipFunctionMode(): | ||
| output = self.fn_ref.original_forward(*args, **kwargs) | ||
| processor_supports_skip_fn = hasattr(module.processor, "_skip_attn_scores") | ||
| if processor_supports_skip_fn: | ||
| module.processor._skip_attn_scores = True | ||
| module.processor._skip_attn_scores_fn = self.skip_attn_scores_fn | ||
| # Use try block in case attn processor raises an exception | ||
| try: | ||
| if processor_supports_skip_fn: | ||
| output = self.fn_ref.original_forward(*args, **kwargs) | ||
| else: | ||
| # Fallback to torch native SDPA intercept approach | ||
| with AttentionScoreSkipFunctionMode(): | ||
| output = self.fn_ref.original_forward(*args, **kwargs) | ||
|
Comment on lines
+137
to
+142
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why do we have to condition like this?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We condition on |
||
| finally: | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should also not raise an exception for user clarity?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm confused by this question, is your suggestion to catch any exceptions in a try:
...
except Exception as e:
logger.error(f"Tried to skip attn scores but got error {e}", exc_info=True)
raise
finally:
# Clean up if necessary
if processor_supports_skip_fn:
module.processor._skip_attn_scores = False
module.processor._skip_attn_scores_fn = None? |
||
| if processor_supports_skip_fn: | ||
| module.processor._skip_attn_scores = False | ||
| module.processor._skip_attn_scores_fn = None | ||
| else: | ||
| if math.isclose(self.dropout, 1.0): | ||
| output = self.skip_processor_output_fn(module, *args, **kwargs) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should it go at the end to preserve BC compatibility (in case someone initialized this class with positional arguments only)?
No strong opinions here.