Commit c593e1a

[Bug Fix] Fix bug of append attention test case (#3202)

Parent: e39159f

1 file changed: +13 -9 lines

test/layers/test_append_attention.py (13 additions, 9 deletions)

@@ -352,6 +352,11 @@ def init_tensor(self):
         self.max_dec_len_this_time = paddle.to_tensor([self.max_dec_len_this_time], "int32", place=paddle.CPUPlace())
         self.seq_lens_this_time = self.seq_lens_encoder
 
+        self.decoder_batch_ids = paddle.full([self.batch_size], 0, dtype="int32")
+        self.decoder_tile_ids_per_batch = paddle.full([self.batch_size], 0, dtype="int32")
+        self.decoder_num_blocks_cpu = paddle.full([1], 0, dtype="int32").pin_memory()
+        self.max_len_tensor_cpu = paddle.full([8], 0, dtype="int32").cpu()
+
         self.cache_shape = (
             self.max_block_num,
             self.kv_num_head,
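
The four new attributes are scratch buffers that the scheduling kernel fills in place instead of allocating fresh tensors on every call. The following is a minimal standalone sketch of the same setup, assuming a GPU build of PaddlePaddle; the hard-coded batch_size and the module-level (non-self) names are illustrative, not part of the test:

import paddle

batch_size = 2  # illustrative; the test derives this from its own config

# One scheduling slot per batch entry; the kernel overwrites these in place.
decoder_batch_ids = paddle.full([batch_size], 0, dtype="int32")
decoder_tile_ids_per_batch = paddle.full([batch_size], 0, dtype="int32")

# A single int32 counter in pinned (page-locked) host memory, which allows
# fast asynchronous device-to-host writes of the decoder block count.
decoder_num_blocks_cpu = paddle.full([1], 0, dtype="int32").pin_memory()

# Eight int32 slots in ordinary CPU memory; the shape [8] comes from the diff.
max_len_tensor_cpu = paddle.full([8], 0, dtype="int32").cpu()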
@@ -414,16 +419,15 @@ def cmp_append_attention(self, naive_cache_k=None, naive_cache_v=None, attn_mask
             kv_batch_ids,
             kv_tile_ids_per_batch,
             kv_num_blocks,
-            decoder_batch_ids,
-            decoder_tile_ids_per_batch,
-            decoder_num_blocks,
             max_len_kv,
-            set_max_lengths,
         ) = get_block_shape_and_split_kv_block(
             self.seq_lens_encoder,
             self.seq_lens_decoder,
             self.seq_lens_this_time,
-            self.cum_offset,
+            self.decoder_batch_ids,
+            self.decoder_tile_ids_per_batch,
+            self.decoder_num_blocks_cpu,
+            self.max_len_tensor_cpu,
             64,
             12,
             (self.q_num_head + 2 * self.kv_num_head) // self.kv_num_head,
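
After this change, get_block_shape_and_split_kv_block no longer returns decoder_batch_ids, decoder_tile_ids_per_batch, decoder_num_blocks, or set_max_lengths, and no longer takes self.cum_offset; the preallocated buffers go in as arguments and are mutated in place. A sketch of the resulting call, assembled from the diff (tuple entries that sit above this hunk are captured with a starred name rather than guessed):

(
    *outputs_not_shown_in_diff,  # leading tuple entries fall outside the hunk
    kv_batch_ids,
    kv_tile_ids_per_batch,
    kv_num_blocks,
    max_len_kv,
) = get_block_shape_and_split_kv_block(
    self.seq_lens_encoder,
    self.seq_lens_decoder,
    self.seq_lens_this_time,
    self.decoder_batch_ids,           # written in place
    self.decoder_tile_ids_per_batch,  # written in place
    self.decoder_num_blocks_cpu,      # written in place (pinned host memory)
    self.max_len_tensor_cpu,          # written in place (CPU memory)
    64,   # block-shape constants, passed through unchanged by the test
    12,
    (self.q_num_head + 2 * self.kv_num_head) // self.kv_num_head,  # heads per KV group
)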
@@ -454,10 +458,10 @@ def cmp_append_attention(self, naive_cache_k=None, naive_cache_v=None, attn_mask
             kv_batch_ids,
             kv_tile_ids_per_batch,
             kv_num_blocks,
-            decoder_batch_ids,
-            decoder_tile_ids_per_batch,
-            decoder_num_blocks,
-            set_max_lengths,
+            self.decoder_batch_ids,
+            self.decoder_tile_ids_per_batch,
+            self.decoder_num_blocks_cpu,
+            self.max_len_tensor_cpu,
             max_len_kv,
             self.rope_emb,  # rope_emb
             None,  # attn_mask
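
Both calls now share the same self.-scoped buffers: get_block_shape_and_split_kv_block writes the decoder scheduling metadata into them, and the subsequent append_attention call reads them back, replacing the old pattern of threading loose return values (decoder_batch_ids, decoder_tile_ids_per_batch, decoder_num_blocks, set_max_lengths) from one call into the next. Reusing one set of preallocated tensors also avoids reallocating between decode steps; that appears to be the calling convention this fix aligns the test with, though the commit message itself only states that the test case was fixed.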
