7 changes: 7 additions & 0 deletions CMakeLists.txt
@@ -913,6 +913,13 @@ if(USE_ROCM)
endif()
endif()

# link CK library
if(USE_ROCM)
if(UNIX AND USE_CK_FLASH_ATTENTION)
include(cmake/External/ck_kernels.cmake)
endif()
endif()
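
The include is gated on USE_CK_FLASH_ATTENTION, which this hunk assumes is defined elsewhere (in PyTorch it is normally propagated from an environment variable by the Python build scripts). Purely as a hedged illustration, if that gate were surfaced as an explicit CMake option the declaration would be a one-liner; the description and default below are assumptions, not part of this PR:

# Hypothetical sketch, not part of this PR: expose the gate as a cache option.
option(USE_CK_FLASH_ATTENTION "Use prebuilt or source-built composable_kernel flash attention kernels" OFF)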

if(DEBUG_CUDA)
string(APPEND CMAKE_CUDA_FLAGS_DEBUG " -lineinfo")
string(APPEND CMAKE_CUDA_FLAGS_RELWITHDEBINFO " -lineinfo")
27 changes: 15 additions & 12 deletions aten/src/ATen/CMakeLists.txt
@@ -187,26 +187,29 @@ if(USE_FLASH_ATTENTION)
caffe2_update_option(USE_ROCM_CK_SDPA ON)
endif()
if(USE_ROCM_CK_SDPA)
if(DEFINED ENV{PYTORCH_ROCM_ARCH})
list(LENGTH PYTORCH_ROCM_ARCH NUM_ARCHS)
if(NUM_ARCHS GREATER 1)
message(WARNING "Building CK for multiple archs can increase build time considerably!
Consider setting PYTORCH_ROCM_ARCH env var value as the gfx arch you need to build for")
endif()
endif()
message(STATUS "USE_ROCM_CK_SDPA is set; building PyTorch with CK SDPA enabled")
message(STATUS "Generating CK kernel instances...")
add_subdirectory(native/transformers/hip/flash_attn/ck)
if(DEFINED CK_KERNELS_INSTALL_FROM_SOURCE)
if(DEFINED ENV{PYTORCH_ROCM_ARCH})
list(LENGTH PYTORCH_ROCM_ARCH NUM_ARCHS)
if(NUM_ARCHS GREATER 1)
message(WARNING "Building CK for multiple archs can increase build time considerably!
Consider setting PYTORCH_ROCM_ARCH env var value as the gfx arch you need to build for")
endif()
endif()
# building CK kernels from source
message(STATUS "Generating CK kernel instances...")
add_subdirectory(native/transformers/hip/flash_attn/ck)
add_subdirectory(native/transformers/hip/flash_attn/ck/fav_v3)
endif() # end of CK_KERNELS_INSTALL_FROM_SOURCE
file(GLOB flash_attention_hip_ck_hip "native/transformers/hip/flash_attn/ck/*.hip")
list(APPEND native_transformers_hip_hip ${flash_attention_hip_ck_hip})
# FAv3 Generation
add_subdirectory(native/transformers/hip/flash_attn/ck/fav_v3)
file(GLOB flash_attention_v3_hip "native/transformers/hip/flash_attn/ck/fav_v3/*.hip")
list(APPEND native_transformers_hip_hip ${flash_attention_v3_hip})
endif()
endif() # end of USE_ROCM_CK_SDPA
file(GLOB flash_attention_hip_aot_hip "native/transformers/hip/flash_attn/aot/*.hip")
file(GLOB flash_attention_src_hip_hip "native/transformers/hip/flash_attn/src/*.hip")
endif()
endif() # end of USE_FLASH_ATTENTION

#Mem_eff attention sources
file(GLOB mem_eff_attention_cuda_cu "native/transformers/cuda/mem_eff_attention/*.cu")
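
Taken together with cmake/External/ck_kernels.cmake and the caffe2/CMakeLists.txt hunk further down, the CK_KERNELS_INSTALL_FROM_SOURCE gate used above reduces to roughly the following selection logic. This is a condensed sketch for orientation (valid CMake, but not code from the PR), and it only applies when ck_kernels.cmake is actually included, i.e. USE_ROCM, UNIX, and USE_CK_FLASH_ATTENTION are all set:

# Condensed sketch of how the CK kernels are obtained (see ck_kernels.cmake below).
if(DEFINED ENV{CK_KERNELS_INSTALLED_PREFIX})
  # reuse a preinstalled libck_kernels.so and link it through the __ck_kernels_lib INTERFACE target
elseif(DEFINED ENV{CK_KERNELS_PACKAGE_BASE_URL})
  # download a prebuilt tarball keyed on the composable_kernel submodule commit and HIP version
else()
  # nothing prebuilt is available: generate and compile the CK kernel instances from source
  set(CK_KERNELS_INSTALL_FROM_SOURCE TRUE)
endif()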
4 changes: 4 additions & 0 deletions aten/src/ATen/native/transformers/cuda/attention.cu
@@ -1403,6 +1403,8 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt> _efficient_
at::ROCmFABackend::Ck) {

#if defined(USE_ROCM_CK_SDPA)
TORCH_WARN_ONCE("Using CK backend for Efficient Attention forward...");

std::optional<Tensor> out(res);
std::optional<Tensor> seqused_k = std::nullopt;
std::optional<Tensor> alibi_slopes = std::nullopt;
@@ -1445,6 +1447,8 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt> _efficient_
" (gfx90a:sramecc+:xnack-/gfx942:sramecc+:xnack-/gfx1100)")
}

TORCH_WARN_ONCE("Using AOTriton backend for Efficient Attention forward...");

// AOTriton may accept aligned on logsumexp tensor in the future for better
// performance, but for now it requires compact logsumexp tensor, even if
// compute_logsumexp is false
5 changes: 5 additions & 0 deletions aten/src/ATen/native/transformers/cuda/attention_backward.cu
@@ -496,6 +496,8 @@ _efficient_attention_backward(
if(at::globalContext().getROCmFAPreferredBackend() == at::ROCmFABackend::Ck)
{
#if defined(USE_ROCM_CK_SDPA)
TORCH_WARN_ONCE("Using CK backend for Efficient Attention backward...");

const auto my_softmax_scale = sdp::calculate_scale(query, scale).expect_float();
// Store grad_bias in optional
std::optional<at::Tensor> opt_grad_bias = grad_bias;
@@ -544,6 +546,9 @@ _efficient_attention_backward(
"[AOTriton] Accelerated SDPA only supports MI200/MI300X/7900XTX/9070XT GPUs"
" (gfx90a/gfx942/gfx1100/gfx1201)")
}

TORCH_WARN_ONCE("Using AOTriton backend for Efficient Attention backward...");

const auto softmax_scale = sdp::calculate_scale(query, scale).expect_float();
bool is_causal;
if (static_cast<int64_t>(sdp::CustomMaskType::NoCustomMask) == custom_mask_type) {
@@ -71,6 +71,15 @@ namespace pytorch_flash {

namespace {

void check_aotriton_gpu_arch(hipStream_t stream) {
auto ret = aotriton::v2::flash::check_gpu(stream);
if (hipSuccess != ret) {
TORCH_CHECK(false,
"[AOTriton] Accelerated SDPA only supports MI200/MI300X/Navi31 GPUs"
" (gfx90a:sramecc+:xnack-/gfx942:sramecc+:xnack-/gfx1100)")
}
}

void check_gpu_arch(hipStream_t stream) {
auto ret = aotriton::v2::flash::check_gpu(stream);
if (hipSuccess != ret) {
@@ -163,7 +172,7 @@ mha_fwd_aot(const at::Tensor &q, // batch_size x seqlen_q x num_heads x
const bool return_softmax,
const std::optional<at::Generator>& gen_) {
auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
check_gpu_arch(stream);
check_aotriton_gpu_arch(stream);

auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16,
@@ -350,7 +359,7 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot

at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()};
auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
check_gpu_arch(stream);
check_aotriton_gpu_arch(stream);

auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16,
@@ -562,7 +571,7 @@ mha_bwd_aot(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x hea
// Cast to char to avoid compiler warning about narrowing
at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()};
auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
check_gpu_arch(stream);
check_aotriton_gpu_arch(stream);

bool is_dropout = p_dropout > 0.0;

@@ -795,7 +804,8 @@ mha_varlen_bwd_aot(const at::Tensor &dout, // total_q x num_heads, x head_size
// Cast to char to avoid compiler warning about narrowing
at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()};
auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
check_gpu_arch(stream);
check_aotriton_gpu_arch(stream);


bool is_dropout = p_dropout > 0.0;

4 changes: 4 additions & 0 deletions caffe2/CMakeLists.txt
@@ -940,6 +940,10 @@ if(USE_ROCM)
if(USE_FLASH_ATTENTION)
target_link_libraries(torch_hip PRIVATE __caffe2_aotriton)
endif()
# link CK library if not building CK_kernels from source
if(USE_CK_FLASH_ATTENTION AND NOT CK_KERNELS_INSTALL_FROM_SOURCE)
target_link_libraries(torch_hip PRIVATE __ck_kernels_lib)
endif()
set(CUDA_LINK_LIBRARIES_KEYWORD)
torch_compile_options(torch_hip) # see cmake/public/utils.cmake
# TODO: Not totally sure if this is live or not
63 changes: 63 additions & 0 deletions cmake/External/ck_kernels.cmake
@@ -0,0 +1,63 @@
#
# create INTERFACE target for CK library
#
if(NOT __ck_kernels_included)
set(__ck_kernels_included TRUE)

set(ck_kernels_install_dir "${PROJECT_SOURCE_DIR}/torch/lib")

set(__ck_kernels_version 0.1)

# create INTERFACE target
add_library(__ck_kernels_lib INTERFACE)

if(DEFINED ENV{CK_KERNELS_INSTALLED_PREFIX})
# Copy .so from $ENV{CK_KERNELS_INSTALLED_PREFIX} into ${ck_kernels_install_dir}
install(DIRECTORY
$ENV{CK_KERNELS_INSTALLED_PREFIX}/
DESTINATION ${ck_kernels_install_dir}
)
set(ck_kernels_install_path "$ENV{CK_KERNELS_INSTALLED_PREFIX}/libck_kernels.so")
# specify path to CK library
target_link_libraries(__ck_kernels_lib INTERFACE ${ck_kernels_install_path})
message(STATUS "Using Preinstalled CK_kernels from $ENV{CK_KERNELS_INSTALLED_PREFIX}; installed at ${ck_kernels_install_dir}")
elseif(DEFINED ENV{CK_KERNELS_PACKAGE_BASE_URL})
# get CK commit hash
execute_process(
COMMAND git -C ${CMAKE_SOURCE_DIR}/third_party submodule status composable_kernel
RESULT_VARIABLE result
OUTPUT_VARIABLE submodule_status
ERROR_VARIABLE submodule_status_error
OUTPUT_STRIP_TRAILING_WHITESPACE
)
if(result EQUAL 0)
string(REGEX REPLACE "^[ \t]" "" submodule_status ${submodule_status})
# extract first 8 characters of the commit hash
string(SUBSTRING "${submodule_status}" 0 8 ck_commit_hash)
else()
message(FATAL_ERROR "Failed to get submodule status for composable_kernel.")
endif()

set(ck_kernels_package_full_url "$ENV{CK_KERNELS_PACKAGE_BASE_URL}/torch_ck_gen_lib/ck_${ck_commit_hash}/hip_${HIP_VERSION_MAJOR}.${HIP_VERSION_MINOR}/libck_kernels.tar.gz")
set(ck_kernels_install_path "${ck_kernels_install_dir}/libck_kernels.so")

ExternalProject_Add(ck_kernels_external
URL "${ck_kernels_package_full_url}"
# URL_HASH
SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/ck_kernels_tarball
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
INSTALL_COMMAND ${CMAKE_COMMAND} -E copy_directory
"${CMAKE_CURRENT_BINARY_DIR}/ck_kernels_tarball"
"${ck_kernels_install_dir}"
BUILD_BYPRODUCTS "${ck_kernels_install_path}"
)
add_dependencies(__ck_kernels_lib ck_kernels_external)
message(STATUS "Using CK_kernels from pre-compiled binary ${ck_kernels_package_full_url}; installed at ${ck_kernels_install_dir}")
# specify path to CK library
target_link_libraries(__ck_kernels_lib INTERFACE ${ck_kernels_install_path})
else()
set(CK_KERNELS_INSTALL_FROM_SOURCE TRUE)
endif() # DEFINED ENV{CK_KERNELS_INSTALLED_PREFIX}

endif() # __ck_kernels_included
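
The URL_HASH argument in the ExternalProject_Add call above is left commented out, so the downloaded tarball is not integrity-checked. Below is a minimal sketch of how the download could be pinned once a checksum is published for each package; the SHA256 digest is a placeholder, not a real value, and the remaining arguments are unchanged from the file above:

ExternalProject_Add(ck_kernels_external
  URL "${ck_kernels_package_full_url}"
  # Pin the prebuilt package; substitute the published digest for this placeholder.
  URL_HASH SHA256=0000000000000000000000000000000000000000000000000000000000000000
  SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/ck_kernels_tarball
  CONFIGURE_COMMAND ""
  BUILD_COMMAND ""
  INSTALL_COMMAND ${CMAKE_COMMAND} -E copy_directory
    "${CMAKE_CURRENT_BINARY_DIR}/ck_kernels_tarball"
    "${ck_kernels_install_dir}"
  BUILD_BYPRODUCTS "${ck_kernels_install_path}"
)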