diff --git a/CMakeLists.txt b/CMakeLists.txt
index 91181735750d6..551e40da19d44 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -913,6 +913,13 @@ if(USE_ROCM)
   endif()
 endif()
 
+# link CK library
+if(USE_ROCM)
+  if(UNIX AND USE_CK_FLASH_ATTENTION)
+    include(cmake/External/ck_kernels.cmake)
+  endif()
+endif()
+
 if(DEBUG_CUDA)
   string(APPEND CMAKE_CUDA_FLAGS_DEBUG " -lineinfo")
   string(APPEND CMAKE_CUDA_FLAGS_RELWITHDEBINFO " -lineinfo")
diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt
index b30d8336e8ec9..5d65e2dcbed59 100644
--- a/aten/src/ATen/CMakeLists.txt
+++ b/aten/src/ATen/CMakeLists.txt
@@ -187,26 +187,29 @@ if(USE_FLASH_ATTENTION)
     caffe2_update_option(USE_ROCM_CK_SDPA ON)
   endif()
   if(USE_ROCM_CK_SDPA)
-    if(DEFINED ENV{PYTORCH_ROCM_ARCH})
-      list(LENGTH PYTORCH_ROCM_ARCH NUM_ARCHS)
-      if(NUM_ARCHS GREATER 1)
-        message(WARNING "Building CK for multiple archs can increase build time considerably!
-        Consider setting PYTORCH_ROCM_ARCH env var value as the gfx arch you need to build for")
-      endif()
-    endif()
     message(STATUS "USE_ROCM_CK_SDPA is set; building PyTorch with CK SDPA enabled")
-    message(STATUS "Generating CK kernel instances...")
-    add_subdirectory(native/transformers/hip/flash_attn/ck)
+    if(DEFINED CK_KERNELS_INSTALL_FROM_SOURCE)
+      if(DEFINED ENV{PYTORCH_ROCM_ARCH})
+        list(LENGTH PYTORCH_ROCM_ARCH NUM_ARCHS)
+        if(NUM_ARCHS GREATER 1)
+          message(WARNING "Building CK for multiple archs can increase build time considerably!
+          Consider setting PYTORCH_ROCM_ARCH env var value as the gfx arch you need to build for")
+        endif()
+      endif()
+      # building CK kernels from source
+      message(STATUS "Generating CK kernel instances...")
+      add_subdirectory(native/transformers/hip/flash_attn/ck)
+      add_subdirectory(native/transformers/hip/flash_attn/ck/fav_v3)
+    endif() # end of CK_KERNELS_INSTALL_FROM_SOURCE
     file(GLOB flash_attention_hip_ck_hip "native/transformers/hip/flash_attn/ck/*.hip")
     list(APPEND native_transformers_hip_hip ${flash_attention_hip_ck_hip})
     # FAv3 Generation
-    add_subdirectory(native/transformers/hip/flash_attn/ck/fav_v3)
     file(GLOB flash_attention_v3_hip "native/transformers/hip/flash_attn/ck/fav_v3/*.hip")
     list(APPEND native_transformers_hip_hip ${flash_attention_v3_hip})
-  endif()
+  endif() # end of USE_ROCM_CK_SDPA
   file(GLOB flash_attention_hip_aot_hip "native/transformers/hip/flash_attn/aot/*.hip")
   file(GLOB flash_attention_src_hip_hip "native/transformers/hip/flash_attn/src/*.hip")
-endif()
+endif() # end of USE_FLASH_ATTENTION
 
 #Mem_eff attention sources
 file(GLOB mem_eff_attention_cuda_cu "native/transformers/cuda/mem_eff_attention/*.cu")
diff --git a/aten/src/ATen/native/transformers/cuda/attention.cu b/aten/src/ATen/native/transformers/cuda/attention.cu
index c2193f2378dd5..e0ae5f53f8357 100644
--- a/aten/src/ATen/native/transformers/cuda/attention.cu
+++ b/aten/src/ATen/native/transformers/cuda/attention.cu
@@ -1403,6 +1403,8 @@ std::tuple _efficient_
       at::ROCmFABackend::Ck) {
 #if defined(USE_ROCM_CK_SDPA)
+    TORCH_WARN_ONCE("Using CK backend for Efficient Attention forward...");
+
     std::optional out(res);
     std::optional seqused_k = std::nullopt;
     std::optional alibi_slopes = std::nullopt;
@@ -1445,6 +1447,8 @@ std::tuple _efficient_
                 " (gfx90a:sramecc+:xnack-/gfx942:sramecc+:xnack-/gfx1100)")
   }
 
+  TORCH_WARN_ONCE("Using AOTriton backend for Efficient Attention forward...");
+
   // AOTriton may accept aligned on logsumexp tensor in the future for better
   // performance, but for now it requires compact logsumexp tensor, even if
   // compute_logsumexp is false
diff --git a/aten/src/ATen/native/transformers/cuda/attention_backward.cu b/aten/src/ATen/native/transformers/cuda/attention_backward.cu
index 55fc1e261219e..4d9116af04c78 100644
--- a/aten/src/ATen/native/transformers/cuda/attention_backward.cu
+++ b/aten/src/ATen/native/transformers/cuda/attention_backward.cu
@@ -496,6 +496,8 @@ _efficient_attention_backward(
   if(at::globalContext().getROCmFAPreferredBackend() ==
     at::ROCmFABackend::Ck) {
 #if defined(USE_ROCM_CK_SDPA)
+    TORCH_WARN_ONCE("Using CK backend for Efficient Attention backward...");
+
     const auto my_softmax_scale = sdp::calculate_scale(query, scale).expect_float();
     // Store grad_bias in optional
     std::optional opt_grad_bias = grad_bias;
@@ -544,6 +546,9 @@ _efficient_attention_backward(
                 "[AOTriton] Accelerated SDPA only supports MI200/MI300X/7900XTX/9070XT GPUs"
                 " (gfx90a/gfx942/gfx1100/gfx1201)")
   }
+
+  TORCH_WARN_ONCE("Using AOTriton backend for Efficient Attention backward...");
+
   const auto softmax_scale = sdp::calculate_scale(query, scale).expect_float();
   bool is_causal;
   if (static_cast(sdp::CustomMaskType::NoCustomMask) == custom_mask_type) {
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip b/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip
index 2467cb809fdbf..67313e1a494c0 100644
--- a/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip
+++ b/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip
@@ -71,6 +71,15 @@ namespace pytorch_flash {
 namespace {
 
+void check_aotriton_gpu_arch(hipStream_t stream) {
+  auto ret = aotriton::v2::flash::check_gpu(stream);
+  if (hipSuccess != ret) {
+    TORCH_CHECK(false,
+                "[AOTriton] Accelerated SDPA only supports MI200/MI300X/Navi31 GPUs"
+                " (gfx90a:sramecc+:xnack-/gfx942:sramecc+:xnack-/gfx1100)")
+  }
+}
+
 void check_gpu_arch(hipStream_t stream) {
   auto ret = aotriton::v2::flash::check_gpu(stream);
   if (hipSuccess != ret) {
@@ -163,7 +172,7 @@ mha_fwd_aot(const at::Tensor &q, // batch_size x seqlen_q x num_heads x
             const bool return_softmax,
             const std::optional& gen_) {
   auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
-  check_gpu_arch(stream);
+  check_aotriton_gpu_arch(stream);
 
   auto q_dtype = q.dtype();
   TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16,
@@ -350,7 +359,7 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot
   at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()};
   auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
-  check_gpu_arch(stream);
+  check_aotriton_gpu_arch(stream);
 
   auto q_dtype = q.dtype();
   TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16,
@@ -562,7 +571,7 @@ mha_bwd_aot(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x hea
   // Cast to char to avoid compiler warning about narrowing
   at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()};
   auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
-  check_gpu_arch(stream);
+  check_aotriton_gpu_arch(stream);
 
   bool is_dropout = p_dropout > 0.0;
@@ -795,7 +804,8 @@ mha_varlen_bwd_aot(const at::Tensor &dout, // total_q x num_heads, x head_size
   // Cast to char to avoid compiler warning about narrowing
   at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()};
   auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
-  check_gpu_arch(stream);
+  check_aotriton_gpu_arch(stream);
+
   bool is_dropout = p_dropout > 0.0;
diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt
index 6ab41b6c84793..8ee5a1451f701 100644
--- a/caffe2/CMakeLists.txt
+++ b/caffe2/CMakeLists.txt
@@ -940,6 +940,10 @@ if(USE_ROCM)
   if(USE_FLASH_ATTENTION)
     target_link_libraries(torch_hip PRIVATE __caffe2_aotriton)
   endif()
+# link CK library if not building CK_kernels from source
+  if(USE_CK_FLASH_ATTENTION AND NOT CK_KERNELS_INSTALL_FROM_SOURCE)
+    target_link_libraries(torch_hip PRIVATE __ck_kernels_lib)
+  endif()
   set(CUDA_LINK_LIBRARIES_KEYWORD)
   torch_compile_options(torch_hip)  # see cmake/public/utils.cmake
   # TODO: Not totally sure if this is live or not
diff --git a/cmake/External/ck_kernels.cmake b/cmake/External/ck_kernels.cmake
new file mode 100644
index 0000000000000..98c6e4ddd291f
--- /dev/null
+++ b/cmake/External/ck_kernels.cmake
@@ -0,0 +1,63 @@
+#
+# create INTERFACE target for CK library
+#
+if(NOT __ck_kernels_included)
+  set(__ck_kernels_included TRUE)
+
+  set(ck_kernels_install_dir "${PROJECT_SOURCE_DIR}/torch/lib")
+
+  set(__ck_kernels_version 0.1)
+
+  # create INTERFACE target
+  add_library(__ck_kernels_lib INTERFACE)
+
+  if(DEFINED ENV{CK_KERNELS_INSTALLED_PREFIX})
+    # Copy .so from $ENV{CK_KERNELS_INSTALLED_PREFIX} into ${ck_kernels_install_dir}
+    install(DIRECTORY
+      $ENV{CK_KERNELS_INSTALLED_PREFIX}/
+      DESTINATION ${ck_kernels_install_dir}
+    )
+    set(ck_kernels_install_path "$ENV{CK_KERNELS_INSTALLED_PREFIX}/libck_kernels.so")
+    # specify path to CK library
+    target_link_libraries(__ck_kernels_lib INTERFACE ${ck_kernels_install_path})
+    message(STATUS "Using Preinstalled CK_kernels from $ENV{CK_KERNELS_INSTALLED_PREFIX}; installed at ${ck_kernels_install_dir}")
+  elseif(DEFINED ENV{CK_KERNELS_PACKAGE_BASE_URL})
+    # get CK commit hash
+    execute_process(
+      COMMAND git -C ${CMAKE_SOURCE_DIR}/third_party submodule status composable_kernel
+      RESULT_VARIABLE result
+      OUTPUT_VARIABLE submodule_status
+      ERROR_VARIABLE submodule_status_error
+      OUTPUT_STRIP_TRAILING_WHITESPACE
+    )
+    if(result EQUAL 0)
+      string(REGEX REPLACE "^[ \t]" "" submodule_status ${submodule_status})
+      # extract first 8 characters of the commit hash
+      string(SUBSTRING "${submodule_status}" 0 8 ck_commit_hash)
+    else()
+      message(FATAL_ERROR "Failed to get submodule status for composable_kernel.")
+    endif()
+
+    set(ck_kernels_package_full_url "$ENV{CK_KERNELS_PACKAGE_BASE_URL}/torch_ck_gen_lib/ck_${ck_commit_hash}/hip_${HIP_VERSION_MAJOR}.${HIP_VERSION_MINOR}/libck_kernels.tar.gz")
+    set(ck_kernels_install_path "${ck_kernels_install_dir}/libck_kernels.so")
+
+    ExternalProject_Add(ck_kernels_external
+      URL "${ck_kernels_package_full_url}"
+      # URL_HASH
+      SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/ck_kernels_tarball
+      CONFIGURE_COMMAND ""
+      BUILD_COMMAND ""
+      INSTALL_COMMAND ${CMAKE_COMMAND} -E copy_directory
+        "${CMAKE_CURRENT_BINARY_DIR}/ck_kernels_tarball"
+        "${ck_kernels_install_dir}"
+      BUILD_BYPRODUCTS "${ck_kernels_install_path}"
+    )
+    add_dependencies(__ck_kernels_lib ck_kernels_external)
+    message(STATUS "Using CK_kernels from pre-compiled binary ${ck_kernels_package_full_url}; installed at ${ck_kernels_install_dir}")
+    # specify path to CK library
+    target_link_libraries(__ck_kernels_lib INTERFACE ${ck_kernels_install_path})
+  else()
+    set(CK_KERNELS_INSTALL_FROM_SOURCE TRUE)
+  endif() # DEFINED ENV{CK_KERNELS_INSTALLED_PREFIX}
+
+endif() # __ck_kernels_included
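
Note on the prebuilt-binary path: when CK_KERNELS_PACKAGE_BASE_URL is set, ck_kernels.cmake downloads an already-compiled libck_kernels.so, stages it into torch/lib, and exposes it through the __ck_kernels_lib INTERFACE target that caffe2/CMakeLists.txt links into torch_hip. The standalone CMake sketch below shows the same pattern in isolation; it is illustrative only, the URL and project name are placeholders, and a real setup would also pin a URL_HASH.

cmake_minimum_required(VERSION 3.19)  # recent CMake avoids older restrictions on INTERFACE library targets
project(ck_kernels_prebuilt_demo)

include(ExternalProject)

# Stand-in for ${PROJECT_SOURCE_DIR}/torch/lib in the PR.
set(ck_kernels_install_dir "${CMAKE_BINARY_DIR}/lib")
set(ck_kernels_install_path "${ck_kernels_install_dir}/libck_kernels.so")

# Placeholder URL; the PR assembles the real one from CK_KERNELS_PACKAGE_BASE_URL,
# the composable_kernel submodule hash, and the HIP version.
set(ck_kernels_package_full_url "https://example.com/libck_kernels.tar.gz")

# Download and unpack only: the tarball already contains the compiled library,
# so configure/build are no-ops and "install" is a plain directory copy.
ExternalProject_Add(ck_kernels_external
  URL "${ck_kernels_package_full_url}"
  SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/ck_kernels_tarball"
  CONFIGURE_COMMAND ""
  BUILD_COMMAND ""
  INSTALL_COMMAND ${CMAKE_COMMAND} -E copy_directory
    "${CMAKE_CURRENT_BINARY_DIR}/ck_kernels_tarball"
    "${ck_kernels_install_dir}"
  BUILD_BYPRODUCTS "${ck_kernels_install_path}"
)

# INTERFACE target that consumers link instead of hard-coding the .so path.
add_library(__ck_kernels_lib INTERFACE)
add_dependencies(__ck_kernels_lib ck_kernels_external)
target_link_libraries(__ck_kernels_lib INTERFACE "${ck_kernels_install_path}")

# A consumer (torch_hip in the PR) would then simply do:
#   target_link_libraries(my_target PRIVATE __ck_kernels_lib)

In the PR itself, the choice between this download path, a preinstalled tree pointed to by CK_KERNELS_INSTALLED_PREFIX, and a from-source build (CK_KERNELS_INSTALL_FROM_SOURCE) is made entirely by which environment variables are set at configure time.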