@@ -41,7 +41,9 @@ std::vector<paddle::Tensor> BlockAttnKernel(
     const paddle::Tensor &encoder_seq_lod_cpu,
     const paddle::Tensor &encoder_batch_map_cpu,
     const paddle::Tensor &decoder_context_len_cpu,
-    const paddle::Tensor &decoder_batch_map_cpu) {
+    const paddle::Tensor &decoder_batch_map_cpu,
+    const std::string &pos_emb_type = "NORMAL",
+    bool rope_3d = false) {
   phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
   auto dev_ctx =
       paddle::experimental::DeviceContextPool::Instance().Get(place);
@@ -72,6 +74,14 @@ std::vector<paddle::Tensor> BlockAttnKernel(
   int enc_batch = enc_batch_tensor.data<int32_t>()[0];
   int dec_batch = dec_batch_tensor.data<int32_t>()[0];
   int total_enc_len = total_enc_len_tensor.data<int32_t>()[0];
+  int rope_max_seqlen = 0;
+  int rope_3d_num_seqs = 1;
+  if (rope_3d) {
+    rope_max_seqlen = rotary_embs.dims()[3];
+    rope_3d_num_seqs = rotary_embs.dims()[0];
+  } else {
+    rope_max_seqlen = rotary_embs.dims()[2];
+  }
 
   auto block_attn_out =
       paddle::full({token_num, hidden_dim}, -1, qkv.type(), qkv.place());
@@ -151,10 +161,10 @@ std::vector<paddle::Tensor> BlockAttnKernel(
         prefix_lens_vp,    // start_tokens
         param.batch_size,  // batch_size
         1,                 // emb_batch_size
-        rotary_embs.dims()[2],  // max_seqlen
+        rope_max_seqlen,  // max_seqlen
         param.head_num, param.kv_head_num, param.head_dim,
         param.max_batch_size, block_size, max_block_per_seq, "BLHD",
-        "HLD", "NORMAL",
+        "HLD", pos_emb_type,
         !p_kcache_perhead_scale.defined()
             ? nullptr
             : p_kcache_perhead_scale.data<float>() +
@@ -246,10 +256,10 @@ std::vector<paddle::Tensor> BlockAttnKernel(
         vsl.slot_mapping_vp,  // real_batch
         param.batch_size,     // batch_size
         1,                    // emb_batch_size
-        rotary_embs.dims()[2],  // max_seqlen TODO!!double check
+        rope_max_seqlen,  // max_seqlen
         param.head_num, param.kv_head_num, param.head_dim,
         param.max_batch_size, block_size, max_block_per_seq, "BLHD", "HLD",
-        "NORMAL",
+        pos_emb_type,
         !p_kcache_perhead_scale.defined()
             ? nullptr
             : p_kcache_perhead_scale.data<float>() +
@@ -260,7 +270,9 @@ std::vector<paddle::Tensor> BlockAttnKernel(
         param.kv_head_num,  // v_cache_scale_inv
         nullptr,            // k_cache_zp
         nullptr,            // v_cache_zp
-        false);             // b_c8_pc
+        false,              // b_c8_pc
+        rope_3d,            // rope_3d
+        rope_3d_num_seqs);
   XFTBLOCK_CHECK_EQ(ret, api::SUCCESS);
 
   // attn decode
@@ -314,6 +326,7 @@ PD_BUILD_OP(block_attn)
         "decoder_context_len_cpu",
         "decoder_batch_map_cpu",
     })
+    .Attrs({"pos_emb_type:std::string", "rope_3d:bool"})
     .Outputs({"block_attn_out"})
     .SetKernelFn(PD_KERNEL(BlockAttnKernel))
     .SetInferShapeFn(PD_INFER_SHAPE(BlockAttnInferShape))
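
Note on the last hunk: in Paddle's custom-operator interface (paddle/extension.h), attributes listed in .Attrs() are forwarded to the kernel function as trailing parameters after the tensor inputs, in declaration order, which is why the new pos_emb_type and rope_3d parameters appear at the end of BlockAttnKernel's signature. A minimal sketch of that wiring, using a hypothetical toy_op that is not part of this patch (only the attribute-forwarding pattern is taken from the diff):

#include "paddle/extension.h"

// Sketch only: attribute values arrive after the tensor inputs, in the
// same order they are declared in .Attrs() below.
std::vector<paddle::Tensor> ToyKernel(const paddle::Tensor &x,
                                      const std::string &pos_emb_type,
                                      bool rope_3d) {
  return {x};
}

std::vector<std::vector<int64_t>> ToyInferShape(std::vector<int64_t> x_shape) {
  return {x_shape};
}

PD_BUILD_OP(toy_op)
    .Inputs({"x"})
    .Attrs({"pos_emb_type:std::string", "rope_3d:bool"})
    .Outputs({"out"})
    .SetKernelFn(PD_KERNEL(ToyKernel))
    .SetInferShapeFn(PD_INFER_SHAPE(ToyInferShape));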