[BUG FIX] Prevent mask/bias materialization; avoid OOB for irregular seqlen #168
Changes from 3 commits
Commits: e69b1c7 · c82f7dc · 8a3bb04 · e23b08f · 510aaf5 · a148a3a · a0475b2 · 3b7b57b · e9f9fcc
The CUDA kernel changes:

```diff
@@ -521,15 +521,16 @@ __forceinline__ __device__ void copy(
 ////////////////////////////////////////////////////////////////////////////////////////////////////

 template <
-    bool Is_even_MN=true, bool Clear_OOB_MN=true, bool Bool_to_Element=false, typename To_type=void,
-    typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
-    typename Engine2, typename Layout2
+    bool Is_even_MN=true, bool Clear_OOB_MN=false, bool Bool_to_Element=false, typename To_type=void,
+    // typename TiledCopy,
+    typename Engine0, typename Layout0, typename Engine1, typename Layout1,
+    typename Engine2, typename Layout2, typename Engine3, typename Layout3
 >
 __forceinline__ __device__ void copy_MN(
-    TiledCopy tiled_copy,
+    // TiledCopy tiled_copy,
     Tensor<Engine0, Layout0> const &S, Tensor<Engine1, Layout1> &D,
-    Tensor<Engine2, Layout2> const &identity_MN,
-    const int max_M=0, const int max_N=0
+    Tensor<Engine2, Layout2> const &identity_MN, Tensor<Engine3, Layout3> const &predicate_N,
+    const int max_M=0
 ) {
     CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});  // (MMA, MMA_M, MMA_N)
     CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});  // (MMA, MMA_M, MMA_N)
@@ -542,14 +543,19 @@ __forceinline__ __device__ void copy_MN(
         if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_M) {
             #pragma unroll
             for (int n = 0; n < size<2>(S); ++n) {
-                if (Is_even_MN || get<1>(identity_MN(0, m, n)) < max_N) {
+                if (Is_even_MN || predicate_N(n)) {
                     if constexpr (Bool_to_Element) {
                         #pragma unroll
                         for (int i = 0; i < size<0>(S); ++i) {
                             D(i, m, n) = static_cast<bool>(S(i, m, n)) ? To_type(1) : To_type(0);
                         }
                     } else {
-                        cute::copy(tiled_copy, S(_, m, n), D(_, m, n));
+                        // Using vectorized load will cause out-of-bounds access when !Is_even_MN && !predicate_N(n)
+                        // cute::copy(tiled_copy, S(_, m, n), D(_, m, n));
+                        #pragma unroll
+                        for (int i = 0; i < size<0>(S); ++i) {
+                            D(i, m, n) = S(i, m, n);
+                        }
                     }
                 } else if (Clear_OOB_MN) {
                     cute::clear(D(_, m, n));
```
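When the sequence length is not a multiple of the tile width, a vectorized copy of a whole fragment can read past the end of the source tensor; guarding each column with `predicate_N(n)` and falling back to an element-wise copy keeps every access in bounds. A minimal Python sketch of the same predication idea (the function, shapes, and tile geometry here are illustrative, not the kernel's):

```python
import torch

# Illustrative sketch (not the kernel): copy one (BLOCK_M, BLOCK_N) tile from a
# source whose true width max_n is not a multiple of BLOCK_N. Columns past
# max_n must be neither read nor copied.
def copy_tile_predicated(src, dst, col_offset, max_n):
    block_m, block_n = dst.shape
    # predicate_n[n] is True only for in-bounds source columns
    predicate_n = (col_offset + torch.arange(block_n)) < max_n
    for n in range(block_n):
        if predicate_n[n]:
            # element-wise copy touches only valid columns; a vectorized copy
            # of the whole tile would read past column max_n - 1
            dst[:, n] = src[:, col_offset + n]
        else:
            # analogue of Clear_OOB_MN: give out-of-bounds slots a defined value
            dst[:, n] = 0
    return dst

src = torch.randn(4, 10)    # 10 valid columns, tile width 8
dst = torch.empty(4, 8)
copy_tile_predicated(src, dst, col_offset=8, max_n=10)  # last tile: 2 valid columns
```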
The Python interface changes:

```diff
@@ -14,7 +14,7 @@ def maybe_contiguous(x: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
 def _sanitize_tensors(*tensors: Optional[torch.Tensor], nan: float = 0.0, posinf: float = 1e6, neginf: float = -1e6) -> None:
     for t in tensors:
         if t is not None and isinstance(t, torch.Tensor):
-            torch.nan_to_num(t, nan=nan, posinf=posinf, neginf=neginf, out=t)
+            torch.nan_to_num_(t, nan=nan, posinf=posinf, neginf=neginf)


 def _get_block_size_n(device, head_dim, is_causal):
```
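`torch.nan_to_num_` is PyTorch's documented in-place variant of `torch.nan_to_num`; it performs the same replacement as the previous `out=t` form while stating the in-place intent directly. A quick, self-contained check of what the sanitizer does (values are arbitrary):

```python
import torch

t = torch.tensor([float("nan"), float("inf"), float("-inf"), 1.0])
torch.nan_to_num_(t, nan=0.0, posinf=1e6, neginf=-1e6)  # modifies t in place
print(t)  # nan -> 0.0, inf -> 1e6, -inf -> -1e6; 1.0 untouched
```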
```diff
@@ -95,7 +95,7 @@ def _flash_dmattn_forward(
         softcap,
         return_softmax,
     )
-    _sanitize_tensors(out)
+    _sanitize_tensors(out, nan=0.0, posinf=torch.finfo(out.dtype).max, neginf=torch.finfo(out.dtype).min)
     return out, softmax_lse, S_dmask
```
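The `torch.finfo` bounds matter because the sanitizer's default of ±1e6 is not representable in float16 (whose largest finite value is 65504), so writing 1e6 into a float16 output would overflow right back to inf. A sketch of the dtype-aware bounds, assuming a float16 output for illustration:

```python
import torch

print(torch.finfo(torch.float16).max)    # 65504.0
print(torch.finfo(torch.bfloat16).max)   # ~3.39e+38

out = torch.tensor([float("inf"), float("nan")], dtype=torch.float16)
torch.nan_to_num_(
    out,
    nan=0.0,
    posinf=torch.finfo(out.dtype).max,   # largest finite float16 value
    neginf=torch.finfo(out.dtype).min,
)
print(out)  # tensor([65504., 0.], dtype=torch.float16): stays finite
```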
```diff
@@ -170,7 +170,7 @@ def _flash_dmattn_backward(
         softcap,
         deterministic,
     )
-    _sanitize_tensors(dq, dk, dv, dbias)
+    _sanitize_tensors(dq, dk, dv, dbias, nan=0.0, posinf=torch.finfo(dq.dtype).max, neginf=torch.finfo(dq.dtype).min)
     return softmax_d
```
```diff
@@ -227,8 +227,6 @@ def forward(
         return_softmax: Optional[bool],
         is_grad_enabled: bool = True,
     ):
-        # q, k, v are expected to be of shape (batch_size, seqlen, num_heads, head_size)
-        seqlen_k = k.shape[1]
         is_grad = is_grad_enabled and any(
             x.requires_grad for x in [q, k, v]
         )
```
```diff
@@ -249,14 +247,6 @@ def forward(
             k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
             v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])

-        if seqlen_k % 128 != 0:
-            k = torch.nn.functional.pad(k, [0, 0, 0, 0, 0, 128 - seqlen_k % 128])
-            v = torch.nn.functional.pad(v, [0, 0, 0, 0, 0, 128 - seqlen_k % 128])
-            if mask is not None:
-                mask = torch.nn.functional.pad(mask, [0, 128 - seqlen_k % 128], value=False)
-            if bias is not None:
-                bias = torch.nn.functional.pad(bias, [0, 128 - seqlen_k % 128], value=torch.finfo(bias.dtype).min)
-
         out_padded, softmax_lse, S_dmask = _wrapped_flash_dmattn_forward(
             q,
             k,
```
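This removed branch is the "mask/bias materialization" in the PR title: `torch.nn.functional.pad` always allocates a fresh tensor, so any call with `seqlen_k % 128 != 0` copied k, v, mask, and bias in full before the kernel even ran. With per-column predication in the kernel, the host-side padding becomes unnecessary. A sketch of the per-call cost, with illustrative sizes:

```python
import torch
import torch.nn.functional as F

batch, heads, seqlen_q, seqlen_k = 2, 8, 1000, 1000
pad = (128 - seqlen_k % 128) % 128   # 24 extra key columns here
bias = torch.randn(batch, heads, seqlen_q, seqlen_k, dtype=torch.float16)

padded = F.pad(bias, [0, pad], value=torch.finfo(bias.dtype).min)
assert padded.data_ptr() != bias.data_ptr()   # a full copy, not a view
print(bias.shape, "->", padded.shape)         # last dim 1000 -> 1024
```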
```diff
@@ -271,7 +261,6 @@ def forward(

         if is_grad:
             ctx.save_for_backward(q, k, v, mask, bias, out_padded, softmax_lse)
-            ctx.seqlen_k = seqlen_k
             ctx.softmax_scale = softmax_scale
             ctx.is_causal = is_causal
             ctx.softcap = softcap
```
```diff
@@ -288,7 +277,7 @@ def backward(
         *args: Any,
     ):
         q, k, v, mask, bias, out, softmax_lse = ctx.saved_tensors
-        dq, dk, dv, dbias = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v), torch.empty_like(bias)
+        dq, dk, dv, dbias = torch.zeros_like(q), torch.zeros_like(k), torch.zeros_like(v), torch.zeros_like(bias)
```

Copilot AI reviewed on Sep 17, 2025:

Using `torch.zeros_like()` instead of `torch.empty_like()` initializes the tensors with zeros, which adds unnecessary overhead since these tensors will be completely overwritten by the backward computation.

Suggested change:

```diff
-        dq, dk, dv, dbias = torch.zeros_like(q), torch.zeros_like(k), torch.zeros_like(v), torch.zeros_like(bias)
+        dq, dk, dv, dbias = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v), torch.empty_like(bias)
```
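Both sides of this review point are easy to see in PyTorch itself: `torch.empty_like` skips initialization, so any element the backward kernel never writes (for example an out-of-bounds tail) holds arbitrary data, while `torch.zeros_like` pays one extra fill per tensor for a guaranteed-defined value. A minimal illustration (shapes are arbitrary):

```python
import torch

x = torch.empty(2, 3)   # uninitialized storage: contents are arbitrary
y = torch.zeros(2, 3)   # costs an extra fill kernel, but every element is 0.0

# Simulate a kernel that writes only part of its output:
x[:, :2] = 1.0
y[:, :2] = 1.0
print(y[:, 2])   # tensor([0., 0.]): the unwritten tail is still defined
# x[:, 2] is whatever happened to be in the allocation
```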
Copilot AI reviewed on Sep 16, 2025:

The word 'Similarity' on line 331 should be 'Similarly' (missing 'l').
Copilot AI also noted:

The commented-out `typename TiledCopy` parameter should be removed entirely rather than left as a comment, as it's no longer used in the function signature.