85 changes: 31 additions & 54 deletions .ci/docker/common/install_rocm_magma.sh
@@ -1,60 +1,37 @@
#!/bin/bash
# Script used in CI and CD pipeline
#!/usr/bin/env bash
# Script used only in CD pipeline

set -ex
set -eou pipefail

ver() {
printf "%3d%03d%03d%03d" $(echo "$1" | tr '.' ' ');
}

# Magma build scripts need `python`
ln -sf /usr/bin/python3 /usr/bin/python

ID=$(grep -oP '(?<=^ID=).+' /etc/os-release | tr -d '"')
case "$ID" in
almalinux)
yum install -y gcc-gfortran
;;
*)
echo "No preinstalls to build magma..."
;;
esac
function do_install() {
rocm_version=$1
if [[ ${rocm_version} =~ ^[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
# chop off any patch version
rocm_version="${rocm_version%.*}"
fi

MKLROOT=${MKLROOT:-/opt/conda/envs/py_$ANACONDA_PYTHON_VERSION}
rocm_version_nodot=${rocm_version//./}

# "install" hipMAGMA into /opt/rocm/magma by copying after build
if [[ $(ver $ROCM_VERSION) -ge $(ver 7.0) ]]; then
git clone https://github.com/ROCm/utk-magma.git -b release/2.9.0_rocm70 magma
pushd magma
# version 2.9 + ROCm 7.0 related updates
git checkout 91c4f720a17e842b364e9de41edeef76995eb9ad
else
git clone https://bitbucket.org/icl/magma.git
pushd magma
# Version 2.7.2 + ROCm related updates
git checkout a1625ff4d9bc362906bd01f805dbbe12612953f6
fi
MAGMA_VERSION=a1625ff4d9bc362906bd01f805dbbe12612953f6
magma_archive="magma-rocm${rocm_version_nodot}-${MAGMA_VERSION}-1.tar.bz2"

rocm_dir="/opt/rocm"
(
set -x
tmp_dir=$(mktemp -d)
pushd ${tmp_dir}
curl -OLs https://ossci-linux.s3.us-east-1.amazonaws.com/${magma_archive}
if tar -xvf "${magma_archive}"
then
mkdir -p "${rocm_dir}/magma"
mv include "${rocm_dir}/magma/include"
mv lib "${rocm_dir}/magma/lib"
else
echo "${magma_archive} not found, skipping magma install"
fi
popd
)
}

cp make.inc-examples/make.inc.hip-gcc-mkl make.inc
echo 'LIBDIR += -L$(MKLROOT)/lib' >> make.inc
if [[ -f "${MKLROOT}/lib/libmkl_core.a" ]]; then
echo 'LIB = -Wl,--start-group -lmkl_gf_lp64 -lmkl_gnu_thread -lmkl_core -Wl,--end-group -lpthread -lstdc++ -lm -lgomp -lhipblas -lhipsparse' >> make.inc
fi
echo 'LIB += -Wl,--enable-new-dtags -Wl,--rpath,/opt/rocm/lib -Wl,--rpath,$(MKLROOT)/lib -Wl,--rpath,/opt/rocm/magma/lib -ldl' >> make.inc
echo 'DEVCCFLAGS += --gpu-max-threads-per-block=256' >> make.inc
export PATH="${PATH}:/opt/rocm/bin"
if [[ -n "$PYTORCH_ROCM_ARCH" ]]; then
amdgpu_targets=`echo $PYTORCH_ROCM_ARCH | sed 's/;/ /g'`
else
amdgpu_targets=`rocm_agent_enumerator | grep -v gfx000 | sort -u | xargs`
fi
for arch in $amdgpu_targets; do
echo "DEVCCFLAGS += --offload-arch=$arch" >> make.inc
done
# hipcc with openmp flag may cause isnan() on __device__ not to be found; depending on context, compiler may attempt to match with host definition
sed -i 's/^FOPENMP/#FOPENMP/g' make.inc
make -f make.gen.hipMAGMA -j $(nproc)
LANG=C.UTF-8 make lib/libmagma.so -j $(nproc) MKLROOT="${MKLROOT}"
make testing/testing_dgemm -j $(nproc) MKLROOT="${MKLROOT}"
popd
mv magma /opt/rocm
do_install $1
4 changes: 2 additions & 2 deletions CMakeLists.txt
@@ -867,7 +867,7 @@ cmake_dependent_option(
"Whether to build the flash_attention kernel for scaled dot product attention.\
Will be disabled if not supported by the platform"
ON
"USE_CUDA OR USE_ROCM;NOT MSVC"
"(USE_CUDA AND NOT MSVC) OR USE_ROCM"
OFF)

# CAVEAT: Again, Flash Attention2 will error while building for sm52 while Mem
@@ -883,7 +883,7 @@ cmake_dependent_option(
# USE_FLASH_ATTENTION -> USE_ROCM -> Dependencies.cmake -> aotriton.cmake
#
if(USE_ROCM)
if(UNIX AND (USE_FLASH_ATTENTION OR USE_MEM_EFF_ATTENTION))
if(USE_FLASH_ATTENTION OR USE_MEM_EFF_ATTENTION)
include(cmake/External/aotriton.cmake)
endif()
endif()
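
For readers skimming the CMake change above: the dependency string for the flash-attention option now treats MSVC as a blocker only for CUDA builds, so a ROCm build on Windows keeps the kernel enabled, and the aotriton.cmake include is likewise no longer restricted to UNIX. A self-contained C++ restatement of the new boolean, purely illustrative and not PyTorch code:

```cpp
// Old gate: (USE_CUDA OR USE_ROCM) AND NOT MSVC  -- any MSVC build disabled the option.
// New gate: (USE_CUDA AND NOT MSVC) OR USE_ROCM  -- ROCm builds keep it even under MSVC.
constexpr bool flash_attention_visible(bool use_cuda, bool use_rocm, bool msvc) {
  return (use_cuda && !msvc) || use_rocm;
}

static_assert(flash_attention_visible(/*use_cuda=*/false, /*use_rocm=*/true, /*msvc=*/true),
              "ROCm + MSVC now keeps USE_FLASH_ATTENTION selectable");
static_assert(!flash_attention_visible(/*use_cuda=*/true, /*use_rocm=*/false, /*msvc=*/true),
              "CUDA + MSVC is still excluded");
```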
40 changes: 35 additions & 5 deletions aten/src/ATen/native/cuda/Indexing.cu
@@ -56,9 +56,10 @@ constexpr uint64_t getDefaultMaxThreadsPerBlock() {
#endif
}

#if 0
#ifdef USE_ROCM
#define SKIP_SORTED_INDICES 32
template <typename scalar_t, int SZ>
__global__ void indexing_backward_kernel(
__global__ void indexing_backward_kernel_many_indices(
const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
int64_t numel, int64_t stride, int64_t stride_before, int64_t outer_dim, bool accumulate) {
using opmath_t = at::opmath_type<scalar_t>;
@@ -141,10 +142,7 @@ __global__ void indexing_backward_kernel(
}
}
}
#endif

#ifdef USE_ROCM
#define SKIP_SORTED_INDICES 32
template <typename scalar_t>
__global__ void indexing_backward_kernel_stride_1(
const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
@@ -784,6 +782,38 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
kBool,
kBFloat16);
} else {
#ifdef USE_ROCM
if (num_indices >= 200000)
AT_DISPATCH_V2(
expandedValue.scalar_type(),
"indexing_backward_many_indices",
AT_WRAP([&] {
indexing_backward_kernel_many_indices<scalar_t, UNROLL><<<new_grid, block, smem_dups_size, stream>>>(
sorted_indices.const_data_ptr<int64_t>(),
orig_indices.const_data_ptr<int64_t>(),
expandedValue.const_data_ptr<scalar_t>(),
src_.mutable_data_ptr<scalar_t>(),
num_indices,
sliceSize,
strideBefore,
nElemBefore,
accumulate);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}),
AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
// AT_EXPAND(AT_FLOAT8_TYPES),
// TODO(#113663): clean up accumulation behavior in float8 dtypes, accumulate=True
// should not be supported here, then reenable AT_FLOAT8_DTYPES
kFloat8_e4m3fn,
kFloat8_e5m2,
kFloat8_e4m3fnuz,
kFloat8_e5m2fnuz,
kComplexHalf,
kHalf,
kBool,
kBFloat16);
else
#endif
AT_DISPATCH_V2(
expandedValue.scalar_type(),
"indexing_backward",
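
The Indexing.cu hunk above is dominated by the AT_DISPATCH_V2 dtype lists; the actual control-flow change is small. A minimal sketch of the branch that was added, with the dtype dispatch and the real launch arguments elided (the 200000 cutoff and the kernel names come from the diff; the stub functions are illustrative only):

```cpp
#include <cstdint>

// Stand-ins for the two kernel launches selected in the diff; bodies elided.
void launch_indexing_backward() {}
void launch_indexing_backward_many_indices() {}

// On ROCm builds, index_put_with_sort_kernel now routes very large index
// counts to the dedicated many-indices kernel; everything else keeps using
// the existing indexing_backward path.
void dispatch_backward(int64_t num_indices) {
#ifdef USE_ROCM
  if (num_indices >= 200000) {
    launch_indexing_backward_many_indices();
    return;
  }
#endif
  launch_indexing_backward();
}
```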
123 changes: 118 additions & 5 deletions aten/src/ATen/native/transformers/cuda/attention.cu
@@ -95,6 +95,72 @@
#endif
#endif

#if defined(USE_ROCM) && (defined(USE_FLASH_ATTENTION) || defined(USE_MEM_EFF_ATTENTION))
namespace pytorch_flash
{
std::tuple<
at::Tensor,
at::Tensor,
at::Tensor,
at::Tensor,
at::Tensor,
at::Tensor,
at::Tensor,
at::Tensor>
mha_fwd(
const at::Tensor& q, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor& k, // batch_size x seqlen_k x num_heads_k x head_size
const at::Tensor& v, // batch_size x seqlen_k x num_heads_k x head_size
std::optional<at::Tensor>&
out_, // batch_size x seqlen_q x num_heads x head_size
std::optional<at::Tensor>&
alibi_slopes_, // num_heads or batch_size x num_heads
const float p_dropout,
const float softmax_scale,
bool is_causal,
std::optional<int64_t> window_size_left,
std::optional<int64_t> window_size_right,
const float softcap,
const bool return_softmax,
std::optional<at::Generator> gen_) {
#if defined(USE_ROCM_CK_SDPA)
if (at::globalContext().getROCmFAPreferredBackend() ==
at::ROCmFABackend::Ck) {
const int non_null_window_left = window_size_left.value_or(-1);
const int non_null_window_right = window_size_right.value_or(-1);
std::optional<at::Tensor> dummy_attn_bias = std::nullopt;
return mha_fwd_ck(
q,
k,
v,
out_,
p_dropout,
softmax_scale,
is_causal,
non_null_window_left,
non_null_window_right,
return_softmax,
gen_,
dummy_attn_bias); // Not used in flash attention
}
#endif
return mha_fwd_aot(
q,
k,
v,
out_,
alibi_slopes_,
p_dropout,
softmax_scale,
is_causal,
window_size_left,
window_size_right,
return_softmax,
gen_);
}
}
#endif

namespace at {

namespace cuda::philox {
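
The new pytorch_flash::mha_fwd wrapper above mostly forwards its arguments; the part worth calling out is the argument shimming when the Composable Kernel backend is preferred: the optional window sizes collapse to -1 sentinels and the attention-bias slot is passed as std::nullopt. A compile-on-its-own sketch of that selection shape (stub names only; mha_fwd_ck and mha_fwd_aot are the real entry points shown in the diff):

```cpp
#include <cstdint>
#include <optional>

// Stubs standing in for mha_fwd_ck and mha_fwd_aot; real signatures are in the diff.
int run_ck(int window_left, int window_right, std::optional<int> attn_bias) { return 0; }
int run_aot(std::optional<int64_t> window_left, std::optional<int64_t> window_right) { return 1; }

// CK is only reachable when the build defines USE_ROCM_CK_SDPA *and* the
// runtime preference (getROCmFAPreferredBackend() in the diff) selects it;
// otherwise everything falls through to the AOTriton path unchanged.
int forward(bool ck_built, bool ck_preferred,
            std::optional<int64_t> window_left, std::optional<int64_t> window_right) {
  if (ck_built && ck_preferred) {
    return run_ck(static_cast<int>(window_left.value_or(-1)),
                  static_cast<int>(window_right.value_or(-1)),
                  /*attn_bias=*/std::nullopt);  // bias is unused by flash attention here
  }
  return run_aot(window_left, window_right);
}
```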
@@ -1406,12 +1472,15 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt> _efficient_
at::Tensor v_t = value.transpose(1, 2);
at::Tensor output_t = res.transpose(1, 2);
bool is_causal;
if (static_cast<int64_t>(sdp::CustomMaskType::CausalFromTopLeft) == custom_mask_type) {
is_causal = true;
} else if (static_cast<int64_t>(sdp::CustomMaskType::NoCustomMask) == custom_mask_type) {
if (static_cast<int64_t>(sdp::CustomMaskType::NoCustomMask) == custom_mask_type) {
is_causal = false;
} else {
TORCH_CHECK(false, "[_efficient_attention_forward] Unsupported mask type on ROCM, for now");
is_causal = true;
#if AOTRITON_V3_API == 0
if (static_cast<int64_t>(sdp::CustomMaskType::CausalFromTopLeft) != custom_mask_type) {
TORCH_CHECK(false, "[_efficient_attention_forward] Unsupported mask type on ROCM, for now");
}
#endif
}

at::Tensor atomic_counter;
@@ -1436,7 +1505,51 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt> _efficient_
auto offset_output = mk_philoxtensor(use_philox_state ? offset_t.data_ptr<int64_t>() : nullptr);
auto persistent_counter = mk_atomictensor(is_causal ? atomic_counter.data_ptr<int32_t>() : nullptr);
hipError_t err; // TODO: Error handling
if (seqstart_q.has_value()) {
if constexpr (AOTRITON_ALWAYS_V3_API) { // Better readability than nesting ifdef
#if AOTRITON_V3_API // if constexpr does not stop errors from undefined functions
using aotriton::v3::flash::CausalType;
using aotriton::v3::flash::VarlenType;
using aotriton::v3::flash::WindowValue;
aotriton::v3::flash::attn_fwd_params params;
params.Q = mk_aotensor(q_t, "q");
params.K = mk_aotensor(k_t, "k");
params.V = mk_aotensor(v_t, "v");
params.Sm_scale = softmax_scale;
params.L = compute_logsumexp ? mk_aotensor<2>(softmax_lse, "M") : empty_t2;
params.Out = mk_aotensor(output_t, "Out");
params.Max_seqlen_q = max_seqlen_q; // Unused if cu_seqlens_q is empty
params.Max_seqlen_k = max_seqlen_k; // Unused if cu_seqlens_k is empty
params.dropout_p = dropout_p;
params.philox_seed_ptr = seed;
params.philox_offset1 = offset1;
params.philox_offset2 = offset2;
params.philox_seed_output = seed_output;
params.philox_offset_output = offset_output;
params.encoded_softmax = mk_aotensor(softmax_fa_t, "encoded_softmax");
params.persistent_atomic_counter = persistent_counter;
params.causal_type = is_causal ? CausalType::WindowedAttention : CausalType::None;
if (static_cast<int64_t>(sdp::CustomMaskType::CausalFromTopLeft) == custom_mask_type) {
params.window_left = WindowValue::TopLeftAligned;
params.window_right = WindowValue::TopLeftAligned;
} else if (static_cast<int64_t>(sdp::CustomMaskType::CausalFromBottomRight) == custom_mask_type) {
params.window_left = WindowValue::BottomRightAligned;
params.window_right = WindowValue::BottomRightAligned;
}
if (bias.has_value()) {
params.B = mk_aotensor(bias.value(), "bias");
}
if (seqstart_q.has_value()) {
params.varlen_type = VarlenType::CompactVarlen;
params.cu_seqlens_q = mk_aotensor<1>(seqstart_q.value(), "cu_seqlens_q");
params.cu_seqlens_k = mk_aotensor<1>(seqstart_k.value(), "cu_seqlens_k");
} else {
params.varlen_type = VarlenType::None;
}
err = aotriton::v3::flash::attn_fwd(params,
aotriton::v3::flash::attn_fwd_params::kVersion,
stream);
#endif // AOTRITON_V3_API
} else if (seqstart_q.has_value()) {
// varlen aka nested tensor
err = attn_fwd_compact_varlen(mk_aotensor(q_t, "q"),
mk_aotensor(k_t, "k"),
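
One behavioral change buried in the two attention.cu hunks above: with the AOTriton V3 API available, CausalFromBottomRight is no longer rejected. Both causal variants are expressed as windowed attention with the matching alignment, and only the V3-less build keeps the old "unsupported mask type" check. A standalone sketch of that mapping, with stand-in enums replacing sdp::CustomMaskType and aotriton::v3::flash::WindowValue (the enum layout here is illustrative, not the real definitions):

```cpp
#include <stdexcept>

// Stand-ins for the real enums so the sketch compiles on its own.
enum class MaskType { NoCustomMask, CausalFromTopLeft, CausalFromBottomRight };
enum class Window { None, TopLeftAligned, BottomRightAligned };

struct CausalConfig {
  bool is_causal;
  Window window_left;
  Window window_right;
};

// Mirrors the control flow added in the diff: no mask -> non-causal; either
// causal variant -> windowed attention with the matching alignment; the
// bottom-right variant is only rejected when the V3 API is compiled out.
CausalConfig resolve_mask(MaskType mask, bool have_v3_api) {
  switch (mask) {
    case MaskType::NoCustomMask:
      return {false, Window::None, Window::None};
    case MaskType::CausalFromTopLeft:
      return {true, Window::TopLeftAligned, Window::TopLeftAligned};
    case MaskType::CausalFromBottomRight:
      if (!have_v3_api) {
        throw std::runtime_error(
            "[_efficient_attention_forward] Unsupported mask type on ROCM, for now");
      }
      return {true, Window::BottomRightAligned, Window::BottomRightAligned};
  }
  throw std::runtime_error("unknown mask type");
}
```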