Add logits_to_keep parameter to BertForSequenceClassification #41369
+9
−2
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Summary
This PR adds support for the
logits_to_keep
parameter toBertForSequenceClassification
Motivation
The
logits_to_keep
parameter enables memory-efficient inference by computing logits only for the last N token positions. This optimization is particularly useful for:Implementation Details
logits_to_keep
parameter to theforward()
method signatureNone
)Code Changes
The implementation adds approximately 10 lines of code:
Testing
Comprehensive local testing performed with
bert-base-uncased
:logits_to_keep
values tested (1, 2, 3, 5, 10, None)Backward Compatibility
None
None
, behavior is identical to current implementationRelated Issues
Inspired by #40984