@@ -352,6 +352,11 @@ def init_tensor(self):
         self.max_dec_len_this_time = paddle.to_tensor([self.max_dec_len_this_time], "int32", place=paddle.CPUPlace())
         self.seq_lens_this_time = self.seq_lens_encoder
 
+        self.decoder_batch_ids = paddle.full([self.batch_size], 0, dtype="int32")
+        self.decoder_tile_ids_per_batch = paddle.full([self.batch_size], 0, dtype="int32")
+        self.decoder_num_blocks_cpu = paddle.full([1], 0, dtype="int32").pin_memory()
+        self.max_len_tensor_cpu = paddle.full([8], 0, dtype="int32").cpu()
+
         self.cache_shape = (
             self.max_block_num,
             self.kv_num_head,
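The additions above preallocate the decoder scheduling buffers once at init time, so the split kernel can write into them in place instead of allocating fresh outputs on every call. A minimal standalone sketch of the pattern, assuming a CUDA build of Paddle (`batch_size` here is a placeholder value, not the test's `self.batch_size`):

```python
import paddle

batch_size = 64  # placeholder for illustration

# Device-side buffers, reused across decode steps.
decoder_batch_ids = paddle.full([batch_size], 0, dtype="int32")
decoder_tile_ids_per_batch = paddle.full([batch_size], 0, dtype="int32")

# Host-side buffers: the pinned tensor allows fast async device-to-host
# copies of the block count; the plain CPU tensor holds length stats that
# are only read on the host. pin_memory() requires a CUDA build.
decoder_num_blocks_cpu = paddle.full([1], 0, dtype="int32").pin_memory()
max_len_tensor_cpu = paddle.full([8], 0, dtype="int32").cpu()
```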
@@ -414,16 +419,15 @@ def cmp_append_attention(self, naive_cache_k=None, naive_cache_v=None, attn_mask
             kv_batch_ids,
             kv_tile_ids_per_batch,
             kv_num_blocks,
-            decoder_batch_ids,
-            decoder_tile_ids_per_batch,
-            decoder_num_blocks,
             max_len_kv,
-            set_max_lengths,
         ) = get_block_shape_and_split_kv_block(
             self.seq_lens_encoder,
             self.seq_lens_decoder,
             self.seq_lens_this_time,
-            self.cum_offset,
+            self.decoder_batch_ids,
+            self.decoder_tile_ids_per_batch,
+            self.decoder_num_blocks_cpu,
+            self.max_len_tensor_cpu,
             64,
             12,
             (self.q_num_head + 2 * self.kv_num_head) // self.kv_num_head,
@@ -454,10 +458,10 @@ def cmp_append_attention(self, naive_cache_k=None, naive_cache_v=None, attn_mask
             kv_batch_ids,
             kv_tile_ids_per_batch,
             kv_num_blocks,
-            decoder_batch_ids,
-            decoder_tile_ids_per_batch,
-            decoder_num_blocks,
-            set_max_lengths,
+            self.decoder_batch_ids,
+            self.decoder_tile_ids_per_batch,
+            self.decoder_num_blocks_cpu,
+            self.max_len_tensor_cpu,
             max_len_kv,
             self.rope_emb,  # rope_emb
             None,  # attn_mask
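Taken together, the hunks above change `get_block_shape_and_split_kv_block` from returning `decoder_batch_ids`, `decoder_tile_ids_per_batch`, `decoder_num_blocks`, and `set_max_lengths` as fresh outputs to filling caller-owned buffers in place (`self.cum_offset` is also dropped from the inputs). A toy, self-contained sketch of that in/out contract; `split_kv_blocks_inplace` is a hypothetical stand-in for the real CUDA op:

```python
import paddle

def split_kv_blocks_inplace(seq_lens, decoder_batch_ids, decoder_num_blocks_cpu):
    # Write scheduling metadata into caller-owned buffers...
    n = seq_lens.shape[0]
    decoder_batch_ids[:n] = paddle.arange(n, dtype="int32")
    decoder_num_blocks_cpu[0] = n  # host-visible block count
    # ...and return only what remains an output (cf. max_len_kv).
    return paddle.max(seq_lens)

seq_lens = paddle.to_tensor([5, 3, 9], dtype="int32")
batch_ids = paddle.full([8], 0, dtype="int32")             # preallocated once
num_blocks_cpu = paddle.full([1], 0, dtype="int32").cpu()  # preallocated once
max_len_kv = split_kv_blocks_inplace(seq_lens, batch_ids, num_blocks_cpu)
print(int(num_blocks_cpu[0]), int(max_len_kv))  # -> 3 9
```

Reusing the buffers this way avoids a device allocation per decode step, and keeping the scalar counters in pinned/CPU memory lets the host read them cheaply, which is presumably the motivation for the `_cpu` variants.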