diff --git a/.ci/docker/common/install_rocm_magma.sh b/.ci/docker/common/install_rocm_magma.sh index db826ed6e027..a8d8ba00b35b 100644 --- a/.ci/docker/common/install_rocm_magma.sh +++ b/.ci/docker/common/install_rocm_magma.sh @@ -1,60 +1,37 @@ -#!/bin/bash -# Script used in CI and CD pipeline +#!/usr/bin/env bash +# Script used only in CD pipeline -set -ex +set -eou pipefail -ver() { - printf "%3d%03d%03d%03d" $(echo "$1" | tr '.' ' '); -} - -# Magma build scripts need `python` -ln -sf /usr/bin/python3 /usr/bin/python - -ID=$(grep -oP '(?<=^ID=).+' /etc/os-release | tr -d '"') -case "$ID" in - almalinux) - yum install -y gcc-gfortran - ;; - *) - echo "No preinstalls to build magma..." - ;; -esac +function do_install() { + rocm_version=$1 + if [[ ${rocm_version} =~ ^[0-9]+\.[0-9]+\.[0-9]+$ ]]; then + # chop off any patch version + rocm_version="${rocm_version%.*}" + fi -MKLROOT=${MKLROOT:-/opt/conda/envs/py_$ANACONDA_PYTHON_VERSION} + rocm_version_nodot=${rocm_version//./} -# "install" hipMAGMA into /opt/rocm/magma by copying after build -if [[ $(ver $ROCM_VERSION) -ge $(ver 7.0) ]]; then - git clone https://github.com/ROCm/utk-magma.git -b release/2.9.0_rocm70 magma - pushd magma - # version 2.9 + ROCm 7.0 related updates - git checkout 91c4f720a17e842b364e9de41edeef76995eb9ad -else - git clone https://bitbucket.org/icl/magma.git - pushd magma # Version 2.7.2 + ROCm related updates - git checkout a1625ff4d9bc362906bd01f805dbbe12612953f6 -fi + MAGMA_VERSION=a1625ff4d9bc362906bd01f805dbbe12612953f6 + magma_archive="magma-rocm${rocm_version_nodot}-${MAGMA_VERSION}-1.tar.bz2" + + rocm_dir="/opt/rocm" + ( + set -x + tmp_dir=$(mktemp -d) + pushd ${tmp_dir} + curl -OLs https://ossci-linux.s3.us-east-1.amazonaws.com/${magma_archive} + if tar -xvf "${magma_archive}" + then + mkdir -p "${rocm_dir}/magma" + mv include "${rocm_dir}/magma/include" + mv lib "${rocm_dir}/magma/lib" + else + echo "${magma_archive} not found, skipping magma install" + fi + popd + ) +} -cp make.inc-examples/make.inc.hip-gcc-mkl make.inc -echo 'LIBDIR += -L$(MKLROOT)/lib' >> make.inc -if [[ -f "${MKLROOT}/lib/libmkl_core.a" ]]; then - echo 'LIB = -Wl,--start-group -lmkl_gf_lp64 -lmkl_gnu_thread -lmkl_core -Wl,--end-group -lpthread -lstdc++ -lm -lgomp -lhipblas -lhipsparse' >> make.inc -fi -echo 'LIB += -Wl,--enable-new-dtags -Wl,--rpath,/opt/rocm/lib -Wl,--rpath,$(MKLROOT)/lib -Wl,--rpath,/opt/rocm/magma/lib -ldl' >> make.inc -echo 'DEVCCFLAGS += --gpu-max-threads-per-block=256' >> make.inc -export PATH="${PATH}:/opt/rocm/bin" -if [[ -n "$PYTORCH_ROCM_ARCH" ]]; then - amdgpu_targets=`echo $PYTORCH_ROCM_ARCH | sed 's/;/ /g'` -else - amdgpu_targets=`rocm_agent_enumerator | grep -v gfx000 | sort -u | xargs` -fi -for arch in $amdgpu_targets; do - echo "DEVCCFLAGS += --offload-arch=$arch" >> make.inc -done -# hipcc with openmp flag may cause isnan() on __device__ not to be found; depending on context, compiler may attempt to match with host definition -sed -i 's/^FOPENMP/#FOPENMP/g' make.inc -make -f make.gen.hipMAGMA -j $(nproc) -LANG=C.UTF-8 make lib/libmagma.so -j $(nproc) MKLROOT="${MKLROOT}" -make testing/testing_dgemm -j $(nproc) MKLROOT="${MKLROOT}" -popd -mv magma /opt/rocm +do_install $1 diff --git a/CMakeLists.txt b/CMakeLists.txt index a5d25e6afa0f..63025c26a05e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -867,7 +867,7 @@ cmake_dependent_option( "Whether to build the flash_attention kernel for scaled dot product attention.\ Will be disabled if not supported by the platform" ON - "USE_CUDA OR USE_ROCM;NOT MSVC" 
+ "(USE_CUDA AND NOT MSVC) OR USE_ROCM" OFF) # CAVEAT: Again, Flash Attention2 will error while building for sm52 while Mem @@ -883,7 +883,7 @@ cmake_dependent_option( # USE_FLASH_ATTENTION -> USE_ROCM -> Dependencies.cmake -> aotriton.cmake # if(USE_ROCM) - if(UNIX AND (USE_FLASH_ATTENTION OR USE_MEM_EFF_ATTENTION)) + if(USE_FLASH_ATTENTION OR USE_MEM_EFF_ATTENTION) include(cmake/External/aotriton.cmake) endif() endif() diff --git a/aten/src/ATen/native/cuda/Indexing.cu b/aten/src/ATen/native/cuda/Indexing.cu index e49fffc2effc..df43bfac16f7 100644 --- a/aten/src/ATen/native/cuda/Indexing.cu +++ b/aten/src/ATen/native/cuda/Indexing.cu @@ -56,9 +56,10 @@ constexpr uint64_t getDefaultMaxThreadsPerBlock() { #endif } -#if 0 +#ifdef USE_ROCM +#define SKIP_SORTED_INDICES 32 template -__global__ void indexing_backward_kernel( +__global__ void indexing_backward_kernel_many_indices( const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight, int64_t numel, int64_t stride, int64_t stride_before, int64_t outer_dim, bool accumulate) { using opmath_t = at::opmath_type; @@ -141,10 +142,7 @@ __global__ void indexing_backward_kernel( } } } -#endif -#ifdef USE_ROCM -#define SKIP_SORTED_INDICES 32 template __global__ void indexing_backward_kernel_stride_1( const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight, @@ -784,6 +782,38 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List= 200000) + AT_DISPATCH_V2( + expandedValue.scalar_type(), + "indexing_backward_many_indices", + AT_WRAP([&] { + indexing_backward_kernel_many_indices<<>>( + sorted_indices.const_data_ptr(), + orig_indices.const_data_ptr(), + expandedValue.const_data_ptr(), + src_.mutable_data_ptr(), + num_indices, + sliceSize, + strideBefore, + nElemBefore, + accumulate); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }), + AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), + // AT_EXPAND(AT_FLOAT8_TYPES), + // TODO(#113663): clean up accumulation behavior in float8 dtypes, accumulate=True + // should not be supported here, then reenable AT_FLOAT8_DTYPES + kFloat8_e4m3fn, + kFloat8_e5m2, + kFloat8_e4m3fnuz, + kFloat8_e5m2fnuz, + kComplexHalf, + kHalf, + kBool, + kBFloat16); + else +#endif AT_DISPATCH_V2( expandedValue.scalar_type(), "indexing_backward", diff --git a/aten/src/ATen/native/transformers/cuda/attention.cu b/aten/src/ATen/native/transformers/cuda/attention.cu index 80049aa9a832..76a62b3f7f8a 100644 --- a/aten/src/ATen/native/transformers/cuda/attention.cu +++ b/aten/src/ATen/native/transformers/cuda/attention.cu @@ -95,6 +95,72 @@ #endif #endif +#if defined(USE_ROCM) && (defined(USE_FLASH_ATTENTION) || defined(USE_MEM_EFF_ATTENTION)) +namespace pytorch_flash +{ +std::tuple< + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor> +mha_fwd( + const at::Tensor& q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor& k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor& v, // batch_size x seqlen_k x num_heads_k x head_size + std::optional& + out_, // batch_size x seqlen_q x num_heads x head_size + std::optional& + alibi_slopes_, // num_heads or batch_size x num_heads + const float p_dropout, + const float softmax_scale, + bool is_causal, + std::optional window_size_left, + std::optional window_size_right, + const float softcap, + const bool return_softmax, + std::optional gen_) { +#if defined(USE_ROCM_CK_SDPA) + if (at::globalContext().getROCmFAPreferredBackend() 
== + at::ROCmFABackend::Ck) { + const int non_null_window_left = window_size_left.value_or(-1); + const int non_null_window_right = window_size_right.value_or(-1); + std::optional dummy_attn_bias = std::nullopt; + return mha_fwd_ck( + q, + k, + v, + out_, + p_dropout, + softmax_scale, + is_causal, + non_null_window_left, + non_null_window_right, + return_softmax, + gen_, + dummy_attn_bias); // Not used in flash attention + } +#endif + return mha_fwd_aot( + q, + k, + v, + out_, + alibi_slopes_, + p_dropout, + softmax_scale, + is_causal, + window_size_left, + window_size_right, + return_softmax, + gen_); +} +} +#endif + namespace at { namespace cuda::philox { @@ -1406,12 +1472,15 @@ std::tuple _efficient_ at::Tensor v_t = value.transpose(1, 2); at::Tensor output_t = res.transpose(1, 2); bool is_causal; - if (static_cast(sdp::CustomMaskType::CausalFromTopLeft) == custom_mask_type) { - is_causal = true; - } else if (static_cast(sdp::CustomMaskType::NoCustomMask) == custom_mask_type) { + if (static_cast(sdp::CustomMaskType::NoCustomMask) == custom_mask_type) { is_causal = false; } else { - TORCH_CHECK(false, "[_efficient_attention_forward] Unsupported mask type on ROCM, for now"); + is_causal = true; +#if AOTRITON_V3_API == 0 + if (static_cast(sdp::CustomMaskType::CausalFromTopLeft) != custom_mask_type) { + TORCH_CHECK(false, "[_efficient_attention_forward] Unsupported mask type on ROCM, for now"); + } +#endif } at::Tensor atomic_counter; @@ -1436,7 +1505,51 @@ std::tuple _efficient_ auto offset_output = mk_philoxtensor(use_philox_state ? offset_t.data_ptr() : nullptr); auto persistent_counter = mk_atomictensor(is_causal ? atomic_counter.data_ptr() : nullptr); hipError_t err; // TODO: Error handling - if (seqstart_q.has_value()) { + if constexpr (AOTRITON_ALWAYS_V3_API) { // Better readability than nesting ifdef +#if AOTRITON_V3_API // if constexpr does not stop errors from undefined functions + using aotriton::v3::flash::CausalType; + using aotriton::v3::flash::VarlenType; + using aotriton::v3::flash::WindowValue; + aotriton::v3::flash::attn_fwd_params params; + params.Q = mk_aotensor(q_t, "q"); + params.K = mk_aotensor(k_t, "k"); + params.V = mk_aotensor(v_t, "v"); + params.Sm_scale = softmax_scale; + params.L = compute_logsumexp ? mk_aotensor<2>(softmax_lse, "M") : empty_t2; + params.Out = mk_aotensor(output_t, "Out"); + params.Max_seqlen_q = max_seqlen_q; // Unused if cu_seqlens_q is empty + params.Max_seqlen_k = max_seqlen_k; // Unused if cu_seqlens_k is empty + params.dropout_p = dropout_p; + params.philox_seed_ptr = seed; + params.philox_offset1 = offset1; + params.philox_offset2 = offset2; + params.philox_seed_output = seed_output; + params.philox_offset_output = offset_output; + params.encoded_softmax = mk_aotensor(softmax_fa_t, "encoded_softmax"); + params.persistent_atomic_counter = persistent_counter; + params.causal_type = is_causal ? 
CausalType::WindowedAttention : CausalType::None; + if (static_cast(sdp::CustomMaskType::CausalFromTopLeft) == custom_mask_type) { + params.window_left = WindowValue::TopLeftAligned; + params.window_right = WindowValue::TopLeftAligned; + } else if (static_cast(sdp::CustomMaskType::CausalFromBottomRight) == custom_mask_type) { + params.window_left = WindowValue::BottomRightAligned; + params.window_right = WindowValue::BottomRightAligned; + } + if (bias.has_value()) { + params.B = mk_aotensor(bias.value(), "bias"); + } + if (seqstart_q.has_value()) { + params.varlen_type = VarlenType::CompactVarlen; + params.cu_seqlens_q = mk_aotensor<1>(seqstart_q.value(), "cu_seqlens_q"); + params.cu_seqlens_k = mk_aotensor<1>(seqstart_k.value(), "cu_seqlens_k"); + } else { + params.varlen_type = VarlenType::None; + } + err = aotriton::v3::flash::attn_fwd(params, + aotriton::v3::flash::attn_fwd_params::kVersion, + stream); +#endif // AOTRITON_V3_API + } else if (seqstart_q.has_value()) { // varlen aka nested tensor err = attn_fwd_compact_varlen(mk_aotensor(q_t, "q"), mk_aotensor(k_t, "k"), diff --git a/aten/src/ATen/native/transformers/cuda/attention_backward.cu b/aten/src/ATen/native/transformers/cuda/attention_backward.cu index 8940bea9a27f..0339f6eec055 100644 --- a/aten/src/ATen/native/transformers/cuda/attention_backward.cu +++ b/aten/src/ATen/native/transformers/cuda/attention_backward.cu @@ -24,6 +24,7 @@ #include #include #else +#include #include #include #include @@ -45,6 +46,7 @@ #include #include #else +#include // MemoryEfficient Attention Specific Imports for ROCM #ifndef DISABLE_AOTRITON #include @@ -482,12 +484,15 @@ _efficient_attention_backward( } const auto softmax_scale = sdp::calculate_scale(query, scale).expect_float(); bool is_causal; - if (static_cast(sdp::CustomMaskType::CausalFromTopLeft) == custom_mask_type) { - is_causal = true; - } else if (static_cast(sdp::CustomMaskType::NoCustomMask) == custom_mask_type) { + if (static_cast(sdp::CustomMaskType::NoCustomMask) == custom_mask_type) { is_causal = false; } else { - TORCH_CHECK(false, "[_efficient_attention_backward] Unsupported mask type in AOTriton, for now"); + is_causal = true; +#if AOTRITON_V3_API == 0 + if (static_cast(sdp::CustomMaskType::CausalFromTopLeft) != custom_mask_type) { + TORCH_CHECK(false, "[_efficient_attention_backward] Unsupported mask type on ROCM, for now"); + } +#endif } at::Tensor q_t = query.permute({0,2,1,3}); at::Tensor k_t = key.permute({0,2,1,3}); @@ -506,7 +511,62 @@ _efficient_attention_backward( using sdp::aotriton_adapter::mk_aoscalartensor; using sdp::aotriton_adapter::cast_dtype; aotriton::TensorView<4> empty_t4(0, {0, 0, 0, 0}, {0, 0, 0, 0}, cast_dtype(query.dtype())); - if (cu_seqlens_q.has_value()) { + if constexpr (AOTRITON_ALWAYS_V3_API) { // Better readability than nesting ifdef +#if AOTRITON_V3_API // if constexpr does not stop errors from undefined functions + using aotriton::v3::flash::CausalType; + using aotriton::v3::flash::VarlenType; + using aotriton::v3::flash::WindowValue; + aotriton::v3::flash::attn_bwd_params params; + params.Q = mk_aotensor(q_t, "q"); + params.K = mk_aotensor(k_t, "k"); + params.V = mk_aotensor(v_t, "v"); + params.B = bias.has_value() ? mk_aotensor(bias.value(), "bias") : empty_t4; + params.Sm_scale = softmax_scale; + params.Out = mk_aotensor(out_t, "out"); + params.DO = mk_aotensor(dout_t, "dout"); + params.DK = mk_aotensor(dk_t, "dk"); + params.DV = mk_aotensor(dv_t, "dv"); + params.DQ = mk_aotensor(dq_t, "dq"); + params.DB = bias_requires_grad ? 
mk_aotensor(grad_bias, "db") : empty_t4; + params.L = mk_aotensor<2>(softmax_lse, "L"); + params.Max_seqlen_q = max_seqlen_q; // Unused if cu_seqlens_q is empty + params.Max_seqlen_k = max_seqlen_k; // Unused if cu_seqlens_k is empty + params.dropout_p = float(dropout_p); + params.philox_seed_ptr = mk_aoscalartensor(philox_seed); + params.philox_offset1 = mk_aoscalartensor(philox_offset); + params.philox_offset2 = 0; + params.causal_type = is_causal ? CausalType::WindowedAttention : CausalType::None; + if (static_cast(sdp::CustomMaskType::CausalFromTopLeft) == custom_mask_type) { + params.window_left = WindowValue::TopLeftAligned; + params.window_right = WindowValue::TopLeftAligned; + } else if (static_cast(sdp::CustomMaskType::CausalFromBottomRight) == custom_mask_type) { + params.window_left = WindowValue::BottomRightAligned; + params.window_right = WindowValue::BottomRightAligned; + } +#if AOTRITON_ALWAYS_V3_API + using sdp::aotriton_adapter::mklazy_empty_like; + using sdp::aotriton_adapter::mklazy_fp32zeros; + using sdp::aotriton_adapter::LazyTensorContext; + LazyTensorContext lazy_delta { .like_tensor = softmax_lse, .tensor_name = "delta" }; + LazyTensorContext lazy_dq_acc { .like_tensor = dq_t, .tensor_name = "dq_acc" }; + params.D = mklazy_empty_like<2>(&lazy_delta); + params.DQ_ACC = mklazy_fp32zeros<4>(&lazy_dq_acc); +#else + at::Tensor delta = at::empty_like(softmax_lse).contiguous(); + params.D = mk_aotensor<2>(delta, "delta"); +#endif + if (cu_seqlens_q.has_value()) { + params.varlen_type = VarlenType::CompactVarlen; + params.cu_seqlens_q = mk_aotensor<1>(cu_seqlens_q.value(), "cu_seqlens_q"); + params.cu_seqlens_k = mk_aotensor<1>(cu_seqlens_k.value(), "cu_seqlens_k"); + } else { + params.varlen_type = VarlenType::None; + } + err = aotriton::v3::flash::attn_bwd(params, + aotriton::v3::flash::attn_bwd_params::kVersion, + stream); +#endif // AOTRITON_V3_API + } else if (cu_seqlens_q.has_value()) { at::Tensor delta = at::empty_like(softmax_lse).contiguous(); // varlen aka Nested tensor err = attn_bwd_compact_varlen(mk_aotensor(q_t, "q"), diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp index 45b4cf118c1b..0df958c4c010 100644 --- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp +++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp @@ -16,6 +16,7 @@ #include #include #include +#include #if AT_CUDNN_ENABLED() #include @@ -25,9 +26,12 @@ #if USE_ROCM #if defined(USE_FLASH_ATTENTION) || defined(USE_MEM_EFF_ATTENTION) +#include #include #define USE_ROCM_ATTENTION 1 #endif +#else +#define USE_ROCM_ATTENTION 0 #endif // Avoid potential compiler -Wall -Werror complains undefined macro @@ -112,9 +116,24 @@ int64_t minimum_gemm_alignment(sdp_params const& params) { // caller_is_meff is added to make the TORCH_WARN message showing the correct result template bool check_head_dim_size_flash(sdp_params const& params, bool debug) { -#if USE_ROCM_ATTENTION && AOTRITON_VERSION_MINOR >= 9 +#if USE_ROCM_ATTENTION // AOTriton 0.9+ supports head_dim up to 512 - const auto max_size = c10::SymInt(512); + const static auto max_hdim = []() { +#if AOTRITON_VERSION_CURRENT == AOTRITON_VERSION_INT(0, 11) + // gfx11xx only support hdim <= 256 on AOTriton 0.11 + auto dprops = at::cuda::getCurrentDeviceProperties(); + const c10::basic_string_view arch(dprops->gcnArchName); + if (arch.starts_with("gfx11")) { + return 256; + } +#endif // AOTriton 0.11 +#if AOTRITON_VERSION_CURRENT >= AOTRITON_VERSION_INT(0, 9) + return 512; +#else 
+ return 256; +#endif + }(); + const auto max_size = c10::SymInt(max_hdim); #else // All head_dim sizes must be equal and less than 256 const auto max_size = c10::SymInt(256); @@ -139,6 +158,28 @@ bool check_head_dim_size_flash(sdp_params const& params, bool debug) { } return false; } + if constexpr(caller_is_meff) { + bool is_half = (params.query.dtype() == at::kHalf) || + (params.query.dtype() == at::kBFloat16); + const int64_t alignment = is_half ? 8 : 4; + if (!(query_size_last % alignment == 0 && query_size_last > 0 && + value_size_last % alignment == 0 && value_size_last > 0)) { + if (debug) { + TORCH_WARN( + "Mem efficient attention requires last dimension of inputs to be divisible by ", + alignment, + ". ", + "Got Query.size(-1): ", + query_size_last, + ", Key.size(-1): ", + params.key.sym_size(-1), + ", Value.size(-1): ", + params.value.sym_size(-1), + " instead."); + } + return false; + } + } return true; } diff --git a/aten/src/ATen/native/transformers/hip/aotriton_adapter.h b/aten/src/ATen/native/transformers/hip/aotriton_adapter.h index b38122248db8..a80d4053b27b 100644 --- a/aten/src/ATen/native/transformers/hip/aotriton_adapter.h +++ b/aten/src/ATen/native/transformers/hip/aotriton_adapter.h @@ -2,8 +2,12 @@ #ifdef USE_ROCM +// Expect to be included after headers of at::zeros_like and at::empty_like + #include #include +#include +#include //////////////////////////////////////////////////////////////////////////////// // Common macros copied from cuda/mem_eff_attention/gemm_kernel_utils.h @@ -111,6 +115,61 @@ inline aotriton::TensorView<0> mk_atomictensor(const int32_t* ptr) aotriton::DType::kInt32); } +#if AOTRITON_VERSION_CURRENT >= AOTRITON_VERSION_INT(0, 11) + +struct LazyTensorContext { + at::Tensor like_tensor; + std::string_view tensor_name; + at::Tensor tensor; +}; + +template +struct LazyTensorFunctions : public LazyTensorContext { + static aotriton::TensorView acquire(void* cookie) { + auto ctx = (LazyTensorContext*)cookie; + if (!ctx->tensor.defined()) { + auto q = ctx->like_tensor; + if constexpr (kRequireZeros) { + ctx->tensor = at::zeros(q.sizes(), + q.options().dtype(at::kFloat)); + } else { + ctx->tensor = at::empty_like(q); + } + } + return mk_aotensor(ctx->tensor, ctx->tensor_name); + } + + static void dispose(void* cookie) { + } +}; + +template +aotriton::LazyTensor mklazy_common(LazyTensorContext* cookie) +{ + using LTF = LazyTensorFunctions; + return aotriton::LazyTensor { + .cookie = cookie, + .acquire = <F::acquire, + .dispose = <F::dispose + }; +} + +template +auto mklazy_empty_like(LazyTensorContext* cookie) +{ + return mklazy_common(cookie); +} + + +// Note: this will not keep the original strides +template +auto mklazy_fp32zeros(LazyTensorContext* cookie) +{ + return mklazy_common(cookie); +} + +#endif // >= 0.11 + } // namespace aotriton_adapter } // namespace sdp diff --git a/aten/src/ATen/native/transformers/hip/aotriton_versions.h b/aten/src/ATen/native/transformers/hip/aotriton_versions.h new file mode 100644 index 000000000000..2f5d3f0e1222 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/aotriton_versions.h @@ -0,0 +1,20 @@ +#pragma once + +#ifdef USE_ROCM + +#define AOTRITON_VERSION_INT(x, y) (x * 100 + y) +#define AOTRITON_VERSION_CURRENT (AOTRITON_VERSION_MAJOR * 100 + AOTRITON_VERSION_MINOR) + +#if AOTRITON_VERSION_CURRENT >= AOTRITON_VERSION_INT(0, 11) +#define AOTRITON_ALWAYS_V3_API 1 +#else +#define AOTRITON_ALWAYS_V3_API 0 +#endif + +#if AOTRITON_VERSION_CURRENT >= AOTRITON_VERSION_INT(0, 10) +#define AOTRITON_V3_API 1 
+#else +#define AOTRITON_V3_API 0 +#endif + +#endif diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip b/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip index 1908096e2f6f..acadb67ae171 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip +++ b/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip @@ -60,20 +60,13 @@ #include // AOTriton headers -#include #include #include -#if AOTRITON_VERSION_MINOR < 9 +#if AOTRITON_VERSION_CURRENT < AOTRITON_VERSION_INT(0, 9) #error "This adaptor code is only tested with AOTriton >= 0.9" #endif -#if (AOTRITON_VERSION_MAJOR * 100 + AOTRITON_VERSION_MINOR) >= 10 -#define V3_API 1 -#else -#define V3_API 0 -#endif - namespace pytorch_flash { namespace { @@ -93,15 +86,15 @@ calculate_swa(std::optional window_size_left, int max_seqlen_q, int max_seqlen_k, bool is_causal) { -#if V3_API // SWA is exposed through V3 API +#if AOTRITON_V3_API // SWA is exposed through V3 API bool needs_swa = false; using aotriton::v3::flash::WindowValue; // Default values when std::optional window_size_left/right have no value int window_left = max_seqlen_q; int window_right = max_seqlen_k; if (is_causal) { - window_left = WindowValue::TopLeftAligned; - window_right = WindowValue::TopLeftAligned; + window_left = WindowValue::BottomRightAligned; + window_right = WindowValue::BottomRightAligned; } if (window_size_left.has_value() || window_size_right.has_value()) { needs_swa = true; @@ -254,10 +247,10 @@ mha_fwd_aot(const at::Tensor &q, // batch_size x seqlen_q x num_heads x seqlen_q, seqlen_k, is_causal); -#if V3_API +#if AOTRITON_V3_API const bool uses_swa = needs_swa; #else - // When V3_API = 0, uses_swa is constexpr and the if (uses_swa) branch can be + // When AOTRITON_V3_API = 0, uses_swa is constexpr and the if (uses_swa) branch can be // optimized out (hopefully). constexpr bool uses_swa = false; #endif @@ -276,8 +269,8 @@ mha_fwd_aot(const at::Tensor &q, // batch_size x seqlen_q x num_heads x auto seed_output = mk_philoxtensor(use_philox_state ? seed_t.data_ptr() : nullptr); auto offset_output = mk_philoxtensor(use_philox_state ? offset_t.data_ptr() : nullptr); auto persistent_counter = mk_atomictensor(is_causal ? atomic_counter.data_ptr() : nullptr); - if (uses_swa) { -#if V3_API + if (uses_swa || AOTRITON_ALWAYS_V3_API) { +#if AOTRITON_V3_API using aotriton::v3::flash::CausalType; using aotriton::v3::flash::VarlenType; aotriton::v3::flash::attn_fwd_params params; @@ -297,7 +290,7 @@ mha_fwd_aot(const at::Tensor &q, // batch_size x seqlen_q x num_heads x params.philox_offset_output = offset_output; params.encoded_softmax = mk_aotensor(softmax_fa_t, "encoded_softmax"); params.persistent_atomic_counter = persistent_counter; - params.causal_type = CausalType::WindowedAttention; + params.causal_type = is_causal ? CausalType::WindowedAttention : CausalType::None; params.varlen_type = VarlenType::None; params.window_left = window_left; params.window_right = window_right; @@ -447,10 +440,10 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot max_seqlen_q, max_seqlen_k, is_causal); -#if V3_API +#if AOTRITON_V3_API const bool uses_swa = needs_swa; #else - // When V3_API = 0, uses_swa is constexpr and the if (uses_swa) branch can be + // When AOTRITON_V3_API = 0, uses_swa is constexpr and the if (uses_swa) branch can be // optimized out (hopefully). 
constexpr bool uses_swa = false; #endif @@ -464,10 +457,11 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot using sdp::aotriton_adapter::mk_aotensor; using sdp::aotriton_adapter::mk_aoscalartensor; using sdp::aotriton_adapter::mk_philoxtensor; + using sdp::aotriton_adapter::mk_atomictensor; using sdp::aotriton_adapter::cast_dtype; at::Tensor atomic_counter; if (is_causal) { - atomic_counter = at::zeros({1}, q.options()); + atomic_counter = at::zeros({1}, q.options().dtype(at::kInt)); } aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype())); auto seed = use_philox_state ? mk_philoxtensor(philox_state.seed_.ptr) : mk_aoscalartensor(seed_t); @@ -476,9 +470,9 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot auto nullscalar = mk_philoxtensor(nullptr); auto seed_output = use_philox_state ? mk_philoxtensor(seed_t.data_ptr()) : nullscalar; auto offset_output = use_philox_state ? mk_philoxtensor(offset_t.data_ptr()) : nullscalar; - auto persistent_counter = is_causal ? mk_philoxtensor(atomic_counter.data_ptr()) : nullscalar; - if (uses_swa) { -#if V3_API + auto persistent_counter = mk_atomictensor(is_causal ? atomic_counter.data_ptr() : nullptr); + if (uses_swa || AOTRITON_ALWAYS_V3_API) { +#if AOTRITON_V3_API using aotriton::v3::flash::CausalType; using aotriton::v3::flash::VarlenType; aotriton::v3::flash::attn_fwd_params params; @@ -500,7 +494,7 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot params.philox_offset_output = offset_output; params.encoded_softmax = mk_aotensor(softmax_fa_t, "encoded_softmax"); params.persistent_atomic_counter = persistent_counter; - params.causal_type = CausalType::WindowedAttention; + params.causal_type = is_causal ? CausalType::WindowedAttention : CausalType::None; params.varlen_type = VarlenType::CompactVarlen; params.window_left = window_left; params.window_right = window_right; @@ -594,10 +588,6 @@ mha_bwd_aot(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x hea const int seqlen_k = k.size(1); const int num_heads_k = k.size(2); - if (is_causal){ - TORCH_CHECK((seqlen_q == seqlen_k), "For backwards kernel seqlen_q must equal seqlen_k for causal kernels"); - } - TORCH_CHECK(batch_size > 0, "batch size must be positive"); TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8"); TORCH_CHECK(head_size_og % 8 == 0, "head_size_og should be a multiple of 8, this is ensured by padding!"); @@ -649,10 +639,10 @@ mha_bwd_aot(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x hea seqlen_q, seqlen_k, is_causal); -#if V3_API +#if AOTRITON_V3_API const bool uses_swa = needs_swa; #else - // When V3_API = 0, uses_swa is constexpr and the if (uses_swa) branch can be + // When AOTRITON_V3_API = 0, uses_swa is constexpr and the if (uses_swa) branch can be // optimized out (hopefully). 
constexpr bool uses_swa = false; #endif @@ -676,10 +666,9 @@ mha_bwd_aot(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x hea hipError_t err; // TODO: Error handling using sdp::aotriton_adapter::mk_aotensor; using sdp::aotriton_adapter::mk_aoscalartensor; - if (uses_swa) { -#if V3_API + if (uses_swa || AOTRITON_ALWAYS_V3_API) { +#if AOTRITON_V3_API // Fused BWD does not support SWA - at::Tensor delta = at::empty_like(softmax_lse_cont).contiguous(); using aotriton::v3::flash::CausalType; using aotriton::v3::flash::VarlenType; aotriton::v3::flash::attn_bwd_params params; @@ -689,21 +678,32 @@ mha_bwd_aot(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x hea params.Sm_scale = softmax_scale; params.Out = mk_aotensor(out_t, "out"); params.DO = mk_aotensor(dout_t, "dout"); - params.DK = mk_aotensor(dq_t, "dq"); - params.DV = mk_aotensor(dk_t, "dk"); - params.DQ = mk_aotensor(dv_t, "dv"); + params.DQ = mk_aotensor(dq_t, "dq"); + params.DK = mk_aotensor(dk_t, "dk"); + params.DV = mk_aotensor(dv_t, "dv"); params.L = mk_aotensor<2>(softmax_lse_cont, "L"); - params.D = mk_aotensor<2>(delta, "delta"); params.Max_seqlen_q = seqlen_q; // Unused if cu_seqlens_q is empty params.Max_seqlen_k = seqlen_k; // Unused if cu_seqlens_k is empty params.dropout_p = p_dropout; params.philox_seed_ptr = mk_aoscalartensor(philox_seed); params.philox_offset1 = mk_aoscalartensor(philox_offset); params.philox_offset2 = 0; - params.causal_type = CausalType::WindowedAttention; - params.varlen_type = VarlenType::None; + params.causal_type = is_causal ? CausalType::WindowedAttention : CausalType::None; params.window_left = window_left; params.window_right = window_right; + params.varlen_type = VarlenType::None; +#if AOTRITON_ALWAYS_V3_API + using sdp::aotriton_adapter::mklazy_empty_like; + using sdp::aotriton_adapter::mklazy_fp32zeros; + using sdp::aotriton_adapter::LazyTensorContext; + LazyTensorContext lazy_delta { .like_tensor = softmax_lse_cont, .tensor_name = "delta" }; + LazyTensorContext lazy_dq_acc { .like_tensor = dq_t, .tensor_name = "dq_acc" }; + params.D = mklazy_empty_like<2>(&lazy_delta); + params.DQ_ACC = mklazy_fp32zeros<4>(&lazy_dq_acc); +#else + at::Tensor delta = at::empty_like(softmax_lse_cont).contiguous(); + params.D = mk_aotensor<2>(delta, "delta"); +#endif err = aotriton::v3::flash::attn_bwd(params, aotriton::v3::flash::attn_bwd_params::kVersion, stream); @@ -838,7 +838,6 @@ mha_varlen_bwd_aot(const at::Tensor &dout, // total_q x num_heads, x head_size CHECK_SHAPE(cu_seqlens_k, batch_size + 1); at::Tensor softmax_lse_cont = softmax_lse.view({batch_size * num_heads, max_seqlen_q}).contiguous(); - at::Tensor delta = at::empty_like(softmax_lse_cont).contiguous(); at::Tensor q_padded, k_padded, v_padded; q_padded = q.unsqueeze(0).transpose(1, 2); @@ -896,10 +895,10 @@ mha_varlen_bwd_aot(const at::Tensor &dout, // total_q x num_heads, x head_size max_seqlen_q, max_seqlen_k, is_causal); -#if V3_API +#if AOTRITON_V3_API const bool uses_swa = needs_swa; #else - // When V3_API = 0, uses_swa is constexpr and the if (uses_swa) branch can be + // When AOTRITON_V3_API = 0, uses_swa is constexpr and the if (uses_swa) branch can be // optimized out (hopefully). 
constexpr bool uses_swa = false; #endif @@ -919,8 +918,8 @@ mha_varlen_bwd_aot(const at::Tensor &dout, // total_q x num_heads, x head_size hipError_t err; // TODO: Error handling using sdp::aotriton_adapter::mk_aotensor; using sdp::aotriton_adapter::mk_aoscalartensor; - if (uses_swa) { -#if V3_API + if (uses_swa || AOTRITON_ALWAYS_V3_API) { +#if AOTRITON_V3_API using aotriton::v3::flash::CausalType; using aotriton::v3::flash::VarlenType; aotriton::v3::flash::attn_bwd_params params; @@ -930,11 +929,10 @@ mha_varlen_bwd_aot(const at::Tensor &dout, // total_q x num_heads, x head_size params.Sm_scale = softmax_scale; params.Out = mk_aotensor(out_t, "out"); params.DO = mk_aotensor(dout_t, "dout"); - params.DK = mk_aotensor(dq_padded, "dq"); - params.DV = mk_aotensor(dk_padded, "dk"); - params.DQ = mk_aotensor(dv_padded, "dv"); + params.DK = mk_aotensor(dk_padded, "dk"); + params.DV = mk_aotensor(dv_padded, "dv"); + params.DQ = mk_aotensor(dq_padded, "dq"); params.L = mk_aotensor<2>(softmax_lse_cont, "L"); - params.D = mk_aotensor<2>(delta, "delta"); params.cu_seqlens_q = mk_aotensor<1>(cu_seqlens_q, "cu_seqlens_q"); params.cu_seqlens_k = mk_aotensor<1>(cu_seqlens_k, "cu_seqlens_k"); params.Max_seqlen_q = max_seqlen_q; // Unused if cu_seqlens_q is empty @@ -943,17 +941,30 @@ mha_varlen_bwd_aot(const at::Tensor &dout, // total_q x num_heads, x head_size params.philox_seed_ptr = mk_aoscalartensor(philox_seed); params.philox_offset1 = mk_aoscalartensor(philox_offset); params.philox_offset2 = 0; - params.causal_type = CausalType::WindowedAttention; + params.causal_type = is_causal ? CausalType::WindowedAttention : CausalType::None; params.varlen_type = VarlenType::CompactVarlen; params.window_left = window_left; params.window_right = window_right; +#if AOTRITON_ALWAYS_V3_API + using sdp::aotriton_adapter::mklazy_empty_like; + using sdp::aotriton_adapter::mklazy_fp32zeros; + using sdp::aotriton_adapter::LazyTensorContext; + LazyTensorContext lazy_delta { .like_tensor = softmax_lse_cont, .tensor_name = "delta" }; + LazyTensorContext lazy_dq_acc { .like_tensor = dq_padded, .tensor_name = "dq_acc" }; + params.D = mklazy_empty_like<2>(&lazy_delta); + params.DQ_ACC = mklazy_fp32zeros<4>(&lazy_dq_acc); +#else + at::Tensor delta = at::empty_like(softmax_lse_cont).contiguous(); + params.D = mk_aotensor<2>(delta, "delta"); +#endif err = aotriton::v3::flash::attn_bwd(params, aotriton::v3::flash::attn_bwd_params::kVersion, stream); -#endif +#endif // AOTRITON_ALWAYS_V3_API } else { using aotriton::v2::flash::attn_bwd_compact_varlen; using sdp::aotriton_adapter::cast_dtype; + at::Tensor delta = at::empty_like(softmax_lse_cont).contiguous(); aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype())); err = attn_bwd_compact_varlen(mk_aotensor(q_padded, "q"), mk_aotensor(k_padded, "k"), diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h b/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h index 17298aae9485..e578847e3273 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h +++ b/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h @@ -270,7 +270,7 @@ std::tuple mha_varle #endif TORCH_API -inline std::tuple< +std::tuple< at::Tensor, at::Tensor, at::Tensor, @@ -294,42 +294,7 @@ mha_fwd( std::optional window_size_right, const float softcap, const bool return_softmax, - std::optional gen_) { -#if defined(USE_CK_FLASH_ATTENTION) - if (at::globalContext().getROCmFAPreferredBackend() == - at::ROCmFABackend::Ck) { - const int 
non_null_window_left = window_size_left.value_or(-1); - const int non_null_window_right = window_size_right.value_or(-1); - std::optional dummy_attn_bias = std::nullopt; - return mha_fwd_ck( - q, - k, - v, - out_, - p_dropout, - softmax_scale, - is_causal, - non_null_window_left, - non_null_window_right, - return_softmax, - gen_, - dummy_attn_bias); // Not used in flash attention - } -#endif - return mha_fwd_aot( - q, - k, - v, - out_, - alibi_slopes_, - p_dropout, - softmax_scale, - is_causal, - window_size_left, - window_size_right, - return_softmax, - gen_); -} + std::optional gen_); inline std::tuple< at::Tensor, diff --git a/aten/src/ATen/native/transformers/hip/gemm_kernel_utils.h b/aten/src/ATen/native/transformers/hip/gemm_kernel_utils.h new file mode 100644 index 000000000000..c18744afc1ff --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/gemm_kernel_utils.h @@ -0,0 +1,32 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +// This file is a trimmed version of cuda/mem_eff_attention/gemm_kernel_utils.h +#pragma once + +#define CHECK_NOSPARSE_CONTIGUOUS_CUDA(TENSOR) \ + TORCH_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \ + TORCH_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \ + TORCH_CHECK(TENSOR.is_contiguous()); + +#define CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(TENSOR) \ + TORCH_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \ + TORCH_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \ + TORCH_CHECK( \ + TENSOR.stride(-1) == 1, #TENSOR ": last dimension must be contiguous"); + +#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \ + TORCH_CHECK( \ + uint64_t(PTR) % ALIGNMENT == 0, #PTR " is not correctly aligned") + +#define ASSIGN_CHECK_OVERFLOW(A, B) \ + { \ + A = B; \ + TORCH_CHECK( \ + B < std::numeric_limits::max(), #B " overflows"); \ + } diff --git a/cmake/External/aotriton.cmake b/cmake/External/aotriton.cmake index 8b380d24f6c8..f09f77bedb80 100644 --- a/cmake/External/aotriton.cmake +++ b/cmake/External/aotriton.cmake @@ -9,96 +9,265 @@ if(NOT __AOTRITON_INCLUDED) # Replaces .ci/docker/aotriton_version.txt # Note packages information may have versions skipped (due to no ABI breaks) # But they must be listed from lower version to higher version - set(__AOTRITON_VER "0.10b") + set(__AOTRITON_VER "0.11b") set(__AOTRITON_MANYLINUX_LIST + "manylinux_2_28" # rocm6.2 "manylinux_2_28" # rocm6.3 "manylinux_2_28" # rocm6.4 "manylinux_2_28" # rocm7.0 ) set(__AOTRITON_ROCM_LIST + "rocm6.2" "rocm6.3" "rocm6.4" "rocm7.0" ) - set(__AOTRITON_CI_COMMIT "6fca155f4deeb8d9529326f7b69f350aeeb93477") + set(__AOTRITON_CI_COMMIT "972223c501ffc22068bb035ac5d64cf54318d895") set(__AOTRITON_SHA256_LIST - "861cd9f7479eec943933c27cb86920247e5b5dd139bc7c1376c81808abb7d7fe" # rocm6.3 - "acea7d811a2d3bbe718b6e07fc2a9f739e49eecd60b4b6a36fcb3fe8edf85d78" # rocm6.4 - "1e9b3dddf0c7fc07131c6f0f5266129e83ce2331f459fa2be8c63f4ae91b0f5b" # rocm7.0 + "6cae3d5de75ee205d22e088f7dfaab1227056d02ea67f29ccdbc09f2be4e8c8f" # rocm6.2 + "72a153549ea20707331e8a1f1e3d1b8de2913f9d5af2b900c56235d578b57efe" # rocm6.3 + "c7f319dd7448cbbbab81889dd8a37d47dbc25ebcbd89760f09e6a0904e556393" # rocm6.4 + "a2a974e0ad929a5e5827c0f896c59bda4872459cbaf8dd8e0a00407f404491cf" # rocm7.0 ) + set(__AOTRITON_IMAGE_LIST + "amd-gfx90a" + "amd-gfx942" + "amd-gfx950" + "amd-gfx11xx" + "amd-gfx120x" + ) + 
set(__AOTRITON_IMAGE_SHA256_LIST + "c19a41c9480510ab32e6fb05e6ed0a3832d6b07634f050b836b760200befa735" # amd-gfx90a + "3a06a99971dddb7703a30378f1c5d6b41468d926ea51821156d1b6857b985bc4" # amd-gfx942 + "27fc21f6761d57987a700436de8cf29cbdd9eeee91318dfed596eeb147d219ad" # amd-gfx950 + "ec134032087344176695505db659387374d1916adfee16f0db47dee38d9c8603" # amd-gfx11xx + "fec05205747ff51649b1e151545267d5aa2037ba9d0338cad286882915b941b0" # amd-gfx120x + ) + set(__AOTRITON_BASE_URL "https://github.com/ROCm/aotriton/releases/download/") # @lint-ignore set(__AOTRITON_Z "gz") + # Set the default __AOTRITON_LIB path + if(NOT WIN32) + set(__AOTRITON_LIB "lib/libaotriton_v2.so") + else() + set(__AOTRITON_LIB "lib/aotriton_v2.lib") + endif() - # Note it is INSTALL"ED" - if(DEFINED ENV{AOTRITON_INSTALLED_PREFIX}) - install(DIRECTORY - $ENV{AOTRITON_INSTALLED_PREFIX}/lib - $ENV{AOTRITON_INSTALLED_PREFIX}/include - DESTINATION ${__AOTRITON_INSTALL_DIR}) - set(__AOTRITON_INSTALL_DIR "$ENV{AOTRITON_INSTALLED_PREFIX}") - message(STATUS "Using Preinstalled AOTriton at ${__AOTRITON_INSTALL_DIR}") - elseif(DEFINED ENV{AOTRITON_INSTALL_FROM_SOURCE}) - ExternalProject_Add(aotriton_external + function(aotriton_build_windows_dependencies dlfcn-win32_external xz_external dlfcn-win32_DIR liblzma_DIR) + # Windows-specific dependencies - build these first + if(NOT noimage) + message(FATAL_ERROR "noimage must be ON for Windows builds") + endif() + # Build dlfcn-win32 + set(__DLFCN_WIN32_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/dlfcn-win32") + set(__DLFCN_WIN32_INSTALL_DIR "${CMAKE_CURRENT_BINARY_DIR}/dlfcn-win32-install") + + ExternalProject_Add(${dlfcn-win32_external} + GIT_REPOSITORY https://github.com/dlfcn-win32/dlfcn-win32.git + GIT_TAG v1.4.2 + PREFIX ${__DLFCN_WIN32_PREFIX} + INSTALL_DIR ${__DLFCN_WIN32_INSTALL_DIR} + CMAKE_ARGS + -DCMAKE_INSTALL_PREFIX=${__DLFCN_WIN32_INSTALL_DIR} + -DCMAKE_BUILD_TYPE=Release + -DCMAKE_C_COMPILER=cl + -DCMAKE_CXX_COMPILER=cl + -DBUILD_SHARED_LIBS=ON + -DBUILD_TESTS=OFF + BUILD_BYPRODUCTS + "${__DLFCN_WIN32_INSTALL_DIR}/lib/dl.lib" + "${__DLFCN_WIN32_INSTALL_DIR}/bin/dl.dll" + ) + ExternalProject_Add_Step(${dlfcn-win32_external} copy_to_aotriton + COMMAND ${CMAKE_COMMAND} -E copy_if_different + "${__DLFCN_WIN32_INSTALL_DIR}/bin/dl.dll" + "${__AOTRITON_INSTALL_DIR}/lib/" + DEPENDEES install + ) + set(${dlfcn-win32_DIR} "${__DLFCN_WIN32_INSTALL_DIR}/share/dlfcn-win32" CACHE PATH "Path to dlfcn-win32 CMake config" FORCE) + + # Build xz/liblzma + set(__XZ_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/xz") + set(__XZ_INSTALL_DIR "${CMAKE_CURRENT_BINARY_DIR}/xz-install") + + ExternalProject_Add(${xz_external} + GIT_REPOSITORY https://github.com/tukaani-project/xz.git + GIT_TAG v5.8.1 + PREFIX ${__XZ_PREFIX} + INSTALL_DIR ${__XZ_INSTALL_DIR} + CMAKE_ARGS + -DCMAKE_INSTALL_PREFIX=${__XZ_INSTALL_DIR} + -DCMAKE_BUILD_TYPE=Release + -DBUILD_SHARED_LIBS=ON + -DENABLE_NLS=OFF + -DXZ_TOOL_LZMAINFO=OFF + -DXZ_TOOL_XZ=OFF + -DXZ_TOOL_XZDEC=OFF + -DXZ_TOOL_LZMADEC=OFF + BUILD_BYPRODUCTS + "${__XZ_INSTALL_DIR}/lib/lzma.lib" + "${__XZ_INSTALL_DIR}/bin/liblzma.dll" + ) + ExternalProject_Add_Step(${xz_external} copy_to_aotriton + COMMAND ${CMAKE_COMMAND} -E copy_if_different + "${__XZ_INSTALL_DIR}/bin/liblzma.dll" + "${__AOTRITON_INSTALL_DIR}/lib/" + DEPENDEES install + ) + set(${liblzma_DIR} "${__XZ_INSTALL_DIR}/lib/cmake/liblzma" CACHE PATH "Path to xz/liblzma CMake config" FORCE) + endfunction() + + function(aotriton_build_from_source noimage project) + if(noimage) + SET(RECURSIVE "OFF") + else() + SET(RECURSIVE "ON") 
+ endif() + if(WIN32) + message(STATUS "Building AOTriton Windows dependencies") + aotriton_build_windows_dependencies(dlfcn-win32_external xz_external dlfcn-win32_DIR liblzma_DIR) + endif() + message(STATUS "PYTORCH_ROCM_ARCH ${PYTORCH_ROCM_ARCH}") + + ExternalProject_Add(${project} GIT_REPOSITORY https://github.com/ROCm/aotriton.git + GIT_SUBMODULES_RECURSE ${RECURSIVE} GIT_TAG ${__AOTRITON_CI_COMMIT} PREFIX ${__AOTRITON_EXTERN_PREFIX} - INSTALL_DIR ${__AOTRITON_INSTALL_DIR} - CMAKE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${__AOTRITON_INSTALL_DIR} + CMAKE_CACHE_ARGS -DAOTRITON_TARGET_ARCH:STRING=${PYTORCH_ROCM_ARCH} + -DCMAKE_INSTALL_PREFIX:FILEPATH=${__AOTRITON_INSTALL_DIR} + CMAKE_ARGS -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} + -DAOTRITON_GPU_BUILD_TIMEOUT=0 -DAOTRITON_NO_PYTHON=ON - -DAOTRITON_NO_SHARED=OFF - # CONFIGURE_COMMAND "" - BUILD_COMMAND "" # No build, install command will repeat the build process due to problems in the build system. - BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.so" + -DAOTRITON_NOIMAGE_MODE=${noimage} + -DHIP_PLATFORM=amd + $<$:-Ddlfcn-win32_DIR=${dlfcn-win32_DIR}> + $<$:-Dliblzma_DIR=${liblzma_DIR}> + BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/${__AOTRITON_LIB}" USES_TERMINAL_DOWNLOAD TRUE USES_TERMINAL_CONFIGURE TRUE USES_TERMINAL_BUILD TRUE USES_TERMINAL_INSTALL TRUE - # INSTALL_COMMAND ${MAKE_COMMAND} install - ) - add_dependencies(__caffe2_aotriton aotriton_external) - message(STATUS "Using AOTriton compiled from source directory ${__AOTRITON_EXTERN_PREFIX}") - else() - set(__AOTRITON_SYSTEM_ROCM "${ROCM_VERSION_DEV_MAJOR}.${ROCM_VERSION_DEV_MINOR}") - list(GET __AOTRITON_ROCM_LIST 0 __AOTRITON_ROCM_DEFAULT_STR) - # Initialize __AOTRITON_ROCM to lowest version, in case all builds > system's ROCM - string(SUBSTRING ${__AOTRITON_ROCM_DEFAULT_STR} 4 -1 __AOTRITON_ROCM) - foreach(AOTRITON_ROCM_BUILD_STR IN LISTS __AOTRITON_ROCM_LIST) - # len("rocm") == 4 - string(SUBSTRING ${AOTRITON_ROCM_BUILD_STR} 4 -1 AOTRITON_ROCM_BUILD) - # Find the last build that <= system's ROCM - # Assume the list is from lower to higher - if(AOTRITON_ROCM_BUILD VERSION_GREATER __AOTRITON_SYSTEM_ROCM) - break() - endif() - set(__AOTRITON_ROCM ${AOTRITON_ROCM_BUILD}) - endforeach() - list(FIND __AOTRITON_ROCM_LIST "rocm${__AOTRITON_ROCM}" __AOTRITON_ROCM_INDEX) - list(GET __AOTRITON_SHA256_LIST ${__AOTRITON_ROCM_INDEX} __AOTRITON_SHA256) - list(GET __AOTRITON_MANYLINUX_LIST ${__AOTRITON_ROCM_INDEX} __AOTRITON_MANYLINUX) - set(__AOTRITON_ARCH ${CMAKE_HOST_SYSTEM_PROCESSOR}) + ) + if(WIN32) + add_dependencies(${project} dlfcn-win32_external xz_external) + endif() + endfunction() + + set(__AOTRITON_ARCH ${CMAKE_HOST_SYSTEM_PROCESSOR}) + function(aotriton_download_runtime index project) + list(GET __AOTRITON_ROCM_LIST ${index} __AOTRITON_ROCM) + list(GET __AOTRITON_MANYLINUX_LIST ${index} __AOTRITON_MANYLINUX) + list(GET __AOTRITON_SHA256_LIST ${index} __AOTRITON_SHA256) + string(CONCAT __AOTRITON_FILE "aotriton-" "${__AOTRITON_VER}-${__AOTRITON_MANYLINUX}" - "_${__AOTRITON_ARCH}-rocm${__AOTRITON_ROCM}" + "_${__AOTRITON_ARCH}-${__AOTRITON_ROCM}" "-shared.tar.${__AOTRITON_Z}") - string(CONCAT __AOTRITON_URL "https://github.com/ROCm/aotriton/releases/download/" # @lint-ignore - "${__AOTRITON_VER}/${__AOTRITON_FILE}") - ExternalProject_Add(aotriton_external + string(CONCAT __AOTRITON_URL + "${__AOTRITON_BASE_URL}" + "${__AOTRITON_VER}/${__AOTRITON_FILE}") + ExternalProject_Add(${project} URL "${__AOTRITON_URL}" URL_HASH SHA256=${__AOTRITON_SHA256} - SOURCE_DIR 
${CMAKE_CURRENT_BINARY_DIR}/aotriton_tarball + SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton_runtime CONFIGURE_COMMAND "" BUILD_COMMAND "" INSTALL_COMMAND ${CMAKE_COMMAND} -E copy_directory - "${CMAKE_CURRENT_BINARY_DIR}/aotriton_tarball" + "${CMAKE_CURRENT_BINARY_DIR}/aotriton_runtime" "${__AOTRITON_INSTALL_DIR}" - BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.so" + BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/${__AOTRITON_LIB}" ) - add_dependencies(__caffe2_aotriton aotriton_external) - message(STATUS "Using AOTriton from pre-compiled binary ${__AOTRITON_URL}.\ + message(STATUS "Using AOTriton Runtime from pre-compiled binary ${__AOTRITON_URL}.\ Set env variables AOTRITON_INSTALL_FROM_SOURCE=1 to build from source.") + endfunction() + + function(aotriton_download_image image project) + list(FIND __AOTRITON_IMAGE_LIST ${image} index) + list(GET __AOTRITON_IMAGE_SHA256_LIST ${index} __AOTRITON_SHA256) + + string(CONCAT __AOTRITON_FILE + "aotriton-${__AOTRITON_VER}-images-" + "${image}.tar.${__AOTRITON_Z}") + string(CONCAT __AOTRITON_URL + "${__AOTRITON_BASE_URL}" + "${__AOTRITON_VER}/${__AOTRITON_FILE}") + + # Set up directories + set(__AOTRITON_DOWNLOAD_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton_download-${image}) + set(__AOTRITON_EXTRACT_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton_image-${image}) + set(__AOTRITON_INSTALL_SOURCE_DIR ${__AOTRITON_EXTRACT_DIR}) + set(__DOWNLOAD_NO_EXTRACT "") + set(__BUILD_COMMANDS "") + + # On Windows, we need custom tar extraction with UTF-8 support + if(WIN32) + set(__DOWNLOAD_NO_EXTRACT "DOWNLOAD_NO_EXTRACT;TRUE") + set(__BUILD_COMMANDS + COMMAND ${CMAKE_COMMAND} -E make_directory "${__AOTRITON_EXTRACT_DIR}" + COMMAND tar --options hdrcharset=UTF-8 -xf "${__AOTRITON_DOWNLOAD_DIR}/${__AOTRITON_FILE}" -C "${__AOTRITON_EXTRACT_DIR}" + ) + set(__AOTRITON_INSTALL_SOURCE_DIR ${__AOTRITON_EXTRACT_DIR}/aotriton) + endif() + + ExternalProject_Add(${project} + URL "${__AOTRITON_URL}" + URL_HASH SHA256=${__AOTRITON_SHA256} + DOWNLOAD_DIR ${__AOTRITON_DOWNLOAD_DIR} + ${__DOWNLOAD_NO_EXTRACT} + SOURCE_DIR ${__AOTRITON_EXTRACT_DIR} + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + ${__BUILD_COMMANDS} + INSTALL_COMMAND ${CMAKE_COMMAND} -E copy_directory + "${__AOTRITON_INSTALL_SOURCE_DIR}" + "${__AOTRITON_INSTALL_DIR}" + BUILD_BYPRODUCTS + "${__AOTRITON_INSTALL_DIR}/lib/aotriton.images/${image}/__signature__" + ) + message(STATUS "Download AOTriton pre-compiled GPU images from ${__AOTRITON_URL}.") + endfunction() + + # Note it is INSTALL"ED" + if(DEFINED ENV{AOTRITON_INSTALLED_PREFIX}) + install(DIRECTORY + $ENV{AOTRITON_INSTALLED_PREFIX}/lib + $ENV{AOTRITON_INSTALLED_PREFIX}/include + DESTINATION ${__AOTRITON_INSTALL_DIR}) + set(__AOTRITON_INSTALL_DIR "$ENV{AOTRITON_INSTALLED_PREFIX}") + message(STATUS "Using Preinstalled AOTriton at ${__AOTRITON_INSTALL_DIR}") + elseif(DEFINED ENV{AOTRITON_INSTALL_FROM_SOURCE}) + aotriton_build_from_source(OFF aotriton_external) + add_dependencies(__caffe2_aotriton aotriton_external) + message(STATUS "Using AOTriton compiled from source directory ${__AOTRITON_EXTERN_PREFIX}") + else() + set(__AOTRITON_SYSTEM_ROCM "${HIP_VERSION_MAJOR}.${HIP_VERSION_MINOR}") + list(FIND __AOTRITON_ROCM_LIST "rocm${__AOTRITON_SYSTEM_ROCM}" __AOTRITON_RUNTIME_INDEX) + if(${__AOTRITON_RUNTIME_INDEX} LESS 0) + message(STATUS "Cannot find AOTriton runtime for ROCM ${__AOTRITON_SYSTEM_ROCM}. 
\ + Build runtime from source") + aotriton_build_from_source(ON aotriton_runtime) + else() + aotriton_download_runtime(${__AOTRITON_RUNTIME_INDEX} aotriton_runtime) + endif() + add_dependencies(__caffe2_aotriton aotriton_runtime) + set(__AOTRITON_CHAINED_IMAGE "aotriton_runtime") + foreach(image ${__AOTRITON_IMAGE_LIST}) + string(SUBSTRING ${image} 7 -1 gfx_pattern) + string(REPLACE "x" "." gfx_regex ${gfx_pattern}) + foreach(target ${PYTORCH_ROCM_ARCH}) + if(target MATCHES ${gfx_regex}) + set(__AOTRITON_DOWNLOAD_TARGET aotriton_image_${gfx_pattern}) + aotriton_download_image(${image} ${__AOTRITON_DOWNLOAD_TARGET}) + add_dependencies(${__AOTRITON_CHAINED_IMAGE} ${__AOTRITON_DOWNLOAD_TARGET}) + set(__AOTRITON_CHAINED_IMAGE ${__AOTRITON_DOWNLOAD_TARGET}) + break() + endif() + endforeach() + endforeach() endif() - target_link_libraries(__caffe2_aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.so) + target_link_libraries(__caffe2_aotriton INTERFACE "${__AOTRITON_INSTALL_DIR}/${__AOTRITON_LIB}") target_include_directories(__caffe2_aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/include) set(AOTRITON_FOUND TRUE) endif() # __AOTRITON_INCLUDED diff --git a/test/nn/test_multihead_attention.py b/test/nn/test_multihead_attention.py index c0419664d009..40dca90b1648 100644 --- a/test/nn/test_multihead_attention.py +++ b/test/nn/test_multihead_attention.py @@ -17,7 +17,6 @@ instantiate_parametrized_tests, parametrize as parametrize_test, run_tests, - skipIfRocm, TEST_NUMPY, TEST_WITH_CROSSREF, ) @@ -746,7 +745,6 @@ def test_multihead_attn_nested_tensor_outside_fast_path(self): class TestMultiheadAttentionNNDeviceType(NNTestCase): - @skipIfRocm(msg="To investigate: yields NaN") def test_multihead_self_attn_two_masks_fast_path(self, device): """ Multihead self-attention should give the same result on the fast path (BetterTransformer) as on the slow path diff --git a/test/test_flop_counter.py b/test/test_flop_counter.py index c44d5e5d4145..17e699e04e58 100644 --- a/test/test_flop_counter.py +++ b/test/test_flop_counter.py @@ -15,7 +15,6 @@ ) from torch.testing._internal.common_utils import ( run_tests, - skipIfRocm, TEST_WITH_TORCHDYNAMO, TestCase, ) @@ -463,7 +462,6 @@ def get_flops( self.assertExpectedInline(str(flops_fw_bw_math), """805306368""") self.assertExpectedInline(str(flops_fw_bw_efficient), """939524096""") - @skipIfRocm # Nested tensor @unittest.skipIf(not HAS_CUDA, "CUDA not available") @unittest.skipIf( not PLATFORM_SUPPORTS_FLASH_ATTENTION @@ -683,7 +681,6 @@ def split_tensor(x): ), ) - @skipIfRocm # Nested tensor @unittest.skipIf(not HAS_CUDA, "CUDA not available") @unittest.skipIf( not PLATFORM_SUPPORTS_FLASH_ATTENTION, diff --git a/test/test_nn.py b/test/test_nn.py index f3b1764af69d..e65f5d53147a 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -36,8 +36,9 @@ download_file, get_function_arglist, load_tests, skipIfMPS, \ IS_PPC, \ parametrize as parametrize_test, subtest, instantiate_parametrized_tests, \ - skipIfTorchDynamo, skipIfRocmVersionLessThan, gcIfJetson, set_default_dtype -from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, PLATFORM_SUPPORTS_FLASH_ATTENTION, _get_torch_rocm_version + skipIfTorchDynamo, gcIfJetson, set_default_dtype +from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, \ + _get_torch_rocm_version from torch.testing._internal.common_nn import NNTestCase, NewModuleTest, CriterionTest, \ module_tests, criterion_tests, loss_reference_fns, _create_basic_net, \ ctcloss_reference, 
     get_new_module_tests, single_batch_reference_fn, _test_bfloat16_ops, _test_module_empty_input
@@ -3148,7 +3149,6 @@ def perm_fn(x):
                                                    [2.42240309, 0.0354595, -0.60659063, -0.05378816]]]))
         torch.testing.assert_close(result, ref_output, rtol=1e-5, atol=0)
 
-    @skipIfRocm(msg='Large numerical errors')
     def test_transformerdecoder(self):
         def get_a_test_layer(use_cuda, activation, batch_first=False):
             d_model = 4
@@ -12937,8 +12937,6 @@ def test_skip_init(self, device):
     @dtypes(torch.float)
     @dtypesIfCUDA(torch.double, torch.float, torch.half)
     def test_transformerencoderlayer(self, device, dtype):
-        if TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION and dtype == torch.half:
-            self.skipTest("Skip on ROCM due to Flash Attention tolerances")
         # this is a deterministic test for TransformerEncoderLayer
         d_model = 4
         nhead = 2
@@ -13160,8 +13158,6 @@ def test_transformerencoderlayer_fast_path(self, device, dtype):
     @dtypes(torch.float)
     @dtypesIfCUDA(torch.half, torch.float)
     def test_transformerencoderlayer_gelu(self, device, dtype):
-        if TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION and dtype == torch.half:
-            self.skipTest("Skip on ROCM due to Flash Attention tolerances")
         # this is a deterministic test for TransformerEncoderLayer with gelu activation
         d_model = 4
         nhead = 2
diff --git a/test/test_transformers.py b/test/test_transformers.py
index 8bdad854cd22..a68ed8e10576 100644
--- a/test/test_transformers.py
+++ b/test/test_transformers.py
@@ -52,6 +52,7 @@
     SM90OrLater,
     tf32_on_and_off,
     tf32_enabled,
+    ROCM_VERSION,
 )
 
 if TEST_FAIRSEQ:
@@ -340,7 +341,7 @@ def test_train_with_pad_and_catch_error(self, device):
             l1_bool = nn.L1Loss()(test_train_bool[:, 0:2, :], test_eval_bool[:, 0:2, :]).item()
             self.assertTrue(l1_bool < 1e-4, "Eval/Train difference in pad_mask BOOL")
 
-    @tf32_on_and_off(0.001)
+    @tf32_on_and_off(0.001, only_if=(not TEST_WITH_ROCM or ROCM_VERSION < (7, 0)))
     @parametrize("attn_mask_dim", [2, 3, None])
     @parametrize("key_padding_mask_dim", [2, None])
     @parametrize("mask_dtype", [torch.bool, torch.float32])
@@ -430,7 +431,6 @@ def hook(module, inputs, output):
         # remove hook
         handle.remove()
 
-    @skipIfRocm
     @tf32_on_and_off(0.001)
     @parametrize("use_torchscript", [False])
     @parametrize("enable_nested_tensor", [True, False])
@@ -524,7 +524,7 @@ def test_transformerencoder_fastpath(self, device, use_torchscript, enable_neste
             slowpath_output = slowpath_output.masked_fill(src_key_padding_mask.unsqueeze(-1), 0)
             self.assertEqual(fastpath_output_expanded, slowpath_output)
 
-    @tf32_on_and_off(0.001)
+    @tf32_on_and_off(0.001, only_if=(not TEST_WITH_ROCM or ROCM_VERSION < (7, 0)))
     @parametrize("with_no_grad", [True, False])
     @parametrize("training", [True, False])
     @parametrize("enable_nested_tensor", [False])
@@ -1110,7 +1110,7 @@ def forward(
                 return_all_hiddens=False,
             )[0]
 
-    @tf32_on_and_off(0.003)
+    @tf32_on_and_off(0.003, only_if=(not TEST_WITH_ROCM or ROCM_VERSION < (7, 0)))
     @parametrize("input_dim,attn_mask_dim,is_causal",
                  [(3, None, False), (3, 2, False), (3, 2, True), (3, 3, False), (3, 3, True),
                   (4, None, False), (4, 2, False), (4, 2, True), (4, 4, False), (4, 4, True)],
@@ -1421,7 +1421,6 @@ def ones_tensor(*shape):
             _ = mha_f(qkv_f, qkv_f, qkv_f, attn_mask=mask, need_weights=False, is_causal=True)
             torch.cuda.synchronize()
 
-    @skipIfRocm  # Missing EFFICIENT_ATTENTION
     @unittest.skipIf(
         not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Platform does not supposrt fused SDPA or pre-SM80 hardware"
     )
@@ -1727,7 +1726,7 @@ def test_unaligned_tensors(self, device):
         make_tensor = partial(torch.rand, size, device=device, dtype=dtype)
         q, k, v = make_tensor(), make_tensor(), make_tensor()
         with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
-            ctxmgr = self.assertRaises(RuntimeError) if not TEST_WITH_ROCM else contextlib.nullcontext()
+            ctxmgr = self.assertRaises(RuntimeError)
             with ctxmgr:
                 torch.nn.functional.scaled_dot_product_attention(q, k, v, None, 0.0, False)
 
@@ -2541,7 +2540,6 @@ def convert_flash_attn_S_to_softmax(
             S_converted = F.pad(S_converted, (0, seqlen_k_og - seqlen_k_rounded))
         return S_converted[:, :, :seqlen_q, :seqlen_k]
 
-    @skipIfRocm  # No cuDNN Attention
     @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system")
     def test_cudnn_attention_different_dk_dv(self, device):
         dtype = torch.bfloat16
@@ -2565,7 +2563,6 @@ def test_cudnn_attention_different_dk_dv(self, device):
         self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2)
 
 
-    @skipIfRocm  # No cuDNN Attention
     @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system")
     def test_cudnn_attention_gqa(self, device):
         batch = 4
@@ -2589,7 +2586,6 @@ def test_cudnn_attention_gqa(self, device):
         self.assertEqual(output_math, output_cudnn)
 
 
-    @skipIfRocm  # No cuDNN Attention
     @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system")
     def test_cudnn_attention_d256_heuristic(self, device):
         dtype = torch.bfloat16
@@ -2614,7 +2610,6 @@ def test_cudnn_attention_d256_heuristic(self, device):
         self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2)
 
 
-    @skipIfRocm(msg="No cuDNN on ROCm")
     @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system")
     def test_fused_attention_different_dk_dv(self, device):
         dtype = torch.bfloat16
@@ -2638,7 +2633,6 @@ def test_fused_attention_different_dk_dv(self, device):
         self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2)
 
 
-    @skipIfRocm  # No cuDNN Attention
     @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system")
     def test_cudnn_attention_fail_d128(self, device):
         # Test that cuDNN attention dispatching correctly bails out on d > 128
@@ -2661,7 +2655,6 @@ def test_cudnn_attention_fail_d128(self, device):
         with self.assertRaisesRegex(RuntimeError, "No available kernel."):
             torch.nn.functional.scaled_dot_product_attention(q, k, v)
 
-    @skipIfRocm(msg="No cuDNN on ROCm")
     @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cudnn Attention is not supported on this system")
     def test_cudnn_attention_trivial_output_transpose(self, device):
         # see also: https://github.com/pytorch/pytorch/issues/134001
@@ -2677,7 +2670,6 @@ def test_cudnn_attention_trivial_output_transpose(self, device):
         o.backward(o)
         torch.testing.assert_close(x.grad, x_cpu.grad.cuda(), atol=7e-3, rtol=7e-3)
 
-    @skipIfRocm  # No cuDNN Attention
     @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cudnn Attention is not supported on this system")
     def test_cudnn_attention_nonmodulo64seqlen(self, device):
         # see also: https://github.com/pytorch/pytorch/issues/137347
@@ -2717,7 +2709,6 @@ def test_cudnn_attention_nonmodulo64seqlen(self, device):
         torch.testing.assert_close(k.grad, k_cpu.grad.cuda(), atol=3e-3, rtol=2e-3)
         torch.testing.assert_close(v.grad, v_cpu.grad.cuda(), atol=3e-3, rtol=2e-3)
 
-    @skipIfRocm
     @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cudnn Attention is not supported on this system")
     def test_cudnn_attention_preserves_query_layout(self, device):
@@ -3130,7 +3121,6 @@ def test_sdp_choice_with_determinism(self, device, warn_only):
         with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]):
             assert torch._fused_sdp_choice(query, key, value) == SDPBackend.EFFICIENT_ATTENTION.value
 
-    @skipIfRocm
     @onlyCUDA
     @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system")
     @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Platform does not support fused SDPA")
@@ -3351,6 +3341,7 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset,
             'grad_value': 8.5,
         }
         if TEST_WITH_ROCM:
+            fudge_factors['out'] = 5.0
             fudge_factors['grad_key'] = 45.0
             fudge_factors['grad_query'] = 360.0
             if seq_len_k >= 1024:
@@ -3360,6 +3351,8 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset,
                     fudge_factors['grad_query'] = 670.0
             if dtype == torch.float32:
                 fudge_factors['grad_key'] = 90.0
+            if "gfx95" in torch.cuda.get_device_properties(0).gcnArchName:
+                fudge_factors['grad_value'] = 16.0
 
         check_out_and_grad(
             (out_ref, out_lp_ref, out),
@@ -3472,6 +3465,7 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset,
             "grad_attn_mask": 45.0,
         }
         if TEST_WITH_ROCM:
+            fudge_factors['out'] = 6.0
             fudge_factors['grad_key'] = 45.0
             fudge_factors['grad_query'] = 360.0
             if seq_len_k >= 1024:
@@ -3481,6 +3475,8 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset,
                     fudge_factors['grad_query'] = 670.0  # gfx90a
             if dtype == torch.float32:
                 fudge_factors['grad_key'] = 90.0
+            if "gfx95" in torch.cuda.get_device_properties(0).gcnArchName:
+                fudge_factors['grad_value'] = 16.0
 
         check_out_and_grad(
             (out_ref, out_lp_ref, out),
@@ -3601,17 +3597,33 @@ def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_le
             'grad_value': 4,
         }
         if TEST_WITH_ROCM:
-            fudge_factors['grad_key'] = 45.0
-            fudge_factors['grad_query'] = 360.0
-            if seq_len_k >= 1024:
-                fudge_factors['grad_key'] = 70.0
-            if seq_len_k >= 2048:
-                fudge_factors['grad_key'] = 190.0
-                fudge_factors['grad_query'] = 650.0
-                if seq_len_q >= 2048:
-                    fudge_factors['grad_query'] = 1100.0
-            if dtype == torch.float32:
-                fudge_factors['grad_key'] = 90.0
+            fudge_factors['grad_value'] = 6.0
+            if TEST_WITH_CK:
+                fudge_factors['out'] = 5.0
+                fudge_factors['grad_key'] = 145.0
+                fudge_factors['grad_query'] = 855.0  # ck min = 855.0
+                if seq_len_k >= 1024:
+                    fudge_factors['grad_key'] = 70.0
+                if seq_len_k >= 2048:
+                    fudge_factors['grad_key'] = 190.0
+                    fudge_factors['grad_query'] = 1550.0  # NEW CK MIN
+                    if seq_len_q >= 2048:
+                        fudge_factors['grad_query'] = 1100.0
+                if dtype == torch.float32:
+                    fudge_factors['grad_key'] = 90.0
+            else:
+                fudge_factors['out'] = 6.0
+                fudge_factors['grad_key'] = 45.0
+                fudge_factors['grad_query'] = 360.0
+                if seq_len_k >= 1024:
+                    fudge_factors['grad_key'] = 70.0
+                if seq_len_k >= 2048:
+                    fudge_factors['grad_key'] = 190.0
+                    fudge_factors['grad_query'] = 650.0
+                    if seq_len_q >= 2048:
+                        fudge_factors['grad_query'] = 1100.0
+                if dtype == torch.float32:
+                    fudge_factors['grad_key'] = 90.0
 
         check_out_and_grad(
             (out_ref, out_lp_ref, out),
@@ -3764,15 +3776,19 @@ def get_dropout_mask(output, fused_kernel, batch_size, n_heads, q_len, kv_len, d
         grads_ref_lp = torch.autograd.grad(out_lp_ref, (query, key, value), upstream_grad)
         grads_ref = torch.autograd.grad(out_ref, (query_ref, key_ref, value_ref), upstream_grad)
 
+        fudge_factors = {
+            'out': 3.0,
+            'grad_query': 110.0,
+            'grad_key': 8.0,
+            'grad_value': 3.0,
+        }
+        if TEST_WITH_ROCM:
+            fudge_factors['out'] = 6.0
+            fudge_factors['grad_value'] = 6.0
         check_out_and_grad(
             (out_ref, out_lp_ref, out),
             *zip(grads_ref, grads_ref_lp, grads),
-            fudge_factors={
-                'out': 3.0,
-                'grad_query': 110.0,
-                'grad_key': 8.0,
-                'grad_value': 3.0,
-            }
+            fudge_factors=fudge_factors
         )
 
     @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system")
@@ -4384,10 +4400,6 @@ def test_causal_variants(self, device, causal_variant: CausalVariant, shape: lis
         make_tensor = partial(
             torch.rand, device=device, dtype=torch.float16, requires_grad=True
         )
-        if TEST_WITH_ROCM and causal_variant == CausalVariant.LOWER_RIGHT:
-            self.skipTest("No support for LOWER_RIGHT variant for now")
-            return
-
         bsz, num_heads, seq_len_q, seq_len_kv, head_dim = shape
         make_q_tensor = partial(make_tensor, SdpaShape(bsz, num_heads, seq_len_q, head_dim))
         make_kv_tensor = partial(make_tensor, SdpaShape(bsz, num_heads, seq_len_kv, head_dim))
@@ -4418,10 +4430,6 @@ def test_causal_variants(self, device, causal_variant: CausalVariant, shape: lis
     @unittest.skipIf(IS_WINDOWS, "torch.compile is not supported on windows")
     @skipIfTorchDynamo("This function already calls torch.compile.")
     def test_causal_variants_compile(self, device, causal_variant: CausalVariant, shape: list[tuple[int]]):
-        if TEST_WITH_ROCM and causal_variant == CausalVariant.LOWER_RIGHT:
-            self.skipTest("No support for LOWER_RIGHT variant for now")
-            return
-
         cnts = CompileCounterWithBackend("aot_eager")
         make_tensor = partial(
             torch.rand, device=device, dtype=torch.float16, requires_grad=True
diff --git a/tools/linter/dictionary.txt b/tools/linter/dictionary.txt
index a3da2299cf23..2a7c3b9d1acd 100644
--- a/tools/linter/dictionary.txt
+++ b/tools/linter/dictionary.txt
@@ -3,6 +3,7 @@ BU
 contiguities
 contiguity
 coo
+DEPENDEES
 Din
 Dout
 dOut
diff --git a/torch/testing/_internal/common_cuda.py b/torch/testing/_internal/common_cuda.py
index 2620c64a95ef..9c4dfd1d7d44 100644
--- a/torch/testing/_internal/common_cuda.py
+++ b/torch/testing/_internal/common_cuda.py
@@ -24,6 +24,7 @@
 
 TEST_CUDNN = LazyVal(lambda: TEST_CUDA and torch.backends.cudnn.is_acceptable(torch.tensor(1., device=CUDA_DEVICE)))
 TEST_CUDNN_VERSION = LazyVal(lambda: torch.backends.cudnn.version() if TEST_CUDNN else 0)
+ROCM_VERSION = LazyVal(lambda : tuple(int(v) for v in torch.version.hip.split('.')[:2]) if torch.version.hip else (0, 0))
 
 SM53OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (5, 3))
 SM60OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (6, 0))
@@ -93,7 +94,6 @@ def evaluate_platform_supports_cudnn_attention():
 def evaluate_platform_supports_fp8():
     if torch.cuda.is_available():
         if torch.version.hip:
-            ROCM_VERSION = tuple(int(v) for v in torch.version.hip.split('.')[:2])
             archs = ['gfx94']
             if ROCM_VERSION >= (6, 3):
                 archs.extend(['gfx120'])
@@ -111,7 +111,8 @@ def evaluate_platform_supports_fp8():
 def _platform_supports_mx_gemm():
     if torch.cuda.is_available():
         if torch.version.hip:
-            return 'gfx95' in torch.cuda.get_device_properties(0).gcnArchName
+            if ROCM_VERSION >= (7, 0):
+                return 'gfx950' in torch.cuda.get_device_properties(0).gcnArchName
         else:
             return SM100OrLater
     return False
@@ -222,7 +223,7 @@ def tf32_enabled():
 # if device is specified, it will check if device is cuda
 # if dtype is specified, it will check if dtype is float32 or complex64
 # tf32 and fp32 are different only when all the three checks pass
-def tf32_on_and_off(tf32_precision=1e-5):
+def tf32_on_and_off(tf32_precision=1e-5, only_if=True):
     def with_tf32_disabled(self, function_call):
         with tf32_off():
             function_call()
@@ -238,7 +239,7 @@ def wrapper(f):
         @functools.wraps(f)
         def wrapped(*args, **kwargs):
             kwargs.update(zip(arg_names, args))
-            cond = torch.cuda.is_tf32_supported()
+            cond = torch.cuda.is_tf32_supported() and only_if
             if 'device' in kwargs:
                 cond = cond and (torch.device(kwargs['device']).type == 'cuda')
             if 'dtype' in kwargs:
@@ -252,7 +253,6 @@ def wrapped(*args, **kwargs):
         return wrapped
     return wrapper
 
-
 # This is a wrapper that wraps a test to run it with TF32 turned off.
 # This wrapper is designed to be used when a test uses matmul or convolutions
 # but the purpose of that test is not testing matmul or convolutions.
diff --git a/torch/utils/_triton.py b/torch/utils/_triton.py
index fa2355463bb3..f260f5781f96 100644
--- a/torch/utils/_triton.py
+++ b/torch/utils/_triton.py
@@ -15,6 +15,17 @@ def has_triton_package() -> bool:
         return False
 
 
+@functools.cache
+def get_triton_version(fallback: tuple[int, int] = (0, 0)) -> tuple[int, int]:
+    try:
+        import triton  # noqa: F401
+
+        major, minor = tuple(int(v) for v in triton.__version__.split(".")[:2])
+        return (major, minor)
+    except ImportError:
+        return fallback
+
+
 @functools.cache
 def _device_supports_tma() -> bool:
     import torch
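Note on the new only_if gate: tf32_on_and_off(..., only_if=...) simply ANDs the extra condition into the existing torch.cuda.is_tf32_supported() check, so when only_if evaluates to False (for example on ROCm >= 7.0 at the test_transformers.py call sites above) the decorator behaves as if TF32 were unavailable and the relaxed TF32 tolerance is not applied. A minimal usage sketch follows; the test class, test name, and tolerance are hypothetical and not part of this patch.

# Hypothetical illustration only -- class, test name, and tolerance are made up,
# and a CUDA/ROCm device is assumed to be available.
import torch
from torch.testing._internal.common_cuda import ROCM_VERSION, tf32_on_and_off
from torch.testing._internal.common_utils import TEST_WITH_ROCM, TestCase, run_tests

class ExampleMatmulTest(TestCase):
    # Sweep TF32 on/off only when not running on ROCm 7.0+, mirroring the
    # decorator arguments used in test_transformers.py above.
    @tf32_on_and_off(0.005, only_if=(not TEST_WITH_ROCM or ROCM_VERSION < (7, 0)))
    def test_example_matmul(self, device="cuda"):
        a = torch.randn(64, 64, device=device)
        b = torch.randn(64, 64, device=device)
        # Compared against a float64 reference; tolerance comes from the decorator.
        self.assertEqual(a @ b, (a.double() @ b.double()).float())

if __name__ == "__main__":
    run_tests()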
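Note on get_triton_version: it returns the installed Triton release as a (major, minor) tuple, or the fallback value when Triton cannot be imported. A hypothetical call site (the threshold and variable name below are illustrative, not part of this patch):

from torch.utils._triton import get_triton_version

# (0, 0) is returned when Triton is not installed, so this stays False there.
HAS_RECENT_TRITON = get_triton_version() >= (3, 2)  # hypothetical version gate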