
Commit 1ee285c

[Intel HPU] enable chunked prefill (#5903)
* [Intel HPU] enable chunked prefill
* fix bug by copilot comments
1 parent: 83ae594

4 files changed, +531 -100 lines changed

fastdeploy/engine/sched/resource_manager_v1.py

Lines changed: 20 additions & 0 deletions
@@ -360,6 +360,12 @@ def _get_num_new_tokens(self, request, token_budget):
         # TODO: set condition to new _get_num_new_tokens
         num_new_tokens = request.need_prefill_tokens - request.num_computed_tokens
         num_new_tokens = min(num_new_tokens, token_budget)
+        if (
+            current_platform.is_intel_hpu()
+            and request.need_prefill_tokens - request.num_computed_tokens > token_budget
+            and token_budget > self.config.cache_config.block_size
+        ):
+            num_new_tokens = token_budget // self.config.cache_config.block_size * self.config.cache_config.block_size
         request.with_image = False

         if not self.config.model_config.enable_mm:
@@ -653,6 +659,13 @@ def _allocate_decode_and_extend():
                             f"request.need_prefill_tokens {request.need_prefill_tokens},"
                             f"request.num_computed_tokens {request.num_computed_tokens}"
                         )
+                        if (
+                            current_platform.is_intel_hpu()
+                            and request.need_prefill_tokens - request.num_computed_tokens
+                            >= self.config.cache_config.block_size
+                            and token_budget < self.config.cache_config.block_size
+                        ):
+                            continue
                         num_new_tokens = self._get_num_new_tokens(request, token_budget)
                         num_new_block = self.get_new_block_nums(request, num_new_tokens)
                         # Allocate blocks to prefill
@@ -718,6 +731,13 @@ def _allocate_decode_and_extend():
                         self._free_blocks(request)
                         break

+                if (
+                    current_platform.is_intel_hpu()
+                    and request.need_prefill_tokens - request.num_computed_tokens
+                    >= self.config.cache_config.block_size
+                    and token_budget < self.config.cache_config.block_size
+                ):
+                    continue
                 # Allocate blocks for the tokens that does not hit cache
                 num_new_tokens = self._get_num_new_tokens(request, token_budget)
                 num_new_block = self.get_new_block_nums(request, num_new_tokens)
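
Taken together, the guards above implement chunked prefill on Intel HPU: a prefill that cannot finish within this step's token budget is cut down to a whole number of KV-cache blocks, and a request with at least one block of prefill left but less than one block of budget is deferred (`continue`) to a later step. Below is a minimal standalone sketch of those two rules, with plain integers standing in for the scheduler's request/config state and the platform check omitted; the helper names are hypothetical, not FastDeploy APIs.

# Sketch only: mirrors the conditions added above, outside the scheduler.

def hpu_chunk_tokens(need_prefill_tokens: int, num_computed_tokens: int,
                     token_budget: int, block_size: int) -> int:
    """How many prefill tokens to schedule this step (block-aligned on HPU)."""
    remaining = need_prefill_tokens - num_computed_tokens
    num_new_tokens = min(remaining, token_budget)
    # Prefill does not fit in this step: round the chunk down to full blocks.
    if remaining > token_budget and token_budget > block_size:
        num_new_tokens = token_budget // block_size * block_size
    return num_new_tokens


def defer_prefill(need_prefill_tokens: int, num_computed_tokens: int,
                  token_budget: int, block_size: int) -> bool:
    """Mirror of the added `continue` guards: skip the request for now when at
    least one block of prefill remains but under one block of budget is left."""
    remaining = need_prefill_tokens - num_computed_tokens
    return remaining >= block_size and token_budget < block_size


# Example: 1000 tokens of prefill left, block size 128
print(hpu_chunk_tokens(1000, 0, 300, 128))  # 256 -> two full blocks this step
print(defer_prefill(1000, 0, 100, 128))     # True -> wait for the next step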

fastdeploy/model_executor/forward_meta.py

Lines changed: 58 additions & 24 deletions
@@ -298,40 +298,64 @@ class HPUForwardMeta(ForwardMeta):
     block_tables: Optional[paddle.Tensor] = None

     #
-    block_groups: Optional[paddle.Tensor] = None
+    rotary_embs_encoder: Optional[paddle.Tensor] = None

     #
-    block_list: Optional[paddle.Tensor] = None
+    block_groups_encoder: Optional[paddle.Tensor] = None

     #
-    block_indices: Optional[paddle.Tensor] = None
+    block_list_encoder: Optional[paddle.Tensor] = None

     #
-    block_offsets: Optional[paddle.Tensor] = None
+    block_indices_encoder: Optional[paddle.Tensor] = None

     #
-    block_mapping: Optional[paddle.Tensor] = None
+    block_offsets_encoder: Optional[paddle.Tensor] = None

     #
-    attention_mask: Optional[paddle.Tensor] = None
+    block_mapping_encoder: Optional[paddle.Tensor] = None

     #
-    block_size: Optional[paddle.Tensor] = None
+    attention_mask_encoder: Optional[paddle.Tensor] = None
+
+    #
+    batch_ids_encoder: Optional[paddle.Tensor] = None
+
+    #
+    total_batch_encoder: int = 0

     #
-    batch_ids: Optional[paddle.Tensor] = None
+    rotary_embs_decoder: Optional[paddle.Tensor] = None

     #
-    total_batch: Optional[paddle.Tensor] = None
+    block_groups_decoder: Optional[paddle.Tensor] = None

     #
-    is_prompt: Optional[paddle.Tensor] = None
+    block_list_decoder: Optional[paddle.Tensor] = None
+
+    #
+    block_indices_decoder: Optional[paddle.Tensor] = None
+
+    #
+    block_offsets_decoder: Optional[paddle.Tensor] = None
+
+    #
+    block_mapping_decoder: Optional[paddle.Tensor] = None
+
+    #
+    attention_mask_decoder: Optional[paddle.Tensor] = None
+
+    #
+    batch_ids_decoder: Optional[paddle.Tensor] = None
+
+    #
+    total_batch_decoder: int = 0

     #
     attn_backend: "AttentionBackend_HPU" = None

     #
-    rotary_embs: Optional[paddle.Tensor] = None
+    block_size: Optional[paddle.Tensor] = None

     #
     caches: Optional[paddle.Tensor] = None
@@ -349,10 +373,12 @@ class HPUForwardMeta(ForwardMeta):
     def init_forward_meta(cls, share_inputs: Dict, attn_backend: "AttentionBackend_HPU"):
         """init forward meta"""
         # TODO(gongshaotian): delete this func
-        is_prompt = share_inputs["is_prompt"]
-        forward_mode = ForwardMode.DECODE
-        if is_prompt:
+        if share_inputs["total_batch_encoder"] > 0 and share_inputs["total_batch_decoder"] > 0:
+            forward_mode = ForwardMode.MIXED
+        elif share_inputs["total_batch_encoder"] > 0:
             forward_mode = ForwardMode.EXTEND
+        elif share_inputs["total_batch_decoder"] > 0:
+            forward_mode = ForwardMode.DECODE
         ret = cls(
             forward_mode=forward_mode,
             input_ids=share_inputs["input_ids"],
@@ -361,18 +387,26 @@ def init_forward_meta(cls, share_inputs: Dict, attn_backend: "AttentionBackend_H
             seq_lens_decoder=share_inputs["seq_lens_decoder"],
             seq_lens_this_time=share_inputs["seq_lens_this_time"],
             block_tables=share_inputs["block_tables"],
-            block_groups=share_inputs["block_groups"],
-            block_list=share_inputs["block_list"],
-            block_indices=share_inputs["block_indices"],
-            block_offsets=share_inputs["block_offsets"],
-            block_mapping=share_inputs["block_mapping"],
-            attention_mask=share_inputs["block_bias"],
+            rotary_embs_encoder=share_inputs["rotary_embs_encoder"],
+            block_groups_encoder=share_inputs["block_groups_encoder"],
+            block_list_encoder=share_inputs["block_list_encoder"],
+            block_indices_encoder=share_inputs["block_indices_encoder"],
+            block_offsets_encoder=share_inputs["block_offsets_encoder"],
+            block_mapping_encoder=share_inputs["block_mapping_encoder"],
+            attention_mask_encoder=share_inputs["block_bias_encoder"],
+            total_batch_encoder=share_inputs["total_batch_encoder"],
+            batch_ids_encoder=share_inputs["batch_ids_encoder"],
+            rotary_embs_decoder=share_inputs["rotary_embs_decoder"],
+            block_groups_decoder=share_inputs["block_groups_decoder"],
+            block_list_decoder=share_inputs["block_list_decoder"],
+            block_indices_decoder=share_inputs["block_indices_decoder"],
+            block_offsets_decoder=share_inputs["block_offsets_decoder"],
+            block_mapping_decoder=share_inputs["block_mapping_decoder"],
+            attention_mask_decoder=share_inputs["block_bias_decoder"],
+            total_batch_decoder=share_inputs["total_batch_decoder"],
+            batch_ids_decoder=share_inputs["batch_ids_decoder"],
             block_size=share_inputs["block_size"],
-            total_batch=share_inputs["total_batch"],
-            batch_ids=share_inputs["batch_ids"],
-            is_prompt=share_inputs["is_prompt"],
             attn_backend=attn_backend,
-            rotary_embs=share_inputs["rotary_embs"],
             caches=share_inputs["caches"],
         )
         return ret
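
With the per-phase batch counts in place, `is_prompt` is gone and `init_forward_meta` now derives the forward mode from how many encoder (prefill) and decoder requests are in the batch, which is what lets a chunked-prefill step mix both. Below is a minimal standalone sketch of that selection logic; the local `ForwardMode` enum and the explicit error for an empty batch are stand-ins of my own, not FastDeploy's definitions.

# Sketch only: mirrors the branch added to HPUForwardMeta.init_forward_meta.
from enum import Enum, auto


class ForwardMode(Enum):
    EXTEND = auto()  # prefill (encoder) requests only
    DECODE = auto()  # decode requests only
    MIXED = auto()   # chunked-prefill step with both kinds of requests


def select_forward_mode(total_batch_encoder: int, total_batch_decoder: int) -> ForwardMode:
    """Choose the forward mode from the encoder/decoder batch counts."""
    if total_batch_encoder > 0 and total_batch_decoder > 0:
        return ForwardMode.MIXED
    elif total_batch_encoder > 0:
        return ForwardMode.EXTEND
    elif total_batch_decoder > 0:
        return ForwardMode.DECODE
    raise ValueError("batch contains neither encoder nor decoder requests")


print(select_forward_mode(2, 3))  # ForwardMode.MIXED
print(select_forward_mode(0, 4))  # ForwardMode.DECODE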
