Skip to content

Commit 40f7f3e

Browse files
[New Feature] fa3 支持flash mask (#3184)
* 支持flash mask * 修改test_flash_mask * 修改test.sh
1 parent b8f3c73 commit 40f7f3e

File tree

8 files changed

+1702
-0
lines changed

8 files changed

+1702
-0
lines changed
Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
/******************************************************************************
2+
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
3+
******************************************************************************/
4+
5+
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
6+
//
7+
// Licensed under the Apache License, Version 2.0 (the "License");
8+
// you may not use this file except in compliance with the License.
9+
// You may obtain a copy of the License at
10+
//
11+
// http://www.apache.org/licenses/LICENSE-2.0
12+
//
13+
// Unless required by applicable law or agreed to in writing, software
14+
// distributed under the License is distributed on an "AS IS" BASIS,
15+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
// See the License for the specific language governing permissions and
17+
// limitations under the License.
18+
19+
#include "paddle/extension.h"
20+
#include "kernel_traits.h"
21+
#include "flash_mask_attn_kernel.hpp"
22+
23+
template <typename paddle_type>
24+
struct cuteType;
25+
26+
template <>
27+
struct cuteType<phi::dtype::float16> {
28+
using type = cutlass::half_t;
29+
};
30+
31+
template <>
32+
struct cuteType<phi::dtype::bfloat16> {
33+
using type = cutlass::bfloat16_t;
34+
};
35+
36+
template <typename T>
37+
std::vector<paddle::Tensor> DispatchFlashAttentionMask(
38+
const paddle::Tensor& q_input,
39+
const paddle::Tensor& k_input,
40+
const paddle::Tensor& v_input,
41+
const paddle::Tensor& cu_seq_q,
42+
const paddle::Tensor& cu_seq_k,
43+
const paddle::Tensor& seq_len_encoder,
44+
const paddle::optional<paddle::Tensor>& mask,
45+
const int head_num,
46+
const int kv_head_num,
47+
const int head_dim,
48+
const int max_seq_len,
49+
const int max_enc_len_this_time,
50+
const int max_dec_len_this_time) {
51+
52+
constexpr int kBlockM = 128;
53+
constexpr int kBlockN = 128;
54+
const int batch_size = cu_seq_q.dims()[0];
55+
56+
paddle::Tensor out = paddle::empty(
57+
{q_input.dims()[0], head_num * head_dim}, q_input.dtype(), q_input.place());
58+
59+
Flash_mask_params params;
60+
memset(&params, 0, sizeof(Flash_mask_params));
61+
62+
params.q_ptr = const_cast<T*>(q_input.data<T>());
63+
params.k_ptr = const_cast<T*>(k_input.data<T>());
64+
params.v_ptr = const_cast<T*>(v_input.data<T>());
65+
params.o_ptr = const_cast<T*>(out.data<T>());
66+
params.cu_seq_q = const_cast<int*>(cu_seq_q.data<int>());
67+
params.cu_seq_k = const_cast<int*>(cu_seq_k.data<int>());
68+
params.seq_len_encoder = const_cast<int*>(seq_len_encoder.data<int>());
69+
params.head_num = head_num;
70+
params.kv_head_num = kv_head_num;
71+
params.max_seq_len_q = max_enc_len_this_time;
72+
params.max_seq_len_k = max_enc_len_this_time + max_dec_len_this_time;
73+
params.batch_size = batch_size;
74+
params.gqa_group_size = head_num / kv_head_num;
75+
constexpr float kLog2e = 1.4426950408889634074;
76+
params.scale_softmax_log2 = 1.0f / std::sqrt(head_dim) * kLog2e;
77+
78+
using cute_type = typename cuteType<T>::type;
79+
80+
if (mask) {
81+
params.mask = const_cast<int*>(mask.get().data<int>());
82+
flash_attn_headdim128<kBlockM, kBlockN, true, cute_type>(params, 0);
83+
} else {
84+
flash_attn_headdim128<kBlockM, kBlockN, false, cute_type>(params, 0);
85+
}
86+
87+
return {out};
88+
}
89+
90+
91+
std::vector<paddle::Tensor> FlashAttentionMask(
92+
const paddle::Tensor& q_input,
93+
const paddle::Tensor& k_input,
94+
const paddle::Tensor& v_input,
95+
const paddle::Tensor& cu_seq_q,
96+
const paddle::Tensor& cu_seq_k,
97+
const paddle::Tensor& seq_len_encoder,
98+
const paddle::optional<paddle::Tensor> &mask,
99+
const int head_num,
100+
const int kv_head_num,
101+
const int head_dim,
102+
const int max_seq_len,
103+
const int max_enc_len_this_time,
104+
const int max_dec_len_this_time) {
105+
106+
if (q_input.dtype() == paddle::DataType::FLOAT16) {
107+
using T = phi::dtype::float16;
108+
return std::move(
109+
DispatchFlashAttentionMask<T>(
110+
q_input,
111+
k_input,
112+
v_input,
113+
cu_seq_q,
114+
cu_seq_k,
115+
seq_len_encoder,
116+
mask,
117+
head_num,
118+
kv_head_num,
119+
head_dim,
120+
max_seq_len,
121+
max_enc_len_this_time,
122+
max_dec_len_this_time));
123+
} else if (q_input.dtype() == paddle::DataType::BFLOAT16) {
124+
using T = phi::dtype::bfloat16;
125+
return std::move(
126+
DispatchFlashAttentionMask<T>(
127+
q_input,
128+
k_input,
129+
v_input,
130+
cu_seq_q,
131+
cu_seq_k,
132+
seq_len_encoder,
133+
mask,
134+
head_num,
135+
kv_head_num,
136+
head_dim,
137+
max_seq_len,
138+
max_enc_len_this_time,
139+
max_dec_len_this_time));
140+
}
141+
142+
}
143+
144+
145+
// Registers the `flash_attention_mask` custom operator with Paddle:
// six required tensor inputs plus the optional `mask`, six integer
// attributes, one output, implemented by FlashAttentionMask.
PD_BUILD_OP(flash_attention_mask)
    .Inputs({"q_input",
             "k_input",
             "v_input",
             "cu_seq_q",
             "cu_seq_k",
             "seq_len_encoder",
             paddle::Optional("mask")})
    .Attrs({"head_num: int",
            "kv_head_num: int",
            "head_dim: int",
            "max_seq_len: int",
            "max_enc_len_this_time: int",
            "max_dec_len_this_time: int"})
    .Outputs({"out"})
    .SetKernelFn(PD_KERNEL(FlashAttentionMask));
Lines changed: 231 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,231 @@
1+
/******************************************************************************
2+
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
3+
******************************************************************************/
4+
5+
#pragma once
6+
#include "cute/algorithm/copy.hpp"
7+
#include "cute/atom/mma_atom.hpp"
8+
#include "cutlass/gemm/collective/collective_builder.hpp"
9+
10+
#include "cutlass/cutlass.h"
11+
#include "cutlass/layout/layout.h"
12+
#include "cutlass/numeric_types.h"
13+
#include "cutlass/pipeline/pipeline.hpp"
14+
#include "cutlass/cluster_launch.hpp"
15+
#include "cutlass/arch/reg_reconfig.h"
16+
17+
#include "kernel_traits.h"
18+
#include "mainloop_attn.hpp"
19+
#include "softmax.hpp"
20+
21+
using namespace cute;
22+
23+
// Builds the CuTe layout used to describe a Q/K/V tensor in global memory.
//
// The layout has shape (token_num, kHeadDim, head_num) with strides
// (head_num * kHeadDim, 1, kHeadDim): the head dimension is contiguous,
// consecutive heads are kHeadDim elements apart and consecutive tokens are
// head_num * kHeadDim elements apart.
template <int kHeadDim>
auto get_gmem_layout(int token_num, int head_num) {
  auto gmem_shape = make_shape(token_num, kHeadDim, head_num);
  auto gmem_stride = make_stride(head_num * kHeadDim, cute::_1{}, kHeadDim);
  return make_layout(gmem_shape, gmem_stride);
}
29+
30+
// Warp-specialized flash-mask attention kernel.
//
// Launch layout (as used by run_flash_mask):
//   blockIdx.x : query tile index (m_block), kBlockM query rows per tile
//   blockIdx.y : query head index (bidh)
//   blockIdx.z : batch index (bidb)
//   blockDim.x : Ktraits::kNWarps * 32 threads
// Shared memory: dynamic shared memory holding Ktraits::SharedStorage plus a
// static int[kBlockM] buffer for the per-row mask values of this tile.
//
// Warp group 0 acts as the producer (one warp issues the Q/K/V loads); all
// other warp groups are consumers that run the MMA/softmax main loop and
// store the output tile.
// NOTE(review): presumably requires Hopper (SM90) -- the kernel relies on TMA
// descriptors and warp-group register reallocation; confirm build flags.
//
// Parameters:
//   mainloop_params  parameters of CollectiveMainloopAttn<Ktraits>.
//   data_params      raw pointers and sizes (cu_seq_q, cu_seq_k,
//                    seq_len_encoder, mask, o_ptr, head_num, ...).
template <typename Ktraits>
__global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, 1)
compute_attn_ws(
    CUTE_GRID_CONSTANT typename CollectiveMainloopAttn<Ktraits>::Params const mainloop_params,
    CUTE_GRID_CONSTANT Flash_mask_params const data_params) {

  using Element = typename Ktraits::Element;
  using ElementAccum = typename Ktraits::ElementAccum;
  using SoftType = ElementAccum;
  using TileShape_MNK = typename Ktraits::TileShape_MNK;
  using ClusterShape = typename Ktraits::ClusterShape_MNK;

  static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma0{});
  static constexpr int NumCopyThreads = cutlass::NumThreadsPerWarpGroup;
  static constexpr int kBlockM = Ktraits::kBlockM;
  static constexpr int kBlockN = Ktraits::kBlockN;
  constexpr int kHeadDim = Ktraits::kHeadDim;
  constexpr bool NeedMask = Ktraits::NeedMask;

  using CollectiveMainloop = CollectiveMainloopAttn<Ktraits>;

  using MainloopPipeline = typename Ktraits::MainloopPipeline;
  using PipelineParams = typename MainloopPipeline::Params;
  using PipelineState = typename MainloopPipeline::PipelineState;

  extern __shared__ char shared_memory[];
  auto &shared_storage = *reinterpret_cast<typename Ktraits::SharedStorage*>(shared_memory);

  // Per-row mask values for the kBlockM query rows of this tile.
  __align__(16) __shared__ int mask[kBlockM];

  const int m_block = blockIdx.x;
  const int bidh = blockIdx.y;
  const int bidb = blockIdx.z;

  const int seq_len_q = data_params.seq_len_encoder[bidb];
  const int seq_len_k = data_params.cu_seq_k[bidb + 1] - data_params.cu_seq_k[bidb];

  // The grid is sized for the longest sequence; tiles that start past the end
  // of this sequence have nothing to do. The condition is uniform across the
  // block, and it is evaluated before touching the mask so that such tiles
  // never read mask entries that do not belong to this sequence.
  if (m_block * kBlockM >= seq_len_q) {
    return;
  }

  // Number of query rows remaining from the start of this tile (>= 1 here).
  const int real_seq = seq_len_q - m_block * kBlockM;
  // Index, within the tile, of the last row that belongs to the sequence.
  const int last_valid_row = min(kBlockM - 1, real_seq - 1);

  if constexpr (NeedMask) {
    const int *mask_this_batch = data_params.mask + data_params.cu_seq_q[bidb] + m_block * kBlockM;

    // Rows past the end of the sequence replicate the last valid row's value
    // instead of reading beyond this sequence's mask entries (which, for the
    // last sequence, could be past the end of the mask buffer -- assuming the
    // mask holds one entry per packed query token). Those rows are not
    // written back: store() below receives real_seq.
    for (int i = threadIdx.x; i < kBlockM; i += Ktraits::kNWarps * cutlass::NumThreadsPerWarp) {
      mask[i] = mask_this_batch[min(i, last_valid_row)];
    }
  }

  int const lane_predicate = cute::elect_one_sync();
  int const warp_idx = cutlass::canonical_warp_idx_sync();

  if (warp_idx == 0 && lane_predicate) {
    CollectiveMainloop::prefetch_tma_descriptors(mainloop_params);
  }

  int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;

  PipelineParams pipeline_params;
  pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK;
  int warp_group_idx = cutlass::canonical_warp_group_idx();
  pipeline_params.role = warp_group_idx == 0
      ? MainloopPipeline::ThreadCategory::Producer
      : MainloopPipeline::ThreadCategory::Consumer;
  pipeline_params.is_leader = warp_group_thread_idx == 0;
  pipeline_params.num_consumers = NumMmaThreads;

  if (warp_idx == 0 && lane_predicate) {
    shared_storage.barrier_Q.init(1);
  }

  MainloopPipeline pipeline_k(shared_storage.pipeline_k, pipeline_params, ClusterShape{});
  MainloopPipeline pipeline_v(shared_storage.pipeline_v, pipeline_params, ClusterShape{});

  // Makes the barrier initialisation and the mask[] writes above visible to
  // every thread in the block before they are used.
  __syncthreads();

  CollectiveMainloop collective_mainloop;

  // Number of key tiles to visit. With a mask it is derived from the mask
  // value of the last valid row of this tile (NOTE(review): this presumes the
  // last row has the largest mask value in the tile -- confirm). Without a
  // mask it is the bound (m_block + 1) * kBlockM + seq_len_k - seq_len_q.
  const int n_block_max = NeedMask ? cute::ceil_div(mask[last_valid_row], kBlockN) : cute::ceil_div((m_block + 1) * kBlockM + seq_len_k - seq_len_q, kBlockN);

  if (warp_group_idx == 0) {  // Producer
    cutlass::arch::warpgroup_reg_dealloc<Ktraits::kNWarps == 8 ? 56 : 24>();

    int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
    if (warp_idx_in_warpgroup == 0) {  // Load Q, K, V
      PipelineState smem_pipe_write_k = cutlass::make_producer_start_state<MainloopPipeline>();
      PipelineState smem_pipe_write_v = cutlass::make_producer_start_state<MainloopPipeline>();

      collective_mainloop.load(
          mainloop_params,
          pipeline_k,
          pipeline_v,
          smem_pipe_write_k,
          smem_pipe_write_v,
          shared_storage,
          n_block_max,
          m_block,
          bidh,
          bidb,
          data_params.cu_seq_q,
          data_params.cu_seq_k,
          seq_len_q,
          seq_len_k);
    }
  } else {  // Consumer
    cutlass::arch::warpgroup_reg_alloc<Ktraits::kNWarps == 8 ? 256 : 240>();
    typename Ktraits::TiledMma1 tiled_mma1;

    PipelineState smem_pipe_read_k, smem_pipe_read_v;

    Tensor tOrO = partition_fragment_C(tiled_mma1, select<0, 2>(TileShape_MNK{}));
    Softmax<2 * (2 * kBlockM / NumMmaThreads)> softmax;

    collective_mainloop.mma(
        mainloop_params,
        pipeline_k,
        pipeline_v,
        smem_pipe_read_k,
        smem_pipe_read_v,
        tOrO,
        softmax,
        mask,
        n_block_max,
        threadIdx.x - NumCopyThreads,
        m_block,
        seq_len_q,
        seq_len_k,
        shared_storage);

    const int o_head_stride = data_params.head_num * kHeadDim;
    // Computed in 64 bits: (token index) * head_num * kHeadDim can exceed the
    // range of a 32-bit int for large token counts.
    const int64_t store_offset =
        (static_cast<int64_t>(data_params.cu_seq_q[bidb]) + m_block * kBlockM) * o_head_stride + bidh * kHeadDim;

    // `template` is required here because collective_mainloop has a
    // dependent type.
    collective_mainloop.template store<NumMmaThreads>(
        mainloop_params,
        tOrO,
        shared_storage,
        tiled_mma1,
        threadIdx.x - NumCopyThreads,
        o_head_stride,
        real_seq,
        reinterpret_cast<Element*>(data_params.o_ptr) + store_offset);
  }

}
176+
177+
178+
// Host-side launcher for compute_attn_ws<Kernel_traits>.
//
// Builds the main-loop parameters from the Q/K/V pointers and global-memory
// layouts, opts in to large dynamic shared memory when needed, and launches
// the kernel on `stream` with grid
// (ceil(max_seq_len_q / kBlockM) rounded up to the cluster size in M,
//  head_num, batch_size) and Kernel_traits::kNWarps * 32 threads per block.
//
// Does nothing when the problem is empty (no query tokens, no batch entries
// or no heads): a zero grid dimension is an invalid launch configuration and
// the tensor layouts would have a zero extent.
template<typename Kernel_traits>
void run_flash_mask(Flash_mask_params &params, cudaStream_t stream) {
  using Element = typename Kernel_traits::Element;
  using ClusterShape = typename Kernel_traits::ClusterShape_MNK;

  using CollectiveMainloop = CollectiveMainloopAttn<Kernel_traits>;
  constexpr int kHeadDim = Kernel_traits::kHeadDim;

  if (params.max_seq_len_q <= 0 || params.batch_size <= 0 ||
      params.head_num <= 0 || params.kv_head_num <= 0) {
    return;
  }

  typename CollectiveMainloop::Params mainloop_params =
      CollectiveMainloop::to_underlying_arguments({
          static_cast<Element const*>(params.q_ptr),
          get_gmem_layout<kHeadDim>(params.max_seq_len_q, params.head_num),
          static_cast<Element const*>(params.k_ptr),
          get_gmem_layout<kHeadDim>(params.max_seq_len_k, params.kv_head_num),
          static_cast<Element const*>(params.v_ptr),
          get_gmem_layout<kHeadDim>(params.max_seq_len_k, params.kv_head_num),
          params.scale_softmax_log2
      });

  // One block per kBlockM query rows, rounded up to a multiple of the cluster
  // extent in the M dimension.
  int num_blocks_m = cutlass::ceil_div(params.max_seq_len_q, Kernel_traits::kBlockM);

  num_blocks_m = cutlass::ceil_div(num_blocks_m, size<0>(ClusterShape{})) * size<0>(ClusterShape{});

  void *kernel;
  kernel = (void *)compute_attn_ws<Kernel_traits>;
  int smem_size = sizeof(typename Kernel_traits::SharedStorage);

  // Dynamic shared memory of 48 KB or more must be requested explicitly.
  // NOTE(review): the return codes of cudaFuncSetAttribute and of the cluster
  // launch below are not checked; failures are silent here.
  if (smem_size >= 48 * 1024) {
    cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
  }

  dim3 grid_dims;
  grid_dims.x = num_blocks_m;
  grid_dims.y = params.head_num;
  grid_dims.z = params.batch_size;

  static constexpr int ctaSize = Kernel_traits::kNWarps * 32;
  dim3 block_dims(ctaSize);
  dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{}));
  cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream};
  cutlass::launch_kernel_on_cluster(launch_params, kernel, mainloop_params, params);
}
221+
222+
// Instantiates and runs the flash-mask attention kernel for head dimension
// 128 with the given tile sizes, mask mode and element type.
//
// The kernel traits are built with kBlockM / 16 + 4 warps and a 2-stage
// pipeline; `params` and `stream` are forwarded to run_flash_mask.
template <int kBlockM, int kBlockN, bool NeedMask, typename InputType>
void flash_attn_headdim128(Flash_mask_params &params, cudaStream_t stream) {
  constexpr int kFixedHeadDim = 128;
  // NOTE(review): presumably 4 producer warps (one warp group) plus one
  // consumer warp per 16 query rows -- confirm against the kernel traits.
  constexpr int kWarpCount = 4 + kBlockM / 16;
  constexpr int kPipelineStages = 2;

  using Traits = Flash_mask_kernel_traits<
      kFixedHeadDim, kBlockM, kBlockN, kWarpCount, kPipelineStages, NeedMask, InputType>;

  run_flash_mask<Traits>(params, stream);
}

0 commit comments

Comments
 (0)