Add SpeculativeSampler for XLA-compatible decoding#2525
Vivek1106-04 wants to merge 2 commits into `keras-team:master`
Conversation
Summary of Changes

Hello @Vivek1106-04, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request introduces a native Keras implementation of speculative decoding, a technique designed to significantly speed up autoregressive inference. By leveraging a smaller draft model to propose candidate tokens that are then efficiently verified by the target model, this feature enables performance improvements for users working with XLA, TPUs, and JAX/TensorFlow environments. The implementation includes robust handling for both parallel and serial verification, ensuring broad applicability and reliability.
Code Review
This pull request introduces SpeculativeSampler, a native Keras implementation for speculative decoding. The implementation is well-structured and uses XLA-compatible operations. However, there are some critical issues in the sampling logic. The current implementation uses greedy decoding (argmax) for both generating draft tokens and the bonus token, which makes the temperature and draft_temperature parameters ineffective. This should be corrected to use proper random sampling. Additionally, there are minor opportunities to improve code clarity and adhere to the documentation style guide by adding a reference to the original paper.
```python
probs = self.compute_probabilities(logits)
probs = ops.cast(probs, "float32") / self.draft_temperature
next_token = ops.argmax(probs, axis=-1)
next_token = ops.cast(next_token, current_prompt.dtype)
```
The current implementation for drafting tokens is greedy due to ops.argmax. The draft_temperature parameter is also applied incorrectly to probabilities instead of logits, and its effect is nullified by argmax. For speculative decoding, draft tokens should be sampled. Additionally, self.compute_probabilities already applies self.temperature, which is not what you want for the draft model when draft_temperature is provided.
This can be fixed by applying draft_temperature to the logits and using keras.random.categorical for sampling. This will also correctly utilize the seed_generator.
Suggested change:

```diff
- probs = self.compute_probabilities(logits)
- probs = ops.cast(probs, "float32") / self.draft_temperature
- next_token = ops.argmax(probs, axis=-1)
- next_token = ops.cast(next_token, current_prompt.dtype)
+ draft_logits = ops.cast(logits, "float32") / self.draft_temperature
+ next_token = random.categorical(
+     draft_logits, 1, seed=self.seed_generator
+ )[:, 0]
+ next_token = ops.cast(next_token, current_prompt.dtype)
```
```python
bonus_probs = ops.take_along_axis(target_probs, gather_idx, axis=1)[
    :, 0, :
]
bonus_token = ops.argmax(bonus_probs, axis=-1)
```
The bonus token is selected greedily using ops.argmax. This makes self.temperature (used to compute target_probs) have no effect on the final token selection. The bonus token should be sampled.
A simple way to fix this is to sample from the probabilities by converting them back to logits using ops.log. Note that using ops.log can be numerically unstable if probabilities are zero; a more robust solution would be to work with logits directly throughout this section.
Suggested change:

```diff
- bonus_token = ops.argmax(bonus_probs, axis=-1)
+ bonus_token = random.categorical(ops.log(bonus_probs), 1, seed=self.seed_generator)[:, 0]
```
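For reference, a minimal sketch of the logits-based route mentioned above (assuming `target_logits`, `gather_idx`, `self.temperature`, and `self.seed_generator` from the surrounding code are in scope) could look like:

```python
# Sketch only: samples the bonus token from logits directly, avoiding
# ops.log on probabilities that may be zero. Assumes `target_logits`,
# `gather_idx`, `self.temperature`, and `self.seed_generator` exist as
# in the surrounding PR code.
bonus_logits = ops.take_along_axis(target_logits, gather_idx, axis=1)[:, 0, :]
bonus_logits = ops.cast(bonus_logits, "float32") / self.temperature
bonus_token = random.categorical(
    bonus_logits, 1, seed=self.seed_generator
)[:, 0]
```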
````python
causal_lm.compile(sampler=sampler)
causal_lm.generate(["Keras is a"])
```
"""
````
The docstring is missing a reference to the original paper on Speculative Decoding. The repository style guide recommends including reference citations if applicable. Please add a `Reference` section to the class docstring before the closing `"""`, like so:

```
Reference:
- [Leviathan et al., 2023](https://arxiv.org/abs/2211.17192)
```
```python
target_probs = self.compute_probabilities(target_logits)
target_tokens = ops.argmax(target_probs[:, :k, :], axis=-1)
```
The verification step correctly uses argmax for greedy selection. However, compute_probabilities applies self.temperature before the argmax, which has no effect on the result and is inefficient as it performs an unnecessary softmax operation. It is clearer and more efficient to perform argmax directly on the logits.
If my other feedback regarding bonus token sampling is addressed (by using logits), target_probs will no longer be needed and line 157 can be removed entirely.
Suggested change:

```diff
- target_probs = self.compute_probabilities(target_logits)
- target_tokens = ops.argmax(target_probs[:, :k, :], axis=-1)
+ target_tokens = ops.argmax(target_logits[:, :k, :], axis=-1)
```
@sachinprasadhs Gentle reminder: can you review the PR? Thank you!
@Vivek1106-04, thanks for the PR, I will take a look at this soon.
Description
This PR implements `SpeculativeSampler`, a native Keras implementation of Speculative Decoding ([Leviathan et al., 2023](https://arxiv.org/abs/2211.17192)).
Speculative decoding accelerates autoregressive inference by using a smaller "draft" model to generate candidate tokens that are verified by the target model in parallel. This unlocks speculative speedups for TPU, XLA, and JAX/TensorFlow workflows where external engines like vLLM are not available.
Key Features
- XLA-compatible ops (`slice_update`, `cumprod`, `where`) and static shapes for full graph compilation.
- `num_speculative_tokens` argument matching vLLM standards.
- Optional `verify_next` callable for batch target verification.
- Serial verification fallback when `verify_next` is not provided.

Algorithm
Draft K tokens (serial, cheap) → Verify K+1 positions (parallel) → Accept matching prefix + bonus token
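To make the accept step concrete, here is a minimal sketch of the XLA-friendly acceptance logic using `cumprod` from the feature list above (variable names are illustrative, not taken from the PR):

```python
from keras import ops

# Hypothetical sketch of the parallel accept step. `draft_tokens` are the K
# drafted tokens and `target_tokens` are the target model's choices at the
# same K positions, both of shape (batch, K).
draft_tokens = ops.convert_to_tensor([[5, 9, 2, 7]])   # (batch=1, K=4)
target_tokens = ops.convert_to_tensor([[5, 9, 4, 7]])  # mismatch at index 2

matches = ops.cast(ops.equal(draft_tokens, target_tokens), "int32")
# cumprod turns the per-position match mask into a prefix mask: it stays 1
# only until the first mismatch, so its sum is the accepted prefix length.
accepted_mask = ops.cumprod(matches, axis=1)  # [[1, 1, 0, 0]]
num_accepted = ops.sum(accepted_mask, axis=1)  # [2]
```

Because every tensor here keeps a static shape, the whole step compiles under XLA with no data-dependent control flow.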
Usage
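A sketch of the intended usage, based on the docstring excerpt quoted in the review above (the sampler's constructor arguments here are illustrative assumptions, not the confirmed signature):

```python
import keras_hub

# Illustrative sketch: a small draft model paired with a larger target model.
draft_lm = keras_hub.models.CausalLM.from_preset("gpt2_base_en")
causal_lm = keras_hub.models.CausalLM.from_preset("gpt2_large_en")

sampler = keras_hub.samplers.SpeculativeSampler(
    draft_model=draft_lm,      # assumed parameter name
    num_speculative_tokens=4,  # K tokens drafted per verification step
)
causal_lm.compile(sampler=sampler)
causal_lm.generate(["Keras is a"])
```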
Fixes #2513