
Commit 83a046e

jammm authored and ScottTodd committed
[ROCm/Windows] Support aotriton for scaled_dot_product_attention on Windows. (pytorch#162330)

Enables flash attention and/or memory efficient attention on Windows with scaled_dot_product_attention via aotriton. Already tested to be working on Windows with TheRock.

Steps to enable: simply set `USE_FLASH_ATTENTION=1` and `USE_MEM_EFF_ATTENTION=1` as usual. See https://github.com/ROCm/TheRock/blob/main/external-builds/pytorch/build_prod_wheels.py#L578-L604

Pull Request resolved: pytorch#162330
Approved by: https://github.com/jeffdaily
Co-authored-by: Scott Todd <[email protected]>
1 parent (8ed5a06), commit 83a046e
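
As a quick sanity check (not part of this commit), here is a minimal Python sketch of how a Windows/ROCm build produced with `USE_FLASH_ATTENTION=1` and `USE_MEM_EFF_ATTENTION=1` could be probed. The tensor shapes and device string are arbitrary, and `sdpa_kernel`/`SDPBackend` are the standard PyTorch APIs rather than anything introduced here:

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

device = "cuda"  # ROCm GPUs are exposed through torch's "cuda" device type
q, k, v = (torch.randn(2, 8, 128, 64, device=device, dtype=torch.float16) for _ in range(3))

# Report whether the flash / memory-efficient SDP backends are currently enabled.
print("flash SDP enabled:", torch.backends.cuda.flash_sdp_enabled())
print("mem-efficient SDP enabled:", torch.backends.cuda.mem_efficient_sdp_enabled())

# Force each backend in turn; SDPA raises if the corresponding kernel is unavailable.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out_flash = F.scaled_dot_product_attention(q, k, v, is_causal=True)
with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
    out_mem_eff = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out_flash.shape, out_mem_eff.shape)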

File tree: 5 files changed (+179 lines, -44 lines)

CMakeLists.txt
Lines changed: 2 additions & 2 deletions

@@ -873,7 +873,7 @@ cmake_dependent_option(
   "Whether to build the flash_attention kernel for scaled dot product attention.\
     Will be disabled if not supported by the platform"
   ON
-  "USE_CUDA OR USE_ROCM;NOT MSVC"
+  "(USE_CUDA AND NOT MSVC) OR USE_ROCM"
   OFF)

 cmake_dependent_option(
@@ -908,7 +908,7 @@ cmake_dependent_option(
 # USE_FLASH_ATTENTION -> USE_ROCM -> Dependencies.cmake -> aotriton.cmake
 #
 if(USE_ROCM)
-  if(UNIX AND (USE_FLASH_ATTENTION OR USE_MEM_EFF_ATTENTION))
+  if(USE_FLASH_ATTENTION OR USE_MEM_EFF_ATTENTION)
     include(cmake/External/aotriton.cmake)
   endif()
 endif()

aten/src/ATen/native/transformers/cuda/attention.cu
Lines changed: 66 additions & 0 deletions

@@ -95,6 +95,72 @@
 #endif
 #endif

+#if defined(USE_ROCM) && (defined(USE_FLASH_ATTENTION) || defined(USE_MEM_EFF_ATTENTION))
+namespace pytorch_flash
+{
+std::tuple<
+    at::Tensor,
+    at::Tensor,
+    at::Tensor,
+    at::Tensor,
+    at::Tensor,
+    at::Tensor,
+    at::Tensor,
+    at::Tensor>
+mha_fwd(
+    const at::Tensor& q, // batch_size x seqlen_q x num_heads x head_size
+    const at::Tensor& k, // batch_size x seqlen_k x num_heads_k x head_size
+    const at::Tensor& v, // batch_size x seqlen_k x num_heads_k x head_size
+    std::optional<at::Tensor>&
+        out_, // batch_size x seqlen_q x num_heads x head_size
+    std::optional<at::Tensor>&
+        alibi_slopes_, // num_heads or batch_size x num_heads
+    const float p_dropout,
+    const float softmax_scale,
+    bool is_causal,
+    std::optional<int64_t> window_size_left,
+    std::optional<int64_t> window_size_right,
+    const float softcap,
+    const bool return_softmax,
+    std::optional<at::Generator> gen_) {
+#if defined(USE_ROCM_CK_SDPA)
+  if (at::globalContext().getROCmFAPreferredBackend() ==
+      at::ROCmFABackend::Ck) {
+    const int non_null_window_left = window_size_left.value_or(-1);
+    const int non_null_window_right = window_size_right.value_or(-1);
+    std::optional<at::Tensor> dummy_attn_bias = std::nullopt;
+    return mha_fwd_ck(
+        q,
+        k,
+        v,
+        out_,
+        p_dropout,
+        softmax_scale,
+        is_causal,
+        non_null_window_left,
+        non_null_window_right,
+        return_softmax,
+        gen_,
+        dummy_attn_bias); // Not used in flash attention
+  }
+#endif
+  return mha_fwd_aot(
+      q,
+      k,
+      v,
+      out_,
+      alibi_slopes_,
+      p_dropout,
+      softmax_scale,
+      is_causal,
+      window_size_left,
+      window_size_right,
+      return_softmax,
+      gen_);
+}
+}
+#endif
+
 namespace at {

 namespace cuda::philox {

aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h
Lines changed: 2 additions & 37 deletions

@@ -270,7 +270,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_varle
 #endif

 TORCH_API
-inline std::tuple<
+std::tuple<
     at::Tensor,
     at::Tensor,
     at::Tensor,
@@ -294,42 +294,7 @@ mha_fwd(
     std::optional<int64_t> window_size_right,
     const float softcap,
     const bool return_softmax,
-    std::optional<at::Generator> gen_) {
-#if defined(USE_ROCM_CK_SDPA)
-  if (at::globalContext().getROCmFAPreferredBackend() ==
-      at::ROCmFABackend::Ck) {
-    const int non_null_window_left = window_size_left.value_or(-1);
-    const int non_null_window_right = window_size_right.value_or(-1);
-    std::optional<at::Tensor> dummy_attn_bias = std::nullopt;
-    return mha_fwd_ck(
-        q,
-        k,
-        v,
-        out_,
-        p_dropout,
-        softmax_scale,
-        is_causal,
-        non_null_window_left,
-        non_null_window_right,
-        return_softmax,
-        gen_,
-        dummy_attn_bias); // Not used in flash attention
-  }
-#endif
-  return mha_fwd_aot(
-      q,
-      k,
-      v,
-      out_,
-      alibi_slopes_,
-      p_dropout,
-      softmax_scale,
-      is_causal,
-      window_size_left,
-      window_size_right,
-      return_softmax,
-      gen_);
-}
+    std::optional<at::Generator> gen_);

 inline std::tuple<
     at::Tensor,

cmake/External/aotriton.cmake
Lines changed: 108 additions & 5 deletions

@@ -45,13 +45,88 @@ if(NOT __AOTRITON_INCLUDED)
   )
   set(__AOTRITON_BASE_URL "https://github.com/ROCm/aotriton/releases/download/") # @lint-ignore
   set(__AOTRITON_Z "gz")
+  # Set the default __AOTRITON_LIB path
+  set(__AOTRITON_LIB "${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.so")
+  if(WIN32)
+    set(__AOTRITON_LIB "${__AOTRITON_INSTALL_DIR}/lib/aotriton_v2.lib")
+  endif()
+
+  function(aotriton_build_windows_dependencies dlfcn-win32_external xz_external dlfcn-win32_DIR liblzma_DIR)
+    # Windows-specific dependencies - build these first
+    if(NOT noimage)
+      message(FATAL_ERROR "noimage must be ON for Windows builds")
+    endif()
+    # Build dlfcn-win32
+    set(__DLFCN_WIN32_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/dlfcn-win32")
+    set(__DLFCN_WIN32_INSTALL_DIR "${CMAKE_CURRENT_BINARY_DIR}/dlfcn-win32-install")
+
+    ExternalProject_Add(${dlfcn-win32_external}
+      GIT_REPOSITORY https://github.com/dlfcn-win32/dlfcn-win32.git
+      GIT_TAG v1.4.2
+      PREFIX ${__DLFCN_WIN32_PREFIX}
+      INSTALL_DIR ${__DLFCN_WIN32_INSTALL_DIR}
+      CMAKE_ARGS
+        -DCMAKE_INSTALL_PREFIX=${__DLFCN_WIN32_INSTALL_DIR}
+        -DCMAKE_BUILD_TYPE=Release
+        -DCMAKE_C_COMPILER=cl
+        -DCMAKE_CXX_COMPILER=cl
+        -DBUILD_SHARED_LIBS=ON
+        -DBUILD_TESTS=OFF
+      BUILD_BYPRODUCTS
+        "${__DLFCN_WIN32_INSTALL_DIR}/lib/dl.lib"
+        "${__DLFCN_WIN32_INSTALL_DIR}/bin/dl.dll"
+    )
+    ExternalProject_Add_Step(${dlfcn-win32_external} copy_to_aotriton
+      COMMAND ${CMAKE_COMMAND} -E copy_if_different
+        "${__DLFCN_WIN32_INSTALL_DIR}/bin/dl.dll"
+        "${__AOTRITON_INSTALL_DIR}/lib/"
+      DEPENDEES install
+    )
+    set(${dlfcn-win32_DIR} "${__DLFCN_WIN32_INSTALL_DIR}/share/dlfcn-win32" CACHE PATH "Path to dlfcn-win32 CMake config" FORCE)
+
+    # Build xz/liblzma
+    set(__XZ_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/xz")
+    set(__XZ_INSTALL_DIR "${CMAKE_CURRENT_BINARY_DIR}/xz-install")
+
+    ExternalProject_Add(${xz_external}
+      GIT_REPOSITORY https://github.com/tukaani-project/xz.git
+      GIT_TAG v5.8.1
+      PREFIX ${__XZ_PREFIX}
+      INSTALL_DIR ${__XZ_INSTALL_DIR}
+      CMAKE_ARGS
+        -DCMAKE_INSTALL_PREFIX=${__XZ_INSTALL_DIR}
+        -DCMAKE_BUILD_TYPE=Release
+        -DBUILD_SHARED_LIBS=ON
+        -DENABLE_NLS=OFF
+        -DXZ_TOOL_LZMAINFO=OFF
+        -DXZ_TOOL_XZ=OFF
+        -DXZ_TOOL_XZDEC=OFF
+        -DXZ_TOOL_LZMADEC=OFF
+      BUILD_BYPRODUCTS
+        "${__XZ_INSTALL_DIR}/lib/lzma.lib"
+        "${__XZ_INSTALL_DIR}/bin/liblzma.dll"
+    )
+    ExternalProject_Add_Step(${xz_external} copy_to_aotriton
+      COMMAND ${CMAKE_COMMAND} -E copy_if_different
+        "${__XZ_INSTALL_DIR}/bin/liblzma.dll"
+        "${__AOTRITON_INSTALL_DIR}/lib/"
+      DEPENDEES install
+    )
+    set(${liblzma_DIR} "${__XZ_INSTALL_DIR}/lib/cmake/liblzma" CACHE PATH "Path to xz/liblzma CMake config" FORCE)
+  endfunction()
+
   function(aotriton_build_from_source noimage project)
     if(noimage)
       SET(RECURSIVE "OFF")
     else()
       SET(RECURSIVE "ON")
     endif()
+    if(WIN32)
+      message(STATUS "Building AOTriton Windows dependencies")
+      aotriton_build_windows_dependencies(dlfcn-win32_external xz_external dlfcn-win32_DIR liblzma_DIR)
+    endif()
     message(STATUS "PYTORCH_ROCM_ARCH ${PYTORCH_ROCM_ARCH}")
+
     ExternalProject_Add(${project}
       GIT_REPOSITORY https://github.com/ROCm/aotriton.git
       GIT_SUBMODULES_RECURSE ${RECURSIVE}
@@ -65,12 +140,19 @@ if(NOT __AOTRITON_INCLUDED)
        -DAOTRITON_GPU_BUILD_TIMEOUT=0
        -DAOTRITON_NO_PYTHON=ON
        -DAOTRITON_NOIMAGE_MODE=${noimage}
-      BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.so"
+       -DHIP_PLATFORM=amd
+       $<$<BOOL:${WIN32}>:-Ddlfcn-win32_DIR=${dlfcn-win32_DIR}>
+       $<$<BOOL:${WIN32}>:-Dliblzma_DIR=${liblzma_DIR}>
+      BUILD_BYPRODUCTS
+        "${__AOTRITON_LIB}"
       USES_TERMINAL_DOWNLOAD TRUE
       USES_TERMINAL_CONFIGURE TRUE
       USES_TERMINAL_BUILD TRUE
       USES_TERMINAL_INSTALL TRUE
     )
+    if(WIN32)
+      add_dependencies(${project} dlfcn-win32_external xz_external)
+    endif()
   endfunction()

   set(__AOTRITON_ARCH ${CMAKE_HOST_SYSTEM_PROCESSOR})
@@ -95,7 +177,7 @@ if(NOT __AOTRITON_INCLUDED)
       INSTALL_COMMAND ${CMAKE_COMMAND} -E copy_directory
         "${CMAKE_CURRENT_BINARY_DIR}/aotriton_runtime"
         "${__AOTRITON_INSTALL_DIR}"
-      BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.so"
+      BUILD_BYPRODUCTS "${__AOTRITON_LIB}"
     )
     message(STATUS "Using AOTriton Runtime from pre-compiled binary ${__AOTRITON_URL}.\
   Set env variables AOTRITON_INSTALL_FROM_SOURCE=1 to build from source.")
@@ -111,14 +193,35 @@ if(NOT __AOTRITON_INCLUDED)
     string(CONCAT __AOTRITON_URL
       "${__AOTRITON_BASE_URL}"
       "${__AOTRITON_VER}/${__AOTRITON_FILE}")
+
+    # Set up directories
+    set(__AOTRITON_DOWNLOAD_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton_download-${image})
+    set(__AOTRITON_EXTRACT_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton_image-${image})
+    set(__AOTRITON_INSTALL_SOURCE_DIR ${__AOTRITON_EXTRACT_DIR})
+    set(__DOWNLOAD_NO_EXTRACT "")
+    set(__BUILD_COMMANDS "")
+
+    # On Windows, we need custom tar extraction with UTF-8 support
+    if(WIN32)
+      set(__DOWNLOAD_NO_EXTRACT "DOWNLOAD_NO_EXTRACT;TRUE")
+      set(__BUILD_COMMANDS
+        COMMAND ${CMAKE_COMMAND} -E make_directory "${__AOTRITON_EXTRACT_DIR}"
+        COMMAND tar --options hdrcharset=UTF-8 -xf "${__AOTRITON_DOWNLOAD_DIR}/${__AOTRITON_FILE}" -C "${__AOTRITON_EXTRACT_DIR}"
+      )
+      set(__AOTRITON_INSTALL_SOURCE_DIR ${__AOTRITON_EXTRACT_DIR}/aotriton)
+    endif()
+
     ExternalProject_Add(${project}
       URL "${__AOTRITON_URL}"
       URL_HASH SHA256=${__AOTRITON_SHA256}
-      SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton_image-${image}
+      DOWNLOAD_DIR ${__AOTRITON_DOWNLOAD_DIR}
+      ${__DOWNLOAD_NO_EXTRACT}
+      SOURCE_DIR ${__AOTRITON_EXTRACT_DIR}
       CONFIGURE_COMMAND ""
       BUILD_COMMAND ""
+      ${__BUILD_COMMANDS}
       INSTALL_COMMAND ${CMAKE_COMMAND} -E copy_directory
-        "${CMAKE_CURRENT_BINARY_DIR}/aotriton_image-${image}"
+        "${__AOTRITON_INSTALL_SOURCE_DIR}"
         "${__AOTRITON_INSTALL_DIR}"
       BUILD_BYPRODUCTS
         "${__AOTRITON_INSTALL_DIR}/lib/aotriton.images/${image}/__signature__"
@@ -164,7 +267,7 @@ if(NOT __AOTRITON_INCLUDED)
       endforeach()
     endforeach()
   endif()
-  target_link_libraries(__caffe2_aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.so)
+  target_link_libraries(__caffe2_aotriton INTERFACE ${__AOTRITON_LIB})
   target_include_directories(__caffe2_aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/include)
   set(AOTRITON_FOUND TRUE)
 endif() # __AOTRITON_INCLUDED

tools/linter/dictionary.txt
Lines changed: 1 addition & 0 deletions

@@ -12,6 +12,7 @@ BU
 contiguities
 contiguity
 coo
+DEPENDEES
 deser
 din
 dout