feat(gdn): add padding index guard for bf16 decode kernel #2810
yzh119 merged 4 commits into flashinfer-ai:main
Conversation
Clamp negative slot indices to 0 before passing to the bf16 fast-path kernel to prevent out-of-bounds memory access when padding indices (-1) are present in initial_state_indices. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Summary of Changes (Gemini Code Assist): This pull request introduces a safeguard in the BF16 decode kernel to improve memory safety. By ensuring that all slot indices are non-negative before processing, it mitigates the risk of out-of-bounds memory access caused by padding indices in initial_state_indices.
📝 Walkthrough: Adds clamping so that negative slot indices in initial_state_indices are guarded before reaching the bf16 decode kernel. Estimated code review effort: 🎯 2 (Simple), ⏱️ ~10 minutes.
🚥 Pre-merge checks: 1 passed, 2 failed (1 warning, 1 inconclusive).
This resolves the issue: sgl-project/sglang#20791 @zihaoye PTAL
Code Review
This pull request adds a guard to prevent out-of-bounds memory access in the bf16 decode kernel when padding indices are present. The change clamps negative indices to 0, which successfully prevents crashes.
However, this approach leads to incorrect, non-zero output for padded items, as they are processed using state from index 0. I've left a comment with a suggestion to zero out the output for these padded items to ensure correctness and consistency with other kernels in the repository.
else:
    h_slot_indices = initial_state_indices

h_slot_indices = h_slot_indices.clamp(min=0)  # guard -1 padding
This correctly prevents out-of-bounds memory access for padded indices. However, it causes padded items (with index -1) to be processed using the state of item 0, resulting in non-zero garbage output for these items. Other kernels in this repository handle padding by skipping computation and zeroing out the output for padded items.
To ensure consistent and correct behavior, the output for padded items should be zeroed out. This can be done after the kernel launch, before returning the output tensor.
For example, you could add the following after the kernel call:
if initial_state_indices is not None:
    padding_mask = initial_state_indices < 0
    if padding_mask.any():
        output[padding_mask] = 0
Actionable comments posted: 1
📒 Files selected for processing (1)
flashinfer/gdn_kernels/gdn_decode_bf16_state.py
h_slot_indices = h_slot_indices.clamp(min=0)  # guard -1 padding
output = torch.empty(B, T, HV, V, device=q.device, dtype=q.dtype)
Clamping padding -1 to 0 silently corrupts real state slot 0.
At Line 1997, padding rows are remapped onto a valid slot, so padded entries now participate in gH read/write and can mutate slot 0 state. This fixes OOB but breaks semantics (padding should be no-op). Also, indices >= pool_size remain unchecked.
Please switch to a validity-mask path (process only valid rows, leave padded rows zero/no-op) instead of aliasing pads to slot 0.
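The validity-mask approach the reviewer suggests can be sketched in a few lines. This is a simplified stand-in, not the kernel's actual code: plain Python lists replace tensors, and `decode_step`, `valid`, and the `+ 1.0` update are illustrative names and logic only.

```python
def decode_step(h_slot_indices, pool, pool_size):
    # Validity mask: padding (-1) and out-of-range slots are ignored
    # instead of being aliased onto slot 0.
    valid = [0 <= i < pool_size for i in h_slot_indices]

    output = [0.0] * len(h_slot_indices)  # padded rows stay zero
    for row, (idx, ok) in enumerate(zip(h_slot_indices, valid)):
        if not ok:
            continue  # no-op for padded/invalid rows
        state = pool[idx]           # masked read
        output[row] = state + 1.0   # stand-in for the real state update
        pool[idx] = output[row]     # masked write
    return output

pool = [10.0, 20.0]
out = decode_step([1, -1, 0], pool, pool_size=2)
print(out)   # → [21.0, 0.0, 11.0]  (padded row stays zero)
print(pool)  # → [11.0, 21.0]       (no slot is corrupted by padding)
```

The key property is that padded rows neither read nor write any real slot, so slot 0 keeps its state and padded outputs are exactly zero.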
Based on learnings: In tests/mamba/selective_state_update_triton.py, pad_slot_id is always negative (-1), i.e., a padding sentinel rather than a real slot.
else:
    h_slot_indices = initial_state_indices

h_slot_indices = h_slot_indices.clamp(min=0)  # guard -1 padding
Can we use an in-place update? https://docs.pytorch.org/docs/stable/generated/torch.Tensor.clamp_.html
I feel the out-of-place clamp (current state) is safer. After the call, if the user inspects their initial_state_indices tensor (aliased as h_slot_indices here), the padding markers are gone, i.e., silently corrupted. This would break any logic that relies on those -1s after the decode step (e.g., a loop that checks which slots are padding).
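The aliasing hazard described above can be shown with a tiny stand-alone sketch. Plain Python lists stand in for tensors, and `clamp_inplace` is a hypothetical helper mirroring the effect of `torch.Tensor.clamp_`:

```python
def clamp_inplace(xs, lo):
    # Mutates xs in place, like torch.Tensor.clamp_(min=lo).
    for i, x in enumerate(xs):
        if x < lo:
            xs[i] = lo
    return xs

initial_state_indices = [3, -1, 0, -1]   # -1 marks padding
h_slot_indices = initial_state_indices   # alias, not a copy

clamp_inplace(h_slot_indices, 0)

# The caller's tensor has silently lost its padding markers:
print(initial_state_indices)  # → [3, 0, 0, 0]
# An out-of-place clamp would build a new list/tensor instead,
# leaving initial_state_indices as [3, -1, 0, -1].
```

Any caller code that later checks `idx < 0` to find padded rows would now find none, which is exactly the silent corruption the comment warns about.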
Ideally, we should change the inside of the kernel to avoid loading when idx < 0. I don't have time to test that now.
It is actually easier to change. I modified the kernel so that it uses 0 when the index is negative. PTAL.
…ull buffer

Replace the Python-level clamp(min=0) guard with an in-kernel check: if pool_batch_idx < 0, redirect to slot 0, which is a reserved null buffer (zero-initialized, never allocated to real requests). This means:
- State reads from slot 0 return zeros (correct fresh initial state)
- State writes to slot 0 are harmlessly discarded
- No per-call tensor allocation at the Python level

Applied to all 3 kernel variants: seqlen1, seqlen234_unified, seqlen1_lowBS_1chunk.

Fixes sgl-project/sglang#20791 (accuracy degradation from OOB access on negative padding indices). Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
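The slot-0 null-buffer pattern described in this commit can be sketched outside the kernel. This is a minimal model, not the kernel itself: a plain Python list of per-slot states stands in for the state pool, and `load_state`/`store_state` are illustrative names.

```python
NULL_SLOT = 0  # reserved, zero-initialized, never handed to a real request

def load_state(pool, idx):
    # In-kernel guard: negative (padding) indices read the null buffer,
    # which holds zeros -- i.e., a fresh initial state.
    return pool[idx if idx >= 0 else NULL_SLOT]

def store_state(pool, idx, new_state):
    # Writes from padded rows land in slot 0. They are harmless because
    # slot 0 is never read back as a real request's state.
    pool[idx if idx >= 0 else NULL_SLOT] = new_state

pool = [0.0, 1.5, 2.5]               # slot 0 is the null buffer
assert load_state(pool, -1) == 0.0   # padding reads zeros
assert load_state(pool, 2) == 2.5    # real slots are unaffected
store_state(pool, -1, 9.9)           # padded write touches only slot 0
assert pool[1:] == [1.5, 2.5]        # real state slots untouched
```

Compared with the Python-level clamp, the guard lives inside the kernel, so no extra tensor is allocated per call and the caller's initial_state_indices keeps its -1 markers.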
Verifies that the bf16 fast-path kernel handles negative (padding) indices correctly via the slot-0 null buffer pattern: - Valid slots produce correct output and state updates (vs. direct-state ref) - Unused real slots are exactly untouched - Slot 0 (null buffer) is excluded from correctness checks Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Remove module-level pytestmark skip (added in flashinfer-ai#2600) and replace with per-function @pytest.mark.skip on the previously-failing tests, so that the new test_decode_kernel_bf16_padding_indices runs in CI while the others remain skipped until their failures are addressed. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Redirect negative pool_batch_idx to slot 0 (null buffer) inside all 3 bf16 decode kernel variants to prevent OOB memory access on padding indices. Also adds a unit test and enables it in CI by replacing the stale module-level skip with per-test skips.
Actionable comments posted: 1
🧹 Nitpick comments (2)
tests/gdn/test_decode_delta_rule.py (2)
205: Prefer xfail (with tracking) over broad unconditional skips. These decorators silence a large part of the suite and can hide regressions. Consider xfail(strict=False) plus an issue link so failures stay visible in CI trends.

♻️ Suggested pattern:

-@pytest.mark.skip(reason="Temporarily skipped due to CI failures.")
+@pytest.mark.xfail(
+    reason="Temporarily unstable in CI; tracked by <issue-url>",
+    strict=False,
+)

Also applies to lines 371, 517, 773, 801, 1168, 1209, 1422, 1454.
881-884: Make padding-path coverage deterministic for batch_size=1. Line 881 currently makes batch_size=1 coverage seed-dependent; the negative-index path may be missed for some seeds.

♻️ Suggested change:

-    if batch_size >= 2:
-        mask[0] = False  # ensure at least one valid
-        mask[-1] = True  # ensure at least one padding
+    if batch_size == 1:
+        mask[0] = True  # always exercise padding path
+    else:
+        mask[0] = False  # ensure at least one valid
+        mask[-1] = True  # ensure at least one padding
📒 Files selected for processing (1)
tests/gdn/test_decode_delta_rule.py
used = indices[valid_mask].to(device)
unused_mask = torch.ones(pool_size, dtype=torch.bool, device=device)
unused_mask[used] = False
unused_mask[0] = False  # slot 0 may be modified (null buffer), don't check it
Cast index tensor to long before advanced indexing.
Line 933 uses used (an int32 tensor) for advanced indexing into unused_mask. PyTorch's advanced indexing operations prefer and consistently support torch.long indices; int32 indices may work but can produce inconsistent or unexpected results across PyTorch versions. Cast to torch.long to ensure reliable behavior.
Proposed fix:

-    used = indices[valid_mask].to(device)
+    used = indices[valid_mask].to(device=device, dtype=torch.long)
     unused_mask = torch.ones(pool_size, dtype=torch.bool, device=device)
     unused_mask[used] = False
/bot run
    "num_q_heads, num_k_heads, num_v_heads",
    [(16, 16, 32)],
)
@pytest.mark.skip(reason="Temporarily skipped due to CI failures.")
|
[SUCCESS] Pipeline #46663351: 14/20 passed
Clamp negative slot indices to 0 before passing to the bf16 fast-path kernel to prevent out-of-bounds memory access when padding indices (-1) are present in initial_state_indices.
📌 Description
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- Installed pre-commit by running pip install pre-commit (or used your preferred method).
- Installed the hooks with pre-commit install.
- Ran pre-commit run --all-files and fixed any reported issues.

🧪 Tests
- Added or updated tests as needed (unittest, etc.).

Reviewer Notes
Summary by CodeRabbit

Bug Fixes
- Guarded negative (padding) slot indices in the bf16 decode kernels to prevent out-of-bounds memory access.

Tests
- Added a unit test covering negative padding indices and re-enabled it in CI by replacing the stale module-level skip with per-test skips.