@@ -298,40 +298,64 @@ class HPUForwardMeta(ForwardMeta):
298298 block_tables : Optional [paddle .Tensor ] = None
299299
300300 #
301- block_groups : Optional [paddle .Tensor ] = None
301+ rotary_embs_encoder : Optional [paddle .Tensor ] = None
302302
303303 #
304- block_list : Optional [paddle .Tensor ] = None
304+ block_groups_encoder : Optional [paddle .Tensor ] = None
305305
306306 #
307- block_indices : Optional [paddle .Tensor ] = None
307+ block_list_encoder : Optional [paddle .Tensor ] = None
308308
309309 #
310- block_offsets : Optional [paddle .Tensor ] = None
310+ block_indices_encoder : Optional [paddle .Tensor ] = None
311311
312312 #
313- block_mapping : Optional [paddle .Tensor ] = None
313+ block_offsets_encoder : Optional [paddle .Tensor ] = None
314314
315315 #
316- attention_mask : Optional [paddle .Tensor ] = None
316+ block_mapping_encoder : Optional [paddle .Tensor ] = None
317317
318318 #
319- block_size : Optional [paddle .Tensor ] = None
319+ attention_mask_encoder : Optional [paddle .Tensor ] = None
320+
321+ #
322+ batch_ids_encoder : Optional [paddle .Tensor ] = None
323+
324+ #
325+ total_batch_encoder : int = 0
320326
321327 #
322- batch_ids : Optional [paddle .Tensor ] = None
328+ rotary_embs_decoder : Optional [paddle .Tensor ] = None
323329
324330 #
325- total_batch : Optional [paddle .Tensor ] = None
331+ block_groups_decoder : Optional [paddle .Tensor ] = None
326332
327333 #
328- is_prompt : Optional [paddle .Tensor ] = None
334+ block_list_decoder : Optional [paddle .Tensor ] = None
335+
336+ #
337+ block_indices_decoder : Optional [paddle .Tensor ] = None
338+
339+ #
340+ block_offsets_decoder : Optional [paddle .Tensor ] = None
341+
342+ #
343+ block_mapping_decoder : Optional [paddle .Tensor ] = None
344+
345+ #
346+ attention_mask_decoder : Optional [paddle .Tensor ] = None
347+
348+ #
349+ batch_ids_decoder : Optional [paddle .Tensor ] = None
350+
351+ #
352+ total_batch_decoder : int = 0
329353
330354 #
331355 attn_backend : "AttentionBackend_HPU" = None
332356
333357 #
334- rotary_embs : Optional [paddle .Tensor ] = None
358+ block_size : Optional [paddle .Tensor ] = None
335359
336360 #
337361 caches : Optional [paddle .Tensor ] = None
@@ -349,10 +373,12 @@ class HPUForwardMeta(ForwardMeta):
349373 def init_forward_meta (cls , share_inputs : Dict , attn_backend : "AttentionBackend_HPU" ):
350374 """init forward meta"""
351375 # TODO(gongshaotian): delete this func
352- is_prompt = share_inputs ["is_prompt" ]
353- forward_mode = ForwardMode .DECODE
354- if is_prompt :
376+ if share_inputs [ "total_batch_encoder" ] > 0 and share_inputs ["total_batch_decoder" ] > 0 :
377+ forward_mode = ForwardMode .MIXED
378+ elif share_inputs [ "total_batch_encoder" ] > 0 :
355379 forward_mode = ForwardMode .EXTEND
380+ elif share_inputs ["total_batch_decoder" ] > 0 :
381+ forward_mode = ForwardMode .DECODE
356382 ret = cls (
357383 forward_mode = forward_mode ,
358384 input_ids = share_inputs ["input_ids" ],
@@ -361,18 +387,26 @@ def init_forward_meta(cls, share_inputs: Dict, attn_backend: "AttentionBackend_H
361387 seq_lens_decoder = share_inputs ["seq_lens_decoder" ],
362388 seq_lens_this_time = share_inputs ["seq_lens_this_time" ],
363389 block_tables = share_inputs ["block_tables" ],
364- block_groups = share_inputs ["block_groups" ],
365- block_list = share_inputs ["block_list" ],
366- block_indices = share_inputs ["block_indices" ],
367- block_offsets = share_inputs ["block_offsets" ],
368- block_mapping = share_inputs ["block_mapping" ],
369- attention_mask = share_inputs ["block_bias" ],
390+ rotary_embs_encoder = share_inputs ["rotary_embs_encoder" ],
391+ block_groups_encoder = share_inputs ["block_groups_encoder" ],
392+ block_list_encoder = share_inputs ["block_list_encoder" ],
393+ block_indices_encoder = share_inputs ["block_indices_encoder" ],
394+ block_offsets_encoder = share_inputs ["block_offsets_encoder" ],
395+ block_mapping_encoder = share_inputs ["block_mapping_encoder" ],
396+ attention_mask_encoder = share_inputs ["block_bias_encoder" ],
397+ total_batch_encoder = share_inputs ["total_batch_encoder" ],
398+ batch_ids_encoder = share_inputs ["batch_ids_encoder" ],
399+ rotary_embs_decoder = share_inputs ["rotary_embs_decoder" ],
400+ block_groups_decoder = share_inputs ["block_groups_decoder" ],
401+ block_list_decoder = share_inputs ["block_list_decoder" ],
402+ block_indices_decoder = share_inputs ["block_indices_decoder" ],
403+ block_offsets_decoder = share_inputs ["block_offsets_decoder" ],
404+ block_mapping_decoder = share_inputs ["block_mapping_decoder" ],
405+ attention_mask_decoder = share_inputs ["block_bias_decoder" ],
406+ total_batch_decoder = share_inputs ["total_batch_decoder" ],
407+ batch_ids_decoder = share_inputs ["batch_ids_decoder" ],
370408 block_size = share_inputs ["block_size" ],
371- total_batch = share_inputs ["total_batch" ],
372- batch_ids = share_inputs ["batch_ids" ],
373- is_prompt = share_inputs ["is_prompt" ],
374409 attn_backend = attn_backend ,
375- rotary_embs = share_inputs ["rotary_embs" ],
376410 caches = share_inputs ["caches" ],
377411 )
378412 return ret
0 commit comments