
Commit cbd27ae

xinyazhang, jammm, and ScottTodd authored
[ROCM] Backport AOTriton 0.11b to release/2.8 and other fixes (#2686)
Fixes: pytorch#163958

Cherry-pick pytorch#161754
Cherry-pick pytorch#162330
Cherry-pick pytorch#163373
Cherry-pick pytorch#163745

Note: TF32 support is still being plagued by `HIPBLASLT_ALLOW_TF32`, which should be handled by another PR due to its complexity.

---------

Co-authored-by: Aaryaman Vasishta <[email protected]>
Co-authored-by: Scott Todd <[email protected]>
1 parent 2cd73af commit cbd27ae

File tree: 17 files changed, +692 -211 lines


CMakeLists.txt

Lines changed: 2 additions & 2 deletions
@@ -867,7 +867,7 @@ cmake_dependent_option(
   "Whether to build the flash_attention kernel for scaled dot product attention.\
    Will be disabled if not supported by the platform"
   ON
-  "USE_CUDA OR USE_ROCM;NOT MSVC"
+  "(USE_CUDA AND NOT MSVC) OR USE_ROCM"
   OFF)
 
 # CAVEAT: Again, Flash Attention2 will error while building for sm52 while Mem

@@ -883,7 +883,7 @@ cmake_dependent_option(
 # USE_FLASH_ATTENTION -> USE_ROCM -> Dependencies.cmake -> aotriton.cmake
 #
 if(USE_ROCM)
-  if(UNIX AND (USE_FLASH_ATTENTION OR USE_MEM_EFF_ATTENTION))
+  if(USE_FLASH_ATTENTION OR USE_MEM_EFF_ATTENTION)
     include(cmake/External/aotriton.cmake)
   endif()
 endif()
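With the relaxed condition, a ROCm build that requests either attention kernel now includes cmake/External/aotriton.cmake, and the MSVC restriction applies only to the CUDA side. The C++ sources then gate the AOTriton paths on the matching preprocessor macros; below is a minimal, self-contained sketch of that gate, assuming the macros come from CMake as above (has_rocm_attention() is an invented helper, not part of the commit).

// Sketch only: mirrors the preprocessor gate used by the ROCm attention sources.
// USE_ROCM, USE_FLASH_ATTENTION and USE_MEM_EFF_ATTENTION are assumed to be set
// by CMake; has_rocm_attention() is a hypothetical helper, not part of the commit.
#include <cstdio>

#if defined(USE_ROCM) && (defined(USE_FLASH_ATTENTION) || defined(USE_MEM_EFF_ATTENTION))
constexpr bool has_rocm_attention() { return true; }
#else
constexpr bool has_rocm_attention() { return false; }
#endif

int main() {
  std::printf("ROCm attention compiled in: %d\n", has_rocm_attention() ? 1 : 0);
  return 0;
}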

aten/src/ATen/native/transformers/cuda/attention.cu

Lines changed: 118 additions & 5 deletions
@@ -95,6 +95,72 @@
 #endif
 #endif
 
+#if defined(USE_ROCM) && (defined(USE_FLASH_ATTENTION) || defined(USE_MEM_EFF_ATTENTION))
+namespace pytorch_flash
+{
+std::tuple<
+    at::Tensor,
+    at::Tensor,
+    at::Tensor,
+    at::Tensor,
+    at::Tensor,
+    at::Tensor,
+    at::Tensor,
+    at::Tensor>
+mha_fwd(
+    const at::Tensor& q, // batch_size x seqlen_q x num_heads x head_size
+    const at::Tensor& k, // batch_size x seqlen_k x num_heads_k x head_size
+    const at::Tensor& v, // batch_size x seqlen_k x num_heads_k x head_size
+    std::optional<at::Tensor>&
+        out_, // batch_size x seqlen_q x num_heads x head_size
+    std::optional<at::Tensor>&
+        alibi_slopes_, // num_heads or batch_size x num_heads
+    const float p_dropout,
+    const float softmax_scale,
+    bool is_causal,
+    std::optional<int64_t> window_size_left,
+    std::optional<int64_t> window_size_right,
+    const float softcap,
+    const bool return_softmax,
+    std::optional<at::Generator> gen_) {
+#if defined(USE_ROCM_CK_SDPA)
+  if (at::globalContext().getROCmFAPreferredBackend() ==
+      at::ROCmFABackend::Ck) {
+    const int non_null_window_left = window_size_left.value_or(-1);
+    const int non_null_window_right = window_size_right.value_or(-1);
+    std::optional<at::Tensor> dummy_attn_bias = std::nullopt;
+    return mha_fwd_ck(
+        q,
+        k,
+        v,
+        out_,
+        p_dropout,
+        softmax_scale,
+        is_causal,
+        non_null_window_left,
+        non_null_window_right,
+        return_softmax,
+        gen_,
+        dummy_attn_bias); // Not used in flash attention
+  }
+#endif
+  return mha_fwd_aot(
+      q,
+      k,
+      v,
+      out_,
+      alibi_slopes_,
+      p_dropout,
+      softmax_scale,
+      is_causal,
+      window_size_left,
+      window_size_right,
+      return_softmax,
+      gen_);
+}
+}
+#endif
+
 namespace at {
 
 namespace cuda::philox {
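The new pytorch_flash::mha_fwd wrapper is a thin runtime dispatcher: when CK support is compiled in (USE_ROCM_CK_SDPA) and the process-wide ROCm flash-attention preference is CK, it forwards to mha_fwd_ck; otherwise it falls back to the AOTriton path mha_fwd_aot. A standalone sketch of that dispatch shape, with invented names (kCkCompiledIn, prefer_ck, fwd_ck, fwd_aot) standing in for the real symbols:

// Sketch of the compile-time + runtime dispatch used by mha_fwd above.
// Every name below is a placeholder, not the PyTorch / CK / AOTriton API.
#include <cstdio>

constexpr bool kCkCompiledIn = false;      // stands in for #if defined(USE_ROCM_CK_SDPA)

static bool prefer_ck() { return false; }  // stands in for getROCmFAPreferredBackend() == Ck
static int fwd_ck()  { std::puts("CK path");       return 0; }
static int fwd_aot() { std::puts("AOTriton path"); return 0; }

static int dispatch_fwd() {
  if (kCkCompiledIn && prefer_ck()) {
    return fwd_ck();   // CK backend explicitly requested and available
  }
  return fwd_aot();    // AOTriton remains the default path
}

int main() { return dispatch_fwd(); }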
@@ -1406,12 +1472,15 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt> _efficient_
   at::Tensor v_t = value.transpose(1, 2);
   at::Tensor output_t = res.transpose(1, 2);
   bool is_causal;
-  if (static_cast<int64_t>(sdp::CustomMaskType::CausalFromTopLeft) == custom_mask_type) {
-    is_causal = true;
-  } else if (static_cast<int64_t>(sdp::CustomMaskType::NoCustomMask) == custom_mask_type) {
+  if (static_cast<int64_t>(sdp::CustomMaskType::NoCustomMask) == custom_mask_type) {
     is_causal = false;
   } else {
-    TORCH_CHECK(false, "[_efficient_attention_forward] Unsupported mask type on ROCM, for now");
+    is_causal = true;
+#if AOTRITON_V3_API == 0
+    if (static_cast<int64_t>(sdp::CustomMaskType::CausalFromTopLeft) != custom_mask_type) {
+      TORCH_CHECK(false, "[_efficient_attention_forward] Unsupported mask type on ROCM, for now");
+    }
+#endif
   }
 
   at::Tensor atomic_counter;
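For readers unfamiliar with the two causal variants referenced here: CausalFromTopLeft anchors the causal diagonal at the top-left of the (seqlen_q x seqlen_k) score matrix, while CausalFromBottomRight shifts it so the last query row can see every key; the two only differ when seqlen_q != seqlen_k. A small standalone illustration (not code from this commit) that prints both masks:

// Illustration of top-left vs. bottom-right causal alignment
// (1 = attend, 0 = masked). Standalone sketch, not code from the diff.
#include <cstdio>

int main() {
  const int seqlen_q = 3, seqlen_k = 5;
  for (int align = 0; align < 2; ++align) {
    std::printf(align == 0 ? "top-left aligned:\n" : "bottom-right aligned:\n");
    for (int i = 0; i < seqlen_q; ++i) {
      for (int j = 0; j < seqlen_k; ++j) {
        // Top-left: query i attends keys j <= i.
        // Bottom-right: diagonal shifted so the last query sees all keys.
        const int limit = (align == 0) ? i : i + (seqlen_k - seqlen_q);
        std::printf("%d ", j <= limit ? 1 : 0);
      }
      std::printf("\n");
    }
  }
  return 0;
}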
@@ -1436,7 +1505,51 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt> _efficient_
   auto offset_output = mk_philoxtensor(use_philox_state ? offset_t.data_ptr<int64_t>() : nullptr);
   auto persistent_counter = mk_atomictensor(is_causal ? atomic_counter.data_ptr<int32_t>() : nullptr);
   hipError_t err; // TODO: Error handling
-  if (seqstart_q.has_value()) {
+  if constexpr (AOTRITON_ALWAYS_V3_API) { // Better readability than nesting ifdef
+#if AOTRITON_V3_API // if constexpr does not stop errors from undefined functions
+    using aotriton::v3::flash::CausalType;
+    using aotriton::v3::flash::VarlenType;
+    using aotriton::v3::flash::WindowValue;
+    aotriton::v3::flash::attn_fwd_params params;
+    params.Q = mk_aotensor(q_t, "q");
+    params.K = mk_aotensor(k_t, "k");
+    params.V = mk_aotensor(v_t, "v");
+    params.Sm_scale = softmax_scale;
+    params.L = compute_logsumexp ? mk_aotensor<2>(softmax_lse, "M") : empty_t2;
+    params.Out = mk_aotensor(output_t, "Out");
+    params.Max_seqlen_q = max_seqlen_q; // Unused if cu_seqlens_q is empty
+    params.Max_seqlen_k = max_seqlen_k; // Unused if cu_seqlens_k is empty
+    params.dropout_p = dropout_p;
+    params.philox_seed_ptr = seed;
+    params.philox_offset1 = offset1;
+    params.philox_offset2 = offset2;
+    params.philox_seed_output = seed_output;
+    params.philox_offset_output = offset_output;
+    params.encoded_softmax = mk_aotensor(softmax_fa_t, "encoded_softmax");
+    params.persistent_atomic_counter = persistent_counter;
+    params.causal_type = is_causal ? CausalType::WindowedAttention : CausalType::None;
+    if (static_cast<int64_t>(sdp::CustomMaskType::CausalFromTopLeft) == custom_mask_type) {
+      params.window_left = WindowValue::TopLeftAligned;
+      params.window_right = WindowValue::TopLeftAligned;
+    } else if (static_cast<int64_t>(sdp::CustomMaskType::CausalFromBottomRight) == custom_mask_type) {
+      params.window_left = WindowValue::BottomRightAligned;
+      params.window_right = WindowValue::BottomRightAligned;
+    }
+    if (bias.has_value()) {
+      params.B = mk_aotensor(bias.value(), "bias");
+    }
+    if (seqstart_q.has_value()) {
+      params.varlen_type = VarlenType::CompactVarlen;
+      params.cu_seqlens_q = mk_aotensor<1>(seqstart_q.value(), "cu_seqlens_q");
+      params.cu_seqlens_k = mk_aotensor<1>(seqstart_k.value(), "cu_seqlens_k");
+    } else {
+      params.varlen_type = VarlenType::None;
+    }
+    err = aotriton::v3::flash::attn_fwd(params,
+                                        aotriton::v3::flash::attn_fwd_params::kVersion,
+                                        stream);
+#endif // AOTRITON_V3_API
+  } else if (seqstart_q.has_value()) {
     // varlen aka nested tensor
     err = attn_fwd_compact_varlen(mk_aotensor(q_t, "q"),
                                   mk_aotensor(k_t, "k"),
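The AOTriton v3 entry point takes a single parameter struct plus the struct's version constant (attn_fwd_params::kVersion) instead of a long positional argument list, which lets fields be added without breaking existing callers. A tiny sketch of that versioned-params calling convention with made-up names (my_op_params, my_op); it shows the shape of the call, not the real AOTriton API:

// Sketch of the "params struct + version tag" calling convention used by
// aotriton::v3::flash::attn_fwd / attn_bwd. All names below are hypothetical.
#include <cstdint>
#include <cstdio>

struct my_op_params {
  static constexpr int32_t kVersion = 3;  // bumped whenever fields are added
  float sm_scale = 1.0f;
  float dropout_p = 0.0f;
  bool  is_causal = false;
};

// The callee checks the caller's compiled-in version before reading fields.
int my_op(const my_op_params& p, int32_t caller_version) {
  if (caller_version > my_op_params::kVersion) {
    std::puts("caller built against a newer params layout");
    return -1;
  }
  std::printf("scale=%.2f dropout=%.2f causal=%d\n",
              p.sm_scale, p.dropout_p, p.is_causal ? 1 : 0);
  return 0;
}

int main() {
  my_op_params params;
  params.is_causal = true;
  return my_op(params, my_op_params::kVersion);  // mirrors attn_fwd(params, kVersion, stream)
}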

aten/src/ATen/native/transformers/cuda/attention_backward.cu

Lines changed: 65 additions & 5 deletions
@@ -24,6 +24,7 @@
 #include <ATen/Functions.h>
 #include <ATen/NativeFunctions.h>
 #else
+#include <ATen/ops/zeros.h>
 #include <ATen/ops/zeros_like.h>
 #include <ATen/ops/empty_strided.h>
 #include <ATen/ops/_flash_attention_backward.h>

@@ -45,6 +46,7 @@
 #include <ATen/native/transformers/cuda/mem_eff_attention/gemm_kernel_utils.h>
 #include <ATen/native/transformers/cuda/mem_eff_attention/pytorch_utils.h>
 #else
+#include <ATen/native/transformers/hip/gemm_kernel_utils.h>
 // MemoryEfficient Attention Specific Imports for ROCM
 #ifndef DISABLE_AOTRITON
 #include <ATen/native/transformers/hip/aotriton_adapter.h>
@@ -482,12 +484,15 @@ _efficient_attention_backward(
   }
   const auto softmax_scale = sdp::calculate_scale(query, scale).expect_float();
   bool is_causal;
-  if (static_cast<int64_t>(sdp::CustomMaskType::CausalFromTopLeft) == custom_mask_type) {
-    is_causal = true;
-  } else if (static_cast<int64_t>(sdp::CustomMaskType::NoCustomMask) == custom_mask_type) {
+  if (static_cast<int64_t>(sdp::CustomMaskType::NoCustomMask) == custom_mask_type) {
     is_causal = false;
   } else {
-    TORCH_CHECK(false, "[_efficient_attention_backward] Unsupported mask type in AOTriton, for now");
+    is_causal = true;
+#if AOTRITON_V3_API == 0
+    if (static_cast<int64_t>(sdp::CustomMaskType::CausalFromTopLeft) != custom_mask_type) {
+      TORCH_CHECK(false, "[_efficient_attention_forward] Unsupported mask type on ROCM, for now");
+    }
+#endif
   }
   at::Tensor q_t = query.permute({0,2,1,3});
   at::Tensor k_t = key.permute({0,2,1,3});
@@ -506,7 +511,62 @@ _efficient_attention_backward(
   using sdp::aotriton_adapter::mk_aoscalartensor;
   using sdp::aotriton_adapter::cast_dtype;
   aotriton::TensorView<4> empty_t4(0, {0, 0, 0, 0}, {0, 0, 0, 0}, cast_dtype(query.dtype()));
-  if (cu_seqlens_q.has_value()) {
+  if constexpr (AOTRITON_ALWAYS_V3_API) { // Better readability than nesting ifdef
+#if AOTRITON_V3_API // if constexpr does not stop errors from undefined functions
+    using aotriton::v3::flash::CausalType;
+    using aotriton::v3::flash::VarlenType;
+    using aotriton::v3::flash::WindowValue;
+    aotriton::v3::flash::attn_bwd_params params;
+    params.Q = mk_aotensor(q_t, "q");
+    params.K = mk_aotensor(k_t, "k");
+    params.V = mk_aotensor(v_t, "v");
+    params.B = bias.has_value() ? mk_aotensor(bias.value(), "bias") : empty_t4;
+    params.Sm_scale = softmax_scale;
+    params.Out = mk_aotensor(out_t, "out");
+    params.DO = mk_aotensor(dout_t, "dout");
+    params.DK = mk_aotensor(dk_t, "dk");
+    params.DV = mk_aotensor(dv_t, "dv");
+    params.DQ = mk_aotensor(dq_t, "dq");
+    params.DB = bias_requires_grad ? mk_aotensor(grad_bias, "db") : empty_t4;
+    params.L = mk_aotensor<2>(softmax_lse, "L");
+    params.Max_seqlen_q = max_seqlen_q; // Unused if cu_seqlens_q is empty
+    params.Max_seqlen_k = max_seqlen_k; // Unused if cu_seqlens_k is empty
+    params.dropout_p = float(dropout_p);
+    params.philox_seed_ptr = mk_aoscalartensor(philox_seed);
+    params.philox_offset1 = mk_aoscalartensor(philox_offset);
+    params.philox_offset2 = 0;
+    params.causal_type = is_causal ? CausalType::WindowedAttention : CausalType::None;
+    if (static_cast<int64_t>(sdp::CustomMaskType::CausalFromTopLeft) == custom_mask_type) {
+      params.window_left = WindowValue::TopLeftAligned;
+      params.window_right = WindowValue::TopLeftAligned;
+    } else if (static_cast<int64_t>(sdp::CustomMaskType::CausalFromBottomRight) == custom_mask_type) {
+      params.window_left = WindowValue::BottomRightAligned;
+      params.window_right = WindowValue::BottomRightAligned;
+    }
+#if AOTRITON_ALWAYS_V3_API
+    using sdp::aotriton_adapter::mklazy_empty_like;
+    using sdp::aotriton_adapter::mklazy_fp32zeros;
+    using sdp::aotriton_adapter::LazyTensorContext;
+    LazyTensorContext lazy_delta { .like_tensor = softmax_lse, .tensor_name = "delta" };
+    LazyTensorContext lazy_dq_acc { .like_tensor = dq_t, .tensor_name = "dq_acc" };
+    params.D = mklazy_empty_like<2>(&lazy_delta);
+    params.DQ_ACC = mklazy_fp32zeros<4>(&lazy_dq_acc);
+#else
+    at::Tensor delta = at::empty_like(softmax_lse).contiguous();
+    params.D = mk_aotensor<2>(delta, "delta");
+#endif
+    if (cu_seqlens_q.has_value()) {
+      params.varlen_type = VarlenType::CompactVarlen;
+      params.cu_seqlens_q = mk_aotensor<1>(cu_seqlens_q.value(), "cu_seqlens_q");
+      params.cu_seqlens_k = mk_aotensor<1>(cu_seqlens_k.value(), "cu_seqlens_k");
+    } else {
+      params.varlen_type = VarlenType::None;
+    }
+    err = aotriton::v3::flash::attn_bwd(params,
+                                        aotriton::v3::flash::attn_bwd_params::kVersion,
+                                        stream);
+#endif // AOTRITON_V3_API
+  } else if (cu_seqlens_q.has_value()) {
     at::Tensor delta = at::empty_like(softmax_lse).contiguous();
     // varlen aka Nested tensor
     err = attn_bwd_compact_varlen(mk_aotensor(q_t, "q"),
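In the AOTRITON_ALWAYS_V3_API branch, the delta and dq_acc workspaces are described through LazyTensorContext and the mklazy_* adapters rather than allocated eagerly with at::empty_like / at::zeros, so the backing storage is only materialized if the selected kernel actually asks for it. A generic, standalone sketch of that deferred-allocation idea with invented types (LazyBuffer here is not the PyTorch adapter):

// Sketch of deferred ("lazy") workspace allocation, the idea behind
// mklazy_empty_like / mklazy_fp32zeros above. Types and names are invented.
#include <cstdio>
#include <functional>
#include <optional>
#include <vector>

using Buffer = std::vector<float>;

class LazyBuffer {
 public:
  explicit LazyBuffer(std::function<Buffer()> make) : make_(std::move(make)) {}
  // Allocation happens only on first use.
  Buffer& get() {
    if (!buf_) {
      std::puts("allocating workspace");
      buf_ = make_();
    }
    return *buf_;
  }
 private:
  std::function<Buffer()> make_;
  std::optional<Buffer> buf_;
};

int main() {
  LazyBuffer dq_acc([] { return Buffer(1024, 0.0f); });  // like a zero-filled fp32 accumulator
  bool kernel_needs_accumulator = false;                 // e.g. decided by the selected kernel
  if (kernel_needs_accumulator) {
    dq_acc.get()[0] += 1.0f;
  }
  std::puts("done");  // no allocation happened in this run
  return 0;
}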

aten/src/ATen/native/transformers/cuda/sdp_utils.cpp

Lines changed: 43 additions & 2 deletions
@@ -16,6 +16,7 @@
 #include <c10/util/irange.h>
 #include <c10/util/Array.h>
 #include <c10/util/Exception.h>
+#include <c10/util/string_view.h>
 
 #if AT_CUDNN_ENABLED()
 #include <ATen/cudnn/cudnn-wrapper.h>

@@ -25,9 +26,12 @@
 
 #if USE_ROCM
 #if defined(USE_FLASH_ATTENTION) || defined(USE_MEM_EFF_ATTENTION)
+#include <ATen/native/transformers/hip/aotriton_versions.h>
 #include <aotriton/flash.h>
 #define USE_ROCM_ATTENTION 1
 #endif
+#else
+#define USE_ROCM_ATTENTION 0
 #endif
 
 // Avoid potential compiler -Wall -Werror complains undefined macro
@@ -112,9 +116,24 @@ int64_t minimum_gemm_alignment(sdp_params const& params) {
 // caller_is_meff is added to make the TORCH_WARN message showing the correct result
 template<bool caller_is_meff = false>
 bool check_head_dim_size_flash(sdp_params const& params, bool debug) {
-#if USE_ROCM_ATTENTION && AOTRITON_VERSION_MINOR >= 9
+#if USE_ROCM_ATTENTION
   // AOTriton 0.9+ supports head_dim up to 512
-  const auto max_size = c10::SymInt(512);
+  const static auto max_hdim = []() {
+#if AOTRITON_VERSION_CURRENT == AOTRITON_VERSION_INT(0, 11)
+    // gfx11xx only support hdim <= 256 on AOTriton 0.11
+    auto dprops = at::cuda::getCurrentDeviceProperties();
+    const c10::basic_string_view<char> arch(dprops->gcnArchName);
+    if (arch.starts_with("gfx11")) {
+      return 256;
+    }
+#endif // AOTriton 0.11
+#if AOTRITON_VERSION_CURRENT >= AOTRITON_VERSION_INT(0, 9)
+    return 512;
+#else
+    return 256;
+#endif
+  }();
+  const auto max_size = c10::SymInt(max_hdim);
 #else
   // All head_dim sizes must be equal and less than 256
   const auto max_size = c10::SymInt(256);
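The new head-dim cap is computed once through an immediately invoked lambda stored in a static, so the device architecture string is only inspected on the first call; gfx11xx devices are capped at 256 under AOTriton 0.11, everything else keeps the 512 limit introduced with 0.9. A standalone sketch of that idiom, with a hard-coded architecture string standing in for the real device query:

// Sketch of the "static const initialized by an immediately-invoked lambda"
// idiom used for max_hdim above. current_arch() is a stand-in for
// at::cuda::getCurrentDeviceProperties()->gcnArchName; values are illustrative.
#include <cstdio>
#include <string>

static std::string current_arch() { return "gfx1100"; }

static int max_head_dim() {
  // Computed once on first call, like the static max_hdim lambda in the diff.
  static const int cached = []() {
    const std::string arch = current_arch();
    if (arch.rfind("gfx11", 0) == 0) {  // starts_with("gfx11")
      return 256;  // tighter cap for gfx11xx, mirroring the AOTriton 0.11 note
    }
    return 512;    // AOTriton 0.9+ default in the diff
  }();
  return cached;
}

int main() {
  std::printf("max head_dim: %d\n", max_head_dim());
  return 0;
}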
@@ -139,6 +158,28 @@ bool check_head_dim_size_flash(sdp_params const& params, bool debug) {
     }
     return false;
   }
+  if constexpr(caller_is_meff) {
+    bool is_half = (params.query.dtype() == at::kHalf) ||
+        (params.query.dtype() == at::kBFloat16);
+    const int64_t alignment = is_half ? 8 : 4;
+    if (!(query_size_last % alignment == 0 && query_size_last > 0 &&
+          value_size_last % alignment == 0 && value_size_last > 0)) {
+      if (debug) {
+        TORCH_WARN(
+            "Mem efficient attention requires last dimension of inputs to be divisible by ",
+            alignment,
+            ". ",
+            "Got Query.size(-1): ",
+            query_size_last,
+            ", Key.size(-1): ",
+            params.key.sym_size(-1),
+            ", Value.size(-1): ",
+            params.value.sym_size(-1),
+            " instead.");
+      }
+      return false;
+    }
+  }
   return true;
 }
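The added block applies the memory-efficient backend's alignment rule inside the shared head-dim check: half and bfloat16 inputs need the last dimension divisible by 8, other dtypes by 4, and a non-positive size is rejected outright. A minimal standalone sketch of that rule (head_dim_aligned is an invented helper, not the PyTorch function):

// Sketch of the alignment rule added for the mem-efficient path:
// fp16/bf16 inputs need head_dim divisible by 8, other dtypes by 4.
#include <cstdint>
#include <cstdio>

static bool head_dim_aligned(int64_t head_dim, bool is_half_precision) {
  const int64_t alignment = is_half_precision ? 8 : 4;
  return head_dim > 0 && head_dim % alignment == 0;
}

int main() {
  std::printf("%d %d %d\n",
              head_dim_aligned(96, true) ? 1 : 0,    // 96 % 8 == 0 -> accepted
              head_dim_aligned(100, true) ? 1 : 0,   // 100 % 8 != 0 -> rejected for fp16/bf16
              head_dim_aligned(100, false) ? 1 : 0); // 100 % 4 == 0 -> accepted for fp32
  return 0;
}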
