diff --git a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_fmha_bwd.cu b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_fmha_bwd.cu index 49e8812b73..bb2e8cf26e 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_fmha_bwd.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_fmha_bwd.cu @@ -41,10 +41,16 @@ std::tuple dispatch_fmha_bwd( causal = false; // Use local attention instead of causal } // Expand -1 window sizes to full sequence length if available - if (window_size_left < 0 && max_seq_len_k.has_value()) { + if (window_size_left < 0) { + TORCH_CHECK( + max_seq_len_k.has_value(), + "window_size_left is negative but max_seq_len_k is not provided"); window_size_left = max_seq_len_k.value(); } - if (window_size_right < 0 && max_seq_len_k.has_value()) { + if (window_size_right < 0) { + TORCH_CHECK( + max_seq_len_k.has_value(), + "window_size_right is negative but max_seq_len_k is not provided"); window_size_right = max_seq_len_k.value(); } } diff --git a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_fmha_fwd.cu b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_fmha_fwd.cu index 7096725588..5b63ae44e9 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_fmha_fwd.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_fmha_fwd.cu @@ -26,10 +26,16 @@ std::tuple dispatch_fmha_fwd( causal = false; // Use local attention instead of causal } // Expand -1 window sizes to full sequence length if available - if (window_size_left < 0 && max_seq_len_k.has_value()) { + if (window_size_left < 0) { + TORCH_CHECK( + max_seq_len_k.has_value(), + "window_size_left is negative but max_seq_len_k is not provided"); window_size_left = max_seq_len_k.value(); } - if (window_size_right < 0 && max_seq_len_k.has_value()) { + if (window_size_right < 0) { + TORCH_CHECK( + max_seq_len_k.has_value(), + "window_size_right is negative but max_seq_len_k is not provided"); window_size_right = max_seq_len_k.value(); } }