Add SpeculativeSampler for XLA-compatible decoding#2525

Open
Vivek1106-04 wants to merge 2 commits into keras-team:master from Vivek1106-04:feat/speculative-sampler

Conversation

@Vivek1106-04 Vivek1106-04 commented Jan 16, 2026

Description

This PR implements `SpeculativeSampler`, a native Keras implementation of speculative decoding (Leviathan et al., 2023).

Speculative decoding accelerates autoregressive inference by using a smaller "draft" model to generate candidate tokens that the target model verifies in parallel. This brings speculative-decoding speedups to TPU, XLA, and JAX/TensorFlow workflows where external engines like vLLM are not available.

Key Features

  • XLA & TPU Compatible: Built with vectorized ops (slice_update, cumprod, where) and static shapes for full graph compilation.
  • vLLM API Alignment: Uses a num_speculative_tokens argument matching the vLLM naming convention.
  • Parallel Verification: Supports a custom verify_next callable for batched target verification (see the sketch after this list).
  • Fallback Safety: Gracefully falls back to serial verification if verify_next is not provided.
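
A rough sketch of what a `verify_next` callable might look like (the exact contract is not spelled out in this thread, so the signature and the `target_model` name below are assumptions, not the PR's actual API). The idea is to score the prompt, with the K drafted tokens already written in, in a single forward pass and return logits for the K drafted positions plus one bonus position:

```python
# Hypothetical sketch: `target_model`, the signature, and the returned
# shape are assumptions. Logits at position i predict the token at
# position i + 1, so slicing [index - 1, index + k) yields scores for
# the k drafted tokens plus one bonus position.
def verify_next(prompt, index, k):
    logits = target_model(prompt)  # (batch, seq_len, vocab_size)
    return logits[:, index - 1 : index + k, :]  # (batch, k + 1, vocab_size)
```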

Algorithm

Draft K tokens (serial, cheap) → Verify K+1 positions (parallel) → Accept matching prefix + bonus token
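
A minimal sketch of the accept step, assuming `draft_tokens` and greedy `target_tokens` are integer tensors of shape `(batch, k)` (names assumed for illustration). A position is accepted only if every earlier position also matched, which a cumulative product over the match mask expresses with no Python loop, keeping the step vectorized and XLA-friendly:

```python
from keras import ops

# draft_tokens, target_tokens: assumed int tensors of shape (batch, k).
matches = ops.cast(ops.equal(draft_tokens, target_tokens), "int32")
accepted_mask = ops.cumprod(matches, axis=1)   # runs of 1s, then 0s
num_accepted = ops.sum(accepted_mask, axis=1)  # matching-prefix length per row
```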

Usage

```python
import keras_hub

sampler = keras_hub.samplers.SpeculativeSampler(num_speculative_tokens=5)
target_lm.compile(sampler=sampler)

# In a custom generate_step, pass the draft_next callback:
output = sampler(
    next=target_next,
    prompt=tokens,
    draft_next=draft_next,
)
```
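
For context, a rough sketch of a `draft_next` callback, assuming it follows the same `(prompt, cache, index) -> (logits, hidden_states, cache)` contract as `next` in keras-hub samplers; `draft_lm` stands in for a hypothetical smaller CausalLM used only for drafting:

```python
from keras import ops

def draft_next(prompt, cache, index):
    # Score only the token preceding `index`, reusing the draft KV cache.
    # `call_with_cache` mirrors the cached-call helper on keras-hub
    # CausalLM models; treat this whole body as an assumed sketch.
    batch_size = ops.shape(prompt)[0]
    last_token = ops.slice(prompt, [0, index - 1], [batch_size, 1])
    logits, hidden_states, cache = draft_lm.call_with_cache(
        last_token, cache, index
    )
    return (
        ops.squeeze(logits, axis=1),
        ops.squeeze(hidden_states, axis=1),
        cache,
    )
```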

Fixes #2513

@gemini-code-assist (Contributor)

Summary of Changes

Hello @Vivek1106-04, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers get up to speed quickly.

This pull request introduces a native Keras implementation of speculative decoding, a technique designed to significantly speed up autoregressive inference. By leveraging a smaller draft model to propose candidate tokens that are then efficiently verified by the target model, this feature enables performance improvements for users working with XLA, TPUs, and JAX/TensorFlow environments. The implementation includes robust handling for both parallel and serial verification, ensuring broad applicability and reliability.

Highlights

  • New Feature: SpeculativeSampler: Introduced SpeculativeSampler, a native Keras implementation of speculative decoding for accelerated autoregressive inference.
  • XLA/TPU Compatibility: The SpeculativeSampler is built with vectorized operations and static shapes, ensuring full graph compilation and compatibility with XLA, TPUs, and JAX/TensorFlow workflows.
  • Flexible Verification: Supports parallel verification via an optional verify_next callable, with a graceful fallback to serial verification if not provided, aligning with vLLM API standards.
  • Core Algorithm Implementation: The sampler implements the three-phase speculative decoding algorithm: serial drafting of K tokens, parallel verification of K+1 positions, and acceptance of matching prefixes plus a bonus token.
  • Comprehensive Testing: Added a dedicated test suite (speculative_sampler_test.py) to validate the functionality, including stateless calls, full/partial acceptance, early stopping, serialization, and batched operations.


@gemini-code-assist bot left a comment

Code Review

This pull request introduces SpeculativeSampler, a native Keras implementation for speculative decoding. The implementation is well-structured and uses XLA-compatible operations. However, there are some critical issues in the sampling logic. The current implementation uses greedy decoding (argmax) for both generating draft tokens and the bonus token, which makes the temperature and draft_temperature parameters ineffective. This should be corrected to use proper random sampling. Additionally, there are minor opportunities to improve code clarity and adhere to the documentation style guide by adding a reference to the original paper.

Comment on lines +118 to +121:

```python
probs = self.compute_probabilities(logits)
probs = ops.cast(probs, "float32") / self.draft_temperature
next_token = ops.argmax(probs, axis=-1)
next_token = ops.cast(next_token, current_prompt.dtype)
```

critical

The current implementation for drafting tokens is greedy due to ops.argmax. The draft_temperature parameter is also applied incorrectly to probabilities instead of logits, and its effect is nullified by argmax. For speculative decoding, draft tokens should be sampled. Additionally, self.compute_probabilities already applies self.temperature, which is not what you want for the draft model when draft_temperature is provided.

This can be fixed by applying draft_temperature to the logits and using keras.random.categorical for sampling. This will also correctly utilize the seed_generator.

Suggested change:

```diff
-probs = self.compute_probabilities(logits)
-probs = ops.cast(probs, "float32") / self.draft_temperature
-next_token = ops.argmax(probs, axis=-1)
-next_token = ops.cast(next_token, current_prompt.dtype)
+draft_logits = ops.cast(logits, "float32") / self.draft_temperature
+next_token = random.categorical(
+    draft_logits, 1, seed=self.seed_generator
+)[:, 0]
+next_token = ops.cast(next_token, current_prompt.dtype)
```

```python
bonus_probs = ops.take_along_axis(target_probs, gather_idx, axis=1)[
    :, 0, :
]
bonus_token = ops.argmax(bonus_probs, axis=-1)
```

critical

The bonus token is selected greedily using ops.argmax. This makes self.temperature (used to compute target_probs) have no effect on the final token selection. The bonus token should be sampled.

A simple way to fix this is to sample from the probabilities by converting them back to logits using ops.log. Note that using ops.log can be numerically unstable if probabilities are zero; a more robust solution would be to work with logits directly throughout this section.

Suggested change:

```diff
-bonus_token = ops.argmax(bonus_probs, axis=-1)
+bonus_token = random.categorical(
+    ops.log(bonus_probs), 1, seed=self.seed_generator
+)[:, 0]
```
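
For illustration, a rough sketch of the logits-based route the comment alludes to, reusing `target_logits` and `gather_idx` from the quoted code (a sketch, not part of the actual diff):

```python
# Gather the bonus position's logits instead of its probabilities, then
# sample directly; this avoids log(0) on zero-probability entries.
bonus_logits = ops.take_along_axis(target_logits, gather_idx, axis=1)[:, 0, :]
bonus_logits = ops.cast(bonus_logits, "float32") / self.temperature
bonus_token = random.categorical(
    bonus_logits, 1, seed=self.seed_generator
)[:, 0]
```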

````python
causal_lm.compile(sampler=sampler)
causal_lm.generate(["Keras is a"])
```
"""
````

medium

The docstring is missing a reference to the original paper on speculative decoding. The repository style guide recommends including reference citations. Please add a Reference section to the class docstring before the closing `"""`, like so:

```
Reference:
 - [Leviathan et al., 2023](https://arxiv.org/abs/2211.17192)
```

Style guide: include reference citations if applicable.

Comment on lines +157 to +158:

```python
target_probs = self.compute_probabilities(target_logits)
target_tokens = ops.argmax(target_probs[:, :k, :], axis=-1)
```

medium

The verification step correctly uses argmax for greedy selection. However, compute_probabilities applies self.temperature before the argmax, which has no effect on the result and is inefficient as it performs an unnecessary softmax operation. It is clearer and more efficient to perform argmax directly on the logits.

If my other feedback regarding bonus token sampling is addressed (by using logits), target_probs will no longer be needed and line 157 can be removed entirely.

Suggested change:

```diff
-target_probs = self.compute_probabilities(target_logits)
-target_tokens = ops.argmax(target_probs[:, :k, :], axis=-1)
+target_tokens = ops.argmax(target_logits[:, :k, :], axis=-1)
```

@Vivek1106-04 (Author)

@sachinprasadhs, a gentle reminder: could you review the PR? Thank you.

@sachinprasadhs (Collaborator)

@Vivek1106-04, thanks for the PR; I will take a look at it soon.

@sachinprasadhs sachinprasadhs self-requested a review February 10, 2026 19:22
@sachinprasadhs sachinprasadhs added the type:feature New feature or request label Feb 10, 2026