diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py
index e58b2da3a8..0a9c8d454f 100644
--- a/tests/pytorch/attention/run_attention_with_cp.py
+++ b/tests/pytorch/attention/run_attention_with_cp.py
@@ -307,6 +307,7 @@ def run_dpa_with_cp(
     if config.attn_bias_type not in ["no_bias", "alibi"]:
         attn_bias_shape = (1, 1, config.max_seqlen_q, config.max_seqlen_kv)
         bias = torch.randn(*attn_bias_shape, dtype=dtypes[dtype]).cuda()
+        bias.requires_grad = True
     else:
         bias = None

@@ -338,7 +339,7 @@ def run_dpa_with_cp(
             out.backward(dout_fp8)
         else:
             out.backward(dout)
-    dq, dk, dv = q.grad, k.grad, v.grad
+    dq, dk, dv, dbias = q.grad, k.grad, v.grad, bias.grad
     d_softmax_offset = None
     if config.softmax_type != "vanilla":
         d_softmax_offset = core_attn.softmax_offset.grad
@@ -394,6 +395,7 @@ def run_dpa_with_cp(
         )
         bias_ = bias_.index_select(2, seq_idx)
         bias_ = bias_.view(*bias_.shape[:2], -1, bias_.shape[-1])
+        bias_.requires_grad = True
     # set up environment
     core_attn.set_context_parallel_group(
         cp_comm_sub_groups if cp_comm_type == "a2a+p2p" else cp_comm_group,
@@ -433,23 +435,23 @@ def run_dpa_with_cp(
             out_.backward(dout_fp8_)
         else:
             out_.backward(dout_)
-    dq_, dk_, dv_ = q_.grad, k_.grad, v_.grad
+    dq_, dk_, dv_, dbias_ = q_.grad, k_.grad, v_.grad, bias_.grad
     d_softmax_offset_ = None
     if config.softmax_type != "vanilla":
         d_softmax_offset_ = core_attn.softmax_offset.grad.clone()

     # get outputs
-    tensors = [out, dq, dk, dv, out_, dq_, dk_, dv_]
+    tensors = [out, dq, dk, dv, dbias, out_, dq_, dk_, dv_, dbias_]
     if fp8_mha:
         tensors_to_deq = [out, out_] if not fp8_bwd else tensors
         for i, tensor in enumerate(tensors_to_deq):
             tensors_to_deq[i] = tensor.dequantize()
         if not fp8_bwd:
-            tensors[0], tensors[4] = tensors_to_deq
+            tensors[0], tensors[5] = tensors_to_deq
     for tensor in tensors:
         assert torch.all(~torch.isnan(tensor))
         assert torch.all(~torch.isinf(tensor))
-    out, dq, dk, dv, out_, dq_, dk_, dv_ = tensors
+    out, dq, dk, dv, dbias, out_, dq_, dk_, dv_, dbias_ = tensors

     ############ compare results between CP and no-CP ############
     if qkv_format == "bshd" or qkv_format == "sbhd":
@@ -467,6 +469,22 @@ def run_dpa_with_cp(
             x.view(*x.shape[:seq_dim], 2, x.shape[seq_dim] // 2, *x.shape[(seq_dim + 1) :])
             for x in [dq_, dk_, dv_, out_]
         ]
+        if dbias is not None and dbias_ is not None:
+            dbias = dbias.view(
+                dbias.shape[0],
+                dbias.shape[1],
+                2 * world_size,
+                dbias.shape[2] // (2 * world_size),
+                dbias.shape[3],
+            )
+            # bias has fixed axis (2) as dbias shape: (1, 1, max_seqlen_q, max_seqlen_kv)
+            dbias = dbias.index_select(2, seq_idx)
+            # Flatten
+            dbias = dbias.view(dbias.shape[0], dbias.shape[1], -1, dbias.shape[-1])
+            dbias_ = dbias_.view(
+                dbias_.shape[0], dbias_.shape[1], 2, dbias_.shape[2] // 2, dbias_.shape[3]
+            )
+
     elif qkv_format == "thd":
         dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [dq, out]]
         dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in [dk, dv]]
@@ -509,9 +527,9 @@ def run_dpa_with_cp(
         )

     atol, rtol, rmse_tol = get_tols(config, dtype)
-    tensors_cp = [out_, dq_, dk_, dv_, d_softmax_offset_, max_logit_]
-    tensors_no_cp = [out, dq, dk, dv, d_softmax_offset, max_logit]
-    names = ["out", "dq", "dk", "dv", "d_softmax_offset", "max_logit"]
+    tensors_cp = [out_, dq_, dk_, dv_, dbias_, d_softmax_offset_, max_logit_]
+    tensors_no_cp = [out, dq, dk, dv, dbias, d_softmax_offset, max_logit]
+    names = ["out", "dq", "dk", "dv", "dbias", "d_softmax_offset", "max_logit"]
     names_cp = [x + "_cp" for x in names]
     names_no_cp = [x + "_no_cp" for x in names]
     is_fp8 = dtype == "fp8"
@@ -519,47 +537,103 @@ def run_dpa_with_cp(
         if t is not None:
             if "softmax_offset" not in names[i] and "max_logit" not in names[i]:
                 if qkv_format == "bshd":
-                    compare_and_assert(
-                        t[:, 0],
-                        tensors_cp[i][:, 0],
-                        names_no_cp[i],
-                        names_cp[i],
-                        atol,
-                        rtol,
-                        rmse_tol,
-                        is_fp8,
-                    )
-                    compare_and_assert(
-                        t[:, 1],
-                        tensors_cp[i][:, 1],
-                        names_no_cp[i],
-                        names_cp[i],
-                        atol,
-                        rtol,
-                        rmse_tol,
-                        is_fp8,
-                    )
+                    # Compare the two sequence chunks separately
+                    # Compare dbias
+                    if names[i] == "dbias":
+                        # After reshaping: (1, 1, 2, seq_q//2, seq_kv)
+                        # Compare along dimension 2 (the split sequence dimension)
+                        compare_and_assert(
+                            t[:, :, 0],  # First sequence chunk
+                            tensors_cp[i][:, :, 0],
+                            names_no_cp[i],
+                            names_cp[i],
+                            atol,
+                            rtol,
+                            rmse_tol,
+                            is_fp8,
+                        )
+                        compare_and_assert(
+                            t[:, :, 1],  # Second sequence chunk
+                            tensors_cp[i][:, :, 1],
+                            names_no_cp[i],
+                            names_cp[i],
+                            atol,
+                            rtol,
+                            rmse_tol,
+                            is_fp8,
+                        )
+                    # Compare Q/K/V/out
+                    else:
+                        # Compare along dimension 1 (the split sequence dimension)
+                        compare_and_assert(
+                            t[:, 0],
+                            tensors_cp[i][:, 0],
+                            names_no_cp[i],
+                            names_cp[i],
+                            atol,
+                            rtol,
+                            rmse_tol,
+                            is_fp8,
+                        )
+                        compare_and_assert(
+                            t[:, 1],
+                            tensors_cp[i][:, 1],
+                            names_no_cp[i],
+                            names_cp[i],
+                            atol,
+                            rtol,
+                            rmse_tol,
+                            is_fp8,
+                        )
                 elif qkv_format == "sbhd":
-                    compare_and_assert(
-                        t[0],
-                        tensors_cp[i][0],
-                        names_no_cp[i],
-                        names_cp[i],
-                        atol,
-                        rtol,
-                        rmse_tol,
-                        is_fp8,
-                    )
-                    compare_and_assert(
-                        t[1],
-                        tensors_cp[i][1],
-                        names_no_cp[i],
-                        names_cp[i],
-                        atol,
-                        rtol,
-                        rmse_tol,
-                        is_fp8,
-                    )
+                    # Compare the two sequence chunks separately
+                    # Compare dbias (same as BSHD)
+                    if names[i] == "dbias":
+                        # After reshaping: (1, 1, 2, seq_q//2, seq_kv)
+                        # Compare along dimension 2 (the split sequence dimension)
+                        compare_and_assert(
+                            t[:, :, 0],  # First sequence chunk
+                            tensors_cp[i][:, :, 0],
+                            names_no_cp[i],
+                            names_cp[i],
+                            atol,
+                            rtol,
+                            rmse_tol,
+                            is_fp8,
+                        )
+                        compare_and_assert(
+                            t[:, :, 1],  # Second sequence chunk
+                            tensors_cp[i][:, :, 1],
+                            names_no_cp[i],
+                            names_cp[i],
+                            atol,
+                            rtol,
+                            rmse_tol,
+                            is_fp8,
+                        )
+                    # Compare Q/K/V/out
+                    else:
+                        # Compare along dimension 0 (the split sequence dimension)
+                        compare_and_assert(
+                            t[0],
+                            tensors_cp[i][0],
+                            names_no_cp[i],
+                            names_cp[i],
+                            atol,
+                            rtol,
+                            rmse_tol,
+                            is_fp8,
+                        )
+                        compare_and_assert(
+                            t[1],
+                            tensors_cp[i][1],
+                            names_no_cp[i],
+                            names_cp[i],
+                            atol,
+                            rtol,
+                            rmse_tol,
+                            is_fp8,
+                        )
                 elif qkv_format == "thd":
                     compare_and_assert(
                         t, tensors_cp[i], names_no_cp[i], names_cp[i], atol, rtol, rmse_tol, is_fp8
diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py
index 4aedcff1b8..57965723ac 100644
--- a/tests/pytorch/attention/test_attention.py
+++ b/tests/pytorch/attention/test_attention.py
@@ -505,7 +505,7 @@ def test_dpa_mask(dtype, model_configs, model):

 model_configs_bias = {
     # test: ModelConfig(b, sq, hq, dqk)
-    "bias_1_0": ModelConfig(4, 128, 16, 64, attn_bias_type="post_scale_bias"),
+    "bias_1_0": ModelConfig(4, 128, 16, 64, attn_bias_type="post_scale_bias", bias_shape="111s"),
     "bias_1_1": ModelConfig(2, 128, 16, 64, max_seqlen_kv=256, attn_bias_type="post_scale_bias"),
     "bias_1_2": ModelConfig(4, 2048, 24, 128, attn_bias_type="post_scale_bias"),
     "bias_1_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_bias_type="post_scale_bias"),
@@ -1118,11 +1118,10 @@ def _run_dot_product_attention(
     bias = None
     if config.attn_bias_type == "post_scale_bias":
         shape = "_".join(config.bias_shape)
+        shape = shape.replace("_1_s", "_1_skv")
         shape = shape.replace("_s_s", "_sq_skv")
         tensor_shape = [dim_to_num[j] for j in shape.split("_")]
         bias = torch.randn(tensor_shape, dtype=dtype, device="cuda")
-        if config.bias_shape != "1hss":
-            bias.requires_grad = False

     # Create RNG
     _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
index efa4c78439..5d91829d39 100644
--- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
+++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
@@ -52,15 +52,16 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
     int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v,
     int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t num_pages_k, int64_t num_pages_v,
     int64_t page_size_k, int64_t page_size_v, int64_t max_pages_per_seq_k,
-    int64_t max_pages_per_seq_v, int64_t bias_b, int64_t bias_h, bool is_training,
-    bool return_max_logit, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout,
-    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
-    int64_t window_size_left, int64_t window_size_right, void *devPtrQ, void *devPtrK,
-    void *devPtrV, void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrS1, void *devPtrS2,
-    void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ,
-    void *devPtrCuSeqlensKV, void *devPtrPageTableK, void *devPtrPageTableV,
-    void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType,
-    void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) {
+    int64_t max_pages_per_seq_v, int64_t bias_b, int64_t bias_h, int64_t bias_sq, int64_t bias_skv,
+    bool is_training, bool return_max_logit, float scaling_factor, float dropout_probability,
+    NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
+    NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
+    void *devPtrQ, void *devPtrK, void *devPtrV, void *devPtrBias, void *devPtrSoftmaxOffset,
+    void *devPtrS1, void *devPtrS2, void *devPtrO, void *devPtrDropoutSeed,
+    void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV,
+    void *devPtrPageTableK, void *devPtrPageTableV, void *devPtrSeqOffsetsQ,
+    void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, void *workspace,
+    size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) {
   using namespace transformer_engine;

   bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);
@@ -120,6 +121,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
                              max_pages_per_seq_v,
                              bias_b,
                              bias_h,
+                             bias_sq,
+                             bias_skv,
                              scaling_factor,
                              is_training,
                              dropout_probability,
@@ -261,10 +264,11 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
     sdpa_options.set_alibi_mask(is_alibi);

     if (is_bias) {
-      bias = mha_graph->tensor(fe::graph::Tensor_attributes()
-                                   .set_name("bias")
-                                   .set_dim({bias_b, bias_h, s_q, s_kv})
-                                   .set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1}));
+      bias = mha_graph->tensor(
+          fe::graph::Tensor_attributes()
+              .set_name("bias")
+              .set_dim({bias_b, bias_h, bias_sq, bias_skv})
+              .set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1}));
       sdpa_options.set_bias(bias);
     }

@@ -540,15 +544,16 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
     int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v,
     int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t bias_b, int64_t bias_h,
-    float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout,
-    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
-    int64_t window_size_left, int64_t window_size_right, bool deterministic, void *devPtrQ,
-    void *devPtrKTranspose, void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats,
-    void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrdQ, void *devPtrdK, void *devPtrdV,
-    void *devPtrdO, void *devPtrdBias, void *devPtrdSoftmaxOffset, void *devPtrDropoutSeed,
-    void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV,
-    void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType,
-    void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) {
+    int64_t bias_sq, int64_t bias_skv, float scaling_factor, float dropout_probability,
+    NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
+    NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
+    bool deterministic, void *devPtrQ, void *devPtrKTranspose, void *devPtrVTranspose,
+    void *devPtrO, void *devPtrSoftmaxStats, void *devPtrBias, void *devPtrSoftmaxOffset,
+    void *devPtrdQ, void *devPtrdK, void *devPtrdV, void *devPtrdO, void *devPtrdBias,
+    void *devPtrdSoftmaxOffset, void *devPtrDropoutSeed, void *devPtrDropoutOffset,
+    void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ,
+    void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, void *workspace,
+    size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) {
   using namespace transformer_engine;

   bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);
@@ -612,6 +617,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
                              0,
                              bias_b,
                              bias_h,
+                             bias_sq,
+                             bias_skv,
                              scaling_factor,
                              true,
                              dropout_probability,
@@ -792,14 +799,16 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
     sdpa_backward_options.set_alibi_mask(is_alibi);

     if (is_bias) {
-      bias = mha_graph->tensor(fe::graph::Tensor_attributes()
-                                   .set_name("bias")
-                                   .set_dim({bias_b, bias_h, s_q, s_kv})
-                                   .set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1}));
-      dBias = mha_graph->tensor(fe::graph::Tensor_attributes()
-                                    .set_name("dBias")
-                                    .set_dim({bias_b, bias_h, s_q, s_kv})
-                                    .set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1}));
+      bias = mha_graph->tensor(
+          fe::graph::Tensor_attributes()
+              .set_name("bias")
+              .set_dim({bias_b, bias_h, bias_sq, bias_skv})
+              .set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1}));
+      dBias = mha_graph->tensor(
+          fe::graph::Tensor_attributes()
+              .set_name("dBias")
+              .set_dim({bias_b, bias_h, bias_sq, bias_skv})
+              .set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1}));
       sdpa_backward_options.set_bias(bias);
       // shapes [1, 1, s, s], [b, 1, s, s], [b, h, s, s]
       // are not supported for dbias calculation but they are
@@ -1064,10 +1073,14 @@ void fused_attn_arbitrary_seqlen_fwd(
   void *devPtrBias = nullptr;
   size_t bias_b = 0;
   size_t bias_h = 0;
+  size_t bias_sq = 0;
+  size_t bias_skv = 0;
   if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) {
     devPtrBias = input_Bias->data.dptr;
     bias_b = input_Bias->data.shape[0];
     bias_h = input_Bias->data.shape[1];
+    bias_sq = input_Bias->data.shape[2];
+    bias_skv = input_Bias->data.shape[3];
   }
   void *devPtrSoftmaxOffset = nullptr;
   if (softmax_type != NVTE_VANILLA_SOFTMAX) {
@@ -1133,7 +1146,7 @@ void fused_attn_arbitrary_seqlen_fwd(
     if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
       Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
       output_bias->data.dptr = nullptr;
-      output_bias->data.shape = {bias_b, bias_h, max_seqlen_q, max_seqlen_kv};
+      output_bias->data.shape = {bias_b, bias_h, bias_sq, bias_skv};
       output_bias->data.dtype = QKV_type;
     }

@@ -1178,9 +1191,9 @@ void fused_attn_arbitrary_seqlen_fwd(
   fused_attn_arbitrary_seqlen_fwd_impl(
       batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v,
       max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k,
-      page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training,
-      return_max_logit, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type,
-      window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias,
+      page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, bias_sq, bias_skv,
+      is_training, return_max_logit, attn_scale, p_dropout, qkv_layout, bias_type, mask_type,
+      softmax_type, window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias,
       devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset,
       devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ,
       devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size,
@@ -1224,11 +1237,15 @@ void fused_attn_arbitrary_seqlen_bwd(
   void *devPtrdBias = nullptr;
   size_t bias_b = 0;
   size_t bias_h = 0;
+  size_t bias_sq = 0;
+  size_t bias_skv = 0;
   if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) {
     devPtrBias = input_Bias->data.dptr;
     devPtrdBias = output_dBias->data.dptr;
     bias_b = output_dBias->data.shape[0];
     bias_h = output_dBias->data.shape[1];
+    bias_sq = output_dBias->data.shape[2];
+    bias_skv = output_dBias->data.shape[3];
   }

   size_t max_batch_size = 0;
@@ -1271,10 +1288,10 @@ void fused_attn_arbitrary_seqlen_bwd(

   fused_attn_arbitrary_seqlen_bwd_impl(
       batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v,
-      max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, attn_scale, p_dropout,
-      qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right,
-      deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias,
-      devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias,
+      max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, bias_sq, bias_skv, attn_scale,
+      p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
+      window_size_right, deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats,
+      devPtrBias, devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias,
       devPtrdSoftmaxOffset, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ,
       devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type),
       workspace->data.dptr, &workspace_size, stream, handle);
diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu
index 5d806290a9..54a1bb9a65 100644
--- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu
+++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu
@@ -1671,6 +1671,8 @@ void fused_attn_fp8_fwd_impl_v1(
   bool is_dropout = (is_training && dropout_probability != 0.0f);
   auto bias_b = b;
   auto bias_h = h;
+  auto bias_sq = s_q;
+  auto bias_skv = s_kv;
   NVTE_CHECK(~is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!");
   NVTE_CHECK(~is_alibi, "FP8 fused attention does not support ALiBi yet!");
   bool is_current_scaling = (o_tensor_type == cudnn_frontend::DataType_t::HALF ||
@@ -1697,6 +1699,8 @@ void fused_attn_fp8_fwd_impl_v1(
                              0,
                              bias_b,
                              bias_h,
+                             bias_sq,
+                             bias_skv,
                              scaling_factor,
                              is_training,
                              dropout_probability,
@@ -1817,8 +1821,8 @@ void fused_attn_fp8_fwd_impl_v1(
   // if (is_bias) {
   //   bias = mha_graph->tensor(fe::graph::Tensor_attributes()
   //                    .set_name("bias")
-  //                    .set_dim({bias_b, bias_h, s_q, s_kv})
-  //                    .set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1}));
+  //                    .set_dim({bias_b, bias_h, bias_sq, bias_skv})
+  //                    .set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1}));
   //   sdpa_options.set_bias(bias);
   // }

@@ -1998,6 +2002,8 @@ void fused_attn_fp8_bwd_impl_v1(
   bool is_dropout = (dropout_probability != 0.0f);
   auto bias_b = b;
   auto bias_h = h;
+  auto bias_sq = s_q;
+  auto bias_skv = s_kv;
   NVTE_CHECK(~is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!");
   NVTE_CHECK(~is_alibi, "FP8 fused attention does not support ALiBi yet!");
   bool is_current_scaling = (dqkv_tensor_type == cudnn_frontend::DataType_t::HALF ||
@@ -2026,6 +2032,8 @@ void fused_attn_fp8_bwd_impl_v1(
                              0,
                              bias_b,
                              bias_h,
+                             bias_sq,
+                             bias_skv,
                              scaling_factor,
                              true,
                              dropout_probability,
@@ -2192,12 +2200,12 @@ void fused_attn_fp8_bwd_impl_v1(
   // if (is_bias) {
   //   bias = mha_graph->tensor(fe::graph::Tensor_attributes()
   //                    .set_name("bias")
-  //                    .set_dim({bias_b, bias_h, s_q, s_kv})
-  //                    .set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1}));
+  //                    .set_dim({bias_b, bias_h, bias_sq, bias_skv})
+  //                    .set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1}));
   //   dBias = mha_graph->tensor(fe::graph::Tensor_attributes()
   //                    .set_name("dBias")
-  //                    .set_dim({bias_b, bias_h, s_q, s_kv})
-  //                    .set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1}));
+  //                    .set_dim({bias_b, bias_h, bias_sq, bias_skv})
+  //                    .set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1}));
   //   sdpa_backward_options.set_bias(bias);
   //   // shapes [1, 1, s, s], [b, 1, s, s], [b, h, s, s]
   //   // are not supported for dbias calculation but they are
diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h
index 72047a73f2..1ba8a60f03 100644
--- a/transformer_engine/common/fused_attn/utils.h
+++ b/transformer_engine/common/fused_attn/utils.h
@@ -101,6 +101,8 @@ struct FADescriptor_v1 {
   std::int64_t max_pages_per_seq_v;
   std::int64_t bias_b;
   std::int64_t bias_h;
+  std::int64_t bias_sq;
+  std::int64_t bias_skv;
   float attnScale;
   bool isTraining;
   float dropoutProbability;
@@ -119,17 +121,18 @@ struct FADescriptor_v1 {

   bool operator<(const FADescriptor_v1 &rhs) const {
     return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, num_pages_k, num_pages_v, page_size_k,
-                    page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h,
-                    attnScale, isTraining, dropoutProbability, layout, mask_type, softmax_type,
-                    window_size_left, window_size_right, deterministic, bias_type, qkv_tensor_type,
-                    o_tensor_type, do_tensor_type, dqkv_tensor_type, generate_max_sum_exp) <
+                    page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, bias_sq,
+                    bias_skv, attnScale, isTraining, dropoutProbability, layout, mask_type,
+                    softmax_type, window_size_left, window_size_right, deterministic, bias_type,
+                    qkv_tensor_type, o_tensor_type, do_tensor_type, dqkv_tensor_type,
+                    generate_max_sum_exp) <
            std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k,
                     rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k,
-                    rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.attnScale, rhs.isTraining,
-                    rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.softmax_type,
-                    rhs.window_size_left, rhs.window_size_right, rhs.deterministic, rhs.bias_type,
-                    rhs.qkv_tensor_type, rhs.o_tensor_type, rhs.do_tensor_type,
-                    rhs.dqkv_tensor_type, rhs.generate_max_sum_exp);
+                    rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.bias_sq, rhs.bias_skv,
+                    rhs.attnScale, rhs.isTraining, rhs.dropoutProbability, rhs.layout,
+                    rhs.mask_type, rhs.softmax_type, rhs.window_size_left, rhs.window_size_right,
+                    rhs.deterministic, rhs.bias_type, rhs.qkv_tensor_type, rhs.o_tensor_type,
+                    rhs.do_tensor_type, rhs.dqkv_tensor_type, rhs.generate_max_sum_exp);
   }
 };

diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py
index 10a06ed965..fec44739a1 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py
@@ -958,12 +958,7 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt
         and fu_core_attention_bias_type == "post_scale_bias"
         and fu_core_attention_bias_shape != "1hss"
     ):
-        if fu_core_attention_bias_requires_grad:
-            # remove this line when cuDNN adds bwd support for
-            # [1, 1, s, s], [b, 1, s, s] and [b, h, s, s]
-            logger.debug("Disabling FusedAttention for dBias in [1, H, S, S] shape")
-            use_fused_attention = False
-        else:
+        if not fu_core_attention_bias_requires_grad:
             # max512 backend will only support [1, h, s, s]
             os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
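
For readers of the test change above: a minimal standalone sketch (not part of the diff) of how the full dbias from the no-CP run is resliced before being compared against one rank's local gradient. It assumes the load-balanced chunk assignment used by run_dpa_with_cp (rank r holds query-sequence chunks r and 2*world_size - 1 - r); the helper name and the example sizes below are illustrative only, not code from the change.

import torch

def reslice_dbias_for_rank(dbias: torch.Tensor, world_size: int, rank: int) -> torch.Tensor:
    """Select the two query-sequence chunks owned by `rank` from the full dbias."""
    b, h, sq, skv = dbias.shape
    # Assumed load-balanced CP split: rank r owns query chunks r and 2*world_size - 1 - r.
    seq_idx = torch.tensor([rank, 2 * world_size - 1 - rank], device=dbias.device)
    dbias = dbias.view(b, h, 2 * world_size, sq // (2 * world_size), skv)
    # Result shape: (b, h, 2, sq // (2 * world_size), skv), matching the
    # (1, 1, 2, seq_q // 2, seq_kv) view applied to the per-rank dbias_ in the test.
    return dbias.index_select(2, seq_idx)

# Example with assumed sizes: max_seqlen_q = 128, max_seqlen_kv = 64, world_size = 2, rank 0.
full_dbias = torch.randn(1, 1, 128, 64)
local_view = reslice_dbias_for_rank(full_dbias, world_size=2, rank=0)
assert local_view.shape == (1, 1, 2, 32, 64)

The two chunks along dimension 2 of this view are what the new "dbias" branch of the comparison loop checks with compare_and_assert, one chunk at a time, mirroring how dq/dk/dv/out are already compared per sequence half.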