@@ -138,7 +138,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
     NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
     float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
     size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
-    int64_t window_size_right, bool return_max_logit) {
+    int64_t window_size_right, bool return_max_logit, bool cuda_graph) {
   using namespace transformer_engine;
   NVTE_Fused_Attn_Backend backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
   const int device_id = cuda::current_device();
@@ -166,7 +166,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
         qkv_layout == NVTE_QKV_Layout::NVTE_T3HD && max_seqlen_q == max_seqlen_kv &&
         max_seqlen_q <= 512 && head_dim_qk == 64 && head_dim_v == 64 &&
         attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) ||
-       // 9.2: {bshd, sbhd}, any seqlen, d=128, {no_mask, causal}
+       // 9.2.1: {bshd, sbhd}, any seqlen, d=128, {no_mask, causal}
        (cudnn_runtime_version >= 90201 && sm_arch_ < 100 && max_seqlen_q % 128 == 0 &&
         max_seqlen_kv % 128 == 0 && head_dim_qk == 128 && head_dim_v == 128 &&
         (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
@@ -407,6 +407,28 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
                    "Please upgrade your cuDNN version if possible."
                 << std::endl;
     }
+    if ((cudnn_runtime_version == 91400) && (max_seqlen_kv > 1024) && (window_size_left != -1) &&
+        (attn_mask_type != NVTE_Mask_Type::NVTE_CAUSAL_MASK) &&
+        (attn_mask_type != NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK)) {
+      backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
+      std::cout << "Warning: Given combination of attention mask (non-causal) and "
+                   "max_seqlen_kv (> 1024) does not support fused attention for cuDNN 9.14.0. "
+                   "Please upgrade your cuDNN version if possible."
+                << std::endl;
+    }
+    if ((cudnn_runtime_version <= 91500) && is_training &&
+        (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) &&
+        (max_seqlen_kv % 128 != 0) && cuda_graph &&
+        (attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK) &&
+        (attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) &&
+        (attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK)) {
+      backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
+      std::cout << "Warning: Given combination of attention mask (non-padding), "
+                   "max_seqlen_kv (not divisible by 128), and qkv_format (BSHD/SBHD) for "
+                   "backward fused attention with graph capture requires cuDNN 9.15.1+. "
+                   "Please upgrade your cuDNN version if possible."
+                << std::endl;
+    }
   } else {
     backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
   }
@@ -419,11 +441,11 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
                                    NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens,
                                    const NVTETensor cu_seqlens_padded, const NVTETensor rng_state,
                                    size_t max_seqlen, bool is_training, bool return_max_logit,
-                                   float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
-                                   NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
-                                   NVTE_Softmax_Type softmax_type, int64_t window_size_left,
-                                   int64_t window_size_right, NVTETensor workspace,
-                                   cudaStream_t stream) {
+                                   bool cuda_graph, float attn_scale, float dropout,
+                                   NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
+                                   NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
+                                   int64_t window_size_left, int64_t window_size_right,
+                                   NVTETensor workspace, cudaStream_t stream) {
   NVTE_API_CALL(nvte_flash_attn_fwd_qkvpacked);
   using namespace transformer_engine;
 
@@ -460,7 +482,8 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
 
   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
       is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout,
-      h, h, max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, return_max_logit);
+      h, h, max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, return_max_logit,
+      cuda_graph);
 
   if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)
@@ -496,16 +519,14 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
   }
 }
 // NVTE fused attention BWD with packed QKV
-void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, const NVTETensor dO,
-                                   const NVTETensor S, NVTETensor dP,
-                                   const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV,
-                                   NVTETensor dBias, NVTETensor dSoftmaxOffset,
-                                   const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded,
-                                   size_t max_seqlen, float attn_scale, float dropout,
-                                   NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-                                   NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
-                                   int64_t window_size_left, int64_t window_size_right,
-                                   bool deterministic, NVTETensor workspace, cudaStream_t stream) {
+void nvte_fused_attn_bwd_qkvpacked(
+    const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, const NVTETensor S,
+    NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV, NVTETensor dBias,
+    NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded,
+    size_t max_seqlen, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
+    NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
+    int64_t window_size_left, int64_t window_size_right, bool deterministic, bool cuda_graph,
+    NVTETensor workspace, cudaStream_t stream) {
   NVTE_API_CALL(nvte_flash_attn_bwd_qkvpacked);
   using namespace transformer_engine;
 
@@ -544,7 +565,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
 
   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
       true, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h, h,
-      max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, false);
+      max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, false, cuda_graph);
 
   if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)
@@ -602,10 +623,10 @@ void nvte_fused_attn_fwd_kvpacked(
     const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
     const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k,
     const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q,
-    size_t max_seqlen_kv, bool is_training, bool return_max_logit, float attn_scale, float dropout,
-    NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
-    NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
-    NVTETensor workspace, cudaStream_t stream) {
+    size_t max_seqlen_kv, bool is_training, bool return_max_logit, bool cuda_graph,
+    float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
+    NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
+    int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) {
   NVTE_API_CALL(nvte_flash_attn_fwd_kvpacked);
   using namespace transformer_engine;
   const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
@@ -681,7 +702,7 @@ void nvte_fused_attn_fwd_kvpacked(
   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
       is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout,
       h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right,
-      return_max_logit);
+      return_max_logit, cuda_graph);
 
   if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)
@@ -728,7 +749,8 @@ void nvte_fused_attn_bwd_kvpacked(
     const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv,
     float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
     NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
-    int64_t window_size_right, bool deterministic, NVTETensor workspace, cudaStream_t stream) {
+    int64_t window_size_right, bool deterministic, bool cuda_graph, NVTETensor workspace,
+    cudaStream_t stream) {
   NVTE_API_CALL(nvte_flash_attn_bwd_kvpacked);
   using namespace transformer_engine;
   const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
@@ -776,9 +798,10 @@ void nvte_fused_attn_bwd_kvpacked(
 
   const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype);
   const NVTEDType KV_type = static_cast<NVTEDType>(input_KV->data.dtype);
-  NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
-      true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q,
-      h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right, false);
+  NVTE_Fused_Attn_Backend fused_attention_backend =
+      nvte_get_fused_attn_backend(true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type,
+                                  softmax_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d,
+                                  d, window_size_left, window_size_right, false, cuda_graph);
 
   if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)
@@ -833,16 +856,19 @@ void nvte_fused_attn_bwd_kvpacked(
   }
 }
 // NVTE fused attention FWD with separate Q, K and V
-void nvte_fused_attn_fwd(
-    const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor Bias,
-    const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
-    const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
-    const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
-    const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state,
-    size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, bool return_max_logit,
-    float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-    NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
-    int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) {
+void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V,
+                         const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S,
+                         NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
+                         const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
+                         const NVTETensor cu_seqlens_q_padded,
+                         const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k,
+                         const NVTETensor page_table_v, const NVTETensor rng_state,
+                         size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
+                         bool return_max_logit, bool cuda_graph, float attn_scale, float dropout,
+                         NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
+                         NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
+                         int64_t window_size_left, int64_t window_size_right, NVTETensor workspace,
+                         cudaStream_t stream) {
   NVTE_API_CALL(nvte_flash_attn_fwd);
   using namespace transformer_engine;
   const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
@@ -913,7 +939,7 @@ void nvte_fused_attn_fwd(
   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
       is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout,
       h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right,
-      return_max_logit);
+      return_max_logit, cuda_graph);
 
   if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)
@@ -963,7 +989,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
                          NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
                          NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
                          int64_t window_size_left, int64_t window_size_right, bool deterministic,
-                         NVTETensor workspace, cudaStream_t stream) {
+                         bool cuda_graph, NVTETensor workspace, cudaStream_t stream) {
   NVTE_API_CALL(nvte_flash_attn_bwd);
   using namespace transformer_engine;
   const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
@@ -1008,7 +1034,8 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
 
   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
       true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q,
-      h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, false);
+      h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, false,
+      cuda_graph);
 
   if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)
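For reference, a minimal caller-side sketch of the extended nvte_get_fused_attn_backend query, with the new cuda_graph flag indicating that the attention call will run under CUDA graph capture. This is illustrative only: the helper name, header path, and the specific enum values used (kNVTEFloat16, NVTE_BS3HD, NVTE_NO_BIAS, NVTE_CAUSAL_MASK, NVTE_VANILLA_SOFTMAX) are assumptions for the sake of the example, not part of this diff.

#include <transformer_engine/fused_attn.h>

// Hypothetical helper: returns true if a fused backend is available for a
// simple self-attention problem that will be captured in a CUDA graph.
// Dtype, layout, bias, mask, and softmax values below are assumed defaults.
static bool fused_attn_supports_graph_capture(size_t num_heads, size_t max_seqlen,
                                              size_t head_dim) {
  NVTE_Fused_Attn_Backend backend = nvte_get_fused_attn_backend(
      /*is_training=*/true, kNVTEFloat16, kNVTEFloat16, NVTE_QKV_Layout::NVTE_BS3HD,
      NVTE_Bias_Type::NVTE_NO_BIAS, NVTE_Mask_Type::NVTE_CAUSAL_MASK,
      NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX, /*dropout=*/0.0f,
      /*num_attn_heads=*/num_heads, /*num_gqa_groups=*/num_heads,
      /*max_seqlen_q=*/max_seqlen, /*max_seqlen_kv=*/max_seqlen,
      /*head_dim_qk=*/head_dim, /*head_dim_v=*/head_dim,
      /*window_size_left=*/-1, /*window_size_right=*/-1,
      /*return_max_logit=*/false, /*cuda_graph=*/true);  // new parameter from this change
  return backend != NVTE_Fused_Attn_Backend::NVTE_No_Backend;
}

Passing cuda_graph at dispatch time presumably lets the backend selector reject the configurations flagged in the new cuDNN <= 9.15.0 warning (backward pass, BSHD/SBHD, non-padding mask, max_seqlen_kv not divisible by 128) up front, rather than failing later during graph capture.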