Skip to content

Commit 1e4968e

Browse files
authored
[Executor] Fixed the issue of CUDA graph execution failure caused by different branches during decoding (#3223)
* Completely resolve the decode chunk-splitting issue * update C8 and C4 kernel * fix problem * fix with pre-commit * retain branch for MTP
1 parent 31d4fcb commit 1e4968e

File tree

3 files changed

+47
-47
lines changed

3 files changed

+47
-47
lines changed

custom_ops/gpu_ops/append_attn/append_attention_c16_impl.cuh

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1061,12 +1061,11 @@ void MultiQueryAppendAttention(
10611061
if (!is_decoder) {
10621062
chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
10631063
}
1064-
const int num_chunks = div_up(max_dec_len, chunk_size);
10651064

1065+
const int num_chunks = div_up(max_seq_len, chunk_size);
10661066
dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
10671067
dim3 blocks(32, num_warps);
1068-
1069-
if (num_chunks <= 1) {
1068+
if (num_chunks <= 0) {
10701069
auto nosplit_kv_kernel =
10711070
multi_query_append_attention_warp1_4_kernel<NV_TYPE,
10721071
false,
@@ -1161,8 +1160,8 @@ void MultiQueryAppendAttention(
11611160
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_k.data<T>())),
11621161
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_v.data<T>())),
11631162
shift_bias ? reinterpret_cast<NV_TYPE *>(
1164-
const_cast<T *>(shift_bias.get().data<T>()))
1165-
: nullptr,
1163+
const_cast<T *>(shift_bias.get().data<T>()))
1164+
: nullptr,
11661165
smooth_weight ? reinterpret_cast<NV_TYPE *>(
11671166
const_cast<T *>(smooth_weight.get().data<T>()))
11681167
: nullptr,
@@ -1208,8 +1207,8 @@ void MultiQueryAppendAttention(
12081207
seq_lens_encoder.data<int>(),
12091208
cu_seqlens_q.data<int>(),
12101209
shift_bias ? reinterpret_cast<NV_TYPE *>(
1211-
const_cast<T *>(shift_bias.get().data<T>()))
1212-
: nullptr,
1210+
const_cast<T *>(shift_bias.get().data<T>()))
1211+
: nullptr,
12131212
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
12141213
smooth_weight.get().data<T>()))
12151214
: nullptr,
@@ -1226,14 +1225,14 @@ void MultiQueryAppendAttention(
12261225
constexpr int blockx = HEAD_DIM / vec_size;
12271226
constexpr int blocky = (128 + blockx - 1) / blockx;
12281227
dim3 grids_merge(min(sm_count * 4, token_num),
1229-
num_heads);
1228+
num_heads);
12301229
dim3 blocks_merge(blockx, blocky);
12311230
merge_multi_chunks_v2_kernel<NV_TYPE,
1232-
vec_size,
1233-
blocky,
1234-
HEAD_DIM,
1235-
OUT_NV_TYPE,
1236-
ENABLE_PREFILL>
1231+
vec_size,
1232+
blocky,
1233+
HEAD_DIM,
1234+
OUT_NV_TYPE,
1235+
ENABLE_PREFILL>
12371236
<<<grids_merge, blocks_merge, 0, stream>>>(
12381237
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
12391238
static_cast<float *>(tmp_m->ptr()),
@@ -1244,8 +1243,8 @@ void MultiQueryAppendAttention(
12441243
batch_id_per_token.data<int>(),
12451244
cu_seqlens_q.data<int>(),
12461245
shift_bias ? reinterpret_cast<NV_TYPE *>(
1247-
const_cast<T *>(shift_bias.get().data<T>()))
1248-
: nullptr,
1246+
const_cast<T *>(shift_bias.get().data<T>()))
1247+
: nullptr,
12491248
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
12501249
smooth_weight.get().data<T>()))
12511250
: nullptr,

