
Fix mixed precision crash in TopP and Random samplers#2585

Open
AlejandroPG06 wants to merge 3 commits into keras-team:master from AlejandroPG06:fix-samplers-final

Conversation

@AlejandroPG06

Overview

This PR addresses a stability issue in TopPSampler and RandomSampler when running with mixed precision (float16). It aligns the implementation with TopKSampler by ensuring categorical sampling inputs are cast to float32.

Fixes #2584

Problem

The TensorFlow backend (and potentially others) does not strictly support half-precision (float16) for multinomial sampling operations. This limitation was previously addressed in TopKSampler but remained unpatched in other samplers.
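
The kernel limitation above is backend-specific, but half precision is fragile for sampling math in general. A minimal NumPy illustration (not the TensorFlow kernel itself) of how quickly float16 breaks down:

```python
import numpy as np

# float16 tops out at 65504, so exp() of even a modest logit overflows
# to inf, which would poison any normalization a sampling kernel does.
logit = 12.0
print(np.exp(np.float16(logit)))  # inf (exp(12) ~ 162755 > 65504)
print(np.exp(np.float32(logit)))  # finite in float32
```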

Solution

I have added an explicit ops.cast(..., "float32") to the log-probabilities before passing them to random.categorical in:

  • keras_hub/src/samplers/top_p_sampler.py
  • keras_hub/src/samplers/random_sampler.py
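
The pattern can be sketched with NumPy stand-ins (the real samplers use `keras.ops` and `keras.random.categorical`; the function and variable names below are illustrative only):

```python
import numpy as np

def safe_categorical(probabilities, seed=0):
    # Mirror of the fix: compute log-probabilities, then cast to
    # float32 before sampling, since half-precision multinomial
    # kernels are unreliable. In the real code this is
    # ops.cast(ops.log(probabilities), "float32").
    log_probs = np.log(probabilities).astype("float32")
    # Recover normalized probabilities from the float32 log-probs.
    p = np.exp(log_probs - log_probs.max(axis=-1, keepdims=True))
    p /= p.sum(axis=-1, keepdims=True)
    # Inverse-CDF sampling, one draw per row.
    cdf = np.cumsum(p, axis=-1)
    u = np.random.default_rng(seed).random((p.shape[0], 1))
    return (u < cdf).argmax(axis=-1)

probs = np.array([[0.1, 0.2, 0.7]], dtype="float16")  # e.g. mixed-precision output
token = safe_categorical(probs)
assert token.shape == (1,) and 0 <= token[0] < 3
```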

Verification

  • Static Analysis: Verified that the implementation now matches the established pattern in src/samplers/top_k_sampler.py.
  • Formatting: Passed pre-commit run --all-files.

@gemini-code-assist
Contributor

Summary of Changes


This pull request resolves a critical stability issue in TopPSampler and RandomSampler that caused crashes when operating in mixed precision (float16) environments. By explicitly casting the log-probabilities to float32 before passing them to the random.categorical function, the samplers now correctly handle the precision requirements of the TensorFlow backend, ensuring robust behavior consistent with TopKSampler.

Highlights

  • Mixed Precision Crash Fix: Implemented explicit casting of log-probabilities to "float32" before multinomial sampling in TopPSampler and RandomSampler to prevent crashes when using float16 (mixed precision).
  • TensorFlow Compatibility: Ensured compatibility with TensorFlow's multinomial sampling operations, which do not strictly support half-precision, by aligning the implementation with TopKSampler.


Changelog
  • keras_hub/src/samplers/random_sampler.py
    • Added ops.cast(ops.log(probabilities), "float32") to ensure full precision for multinomial sampling.
    • Added a comment explaining the necessity of float32 casting due to TensorFlow's half-precision multinomial sampling limitations.
  • keras_hub/src/samplers/top_p_sampler.py
    • Added ops.cast(ops.log(probabilities), "float32") to ensure full precision for multinomial sampling.
    • Added a comment explaining the necessity of float32 casting due to TensorFlow's half-precision multinomial sampling limitations.


@gemini-code-assist (bot) left a comment


Code Review

This pull request effectively resolves a mixed precision crash in TopPSampler and RandomSampler by ensuring the inputs to random.categorical are cast to float32. The changes are correct and align with existing patterns in other samplers. I've added a couple of minor suggestions to perform the type casting before the log operation for potentially improved numerical precision. Overall, this is a good fix.
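
The reviewer's suggestion matters because casting before the log means the log is evaluated in float32 rather than float16. A NumPy comparison (illustrative only, not the keras_hub code):

```python
import numpy as np

p = np.array([1e-4, 0.5], dtype="float16")  # probabilities stored in float16

# Log-then-cast: the log is evaluated at float16 precision, then widened.
log_then_cast = np.log(p).astype("float32")
# Cast-then-log: the log is evaluated at full float32 precision.
cast_then_log = np.log(p.astype("float32"))

# Both are finite here, but float16 carries only ~3 decimal digits,
# so the two results disagree at float16 resolution for small probs.
print(log_then_cast, cast_then_log)
```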

AlejandroPG06 and others added 2 commits February 6, 2026 12:45
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
@AlejandroPG06
Copy link
Author

Hi @dhantule,

Just a friendly ping on this!

The keras-nightly check failure appears unrelated to my changes (a flaky test); the stable test suite passed.
Please let me know if there is anything else needed from my end to move this forward.

Thanks!

@sachinprasadhs
Copy link
Collaborator

I have commented on the original issue. Please respond. Thanks

Linked issue: [Bug] Missing float32 cast in TopP and Random samplers causes crashes in Mixed Precision (#2584)