Merged

Changes from 1 commit
4 changes: 2 additions & 2 deletions CMakeLists.txt
@@ -873,7 +873,7 @@ cmake_dependent_option(
"Whether to build the flash_attention kernel for scaled dot product attention.\
Will be disabled if not supported by the platform"
ON
-"USE_CUDA OR USE_ROCM;NOT MSVC"
+"(USE_CUDA AND NOT MSVC) OR USE_ROCM"
OFF)

cmake_dependent_option(
@@ -908,7 +908,7 @@ cmake_dependent_option(
# USE_FLASH_ATTENTION -> USE_ROCM -> Dependencies.cmake -> aotriton.cmake
#
if(USE_ROCM)
-if(UNIX AND (USE_FLASH_ATTENTION OR USE_MEM_EFF_ATTENTION))
+if(USE_FLASH_ATTENTION OR USE_MEM_EFF_ATTENTION)
include(cmake/External/aotriton.cmake)
endif()
endif()
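For context: cmake_dependent_option ANDs together the semicolon-separated entries of its <depends> argument, so the old "USE_CUDA OR USE_ROCM;NOT MSVC" meant (USE_CUDA OR USE_ROCM) AND (NOT MSVC) and forced the option OFF for ROCm-on-MSVC (Windows) builds; the new single-entry condition only excludes MSVC on the CUDA path. A minimal sketch with an illustrative option name:

include(CMakeDependentOption)
# EXAMPLE_OPT defaults to ON and remains user-settable only while the
# <depends> expression holds; otherwise it is hidden and forced to OFF.
cmake_dependent_option(EXAMPLE_OPT "Illustrative dependent option" ON
    "(USE_CUDA AND NOT MSVC) OR USE_ROCM" OFF)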
66 changes: 66 additions & 0 deletions aten/src/ATen/native/transformers/cuda/attention.cu
@@ -95,6 +95,72 @@
#endif
#endif

#if defined(USE_ROCM) && (defined(USE_FLASH_ATTENTION) || defined(USE_MEM_EFF_ATTENTION))
namespace pytorch_flash
{
std::tuple<
at::Tensor,
at::Tensor,
at::Tensor,
at::Tensor,
at::Tensor,
at::Tensor,
at::Tensor,
at::Tensor>
mha_fwd(
const at::Tensor& q, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor& k, // batch_size x seqlen_k x num_heads_k x head_size
const at::Tensor& v, // batch_size x seqlen_k x num_heads_k x head_size
std::optional<at::Tensor>&
out_, // batch_size x seqlen_q x num_heads x head_size
std::optional<at::Tensor>&
alibi_slopes_, // num_heads or batch_size x num_heads
const float p_dropout,
const float softmax_scale,
bool is_causal,
std::optional<int64_t> window_size_left,
std::optional<int64_t> window_size_right,
const float softcap,
const bool return_softmax,
std::optional<at::Generator> gen_) {
#if defined(USE_ROCM_CK_SDPA)
if (at::globalContext().getROCmFAPreferredBackend() ==
at::ROCmFABackend::Ck) {
const int non_null_window_left = window_size_left.value_or(-1);
const int non_null_window_right = window_size_right.value_or(-1);
std::optional<at::Tensor> dummy_attn_bias = std::nullopt;
return mha_fwd_ck(
q,
k,
v,
out_,
p_dropout,
softmax_scale,
is_causal,
non_null_window_left,
non_null_window_right,
return_softmax,
gen_,
dummy_attn_bias); // Not used in flash attention
}
#endif
return mha_fwd_aot(
q,
k,
v,
out_,
alibi_slopes_,
p_dropout,
softmax_scale,
is_causal,
window_size_left,
window_size_right,
return_softmax,
gen_);
}
}
#endif

namespace at {

namespace cuda::philox {
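The mha_fwd wrapper added above dispatches to mha_fwd_ck when the globally preferred ROCm flash-attention backend is CK (and the build has USE_ROCM_CK_SDPA), falling back to the AOTriton-based mha_fwd_aot otherwise. A sketch of flipping that preference from Python — assuming the torch.backends.cuda.preferred_rocm_fa_library API on recent ROCm builds (treat the name and accepted values as an assumption):

import torch
# "ck" routes pytorch_flash::mha_fwd through mha_fwd_ck on builds with
# USE_ROCM_CK_SDPA; "default" falls through to mha_fwd_aot.
torch.backends.cuda.preferred_rocm_fa_library("ck")
print(torch.backends.cuda.preferred_rocm_fa_library())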
39 changes: 2 additions & 37 deletions aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h
@@ -270,7 +270,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_varle
#endif

TORCH_API
-inline std::tuple<
+std::tuple<
at::Tensor,
at::Tensor,
at::Tensor,
@@ -294,42 +294,7 @@ mha_fwd(
std::optional<int64_t> window_size_right,
const float softcap,
const bool return_softmax,
-std::optional<at::Generator> gen_) {
-#if defined(USE_ROCM_CK_SDPA)
-if (at::globalContext().getROCmFAPreferredBackend() ==
-at::ROCmFABackend::Ck) {
-const int non_null_window_left = window_size_left.value_or(-1);
-const int non_null_window_right = window_size_right.value_or(-1);
-std::optional<at::Tensor> dummy_attn_bias = std::nullopt;
-return mha_fwd_ck(
-q,
-k,
-v,
-out_,
-p_dropout,
-softmax_scale,
-is_causal,
-non_null_window_left,
-non_null_window_right,
-return_softmax,
-gen_,
-dummy_attn_bias); // Not used in flash attention
-}
-#endif
-return mha_fwd_aot(
-q,
-k,
-v,
-out_,
-alibi_slopes_,
-p_dropout,
-softmax_scale,
-is_causal,
-window_size_left,
-window_size_right,
-return_softmax,
-gen_);
-}
+std::optional<at::Generator> gen_);

inline std::tuple<
at::Tensor,
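The hunk above demotes mha_fwd from an inline, header-defined function to a plain TORCH_API declaration; its body (including the CK/AOT dispatch) moved verbatim into attention.cu, shown earlier. A minimal sketch of the pattern, with illustrative names:

// header: declaration only; TORCH_API marks the symbol for export.
TORCH_API int do_work(int x);

// one .cpp/.cu file: the single out-of-line definition, so the dispatch
// logic is compiled into exactly one translation unit.
int do_work(int x) { return 2 * x; }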
113 changes: 108 additions & 5 deletions cmake/External/aotriton.cmake
Member Author (@ScottTodd):
My local builds seemed to succeed with this branch and aotriton actually enabled (visible in build logs + files present in the .whls). However, I'm seeing the same performance via comfyui with and without aotriton on gfx1100, even with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 and python D:\projects\ComfyUI\main.py --use-split-cross-attention. I see about 12.6 it/s for image generation tasks, while a month ago I reported 20.0 it/s with aotriton 🤔

Logs before updating comfyui itself to latest had this:

D:\projects\ComfyUI\comfy\ops.py:47: UserWarning: 1Torch was not compiled with memory efficient attention. (Triggered internally at D:/b/pytorch_2_9/aten/src/ATen/native/transformers/hip/sdp_utils.cpp:800.)
  return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)

Those logs are not present after updating comfyui to latest.

The latest torch + rocm wheels from https://github.com/ROCm/TheRock/blob/main/RELEASES.md#installing-pytorch-python-packages get me about 14 it/s.

rocm==7.10.0a20251015
rocm-sdk-core==7.10.0a20251015
rocm-sdk-libraries-gfx110X-dgpu==7.10.0a20251015
torch==2.10.0a0+rocm7.10.0a20251015
torchaudio==2.8.0a0+rocm7.10.0a20251015
torchsde==0.2.6
torchvision==0.25.0a0+rocm7.10.0a20251015

Not sure where the diffs are coming from. Could be:

  • Missing more changes on 2.9 that are present on 2.10a
  • My system is under more load now (could also test with older releases)
  • Aotriton is not actually enabled / in use? (a quick check for this is sketched below)
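One way to rule out that last bullet, a minimal check (shapes are arbitrary; this relies on torch.nn.attention.sdpa_kernel restricting SDPA to the flash backend, which raises "No available kernel" when the build lacks it):

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

print(torch.__version__, torch.version.hip)
q = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)
# Restrict SDPA to flash attention; this errors out on builds compiled
# without it (e.g. torch built without aotriton on ROCm).
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    F.scaled_dot_product_attention(q, q, q)
print("flash attention kernel available")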

@jammm (Oct 16, 2025):

Ah I missed the part where you already had TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 set.

> Those logs are not present after updating comfyui to latest.

The latest one on main disables MIOpen itself, but aotriton should still be running I think.

Member Author (@ScottTodd):

Weird...

  • My locally built .whl files have torch/lib/aotriton_v2.dll
  • I do not see that DLL in site-packages/torch/lib/ after installing the locally built .whl files
  • I do see that DLL after installing our nightly built .whl files (from torch 2.10a / nightly / main)
  • Running a validation script against the locally built wheels shows aotriton missing:
    (3.12.venv) λ python D:\scratch\python\validate_torch_vroom.py
    Benchmarking Scaled Dot-Product Attention (Flash) in FP16 ...
    D:\scratch\python\validate_torch_vroom.py:72: UserWarning: 1Torch was not compiled with memory efficient attention. (Triggered internally at D:/b/pytorch_2_9/aten/src/ATen/native/transformers/hip/sdp_utils.cpp:800.)
      out = scaled_dot_product_attention(q, k, v)
    D:\scratch\python\validate_torch_vroom.py:72: UserWarning: Memory efficient kernel not used because: (Triggered internally at D:/b/pytorch_2_9/aten/src/ATen/native/transformers/hip/sdp_utils.cpp:938.)
      out = scaled_dot_product_attention(q, k, v)
    D:\scratch\python\validate_torch_vroom.py:72: UserWarning: Flash attention kernel not used because: (Triggered internally at D:/b/pytorch_2_9/aten/src/ATen/native/transformers/hip/sdp_utils.cpp:940.)
      out = scaled_dot_product_attention(q, k, v)
    D:\scratch\python\validate_torch_vroom.py:72: UserWarning: Torch was not compiled with flash attention. (Triggered internally at D:/b/pytorch_2_9/aten/src/ATen/native/transformers/hip/sdp_utils.cpp:749.)
      out = scaled_dot_product_attention(q, k, v)
    D:\scratch\python\validate_torch_vroom.py:72: UserWarning: cuDNN attention kernel not used because: (Triggered internally at D:/b/pytorch_2_9/aten/src/ATen/native/transformers/hip/sdp_utils.cpp:942.)
      out = scaled_dot_product_attention(q, k, v)
    D:\scratch\python\validate_torch_vroom.py:72: UserWarning: Torch was not compiled with cuDNN attention. (Triggered internally at D:/b/pytorch_2_9/aten/src/ATen/native/transformers/hip/sdp_utils.cpp:683.)
      out = scaled_dot_product_attention(q, k, v)
    Traceback (most recent call last):
      File "D:\scratch\python\validate_torch_vroom.py", line 215, in <module>
        sdpa_time, sdpa_mem, sdpa_gflops = measure_op(run_sdpa, warmup=3, total_runs=10)
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "D:\scratch\python\validate_torch_vroom.py", line 34, in measure_op
        t_ms, peak_mb, gf_s = op_func()
                              ^^^^^^^^^
      File "D:\scratch\python\validate_torch_vroom.py", line 72, in run_sdpa
        out = scaled_dot_product_attention(q, k, v)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    RuntimeError: No available kernel. Aborting execution.
    

Reply:

@ScottTodd the aotriton_v2.dll file is copied over from <torch_src>/torch/lib which could be a remnant of previous builds. It's likely that it got copied over even though torch was built without aotriton.

Member Author (@ScottTodd, Oct 16, 2025):

🤦 I built torch-2.9.0 after we changed the version but installed my prior build of torch-2.9.0a0...

Okay, aotriton is there with my local build from this PR (or the release/2.9_rocm7.9 branch)

17 it/s with

set TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1
python D:\projects\ComfyUI\main.py --use-pytorch-cross-attention

14.5 it/s with

set TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=0
python D:\projects\ComfyUI\main.py

Reply:

@ScottTodd @jammm maybe this change is missing: pytorch#165538

Member Author (@ScottTodd):

> @ScottTodd @jammm maybe this change is missing: pytorch#165538

Could be useful. In my build logs I see this though, showing that it isn't strictly required here yet:

-- Cannot find AOTriton runtime for ROCM 7.1.       Build runtime from source

(the top level version should technically be 7.9 I think, but it is specified in multiple subprojects with different values)

@@ -45,13 +45,88 @@ if(NOT __AOTRITON_INCLUDED)
)
set(__AOTRITON_BASE_URL "https://github.com/ROCm/aotriton/releases/download/") # @lint-ignore
set(__AOTRITON_Z "gz")
# Set the default __AOTRITON_LIB path
set(__AOTRITON_LIB "${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.so")
if(WIN32)
set(__AOTRITON_LIB "${__AOTRITON_INSTALL_DIR}/lib/aotriton_v2.lib")
endif()

function(aotriton_build_windows_dependencies dlfcn-win32_external xz_external dlfcn-win32_DIR liblzma_DIR)
# Windows-specific dependencies - build these first
if(NOT noimage)
message(FATAL_ERROR "noimage must be ON for Windows builds")
endif()
# Build dlfcn-win32
set(__DLFCN_WIN32_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/dlfcn-win32")
set(__DLFCN_WIN32_INSTALL_DIR "${CMAKE_CURRENT_BINARY_DIR}/dlfcn-win32-install")

ExternalProject_Add(${dlfcn-win32_external}
GIT_REPOSITORY https://github.com/dlfcn-win32/dlfcn-win32.git
GIT_TAG v1.4.2
PREFIX ${__DLFCN_WIN32_PREFIX}
INSTALL_DIR ${__DLFCN_WIN32_INSTALL_DIR}
CMAKE_ARGS
-DCMAKE_INSTALL_PREFIX=${__DLFCN_WIN32_INSTALL_DIR}
-DCMAKE_BUILD_TYPE=Release
-DCMAKE_C_COMPILER=cl
-DCMAKE_CXX_COMPILER=cl
-DBUILD_SHARED_LIBS=ON
-DBUILD_TESTS=OFF
BUILD_BYPRODUCTS
"${__DLFCN_WIN32_INSTALL_DIR}/lib/dl.lib"
"${__DLFCN_WIN32_INSTALL_DIR}/bin/dl.dll"
)
ExternalProject_Add_Step(${dlfcn-win32_external} copy_to_aotriton
COMMAND ${CMAKE_COMMAND} -E copy_if_different
"${__DLFCN_WIN32_INSTALL_DIR}/bin/dl.dll"
"${__AOTRITON_INSTALL_DIR}/lib/"
DEPENDEES install
)
set(${dlfcn-win32_DIR} "${__DLFCN_WIN32_INSTALL_DIR}/share/dlfcn-win32" CACHE PATH "Path to dlfcn-win32 CMake config" FORCE)

# Build xz/liblzma
set(__XZ_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/xz")
set(__XZ_INSTALL_DIR "${CMAKE_CURRENT_BINARY_DIR}/xz-install")

ExternalProject_Add(${xz_external}
GIT_REPOSITORY https://github.com/tukaani-project/xz.git
GIT_TAG v5.8.1
PREFIX ${__XZ_PREFIX}
INSTALL_DIR ${__XZ_INSTALL_DIR}
CMAKE_ARGS
-DCMAKE_INSTALL_PREFIX=${__XZ_INSTALL_DIR}
-DCMAKE_BUILD_TYPE=Release
-DBUILD_SHARED_LIBS=ON
-DENABLE_NLS=OFF
-DXZ_TOOL_LZMAINFO=OFF
-DXZ_TOOL_XZ=OFF
-DXZ_TOOL_XZDEC=OFF
-DXZ_TOOL_LZMADEC=OFF
BUILD_BYPRODUCTS
"${__XZ_INSTALL_DIR}/lib/lzma.lib"
"${__XZ_INSTALL_DIR}/bin/liblzma.dll"
)
ExternalProject_Add_Step(${xz_external} copy_to_aotriton
COMMAND ${CMAKE_COMMAND} -E copy_if_different
"${__XZ_INSTALL_DIR}/bin/liblzma.dll"
"${__AOTRITON_INSTALL_DIR}/lib/"
DEPENDEES install
)
set(${liblzma_DIR} "${__XZ_INSTALL_DIR}/lib/cmake/liblzma" CACHE PATH "Path to xz/liblzma CMake config" FORCE)
endfunction()

function(aotriton_build_from_source noimage project)
if(noimage)
SET(RECURSIVE "OFF")
else()
SET(RECURSIVE "ON")
endif()
if(WIN32)
message(STATUS "Building AOTriton Windows dependencies")
aotriton_build_windows_dependencies(dlfcn-win32_external xz_external dlfcn-win32_DIR liblzma_DIR)
endif()
message(STATUS "PYTORCH_ROCM_ARCH ${PYTORCH_ROCM_ARCH}")

ExternalProject_Add(${project}
GIT_REPOSITORY https://github.com/ROCm/aotriton.git
GIT_SUBMODULES_RECURSE ${RECURSIVE}
@@ -65,12 +65,19 @@ if(NOT __AOTRITON_INCLUDED)
-DAOTRITON_GPU_BUILD_TIMEOUT=0
-DAOTRITON_NO_PYTHON=ON
-DAOTRITON_NOIMAGE_MODE=${noimage}
-BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.so"
+-DHIP_PLATFORM=amd
+$<$<BOOL:${WIN32}>:-Ddlfcn-win32_DIR=${dlfcn-win32_DIR}>
+$<$<BOOL:${WIN32}>:-Dliblzma_DIR=${liblzma_DIR}>
+BUILD_BYPRODUCTS
+"${__AOTRITON_LIB}"
USES_TERMINAL_DOWNLOAD TRUE
USES_TERMINAL_CONFIGURE TRUE
USES_TERMINAL_BUILD TRUE
USES_TERMINAL_INSTALL TRUE
)
if(WIN32)
add_dependencies(${project} dlfcn-win32_external xz_external)
endif()
endfunction()

set(__AOTRITON_ARCH ${CMAKE_HOST_SYSTEM_PROCESSOR})
@@ -95,7 +177,7 @@ if(NOT __AOTRITON_INCLUDED)
INSTALL_COMMAND ${CMAKE_COMMAND} -E copy_directory
"${CMAKE_CURRENT_BINARY_DIR}/aotriton_runtime"
"${__AOTRITON_INSTALL_DIR}"
-BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.so"
+BUILD_BYPRODUCTS "${__AOTRITON_LIB}"
)
message(STATUS "Using AOTriton Runtime from pre-compiled binary ${__AOTRITON_URL}.\
Set env variables AOTRITON_INSTALL_FROM_SOURCE=1 to build from source.")
@@ -111,14 +193,35 @@ if(NOT __AOTRITON_INCLUDED)
string(CONCAT __AOTRITON_URL
"${__AOTRITON_BASE_URL}"
"${__AOTRITON_VER}/${__AOTRITON_FILE}")

# Set up directories
set(__AOTRITON_DOWNLOAD_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton_download-${image})
set(__AOTRITON_EXTRACT_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton_image-${image})
set(__AOTRITON_INSTALL_SOURCE_DIR ${__AOTRITON_EXTRACT_DIR})
set(__DOWNLOAD_NO_EXTRACT "")
set(__BUILD_COMMANDS "")

# On Windows, we need custom tar extraction with UTF-8 support
if(WIN32)
set(__DOWNLOAD_NO_EXTRACT "DOWNLOAD_NO_EXTRACT;TRUE")
set(__BUILD_COMMANDS
COMMAND ${CMAKE_COMMAND} -E make_directory "${__AOTRITON_EXTRACT_DIR}"
COMMAND tar --options hdrcharset=UTF-8 -xf "${__AOTRITON_DOWNLOAD_DIR}/${__AOTRITON_FILE}" -C "${__AOTRITON_EXTRACT_DIR}"
)
set(__AOTRITON_INSTALL_SOURCE_DIR ${__AOTRITON_EXTRACT_DIR}/aotriton)
endif()

ExternalProject_Add(${project}
URL "${__AOTRITON_URL}"
URL_HASH SHA256=${__AOTRITON_SHA256}
-SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton_image-${image}
+DOWNLOAD_DIR ${__AOTRITON_DOWNLOAD_DIR}
+${__DOWNLOAD_NO_EXTRACT}
+SOURCE_DIR ${__AOTRITON_EXTRACT_DIR}
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
${__BUILD_COMMANDS}
INSTALL_COMMAND ${CMAKE_COMMAND} -E copy_directory
-"${CMAKE_CURRENT_BINARY_DIR}/aotriton_image-${image}"
+"${__AOTRITON_INSTALL_SOURCE_DIR}"
"${__AOTRITON_INSTALL_DIR}"
BUILD_BYPRODUCTS
"${__AOTRITON_INSTALL_DIR}/lib/aotriton.images/${image}/__signature__"
@@ -164,7 +267,7 @@ if(NOT __AOTRITON_INCLUDED)
endforeach()
endforeach()
endif()
-target_link_libraries(__caffe2_aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.so)
+target_link_libraries(__caffe2_aotriton INTERFACE ${__AOTRITON_LIB})
target_include_directories(__caffe2_aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/include)
set(AOTRITON_FOUND TRUE)
endif() # __AOTRITON_INCLUDED
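For orientation, the module ends by attaching AOTriton to the __caffe2_aotriton INTERFACE target, so a downstream target picks up ${__AOTRITON_LIB} and the install include directory in one line. A hedged sketch of consuming it (the consumer target name is illustrative):

include(cmake/External/aotriton.cmake)
# __caffe2_aotriton carries the link line and include paths set above.
target_link_libraries(my_attention_kernels PRIVATE __caffe2_aotriton)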
1 change: 1 addition & 0 deletions tools/linter/dictionary.txt
@@ -12,6 +12,7 @@ BU
contiguities
contiguity
coo
+DEPENDEES
deser
din
dout