[FIX] Prevent TypeError in text-only Gemma3CausalLM and improve generate_step defaults #2423
Conversation
1. Bug Fix: TypeError during generation for text-only models
The Problem:
When using a Gemma3CausalLM model configured for text-only processing (i.e., with vision_encoder=None and preprocessor=None), a call to causal_lm.generate() fails with a TypeError. The root cause is that the internal generate_step method returns a dictionary containing an 'images': None key-value pair. This None value is eventually passed to ops.concatenate during the output normalization step, which does not accept None as a valid input. This workflow is common when pretraining a model from scratch.
The Fix:
The generate_step method has been modified to include the 'images' key in its returned dictionary only if an image tensor is actually present. This ensures that a None value is never passed to downstream functions, resolving the TypeError.
Proof of Bug and Fix:
The following Colab notebook demonstrates the bug with the original code and shows successful execution after applying this fix: https://colab.research.google.com/drive/1QVk2idB6fcdYYJb1cBQGaKHe5QSGjCti?usp=sharing
2. Refactoring: Remove Hardcoded Stop Token
The Problem:
The internal generate_step method has a hardcoded default of stop_token_ids=[106], which corresponds to the <end_of_turn> token. This is conceptually incorrect for a base architectural model: the model itself should not have opinions about instruction-following or conversational tokens, and the hardcoded value can interfere with pretraining or with sampling raw text.
The Fix:
The method signature has been changed from stop_token_ids=[106] to stop_token_ids=None. This is a safe, non-breaking change because the public-facing Gemma3CausalLM.generate() method is already responsible for setting the appropriate stop tokens when a user specifies stop_token_ids="auto".
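For concreteness, here is a minimal, hypothetical sketch of the shape of both changes. It is not the actual keras-hub source: the real generate_step also performs sampling, cache updates, and stop-token handling, and the input/output names here are assumptions based on the description above.

```python
# Hypothetical, simplified sketch of the fix; not the actual keras-hub
# source, which also performs sampling, caching, and stop-token handling.

def generate_step(inputs, stop_token_ids=None):  # was: stop_token_ids=[106]
    token_ids = inputs["token_ids"]
    padding_mask = inputs["padding_mask"]
    images = inputs.get("images")  # None for a text-only Gemma3CausalLM

    # ... the decoding loop would run here, extending `token_ids` ...

    outputs = {"token_ids": token_ids, "padding_mask": padding_mask}
    # Attach the "images" key only when an image tensor actually exists,
    # so that None never reaches ops.concatenate during output
    # normalization.
    if images is not None:
        outputs["images"] = images
    return outputs

# Text-only call: the returned dict simply has no "images" key.
out = generate_step({"token_ids": [[2, 5, 7]], "padding_mask": [[1, 1, 1]]})
print(out)  # {'token_ids': [[2, 5, 7]], 'padding_mask': [[1, 1, 1]]}
```

Because the public Gemma3CausalLM.generate() resolves stop_token_ids="auto" to the appropriate tokens before calling into generate_step, the None default is invisible to end users.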
Code Review
This pull request introduces two valuable improvements to the Gemma3CausalLM model. First, it corrects a TypeError that occurred during text-only generation by ensuring the images key is only included in the output dictionary when an image is present. Second, it refactors the generate_step method to remove a hardcoded stop_token_ids default value, correctly changing it to None. This change not only makes the base model more generic but also resolves a potential issue with mutable default arguments. The changes are well-implemented and align with the project's coding standards. The pull request is well-documented with a clear explanation of the problem and the fix.
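The review's note about mutable default arguments refers to a standard Python pitfall: a default such as stop_token_ids=[106] is evaluated once, at function definition time, and the same list object is shared by every call that omits the argument. A self-contained illustration (generic Python, not the keras-hub code itself):

```python
def decode(stop_token_ids=[106]):
    # Any mutation is visible to all later calls that use the default.
    stop_token_ids.append(0)
    return stop_token_ids

print(decode())  # [106, 0]
print(decode())  # [106, 0, 0]  <- state leaked from the previous call

# The None-sentinel pattern adopted by this PR avoids the shared state:
def decode_fixed(stop_token_ids=None):
    if stop_token_ids is None:
        stop_token_ids = []  # fresh object on every call
    stop_token_ids.append(0)
    return stop_token_ids

print(decode_fixed())  # [0]
print(decode_fixed())  # [0]
```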