
Conversation

philiproeleveld

What does this PR do?

Fixes #40984

Adds logits_to_keep to many (older) ForCausalLM models that inherit from GenerationMixin.
Also consistently renames the output variables to loss and logits, and removes per-model code for float casting and for moving labels to the logits' device in models where the shared loss function already handles this (e.g. gpt_neo).
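
For context, the pattern being added mirrors what the newer ForCausalLM implementations in transformers do: project only the trailing positions through the lm_head instead of the full sequence. A minimal sketch of that slicing logic, using illustrative tensor names and sizes rather than any specific model's internals:

    import torch

    # Sketch of the logits_to_keep slicing used by newer models: instead of
    # projecting every hidden state through the (large) lm_head, keep only
    # the trailing positions. logits_to_keep=0 keeps everything, since
    # slice(-0, None) == slice(0, None).
    hidden_states = torch.randn(2, 128, 768)   # (batch, seq_len, hidden) -- illustrative
    lm_head = torch.nn.Linear(768, 50257)      # hidden -> vocab projection

    logits_to_keep = 1  # generate() only needs the last token's logits
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    logits = lm_head(hidden_states[:, slice_indices, :])
    print(logits.shape)  # torch.Size([2, 1, 50257])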

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR. @Rocketknight1

@philiproeleveld (Author)

I was wondering if I could also add this to the many seq2seq models that inherit from GenerationMixin, T5 for example. In theory they would benefit when the user provides a very long decoder_input_ids, but that's not really what they're designed for...
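
To make the potential benefit concrete, a rough back-of-the-envelope for the seq2seq case (all numbers are illustrative; the vocab size is roughly T5's):

    # Rough illustration of the saving for a seq2seq model given a very long
    # decoder prefix; batch/length are made up, vocab ~= T5's 32128.
    batch, dec_len, vocab = 8, 4096, 32128
    full_logits = batch * dec_len * vocab  # materialized without logits_to_keep
    kept_logits = batch * 1 * vocab        # with logits_to_keep=1
    print(full_logits // kept_logits)      # 4096x fewer logits to compute and store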


github-actions bot commented Oct 3, 2025

[For maintainers] Suggested jobs to run (before merge)

run-slow: bart, bert, bert_generation, big_bird, bigbird_pegasus, biogpt, blenderbot, blenderbot_small, blip, bloom, camembert, chameleon, codegen, cpmant, ctrl, data2vec

@Rocketknight1 (Member) left a comment

Mostly looks good, with one comment that applies to a couple of models!

This is a fairly big change that standardizes a lot of older models with the modern API, so cc core maintainers @ArthurZucker @Cyrilvallez

-     vocab_size=self.config.vocab_size,
-     **kwargs,
- )
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
@Rocketknight1 (Member) commented on this diff:

Some of these cases slightly change other behaviour (e.g. not moving labels to logits.device). Have you checked that this is equivalent?
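
For readers following the thread: the removed per-model code upcast logits to float and moved labels onto the logits' device before computing cross-entropy, and the question is whether the shared loss function reproduces that internally. A minimal sketch of the two paths side by side; the manual version restates the old in-model pattern, while the import path for the shared default is my assumption about where it currently lives in transformers:

    import torch
    import torch.nn.functional as F

    logits = torch.randn(2, 6, 50)
    labels = torch.randint(0, 50, (2, 6))

    # Old in-model pattern (e.g. gpt_neo): cast, move, shift, flatten.
    lm_logits = logits.float()
    moved = labels.to(lm_logits.device)  # the device move in question
    shift_logits = lm_logits[..., :-1, :].contiguous()
    shift_labels = moved[..., 1:].contiguous()
    manual = F.cross_entropy(shift_logits.view(-1, 50), shift_labels.view(-1))

    # New pattern: delegate to the shared causal-LM loss. Import path is an
    # assumption; it is expected to upcast and device-move internally.
    from transformers.loss.loss_utils import ForCausalLMLoss

    shared = ForCausalLMLoss(logits=logits, labels=labels, vocab_size=50)
    print(manual.item(), shared.item())  # expected to match on CPU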

Development

Successfully merging this pull request may close these issues.

Adding logits_to_keep to older models