[BUG FIX] Prevent mask/bias materialization; avoid OOB for irregular seqlen #168
Conversation
Uses in-place nan_to_num_ operation for better memory efficiency. Updates tensor sanitization to use dtype-specific infinity bounds instead of fixed values, preventing potential overflow issues. Changes tensor initialization from empty_like to zeros_like to ensure deterministic starting values for gradients. Fixes bias padding value from minimum float to zero for better numerical behavior. Enhances documentation to clarify support for flexible mask and bias head dimensions in MQA/GQA scenarios.
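The dtype-specific bounds matter because fp16's finite range tops out at ±65504, so a fixed ±1e6 clamp can itself overflow to infinity. Conceptually the sanitization does the following; this is a standalone CUDA sketch for illustration only, since the actual change lives in the Python interface (presumably sourcing the limits from the tensor's dtype):

```cuda
#include <cfloat>

// Replace NaN with 0 and +/-Inf with the dtype's finite limits,
// rather than a fixed +/-1e6 that may not be representable.
__global__ void sanitize_fp32(float* data, size_t n) {
    size_t i = blockIdx.x * (size_t)blockDim.x + threadIdx.x;
    if (i >= n) return;
    float v = data[i];
    if (isnan(v)) {
        data[i] = 0.0f;                           // nan -> 0.0
    } else if (isinf(v)) {
        data[i] = v > 0.0f ? FLT_MAX : -FLT_MAX;  // +/-inf -> finite bounds
    }
}
```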
Eliminates unnecessary padding of key and value tensors to multiples of 128 in sequence length dimension. Removes associated context saving and gradient unpadding operations that are no longer needed without the sequence length padding. Simplifies the forward and backward pass implementation by removing conditional padding logic for masks and biases.
Replaces vectorized copy with element-wise assignment to prevent memory access violations when bounds checking is disabled. Changes predicate handling to use dedicated predicate tensor instead of coordinate-based bounds checking for improved safety. Updates default Clear_OOB_MN to false and removes max_N parameter as bounds checking now relies on predicate tensor.
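A minimal sketch of that element-wise copy, assuming a CuTe-style tile and a per-column predicate (the helper name and exact signature here are illustrative; the real implementation lives in csrc/flash_dmattn/src/utils.h and also takes a row limit, omitted for brevity):

```cuda
#include <cute/tensor.hpp>
using namespace cute;

// Copy S -> D one element at a time, guarded by a per-column predicate, so
// no vectorized access can run past the end of a non-128-aligned sequence.
// With Clear_OOB_MN=false, out-of-bounds columns are simply left untouched.
template <bool Clear_OOB_MN = false,
          typename Engine0, typename Layout0,
          typename Engine1, typename Layout1,
          typename Engine2, typename Layout2>
__forceinline__ __device__ void copy_scalar_predicated(
    Tensor<Engine0, Layout0> const &S,    // source tile (global memory)
    Tensor<Engine1, Layout1>       &D,    // destination tile (shared memory)
    Tensor<Engine2, Layout2> const &pred  // per-column validity, indexed by n
) {
    #pragma unroll
    for (int m = 0; m < size<1>(S); ++m) {
        #pragma unroll
        for (int n = 0; n < size<2>(S); ++n) {
            if (pred(n)) {
                #pragma unroll
                for (int i = 0; i < size<0>(S); ++i) {
                    D(i, m, n) = S(i, m, n);  // scalar load/store, never OOB
                }
            } else if (Clear_OOB_MN) {
                clear(D(_, m, n));  // zero out-of-bounds columns on request
            }
        }
    }
}
```

The trade-off: scalar accesses lose coalescing relative to 128-bit vectorized copies, but they cannot straddle a row whose byte stride is not a multiple of the vector width.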
Pull Request Overview
This PR prevents mask/bias materialization and fixes out-of-bounds (OOB) accesses for irregular sequence lengths. The fix replaces vectorized loads with scalar per-element loads so that non-128-aligned sequence lengths are handled robustly while preserving memory efficiency.
- Removes K-dimension padding requirement for mask/bias tensors, preserving expand views
- Updates copy utilities to use scalar loads with column predicates instead of vectorized operations
- Improves tensor sanitization to use appropriate floating-point limits
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| flash_dmattn/flash_dmattn_interface.py | Removes K-dimension padding logic and improves tensor sanitization |
| csrc/flash_dmattn/src/utils.h | Updates copy_MN function to use scalar loads with predicates instead of vectorized copies |
```diff
 ):
     q, k, v, mask, bias, out, softmax_lse = ctx.saved_tensors
-    dq, dk, dv, dbias = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v), torch.empty_like(bias)
+    dq, dk, dv, dbias = torch.zeros_like(q), torch.zeros_like(k), torch.zeros_like(v), torch.zeros_like(bias)
```
Copilot AI (Sep 16, 2025)
Using torch.zeros_like() initializes all tensors with zeros, which is unnecessary overhead since these gradient tensors will be fully written by the backward kernel. Consider using torch.empty_like() for better performance.
Suggested change:

```diff
-dq, dk, dv, dbias = torch.zeros_like(q), torch.zeros_like(k), torch.zeros_like(v), torch.zeros_like(bias)
+dq, dk, dv, dbias = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v), torch.empty_like(bias)
```
```diff
 ):
     """
-    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
+    Supports multi-query attention and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
```
Copilot AI (Sep 16, 2025)
The word 'Similarity' on line 331 should be 'Similarly' (missing 'l').
csrc/flash_dmattn/src/utils.h (outdated)

```diff
-    typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
-    typename Engine2, typename Layout2
+    bool Is_even_MN=true, bool Clear_OOB_MN=false, bool Bool_to_Element=false, typename To_type=void,
+    // typename TiledCopy,
```
Copilot AI (Sep 16, 2025)
The commented-out typename TiledCopy parameter should be removed entirely rather than left as a comment, as it's no longer used in the function signature.
Suggested change:

```diff
-// typename TiledCopy,
```
csrc/flash_dmattn/src/utils.h (outdated)

```diff
     // typename TiledCopy,
     typename Engine0, typename Layout0, typename Engine1, typename Layout1,
     typename Engine2, typename Layout2, typename Engine3, typename Layout3
 >
 __forceinline__ __device__ void copy_MN(
-    TiledCopy tiled_copy,
+    // TiledCopy tiled_copy,
```
Copilot AI (Sep 16, 2025)
The commented-out TiledCopy tiled_copy parameter should be removed entirely rather than left as a comment, as it's no longer used in the function.
Improves memory alignment by ensuring head dimensions are padded to multiples of 8 for 16-bit memory allocations. Comments out sequence length padding implementation for future consideration, including corresponding mask and bias padding logic in both forward and backward passes.
Implements a device function that performs a logical OR reduction across mask tensor elements and synchronizes the result across the thread block using warp-level primitives. Enables efficient sparse attention pattern processing by allowing threads to collectively determine whether any mask elements are active within a given region.
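A minimal sketch of such a block-wide OR, assuming one boolean per thread, a block size that is a multiple of 32, and a small shared-memory scratch array (all names here are illustrative, not the PR's):

```cuda
// Block-wide logical OR: returns true for every thread iff any thread in
// the block observed an active mask element. __any_sync ORs the 32 lanes
// of each warp; shared memory then combines the per-warp results.
__device__ bool block_any(bool thread_active, bool* warp_or) {
    const int lane      = threadIdx.x % 32;
    const int warp      = threadIdx.x / 32;
    const int num_warps = blockDim.x / 32;  // assumes blockDim.x % 32 == 0

    bool warp_active = __any_sync(0xFFFFFFFFu, thread_active);
    if (lane == 0) warp_or[warp] = warp_active;
    __syncthreads();

    bool block_active = false;
    for (int w = 0; w < num_warps; ++w) block_active |= warp_or[w];
    return block_active;
}
```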
Splits the generic copy_MN function into four specialized functions:
- copy_MN for basic tensor copying with tiled copy operations
- copy_mask for masked copying operations
- copy_mask_with_or_reduce for copying with OR reduction and block activity tracking
- copy_bias for bias-specific copying with element-wise assignment

Removes the Bool_to_Element template parameter and related conditional logic, simplifying the codebase by creating purpose-specific functions instead of a single overloaded function with multiple behaviors; a sketch of the most involved helper follows this list.
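As a rough illustration, copy_mask_with_or_reduce fuses the copy with the activity test (signature assumed from the commit description, simplified to drop the row-limit and Is_even_MN handling):

```cuda
#include <cute/tensor.hpp>

// Copy the mask tile element-wise and record whether any in-bounds element
// is nonzero; the per-thread flag can then feed a block-wide OR (see the
// block_any sketch above) so fully-masked tiles can be skipped.
template <typename Engine0, typename Layout0,
          typename Engine1, typename Layout1,
          typename Engine2, typename Layout2>
__forceinline__ __device__ void copy_mask_with_or_reduce(
    cute::Tensor<Engine0, Layout0> const &S,
    cute::Tensor<Engine1, Layout1>       &D,
    cute::Tensor<Engine2, Layout2> const &pred,  // per-column validity
    bool &any_active                             // out: tile has live entries
) {
    any_active = false;
    #pragma unroll
    for (int m = 0; m < cute::size<1>(S); ++m) {
        #pragma unroll
        for (int n = 0; n < cute::size<2>(S); ++n) {
            if (!pred(n)) continue;
            #pragma unroll
            for (int i = 0; i < cute::size<0>(S); ++i) {
                D(i, m, n) = S(i, m, n);
                any_active |= (S(i, m, n) != 0);
            }
        }
    }
}
```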
Refactors combined mask-bias memory operations into separate dedicated operations to improve performance and maintainability. Introduces specialized copy functions for mask and bias operations with proper bounds checking and OR-reduction for mask activity detection. Removes redundant synchronization points by leveraging built-in synchronization in the new copy functions. Adds predicate tensor allocation for proper boundary handling in both regular and split-KV attention kernels.
Splits the unified GmemTiledCopyMaskBias into separate GmemTiledCopyMask and GmemTiledCopyBias operations to enable independent optimization of memory access patterns. Introduces specialized copy_mask_with_or_reduce and copy_bias functions that replace generic copy_MN calls, allowing for better memory coalescing and reduced synchronization overhead. Adds predicate tensor allocation for bounds checking on the N dimension to improve memory safety and enable more efficient vectorized operations in future optimizations.
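The N-dimension predicate mentioned here follows the usual flash-attention idiom: each tile element carries its global (row, column) coordinate, and a column is valid only if its global column index falls inside the real key length. A sketch in kernel context (tensor names taken from the diff below; the surrounding variables are assumed):

```cuda
// One bool per column of the bias tile: true iff that column is in bounds.
// tBiascBias holds global coordinates; get<1>(...) extracts the column.
Tensor tBiaspBias = make_tensor<bool>(make_shape(size<2>(tBiassBias)));
#pragma unroll
for (int n = 0; n < size(tBiaspBias); ++n) {
    tBiaspBias(n) = get<1>(tBiascBias(0, 0, n))
                    < binfo.actual_seqlen_k - n_block * kBlockN;
}
```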
Splits the combined GmemTiledCopyMaskBias type into separate GmemTiledCopyMask and GmemTiledCopyBias types in both forward and backward kernel traits. This separation improves code clarity and allows for independent handling of mask and bias copy operations, enabling more flexible memory access patterns and potential optimizations.
Pull Request Overview
Copilot reviewed 5 out of 5 changed files in this pull request and generated 5 comments.
```diff
 def _sanitize_tensors(*tensors: Optional[torch.Tensor], nan: float = 0.0, posinf: float = 1e6, neginf: float = -1e6) -> None:
     for t in tensors:
         if t is not None and isinstance(t, torch.Tensor):
-            torch.nan_to_num(t, nan=nan, posinf=posinf, neginf=neginf, out=t)
+            torch.nan_to_num_(t, nan=nan, posinf=posinf, neginf=neginf)
```
Copilot AI (Sep 17, 2025)
The function signature changed from using out=t parameter to in-place operation torch.nan_to_num_(), but the function parameters still include nan, posinf, neginf which suggests the old API expected these to be configurable. However, the function calls at lines 98 and 173 now pass specific dtype-based values, making the default parameters in the function signature potentially misleading.
```diff
 ):
     q, k, v, mask, bias, out, softmax_lse = ctx.saved_tensors
-    dq, dk, dv, dbias = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v), torch.empty_like(bias)
+    dq, dk, dv, dbias = torch.zeros_like(q), torch.zeros_like(k), torch.zeros_like(v), torch.zeros_like(bias)
```
Copilot AI (Sep 17, 2025)
Using torch.zeros_like() instead of torch.empty_like() initializes the tensors with zeros, which adds unnecessary overhead since these tensors will be completely overwritten by the backward computation.
Suggested change:

```diff
-dq, dk, dv, dbias = torch.zeros_like(q), torch.zeros_like(k), torch.zeros_like(v), torch.zeros_like(bias)
+dq, dk, dv, dbias = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v), torch.empty_like(bias)
```
```cuda
    typename Engine0, typename Layout0, typename Engine1, typename Layout1,
    typename Engine2, typename Layout2, typename Engine3, typename Layout3
>
__forceinline__ __device__ void copy_mask(
```
Copilot AI (Sep 17, 2025)
The copy_mask function at lines 585-612 is nearly identical to copy_MN function at lines 548-575, with only minor template parameter differences. This duplicates logic and creates maintenance burden.
```cuda
// cute::copy(tiled_copy, S(_, m, n), D(_, m, n));
#pragma unroll
for (int i = 0; i < size<0>(S); ++i) {
    D(i, m, n) = S(i, m, n);
}
```
Copilot AI (Sep 17, 2025)
The commented out cute::copy call on line 685 suggests this was changed to manual scalar copying, but the comment should be removed or explain why the manual loop is necessary instead of using the copy utility.
```diff
 FLASH_NAMESPACE::copy_bias<Is_even_MN, /*Clear_OOB_MN=*/true>(
     gmem_tiled_copy_Bias,
     tBiasgBias(_, _, _, n_block), tBiassBias,
-    tBiascBias,
-    binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN
+    tBiascBias, tBiaspBias,
+    binfo.actual_seqlen_q - m_block * kBlockM
 );
+// Because copy_bias currently uses scalar loads, we need to sync here.
+// TODO: Remove sync after fixing to vectorized loads.
+__syncthreads();
```
Copilot AI (Sep 17, 2025)
Multiple instances of manual __syncthreads() calls are added (lines 399, 525, 654, 1085, 1233, 1380) specifically for scalar bias copying. This adds synchronization overhead that could impact performance, and the TODO comments indicate this is a temporary workaround.
Summary

Related: #161, #169

Root Cause

With seqlen_k = 4095, bias rows became misaligned (e.g., an 8190-byte stride for fp16), causing CUDA misaligned address errors on vectorized accesses to the mask/bias expand views.

Changes

- Added a per-column predicate (predicate_N) and updated the copy helper:
  - utils.h: copy_MN supports a per-column predicate and row limit, with scalar element-wise copy. When Bool_to_Element=true, it converts the bool mask to a numeric element.
  - flash_fwd_kernel.h: Use copy_MN to load mask/bias with per-N predicates in masking steps; use the even fast-path in the non-masking loop.
  - flash_bwd_kernel.h: Symmetric updates for dS/mask/bias loads and dbias writes with Clear_OOB_MN=false.
- Preserves expand views (no mask/bias materialization).

Reproduction

Before fix: running benchmarks/forward_equivalence.py with an irregular sequence length (e.g., seqlen_k = 4095) fails with RuntimeError: CUDA error: misaligned address during the forward pass.
After fix: the same run completes without the misaligned address error.

Tests

Compatibility

- Preserves expand views for mask/bias (no materialization).

Checklist