
Commit 26aad6b

Disable cuDNN attention for known IMA and NaNs (NVIDIA#2344)
* Fix cuDNN backend selection for more cases; add CUDA graph capture (CG) as an option as well
* Fix logic
* Fix cuDNN checks
* Add more checks
* Fix cuDNN version
* Fix error message
* Add check for window size

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
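Context for the diff below: the new trailing `cuda_graph` flag is threaded from every fused-attention entry point into `nvte_get_fused_attn_backend`, which can now return `NVTE_No_Backend` for configurations with known illegal memory accesses (IMA) or NaNs. The following caller-side sketch only illustrates the call shape; it is not part of this commit, and the dtype/layout/mask values are assumptions to be checked against the transformer_engine headers.

// Caller-side sketch (not part of this commit). Enum values below are
// illustrative assumptions; verify against transformer_engine/fused_attn.h.
#include <transformer_engine/fused_attn.h>
#include <transformer_engine/transformer_engine.h>

NVTE_Fused_Attn_Backend pick_backend(bool is_training, bool cuda_graph,
                                     NVTE_Softmax_Type softmax_type) {
  return nvte_get_fused_attn_backend(
      is_training, kNVTEBFloat16, kNVTEBFloat16,            // Q and KV dtypes (assumed)
      NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD,                 // layout (assumed)
      NVTE_Bias_Type::NVTE_NO_BIAS, NVTE_Mask_Type::NVTE_CAUSAL_MASK, softmax_type,
      /*dropout=*/0.0f, /*num_attn_heads=*/32, /*num_gqa_groups=*/32,
      /*max_seqlen_q=*/4096, /*max_seqlen_kv=*/4096,
      /*head_dim_qk=*/128, /*head_dim_v=*/128,
      /*window_size_left=*/-1, /*window_size_right=*/-1,
      /*return_max_logit=*/false, cuda_graph);              // new trailing flag from this commit
}

A return of NVTE_Fused_Attn_Backend::NVTE_No_Backend is the signal to take a non-fused path; the new checks that trigger it are in fused_attn.cpp below.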
1 parent f62cad9 commit 26aad6b

File tree: 10 files changed (+178, -116 lines)


transformer_engine/common/fused_attn/fused_attn.cpp

Lines changed: 68 additions & 41 deletions
@@ -138,7 +138,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
     NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
     float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
     size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
-    int64_t window_size_right, bool return_max_logit) {
+    int64_t window_size_right, bool return_max_logit, bool cuda_graph) {
   using namespace transformer_engine;
   NVTE_Fused_Attn_Backend backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
   const int device_id = cuda::current_device();
@@ -166,7 +166,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
           qkv_layout == NVTE_QKV_Layout::NVTE_T3HD && max_seqlen_q == max_seqlen_kv &&
           max_seqlen_q <= 512 && head_dim_qk == 64 && head_dim_v == 64 &&
           attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) ||
-         // 9.2: {bshd, sbhd}, any seqlen, d=128, {no_mask, causal}
+         // 9.2.1: {bshd, sbhd}, any seqlen, d=128, {no_mask, causal}
          (cudnn_runtime_version >= 90201 && sm_arch_ < 100 && max_seqlen_q % 128 == 0 &&
           max_seqlen_kv % 128 == 0 && head_dim_qk == 128 && head_dim_v == 128 &&
           (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
@@ -407,6 +407,28 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
                    " Please upgrade your cuDNN version if possible."
                 << std::endl;
     }
+    if ((cudnn_runtime_version == 91400) && (max_seqlen_kv > 1024) && (window_size_left != -1) &&
+        (attn_mask_type != NVTE_Mask_Type::NVTE_CAUSAL_MASK) &&
+        (attn_mask_type != NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK)) {
+      backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
+      std::cout << "Warning: Given combination of attention mask (non-causal) and "
+                   "max_seqlen_kv (> 1024) does not support fused attention for cuDNN 9.14.0. "
+                   " Please upgrade your cuDNN version if possible."
+                << std::endl;
+    }
+    if ((cudnn_runtime_version <= 91500) && is_training &&
+        (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) &&
+        (max_seqlen_kv % 128 != 0) && cuda_graph &&
+        (attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK) &&
+        (attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) &&
+        (attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK)) {
+      backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
+      std::cout << "Warning: Given combination of attention mask (non-padding),"
+                   " max_seqlen_kv (not divisible by 128), and qkv_format (BSHD/SBHD) for"
+                   " backward fused attention with graph capture requires cuDNN 9.15.1+. "
+                   "Please upgrade your cuDNN version if possible."
+                << std::endl;
+    }
   } else {
     backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
   }
@@ -419,11 +441,11 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
                                    NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens,
                                    const NVTETensor cu_seqlens_padded, const NVTETensor rng_state,
                                    size_t max_seqlen, bool is_training, bool return_max_logit,
-                                   float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
-                                   NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
-                                   NVTE_Softmax_Type softmax_type, int64_t window_size_left,
-                                   int64_t window_size_right, NVTETensor workspace,
-                                   cudaStream_t stream) {
+                                   bool cuda_graph, float attn_scale, float dropout,
+                                   NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
+                                   NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
+                                   int64_t window_size_left, int64_t window_size_right,
+                                   NVTETensor workspace, cudaStream_t stream) {
   NVTE_API_CALL(nvte_flash_attn_fwd_qkvpacked);
   using namespace transformer_engine;

@@ -460,7 +482,8 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,

   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
       is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout,
-      h, h, max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, return_max_logit);
+      h, h, max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, return_max_logit,
+      cuda_graph);

   if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)
@@ -496,16 +519,14 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
   }
 }
 // NVTE fused attention BWD with packed QKV
-void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, const NVTETensor dO,
-                                   const NVTETensor S, NVTETensor dP,
-                                   const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV,
-                                   NVTETensor dBias, NVTETensor dSoftmaxOffset,
-                                   const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded,
-                                   size_t max_seqlen, float attn_scale, float dropout,
-                                   NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-                                   NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
-                                   int64_t window_size_left, int64_t window_size_right,
-                                   bool deterministic, NVTETensor workspace, cudaStream_t stream) {
+void nvte_fused_attn_bwd_qkvpacked(
+    const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, const NVTETensor S,
+    NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV, NVTETensor dBias,
+    NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded,
+    size_t max_seqlen, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
+    NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
+    int64_t window_size_left, int64_t window_size_right, bool deterministic, bool cuda_graph,
+    NVTETensor workspace, cudaStream_t stream) {
   NVTE_API_CALL(nvte_flash_attn_bwd_qkvpacked);
   using namespace transformer_engine;

@@ -544,7 +565,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con

   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
       true, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h, h,
-      max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, false);
+      max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, false, cuda_graph);

   if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)
@@ -602,10 +623,10 @@ void nvte_fused_attn_fwd_kvpacked(
     const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
     const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k,
     const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q,
-    size_t max_seqlen_kv, bool is_training, bool return_max_logit, float attn_scale, float dropout,
-    NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
-    NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
-    NVTETensor workspace, cudaStream_t stream) {
+    size_t max_seqlen_kv, bool is_training, bool return_max_logit, bool cuda_graph,
+    float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
+    NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
+    int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) {
   NVTE_API_CALL(nvte_flash_attn_fwd_kvpacked);
   using namespace transformer_engine;
   const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
@@ -681,7 +702,7 @@ void nvte_fused_attn_fwd_kvpacked(
   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
       is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout,
       h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right,
-      return_max_logit);
+      return_max_logit, cuda_graph);

   if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)
@@ -728,7 +749,8 @@ void nvte_fused_attn_bwd_kvpacked(
     const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv,
     float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
     NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
-    int64_t window_size_right, bool deterministic, NVTETensor workspace, cudaStream_t stream) {
+    int64_t window_size_right, bool deterministic, bool cuda_graph, NVTETensor workspace,
+    cudaStream_t stream) {
   NVTE_API_CALL(nvte_flash_attn_bwd_kvpacked);
   using namespace transformer_engine;
   const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
@@ -776,9 +798,10 @@ void nvte_fused_attn_bwd_kvpacked(
   const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype);
   const NVTEDType KV_type = static_cast<NVTEDType>(input_KV->data.dtype);

-  NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
-      true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q,
-      h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right, false);
+  NVTE_Fused_Attn_Backend fused_attention_backend =
+      nvte_get_fused_attn_backend(true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type,
+                                  softmax_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d,
+                                  d, window_size_left, window_size_right, false, cuda_graph);

   if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)
@@ -833,16 +856,19 @@ void nvte_fused_attn_bwd_kvpacked(
   }
 }
 // NVTE fused attention FWD with separate Q, K and V
-void nvte_fused_attn_fwd(
-    const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor Bias,
-    const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
-    const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
-    const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
-    const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state,
-    size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, bool return_max_logit,
-    float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-    NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
-    int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) {
+void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V,
+                         const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S,
+                         NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
+                         const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
+                         const NVTETensor cu_seqlens_q_padded,
+                         const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k,
+                         const NVTETensor page_table_v, const NVTETensor rng_state,
+                         size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
+                         bool return_max_logit, bool cuda_graph, float attn_scale, float dropout,
+                         NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
+                         NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
+                         int64_t window_size_left, int64_t window_size_right, NVTETensor workspace,
+                         cudaStream_t stream) {
   NVTE_API_CALL(nvte_flash_attn_fwd);
   using namespace transformer_engine;
   const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
@@ -913,7 +939,7 @@ void nvte_fused_attn_fwd(
   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
       is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout,
       h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right,
-      return_max_logit);
+      return_max_logit, cuda_graph);

   if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)
@@ -963,7 +989,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
                          NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
                          NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
                          int64_t window_size_left, int64_t window_size_right, bool deterministic,
-                         NVTETensor workspace, cudaStream_t stream) {
+                         bool cuda_graph, NVTETensor workspace, cudaStream_t stream) {
   NVTE_API_CALL(nvte_flash_attn_bwd);
   using namespace transformer_engine;
   const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
@@ -1008,7 +1034,8 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso

   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
       true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q,
-      h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, false);
+      h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, false,
+      cuda_graph);

   if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)
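
Taken together, the checks above mean a backend query made with cuda_graph set to true can now come back as NVTE_No_Backend on affected cuDNN builds, and the caller is expected to take a non-fused path. A hedged sketch of that branching, reusing pick_backend from the sketch near the top; run_unfused_attention() is a hypothetical placeholder, not a Transformer Engine API:

// Hypothetical fallback around the updated query (not part of this commit).
void run_unfused_attention();  // placeholder for an application-provided non-fused path

void attention_step(bool is_training, bool graph_capture, NVTE_Softmax_Type softmax_type) {
  NVTE_Fused_Attn_Backend backend = pick_backend(is_training, graph_capture, softmax_type);
  if (backend != NVTE_Fused_Attn_Backend::NVTE_No_Backend) {
    // A fused backend was selected; launching (and, under graph_capture, capturing)
    // nvte_fused_attn_fwd/nvte_fused_attn_bwd is expected to be safe on this cuDNN build.
  } else {
    // Known IMA/NaN combination on this cuDNN version: fall back to the unfused path.
    run_unfused_attention();
  }
}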
