Skip to content

Commit 504461b

Browse files
authored
[Iluvatar GPU] Optimize attention performance and fix moe load ckpt error (#3651)
1 parent 5532e8a commit 504461b

File tree

17 files changed: +1339 additions, -358 deletions

.github/workflows/ci_iluvatar.yml

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -28,18 +28,22 @@ jobs:
2828
REPO="https://github.com/${{ github.repository }}.git"
2929
FULL_REPO="${{ github.repository }}"
3030
REPO_NAME="${FULL_REPO##*/}"
31+
BASE_BRANCH="${{ github.base_ref }}"
3132
# Clean the repository directory before starting
3233
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
3334
-e "REPO_NAME=${REPO_NAME}" \
35+
-e "BASE_BRANCH=${BASE_BRANCH}" \
3436
${docker_image} /bin/bash -c '
3537
if [ -d ${REPO_NAME} ]; then
3638
echo "Directory ${REPO_NAME} exists, removing it..."
3739
rm -rf ${REPO_NAME}
3840
fi
3941
'
42+
git config --global http.proxy "http://61.151.249.150:33128"
43+
git config --global https.proxy "http://61.151.249.150:33128"
4044
git config --global user.name "FastDeployCI"
4145
git config --global user.email "[email protected]"
42-
git clone ${REPO} ${REPO_NAME}
46+
git clone --recursive ${REPO} ${REPO_NAME} -b ${BASE_BRANCH}
4347
cd FastDeploy
4448
if [ "${{ github.event_name }}" = "pull_request" ]; then
4549
git fetch origin pull/${{ github.event.pull_request.number }}/head:pr/${{ github.event.pull_request.number }}

custom_ops/gpu_ops/helper.h

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -193,11 +193,13 @@ template <> class PDTraits<paddle::DataType::UINT8> {
193193
typedef uint8_t data_t;
194194
};
195195

196+
#ifndef PADDLE_WITH_COREX
196197
template <> class PDTraits<paddle::DataType::FLOAT8_E4M3FN> {
197198
public:
198199
typedef __nv_fp8_e4m3 DataType;
199200
typedef paddle::float8_e4m3fn data_t;
200201
};
202+
#endif
201203

202204
template <typename T, int Size> struct alignas(sizeof(T) * Size) AlignedVector {
203205
T val[Size];

custom_ops/iluvatar_ops/mixed_fused_attn.cu

Lines changed: 376 additions & 0 deletions
Large diffs are not rendered by default.

custom_ops/iluvatar_ops/moe_dispatch.cu

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -53,6 +53,7 @@ void MoeDispatchKernel(const paddle::Tensor& input,
5353
const paddle::optional<paddle::Tensor>& gating_correction_bias,
5454
const int moe_topk,
5555
const bool group_moe,
56+
const std::string &moe_quant_type,
5657
const bool topk_only_mode,
5758
const int num_rows,
5859
const int hidden_size,
@@ -183,6 +184,7 @@ std::vector<paddle::Tensor> MoeExpertDispatch(
183184
const paddle::optional<paddle::Tensor>& w4a8_in_scale,
184185
const int moe_topk,
185186
const bool group_moe,
187+
const std::string &moe_quant_type,
186188
const bool topk_only_mode) {
187189
const auto input_type = input.dtype();
188190
auto place = input.place();
@@ -220,6 +222,7 @@ std::vector<paddle::Tensor> MoeExpertDispatch(
220222
gating_correction_bias,
221223
moe_topk,
222224
group_moe,
225+
moe_quant_type,
223226
topk_only_mode,
224227
num_rows,
225228
hidden_size,
@@ -236,6 +239,7 @@ std::vector<paddle::Tensor> MoeExpertDispatch(
236239
gating_correction_bias,
237240
moe_topk,
238241
group_moe,
242+
moe_quant_type,
239243
topk_only_mode,
240244
num_rows,
241245
hidden_size,
@@ -305,7 +309,7 @@ PD_BUILD_STATIC_OP(moe_expert_dispatch)
305309
"top_k_weight",
306310
"top_k_indices",
307311
"expert_idx_per_token"})
308-
.Attrs({"moe_topk:int", "group_moe:bool", "topk_only_mode:bool"})
312+
.Attrs({"moe_topk:int", "group_moe:bool", "moe_quant_type:std::string", "topk_only_mode:bool"})
309313
.SetKernelFn(PD_KERNEL(MoeExpertDispatch))
310314
.SetInferShapeFn(PD_INFER_SHAPE(MoeExpertDispatchInferShape))
311315
.SetInferDtypeFn(PD_INFER_DTYPE(MoeExpertDispatchInferDtype));

custom_ops/iluvatar_ops/paged_attn.cu

Lines changed: 33 additions & 47 deletions
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,8 @@ void PagedAttnKernel(const paddle::Tensor& q,
2727
const paddle::optional<paddle::Tensor> &v,
2828
const paddle::optional<paddle::Tensor> &rope_sin,
2929
const paddle::optional<paddle::Tensor> &rope_cos,
30+
int num_heads,
31+
int head_dim,
3032
int num_kv_heads,
3133
float scale,
3234
int block_size,
@@ -86,32 +88,36 @@ void PagedAttnKernel(const paddle::Tensor& q,
8688
common::errors::InvalidArgument(
8789
"paged_attention expects seq_lens is contiguous"));
8890
// check dim and shape
89-
// k_cache: [num_blocks, kv_num_heads, block_size, head_size]
90-
// v_cache: [num_blocks, kv_num_heads, block_size, head_size]
91+
// k_cache: [num_blocks, kv_num_heads, block_size, head_dim]
92+
// v_cache: [num_blocks, kv_num_heads, block_size, head_dim]
9193
// block_table: [num_seqs, max_num_blocks_per_seq]
9294
// seq_lens: [num_seqs]
9395
// q and out:
94-
// merged_qkv = false: [num_seqs, num_heads, head_size]
95-
// merged_qkv = true: [num_seqs, num_heads+2*num_kv_heads, head_size]
96+
// if merged_qkv = false:
97+
// q:[num_seqs, hidden_size]
98+
// out:[num_seqs, hidden_size]
99+
// if merged_qkv = true:
100+
// q: [num_seqs, (num_heads+2*num_kv_heads)*head_dim]
101+
// out: [num_seqs, hidden_size]
96102

97103
const auto& q_dims = q.dims();
98104
PADDLE_ENFORCE_EQ(q_dims.size(),
99-
3,
105+
2,
100106
common::errors::InvalidArgument(
101107
"paged_attn receive query dims is "
102-
"[num_seqs, num_heads, head_size]"));
108+
"[num_seqs, (num_heads+2*num_kv_heads)*head_dim]"));
103109
PADDLE_ENFORCE_EQ(out.dims().size(),
104-
3,
110+
2,
105111
common::errors::InvalidArgument(
106112
"paged_attn receive out dims is "
107-
"[num_seqs, num_heads, head_size]"));
113+
"[num_seqs, hidden_size]"));
108114

109115
const auto& kv_cache_dims = k_cache.dims();
110116
PADDLE_ENFORCE_EQ(kv_cache_dims.size(),
111117
4,
112118
common::errors::InvalidArgument(
113119
"paged_attn receive kv cache dims is "
114-
"[num_blocks, kv_num_heads, block_size, head_size]"));
120+
"[num_blocks, kv_num_heads, block_size, head_dim]"));
115121

