Skip to content

Commit 88dc834

Browse files
henrylhtsang authored and meta-codesync[bot] committed
better error handling when window size is -1 and no max seq len is provided (#4979)
Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/1993 Pull Request resolved: #4979 Improve error handling when one of the window sizes is not provided and we don't know what the max seq len is. Reviewed By: Aya-ZIbra Differential Revision: D84021746 fbshipit-source-id: f0142f27b5e365a0f8698e29f9fe2f1f940cb859
1 parent 724f2ce commit 88dc834

File tree

2 files changed

+16
-4
lines changed

2 files changed

+16
-4
lines changed

fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_fmha_bwd.cu

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,16 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> dispatch_fmha_bwd(
4141
causal = false; // Use local attention instead of causal
4242
}
4343
// Expand -1 window sizes to full sequence length if available
44-
if (window_size_left < 0 && max_seq_len_k.has_value()) {
44+
if (window_size_left < 0) {
45+
TORCH_CHECK(
46+
max_seq_len_k.has_value(),
47+
"window_size_left is negative but max_seq_len_k is not provided");
4548
window_size_left = max_seq_len_k.value();
4649
}
47-
if (window_size_right < 0 && max_seq_len_k.has_value()) {
50+
if (window_size_right < 0) {
51+
TORCH_CHECK(
52+
max_seq_len_k.has_value(),
53+
"window_size_right is negative but max_seq_len_k is not provided");
4854
window_size_right = max_seq_len_k.value();
4955
}
5056
}

fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_fmha_fwd.cu

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,16 @@ std::tuple<at::Tensor, at::Tensor> dispatch_fmha_fwd(
2626
causal = false; // Use local attention instead of causal
2727
}
2828
// Expand -1 window sizes to full sequence length if available
29-
if (window_size_left < 0 && max_seq_len_k.has_value()) {
29+
if (window_size_left < 0) {
30+
TORCH_CHECK(
31+
max_seq_len_k.has_value(),
32+
"window_size_left is negative but max_seq_len_k is not provided");
3033
window_size_left = max_seq_len_k.value();
3134
}
32-
if (window_size_right < 0 && max_seq_len_k.has_value()) {
35+
if (window_size_right < 0) {
36+
TORCH_CHECK(
37+
max_seq_len_k.has_value(),
38+
"window_size_right is negative but max_seq_len_k is not provided");
3339
window_size_right = max_seq_len_k.value();
3440
}
3541
}

0 commit comments

Comments
 (0)