
Commit 8b73e5a

[BENCHMARK] Integrate CUTLASS's FlashAttention into our Triton benchmarking (#4513)
This PR integrates CUTLASS's FlashAttention (FWD) into our Triton benchmarking. The commits are organized into four parts:

1. Restructure CUTLASS directory: the CUTLASS benchmark directory is restructured to use a single main file that includes specific headers for each kernel. At the moment, we support both GEMM and FA. Some headers shared between the two have been moved to the main `.cpp` file.
2. Disable XeTLA check_close: as discussed with @mfrancepillois, running the FA benchmark takes too long due to check_close. Therefore, this PR disables XeTLA's result checking by default for FlashAttention. However, CUTLASS results are still validated against PyTorch (an illustrative sketch of such a check follows the commit header below).
3. Add CUTLASS FA forward kernel: the forward kernel for CUTLASS FlashAttention is added. The backward mode currently only sets the benchmark's expected values to NaN to clearly indicate that it is not supported.
4. The CI is updated to run the CUTLASS FA forward benchmark. The backward mode for CUTLASS is not triggered by the CI.

---------

Signed-off-by: Jefferson Le Quellec <[email protected]>
1 parent bb5a79a commit 8b73e5a
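Note on item 2: the benchmark harness's actual PyTorch validation code is not part of this excerpt. The sketch below is only an illustration of what "validated against PyTorch" can look like for the non-causal forward pass; the function name, shapes, and tolerances are assumptions, not code from this commit.

// Illustrative only: a naive ATen reference for non-causal attention, compared
// against a CUTLASS output tensor O. Names and tolerances are assumptions.
#include <ATen/ATen.h>
#include <cmath>

bool check_against_torch_reference(const at::Tensor &Q, const at::Tensor &K,
                                   const at::Tensor &V, const at::Tensor &O,
                                   int HeadSizeQK) {
  const float sm_scale = 1.0f / std::sqrt(static_cast<float>(HeadSizeQK));
  // scores = (Q K^T) * sm_scale, computed in float for accuracy.
  at::Tensor scores =
      at::matmul(Q.to(at::kFloat), K.to(at::kFloat).transpose(-2, -1)) * sm_scale;
  at::Tensor ref = at::matmul(at::softmax(scores, /*dim=*/-1), V.to(at::kFloat));
  // Loose tolerances, since the kernel accumulates in a different order.
  return at::allclose(O.to(at::kFloat), ref, /*rtol=*/1e-2, /*atol=*/1e-2);
}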

File tree

11 files changed: +487 −196 lines

.github/workflows/triton-benchmarks.yml

Lines changed: 2 additions & 0 deletions
@@ -277,6 +277,7 @@ jobs:
             source ../../scripts/capture-hw-details.sh
             python build_report.py $REPORTS/attn-performance.csv $REPORTS/attn-triton-report.csv --benchmark flash-attn --compiler triton --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
             python build_report.py $REPORTS/attn-performance.csv $REPORTS/attn-xetla-report.csv --benchmark flash-attn --compiler xetla --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL" --tflops_col XeTLA-TFlops --hbm_col "XeTLA-GB/s" --tag $TAG
+            python build_report.py $REPORTS/attn-performance.csv $REPORTS/attn-cutlass-report.csv --benchmark flash-attn --compiler cutlass --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL" --tflops_col CUTLASS-TFlops --hbm_col "CUTLASS-GB/s" --tag $TAG
 
       - name: Run Triton FA bwd kernel benchmark
         if: ${{ steps.install.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'flash_attention_bwd_benchmark.py')) && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'flash_attention_bwd_benchmark.py') }}
@@ -302,6 +303,7 @@ jobs:
             source ../../scripts/capture-hw-details.sh
             python build_report.py $REPORTS/attn-tensor-desc-performance.csv $REPORTS/attn-tensor-desc-triton-report.csv --benchmark flash-attn-tensor-desc --compiler triton --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
             python build_report.py $REPORTS/attn-tensor-desc-performance.csv $REPORTS/attn-tensor-desc-xetla-report.csv --benchmark flash-attn-tensor-desc --compiler xetla --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL" --tflops_col XeTLA-TFlops --hbm_col "XeTLA-GB/s" --tag $TAG
+            python build_report.py $REPORTS/attn-tensor-desc-performance.csv $REPORTS/attn-tensor-desc-cutlass-report.csv --benchmark flash-attn-tensor-desc --compiler cutlass --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL" --tflops_col CUTLASS-TFlops --hbm_col "CUTLASS-GB/s" --tag $TAG
 
       - name: Run Triton FlexAttention Causal Mask fwd kernel benchmark
         if: ${{ steps.install.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'flex_attention_benchmark_causal_mask.py')) && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'flex_attention_benchmark_causal_mask.py') }}

benchmarks/cmake/FindCUTLASSLibrary.cmake

Lines changed: 1 addition & 0 deletions
@@ -27,6 +27,7 @@ if (NOT CUTLASSLibrary_FOUND)
 
 set(CUTLASSLibrary_INCLUDE_DIR "${CUTLASSLibrary_SOURCE_DIR}/include" CACHE INTERNAL "CUTLASSLibrary_SOURCE_DIR")
 set(CUTLASSLibrary_INCLUDE_TOOL_DIR "${CUTLASSLibrary_SOURCE_DIR}/tools/util/include" CACHE INTERNAL "CUTLASSLibrary_SOURCE_DIR")
+set(CUTLASSLibrary_INCLUDE_APPLICATION_DIR "${CUTLASSLibrary_SOURCE_DIR}/applications" CACHE INTERNAL "CUTLASSLibrary_SOURCE_DIR")
 
 find_package_handle_standard_args(
   CUTLASSLibrary
benchmarks/cutlass_kernel/CMakeLists.txt

Lines changed: 4 additions & 1 deletion
@@ -16,6 +16,9 @@ target_compile_options(cutlass_kernel PRIVATE "-DSYCL_INTEL_TARGET")
 target_link_options(cutlass_kernel PRIVATE ${CUTLASS_KERNEL_FLAGS})
 target_link_libraries(cutlass_kernel PUBLIC ${TORCH_LIBRARIES} ${TORCH_PYTHON_LIBRARY})
 
-target_include_directories(cutlass_kernel PUBLIC "${CUTLASSLibrary_INCLUDE_DIR}" "${CUTLASSLibrary_INCLUDE_TOOL_DIR}")
+target_include_directories(cutlass_kernel PUBLIC "${CUTLASSLibrary_INCLUDE_DIR}" "${CUTLASSLibrary_INCLUDE_TOOL_DIR}" "${CUTLASSLibrary_INCLUDE_APPLICATION_DIR}")
+
+add_subdirectory(gemm)
+add_subdirectory(attention)
 
 install(TARGETS cutlass_kernel LIBRARY DESTINATION .)
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+target_include_directories(cutlass_kernel PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})
Lines changed: 246 additions & 0 deletions
@@ -0,0 +1,246 @@

#include "flash_attention_v2/collective/fmha_fusion.hpp"
#include "flash_attention_v2/collective/xe_flash_attn_prefill_epilogue.hpp"
#include "flash_attention_v2/collective/xe_flash_attn_prefill_mma.hpp"
#include "flash_attention_v2/collective/xe_flash_attn_prefill_softmax_epilogue.hpp"
#include "flash_attention_v2/kernel/tile_scheduler.hpp"
#include "flash_attention_v2/kernel/xe_flash_attn_prefill.hpp"

#include "cutlass/gemm/dispatch_policy.hpp"

#include <exception>
#include <iostream>

////////////////////////////////////////////////////////////////////////////////
// PRIVATE FUNCTION
////////////////////////////////////////////////////////////////////////////////

template <typename FMHA> static auto run(typename FMHA::Params params) -> void {
  cute::dim3 const block = FMHA::get_block_shape();
  cute::dim3 const grid = FMHA::get_grid_shape(params);

  int smem_size = FMHA::SharedStorageSize;

  const auto sycl_block = syclcompat::dim3(block.x, block.y, block.z);
  const auto sycl_grid = syclcompat::dim3(grid.x, grid.y, grid.z);

#if !defined(SYCL_EXT_ONEAPI_WORK_GROUP_SCRATCH_MEMORY)
  using namespace syclcompat::experimental;
  auto event = launch<cutlass::device_kernel<FMHA>>(
      launch_policy{
          sycl_grid, sycl_block,
          local_mem_size{static_cast<std::size_t>(smem_size)},
          kernel_properties{
              sycl_exp::sub_group_size<FMHA::DispatchPolicy::SubgroupSize>}},
      params);
#else
  syclcompat::experimental::launch_properties launch_props{
      sycl::ext::oneapi::experimental::work_group_scratch_size(smem_size),
  };
  syclcompat::experimental::kernel_properties kernel_props{
      sycl::ext::oneapi::experimental::sub_group_size<
          FMHA::DispatchPolicy::SubgroupSize>};
  syclcompat::experimental::launch_policy policy{sycl_grid, sycl_block,
                                                 launch_props, kernel_props};
  auto event = syclcompat::experimental::launch<cutlass::device_kernel<FMHA>>(
      policy, params);
#endif

  EventManager::getInstance().addEvent(event);
}

template <bool Causal, typename TileShapeQK, typename TileShapePV,
          typename TileShapeOutput, typename SubgroupLayout, int PipelineStages>
static auto attention_run(const at::Tensor &Q, const at::Tensor &K,
                          const at::Tensor &V, at::Tensor &O, int Batch,
                          int NumHeadsQ, int NumHeadsKV, int SeqLengthQO,
                          int SeqLengthKV, int HeadSizeQK, int HeadSizeVO,
                          float sm_scale) -> int {
  RECORD_FUNCTION("cutlass fa", {});

  using ElementAccumulator = float;
  using ElementInputQ = cutlass::half_t;
  using ElementInputKV = cutlass::half_t;
  using ElementOutput = float;

  using LayoutQ = cutlass::layout::RowMajor;
  using LayoutK = cutlass::layout::ColumnMajor;
  using LayoutV = cutlass::layout::RowMajor;
  using LayoutO = cutlass::layout::RowMajor;

  using GEMMDispatchPolicy =
      cutlass::gemm::MainloopIntelXeXMX16<PipelineStages>;
  using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16;

  using MMAOperation = cute::XE_8x16x16_F32F16F16F32_TT;

  using GmemTiledCopyQ = cute::XE_2D_U16x8x32_LD_N;
  using GmemTiledCopyK = cute::XE_2D_U16x16x16_LD_T;
  using GmemTiledCopyV = cute::XE_2D_U16x16x32_LD_V;
  using GmemTiledCopyStore = cute::XE_2D_U32x8x16_ST_N;

  using ProblemShapeType = cute::tuple<int, int, int, int, int, int, int>;

  /// MAIN LOOP ///

  using CollectiveMainloop =
      cutlass::flash_attention::collective::FlashPrefillMma<
          GEMMDispatchPolicy, ProblemShapeType, ElementInputQ,
          cutlass::gemm::TagToStrideA_t<LayoutQ>, ElementInputKV,
          cutlass::gemm::TagToStrideB_t<LayoutK>, ElementInputKV,
          cutlass::gemm::TagToStrideB_t<LayoutV>, MMAOperation, TileShapeQK,
          TileShapePV, SubgroupLayout,
          GmemTiledCopyQ, // Q
          GmemTiledCopyK, // K
          GmemTiledCopyV, // V,
          Causal>;

  /// EPILOGUE LOOP ///

  using CollectiveSoftmaxEpilogue =
      cutlass::flash_attention::collective::FlashPrefillSoftmaxEpilogue<
          Causal, EpilogueDispatchPolicy, ElementAccumulator>;
  using CollectiveEpilogue =
      cutlass::flash_attention::collective::FlashPrefillEpilogue<
          EpilogueDispatchPolicy, MMAOperation, TileShapeOutput, SubgroupLayout,
          ElementAccumulator, cutlass::gemm::TagToStrideC_t<LayoutO>,
          ElementOutput, GmemTiledCopyStore>;

  /// FA ///

  using FMHAPrefillKernel = cutlass::flash_attention::kernel::FMHAPrefill<
      ProblemShapeType, CollectiveMainloop, CollectiveSoftmaxEpilogue,
      CollectiveEpilogue>;

  /// FA INVOCATION ///

  try {
    /// Buffer Initialization
    const cutlass::half_t *_Q =
        static_cast<const cutlass::half_t *>(Q.data_ptr());
    const cutlass::half_t *_K =
        static_cast<const cutlass::half_t *>(K.data_ptr());
    const cutlass::half_t *_V =
        static_cast<const cutlass::half_t *>(V.data_ptr());
    const float *_O = static_cast<const float *>(O.data_ptr());

    /// Problem size
    using ProblemShapeType = typename FMHAPrefillKernel::ProblemShape;
    ProblemShapeType problem_size =
        ProblemShapeType{Batch, NumHeadsQ, NumHeadsKV, SeqLengthQO,
                         SeqLengthKV, HeadSizeQK, HeadSizeVO};

    /// Stride
    using StrideQ = typename FMHAPrefillKernel::StrideQ;
    using StrideK = typename FMHAPrefillKernel::StrideK;
    using StrideV = typename FMHAPrefillKernel::StrideV;
    using StrideO = typename FMHAPrefillKernel::StrideO;
    StrideQ stride_Q = cutlass::make_cute_packed_stride(
        StrideQ{},
        cute::make_shape(SeqLengthQO, HeadSizeQK, Batch * NumHeadsQ));
    StrideK stride_K = cutlass::make_cute_packed_stride(
        StrideK{},
        cute::make_shape(SeqLengthKV, HeadSizeQK, Batch * NumHeadsKV));
    StrideV stride_V = cutlass::make_cute_packed_stride(
        StrideV{},
        cute::make_shape(HeadSizeVO, SeqLengthKV, Batch * NumHeadsKV));
    StrideO stride_O = cutlass::make_cute_packed_stride(
        StrideO{},
        cute::make_shape(SeqLengthQO, HeadSizeVO, Batch * NumHeadsQ));

    static cutlass::KernelHardwareInfo hw_info;
    if (hw_info.sm_count == 0) {
      hw_info.sm_count =
          cutlass::KernelHardwareInfo::query_device_multiprocessor_count(0);
      CUTLASS_TRACE_HOST(
          "Query result for SM count per device: " << hw_info.sm_count);
    }

    typename FMHAPrefillKernel::Arguments arguments = {
        cutlass::gemm::GemmUniversalMode::kGemm,
        problem_size,
        {_Q, stride_Q, _K, stride_K, _V, stride_V},
        {sm_scale},
        {_O, stride_O},
        hw_info};

    size_t workspace_size = FMHAPrefillKernel::get_workspace_size(arguments);
    cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
    auto workspace_ptr = workspace.get();

    if (!FMHAPrefillKernel::can_implement(arguments)) {
      std::cout << "Invalid Problem Size: " << Batch << 'x' << NumHeadsQ << 'x'
                << SeqLengthQO << 'x' << SeqLengthKV << 'x' << HeadSizeQK << 'x'
                << HeadSizeVO << (Causal ? "xCausal" : "xNonCausal")
                << std::endl;
      return -1;
    }

    CUTLASS_CHECK(
        FMHAPrefillKernel::initialize_workspace(arguments, workspace_ptr));
    auto params =
        FMHAPrefillKernel::to_underlying_arguments(arguments, workspace_ptr);
    run<FMHAPrefillKernel>(params);

    syclcompat::wait();

  } catch (std::exception &e) {
    std::cerr << "Runtime error: " << e.what() << std::endl;
    return -1;
  } catch (...) {
    std::cerr << "Unexpected error" << std::endl;
    return -1;
  }

  return 0;
}

////////////////////////////////////////////////////////////////////////////////
// PUBLIC FUNCTION
////////////////////////////////////////////////////////////////////////////////

using FARunPtr = int (*)(const at::Tensor &Q, const at::Tensor &K,
                         const at::Tensor &V, at::Tensor &O, int Batch,
                         int NumHeadsQ, int NumHeadsKV, int SeqLengthQO,
                         int SeqLengthKV, int HeadSizeQK, int HeadSizeVO,
                         float sm_scale);

auto attention(const at::Tensor &Q, const at::Tensor &K, const at::Tensor &V,
               at::Tensor &O, int Batch, int NumHeadsQ, int NumHeadsKV,
               int SeqLengthQO, int SeqLengthKV, int HeadSizeQK, int HeadSizeVO,
               bool Causal, float sm_scale) -> int {
  constexpr int PipelineStages = 2;
  FARunPtr f = nullptr;

  if (HeadSizeVO == 64) {
    using ShapeQK = cute::Shape<cute::_128, cute::_64, cute::_64>;
    using ShapePV = cute::Shape<cute::_128, cute::_32, cute::_64>;
    using ShapeOutPut = cute::Shape<cute::_128, cute::_64, cute::_64>;
    using SubgroupLayout =
        cute::Layout<cute::Shape<cute::_8, cute::_1, cute::_1>,
                     cute::Stride<cute::_1, cute::_1, cute::_1>>;

    f = Causal ? attention_run<true, ShapeQK, ShapePV, ShapeOutPut,
                               SubgroupLayout, PipelineStages>
               : attention_run<false, ShapeQK, ShapePV, ShapeOutPut,
                               SubgroupLayout, PipelineStages>;

  } else if (HeadSizeVO == 128) {
    using ShapeQK = cute::Shape<cute::_128, cute::_64, cute::_64>;
    using ShapePV = cute::Shape<cute::_128, cute::_32, cute::_64>;
    using ShapeOutPut = cute::Shape<cute::_128, cute::_128, cute::_64>;
    using SubgroupLayout =
        cute::Layout<cute::Shape<cute::_16, cute::_1, cute::_1>,
                     cute::Stride<cute::_1, cute::_1, cute::_1>>;

    f = Causal ? attention_run<true, ShapeQK, ShapePV, ShapeOutPut,
                               SubgroupLayout, PipelineStages>
               : attention_run<false, ShapeQK, ShapePV, ShapeOutPut,
                               SubgroupLayout, PipelineStages>;
  } else {
    std::cerr << "Unsupported HeadSizeVO: " << HeadSizeVO << std::endl;
    return -1;
  }

  return f(Q, K, V, O, Batch, NumHeadsQ, NumHeadsKV, SeqLengthQO, SeqLengthKV,
           HeadSizeQK, HeadSizeVO, sm_scale);
}
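A possible host-side call into the `attention()` entry point above, shown only as a sketch: the tensor shapes, the XPU device choice, and the sm_scale convention (1/sqrt(HeadSizeQK), the usual FlashAttention scaling) are assumptions about how the benchmark binds to this function, not code from this commit.

// Hypothetical invocation of attention(); assumes linking against the
// cutlass_kernel target and a PyTorch build with XPU support.
#include <ATen/ATen.h>
#include <cmath>

int run_attention_example() {
  const int Batch = 4, NumHeadsQ = 32, NumHeadsKV = 32;
  const int SeqLengthQO = 1024, SeqLengthKV = 1024;
  const int HeadSizeQK = 64, HeadSizeVO = 64;

  // Q/K/V in half precision, O in float, matching the kernel's element types.
  auto opts_half = at::TensorOptions().dtype(at::kHalf).device(at::kXPU);
  auto opts_float = at::TensorOptions().dtype(at::kFloat).device(at::kXPU);

  // Assumed packed layout: heads folded into the batch dimension.
  at::Tensor Q = at::randn({Batch * NumHeadsQ, SeqLengthQO, HeadSizeQK}, opts_half);
  at::Tensor K = at::randn({Batch * NumHeadsKV, SeqLengthKV, HeadSizeQK}, opts_half);
  at::Tensor V = at::randn({Batch * NumHeadsKV, SeqLengthKV, HeadSizeVO}, opts_half);
  at::Tensor O = at::zeros({Batch * NumHeadsQ, SeqLengthQO, HeadSizeVO}, opts_float);

  // Softmax scaling applied inside the kernel; 1/sqrt(HeadSizeQK) is assumed here.
  const float sm_scale = 1.0f / std::sqrt(static_cast<float>(HeadSizeQK));

  return attention(Q, K, V, O, Batch, NumHeadsQ, NumHeadsKV, SeqLengthQO,
                   SeqLengthKV, HeadSizeQK, HeadSizeVO, /*Causal=*/false,
                   sm_scale);
}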
Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-dee33709bdc0cc579df49f251da894d4546b2624
+dd43242ea2f3e08e73a73153f00a5dbe5a31c41c
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+target_include_directories(cutlass_kernel PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})