custom_ops/gpu_ops/append_attn/append_attention_c4_impl.cuh

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1285,10 +1285,11 @@ void MultiQueryAppendC4Attention(
12851285
if (!is_decoder) {
12861286
chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
12871287
}
1288-
const int num_chunks = div_up(max_dec_len, chunk_size);
1288+
1289+
const int num_chunks = div_up(max_seq_len, chunk_size);
12891290
dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
12901291
dim3 blocks(32, num_warps);
1291-
if (num_chunks <= 1) {
1292+
if (num_chunks <= 0) {
12921293
auto nosplit_kv_kernel =
12931294
multi_query_append_attention_c4_warp1_4_kernel<NV_TYPE,
12941295
uint8_t,
@@ -1392,15 +1393,15 @@ void MultiQueryAppendC4Attention(
13921393
const_cast<uint8_t *>(cache_v.data<uint8_t>()),
13931394
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_k_scale.data<T>())),
13941395
cache_k_zp ? reinterpret_cast<NV_TYPE *>(
1395-
const_cast<T *>(cache_k_zp.get().data<T>()))
1396-
: nullptr,
1396+
const_cast<T *>(cache_k_zp.get().data<T>()))
1397+
: nullptr,
13971398
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_v_scale.data<T>())),
13981399
cache_v_zp ? reinterpret_cast<NV_TYPE *>(
1399-
const_cast<T *>(cache_v_zp.get().data<T>()))
1400-
: nullptr,
1400+
const_cast<T *>(cache_v_zp.get().data<T>()))
1401+
: nullptr,
14011402
shift_bias ? reinterpret_cast<NV_TYPE *>(
1402-
const_cast<T *>(shift_bias.get().data<T>()))
1403-
: nullptr,
1403+
const_cast<T *>(shift_bias.get().data<T>()))
1404+
: nullptr,
14041405
smooth_weight ? reinterpret_cast<NV_TYPE *>(
14051406
const_cast<T *>(smooth_weight.get().data<T>()))
14061407
: nullptr,
@@ -1445,8 +1446,8 @@ void MultiQueryAppendC4Attention(
14451446
seq_lens_encoder.data<int>(),
14461447
cu_seqlens_q.data<int>(),
14471448
shift_bias ? reinterpret_cast<NV_TYPE *>(
1448-
const_cast<T *>(shift_bias.get().data<T>()))
1449-
: nullptr,
1449+
const_cast<T *>(shift_bias.get().data<T>()))
1450+
: nullptr,
14501451
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
14511452
smooth_weight.get().data<T>()))
14521453
: nullptr,
@@ -1463,14 +1464,14 @@ void MultiQueryAppendC4Attention(
14631464
constexpr int blockx = HEAD_DIM / vec_size;
14641465
constexpr int blocky = (128 + blockx - 1) / blockx;
14651466
dim3 grids_merge(min(sm_count * 4, token_num),
1466-
num_heads);
1467+
num_heads);
14671468
dim3 blocks_merge(blockx, blocky);
14681469
merge_multi_chunks_v2_kernel<NV_TYPE,
1469-
vec_size,
1470-
blocky,
1471-
HEAD_DIM,
1472-
OUT_NV_TYPE,
1473-
ENABLE_PREFILL>
1470+
vec_size,
1471+
blocky,
1472+
HEAD_DIM,
1473+
OUT_NV_TYPE,
1474+
ENABLE_PREFILL>
14741475
<<<grids_merge, blocks_merge, 0, stream>>>(
14751476
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
14761477
static_cast<float *>(tmp_m->ptr()),
@@ -1481,8 +1482,8 @@ void MultiQueryAppendC4Attention(
14811482
batch_id_per_token.data<int>(),
14821483
cu_seqlens_q.data<int>(),
14831484
shift_bias ? reinterpret_cast<NV_TYPE *>(
1484-
const_cast<T *>(shift_bias.get().data<T>()))
1485-
: nullptr,
1485+
const_cast<T *>(shift_bias.get().data<T>()))
1486+
: nullptr,
14861487
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
14871488
smooth_weight.get().data<T>()))
14881489
: nullptr,

custom_ops/gpu_ops/append_attn/append_attention_c8_impl.cuh

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1254,10 +1254,10 @@ void MultiQueryAppendC8Attention(
12541254
chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
12551255
}
12561256

1257-
const int num_chunks = div_up(max_dec_len, chunk_size);
1257+
const int num_chunks = div_up(max_seq_len, chunk_size);
12581258
dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
12591259
dim3 blocks(32, num_warps);
1260-
if (num_chunks <= 1) {
1260+
if (num_chunks <= 0) {
12611261
auto nosplit_kv_kernel =
12621262
multi_query_append_attention_c8_warp1_4_kernel<NV_TYPE,
12631263
uint8_t,
@@ -1377,8 +1377,8 @@ void MultiQueryAppendC8Attention(
13771377
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_k_scale.data<T>())),
13781378
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_v_scale.data<T>())),
13791379
shift_bias ? reinterpret_cast<NV_TYPE *>(
1380-
const_cast<T *>(shift_bias.get().data<T>()))
1381-
: nullptr,
1380+
const_cast<T *>(shift_bias.get().data<T>()))
1381+
: nullptr,
13821382
smooth_weight ? reinterpret_cast<NV_TYPE *>(
13831383
const_cast<T *>(smooth_weight.get().data<T>()))
13841384
: nullptr,
@@ -1418,8 +1418,8 @@ void MultiQueryAppendC8Attention(
14181418
seq_lens_encoder.data<int>(),
14191419
cu_seqlens_q.data<int>(),
14201420
shift_bias ? reinterpret_cast<NV_TYPE *>(
1421-
const_cast<T *>(shift_bias.get().data<T>()))
1422-
: nullptr,
1421+
const_cast<T *>(shift_bias.get().data<T>()))
1422+
: nullptr,
14231423
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
14241424
smooth_weight.get().data<T>()))
14251425
: nullptr,
@@ -1436,14 +1436,14 @@ void MultiQueryAppendC8Attention(
14361436
constexpr int blockx = HEAD_DIM / vec_size;
14371437
constexpr int blocky = (128 + blockx - 1) / blockx;
14381438
dim3 grids_merge(min(sm_count * 4, token_num),
1439-
num_heads);
1439+
num_heads);
14401440
dim3 blocks_merge(blockx, blocky);
14411441
merge_multi_chunks_v2_kernel<NV_TYPE,
1442-
vec_size,
1443-
blocky,
1444-
HEAD_DIM,
1445-
OUT_NV_TYPE,
1446-
ENABLE_PREFILL>
1442+
vec_size,
1443+
blocky,
1444+
HEAD_DIM,
1445+
OUT_NV_TYPE,
1446+
ENABLE_PREFILL>
14471447
<<<grids_merge, blocks_merge, 0, stream>>>(
14481448
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
14491449
static_cast<float *>(tmp_m->ptr()),
@@ -1454,8 +1454,8 @@ void MultiQueryAppendC8Attention(
14541454
batch_id_per_token.data<int>(),
14551455
cu_seqlens_q.data<int>(),
14561456
shift_bias ? reinterpret_cast<NV_TYPE *>(
1457-
const_cast<T *>(shift_bias.get().data<T>()))
1458-
: nullptr,
1457+
const_cast<T *>(shift_bias.get().data<T>()))
1458+
: nullptr,
14591459
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
14601460
smooth_weight.get().data<T>()))
14611461
: nullptr,

0 commit comments

Comments (0)