116122
const auto& block_table_dims = block_table.dims();
117123
PADDLE_ENFORCE_EQ(block_table_dims.size(),
@@ -127,8 +133,6 @@ void PagedAttnKernel(const paddle::Tensor& q,
127133
"paged_attn receive seq_lens dims is [num_seqs]"));
128134

129135
int num_seqs = q_dims[0];
130-
int num_heads = merged_qkv ? q_dims[1] - 2 * num_kv_heads : q_dims[1];
131-
int head_size = q_dims[2];
132136
int max_num_blocks_per_seq = block_table_dims[1];
133137
int q_stride = q.strides()[0];
134138
int num_blocks = kv_cache_dims[0];
@@ -142,9 +146,9 @@ void PagedAttnKernel(const paddle::Tensor& q,
142146
common::errors::InvalidArgument(
143147
"kv_cache_dims[2] must be equal to block_size"));
144148
PADDLE_ENFORCE_EQ(kv_cache_dims[3],
145-
head_size,
149+
head_dim,
146150
common::errors::InvalidArgument(
147-
"kv_cache_dims[3] must be equal to head_size"));
151+
"kv_cache_dims[3] must be equal to head_dim"));
148152
PADDLE_ENFORCE_EQ(block_table_dims[0],
149153
num_seqs,
150154
common::errors::InvalidArgument(
@@ -162,14 +166,13 @@ void PagedAttnKernel(const paddle::Tensor& q,
162166
const float *rope_sin_ptr = merged_qkv ? rope_sin.get().data<float>() : nullptr;
163167
const float *rope_cos_ptr = merged_qkv ? rope_cos.get().data<float>() : nullptr;
164168

165-
auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(q.place()));
166169
cuinferHandle_t cuinfer_handle = iluvatar::getContextInstance()->getIxInferHandle();
167170

168171
size_t workspace_size = 0;
169172
CUINFER_CHECK(cuInferPageAttentionGetWorkspaceV7(num_seqs,
170173
num_heads,
171174
num_kv_heads,
172-
head_size,
175+
head_dim,
173176
block_size,
174177
max_context_len,
175178
&workspace_size));
@@ -189,7 +192,7 @@ void PagedAttnKernel(const paddle::Tensor& q,
189192
num_seqs,
190193
num_heads,
191194
num_kv_heads,
192-
head_size,
195+
head_dim,
193196
q_stride,
194197
kv_block_stride,
195198
kv_head_stride,
@@ -215,6 +218,8 @@ std::vector<paddle::Tensor> PagedAttn(const paddle::Tensor& q,
215218
const paddle::optional<paddle::Tensor> &v,
216219
const paddle::optional<paddle::Tensor> &rope_sin,
217220
const paddle::optional<paddle::Tensor> &rope_cos,
221+
int num_heads,
222+
int head_dim,
218223
int num_kv_heads,
219224
float scale,
220225
int block_size,
@@ -228,11 +233,7 @@ std::vector<paddle::Tensor> PagedAttn(const paddle::Tensor& q,
228233
bool merged_qkv) {
229234

230235
const auto dtype = q.dtype();
231-
auto out_shape = q.shape();
232-
if (merged_qkv) {
233-
out_shape[1] -= 2 * num_kv_heads;
234-
}
235-
auto out = paddle::empty(out_shape, dtype, q.place());
236+
auto out = paddle::empty({q.shape()[0], num_heads * head_dim}, dtype, q.place());
236237

237238
switch (dtype) {
238239
case paddle::DataType::BFLOAT16:
@@ -246,6 +247,8 @@ std::vector<paddle::Tensor> PagedAttn(const paddle::Tensor& q,
246247
v,
247248
rope_sin,
248249
rope_cos,
250+
num_heads,
251+
head_dim,
249252
num_kv_heads,
250253
scale,
251254
block_size,
@@ -270,6 +273,8 @@ std::vector<paddle::Tensor> PagedAttn(const paddle::Tensor& q,
270273
v,
271274
rope_sin,
272275
rope_cos,
276+
num_heads,
277+
head_dim,
273278
num_kv_heads,
274279
scale,
275280
block_size,
@@ -299,6 +304,8 @@ std::vector<std::vector<int64_t>> PagedAttnInferShape(const std::vector<int64_t>
299304
const std::vector<int64_t>& v_shape,
300305
const std::vector<int64_t>& rope_sin_shape,
301306
const std::vector<int64_t>& rope_cos_shape,
307+
int num_heads,
308+
int head_dim,
302309
int num_kv_heads,
303310
float scale,
304311
int block_size,
@@ -311,36 +318,13 @@ std::vector<std::vector<int64_t>> PagedAttnInferShape(const std::vector<int64_t>
311318
bool use_sqrt_alibi,
312319
bool merged_qkv) {
313320
if (merged_qkv) {
314-
int64_t num_tokens = q_shape[0];
315-
int64_t num_heads = q_shape[1] - 2 * num_kv_heads;
316-
int64_t head_dim = q_shape[2];
317-
return {{num_tokens, num_heads, head_dim}};
321+
return {{q_shape[0], num_heads * head_dim}};
318322
} else {
319323
return {q_shape};
320324
}
321325
}
322326

323-
std::vector<paddle::DataType> PagedAttnInferDtype(const paddle::DataType& q_dtype,
324-
const paddle::DataType& k_cache_dtype,
325-
const paddle::DataType& v_cache_dtype,
326-
const paddle::DataType& block_table_dtype,
327-
const paddle::DataType& seq_lens_dtype,
328-
const paddle::DataType& alibi_slopes_dtype,
329-
const paddle::DataType& k_dtype,
330-
const paddle::DataType& v_dtype,
331-
const paddle::DataType& rope_sin_dtype,
332-
const paddle::DataType& rope_cos_dtype,
333-
int num_kv_heads,
334-
float scale,
335-
int block_size,
336-
int max_context_len,
337-
bool causal,
338-
int window_left,
339-
int window_right,
340-
float softcap,
341-
bool enable_cuda_graph,
342-
bool use_sqrt_alibi,
343-
bool merged_qkv) {
327+
std::vector<paddle::DataType> PagedAttnInferDtype(const paddle::DataType& q_dtype) {
344328
return {q_dtype};
345329
}
346330

@@ -351,7 +335,9 @@ PD_BUILD_STATIC_OP(paged_attn)
351335
paddle::Optional("v"), paddle::Optional("rope_sin"),
352336
paddle::Optional("rope_cos")})
353337
.Outputs({"out"})
354-
.Attrs({"num_kv_heads:int",
338+
.Attrs({"num_heads:int",
339+
"head_dim:int",
340+
"num_kv_heads:int",
355341
"scale:float",
356342
"block_size:int",
357343
"max_context_len:int",

0 commit comments

Comments
 (0)