
Commit 77c1bd0

[XPU] Fixed the performance degradation caused by enabling ENABLE_V1_KVCACHE_SCHEDULER (#3900)
* fix bug
* fix bug
* update
* update
* update
1 parent: 473cde7

4 files changed: 20 additions, 4 deletions

fastdeploy/config.py
Lines changed: 4 additions & 1 deletion

@@ -1236,7 +1236,10 @@ def postprocess(self):
 
         if self.max_num_batched_tokens is None:
             if int(envs.ENABLE_V1_KVCACHE_SCHEDULER):
-                self.max_num_batched_tokens = 8192  # if set to max_model_len, it's easy to be OOM
+                if paddle.is_compiled_with_xpu():
+                    self.max_num_batched_tokens = self.max_model_len
+                else:
+                    self.max_num_batched_tokens = 8192  # if set to max_model_len, it's easy to be OOM
             else:
                 if self.cache_config.enable_chunked_prefill:
                     self.max_num_batched_tokens = 2048
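The core of the fix, distilled: on XPU builds the V1 scheduler's default prefill budget is raised to the full context window, while other devices keep the conservative 8192-token cap. A minimal standalone sketch of that selection rule (the helper name default_max_num_batched_tokens is hypothetical, not part of the commit):

import paddle

def default_max_num_batched_tokens(max_model_len: int) -> int:
    # XPU builds: budget the whole context window; the fixed 8192 cap
    # is what caused the reported performance degradation on XPU.
    if paddle.is_compiled_with_xpu():
        return max_model_len
    # Other devices: keep the fixed cap, since setting it to
    # max_model_len "is easy to be OOM" per the original comment.
    return 8192

print(default_max_num_batched_tokens(131072))  # 131072 on XPU builds, 8192 elsewhere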

fastdeploy/engine/args_utils.py
Lines changed: 6 additions & 1 deletion

@@ -19,6 +19,8 @@
 from dataclasses import fields as dataclass_fields
 from typing import Any, Dict, List, Optional
 
+import paddle
+
 from fastdeploy import envs
 from fastdeploy.config import (
     CacheConfig,
@@ -1006,7 +1008,10 @@ def create_engine_config(self) -> FDConfig:
 
         if self.max_num_batched_tokens is None:
             if int(envs.ENABLE_V1_KVCACHE_SCHEDULER):
-                self.max_num_batched_tokens = 8192  # if set to max_model_len, it's easy to be OOM
+                if paddle.is_compiled_with_xpu():
+                    self.max_num_batched_tokens = self.max_model_len
+                else:
+                    self.max_num_batched_tokens = 8192  # if set to max_model_len, it's easy to be OOM
             else:
                 if self.enable_chunked_prefill:
                     self.max_num_batched_tokens = 2048
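The same default is mirrored here in create_engine_config, and in both places the branch only runs when the V1 scheduler is switched on via an environment variable. A rough reproduction of that gate with plain os.environ (the real lookup goes through fastdeploy.envs, whose default value is not shown in this diff, so the "0" fallback below is an assumption):

import os

# Hedged stand-in for envs.ENABLE_V1_KVCACHE_SCHEDULER; assumes "0" default.
enable_v1 = int(os.environ.get("ENABLE_V1_KVCACHE_SCHEDULER", "0"))

if enable_v1:
    # The device-aware default from this commit takes effect on this path.
    print("V1 KV-cache scheduler: device-aware max_num_batched_tokens default")
else:
    print("Legacy scheduler: chunked-prefill default of 2048 applies instead")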

fastdeploy/engine/sched/resource_manager_v1.py
Lines changed: 3 additions & 1 deletion

@@ -345,7 +345,9 @@ def schedule(self):
         while self.waiting and token_budget > 0:
             if len(self.running) == self.max_num_seqs:
                 break
-            if self.config.model_config.enable_mm and self.exist_prefill(scheduled_reqs):
+            if (self.config.model_config.enable_mm or paddle.is_compiled_with_xpu()) and self.exist_prefill(
+                scheduled_reqs
+            ):
                 break
             request = self.waiting[0]
             if request.status == RequestStatus.WAITING:
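This scheduler change makes XPU behave like the existing multimodal path: once a prefill has been scheduled in the current step, no further waiting requests are admitted, so prefill and decode work are not mixed in one batch. The condition, pulled out into a self-contained predicate for illustration (should_stop_admitting is a hypothetical name, not an identifier from the commit):

def should_stop_admitting(enable_mm: bool, on_xpu: bool, has_prefill_in_batch: bool) -> bool:
    # Before this commit only multimodal models broke out of the admission
    # loop here; now XPU builds do too, keeping each prefill step pure.
    return (enable_mm or on_xpu) and has_prefill_in_batch

# On XPU, a step that already holds a prefill admits nothing further:
assert should_stop_admitting(False, True, True)
assert not should_stop_admitting(False, False, True)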

fastdeploy/worker/xpu_model_runner.py
Lines changed: 7 additions & 1 deletion

@@ -383,6 +383,7 @@ def insert_tasks_v1(self, req_dicts: List[Request]):
 
         req_len = len(req_dicts)
         has_prefill_task = False
+        has_decode_task = False
         for i in range(req_len):
             request = req_dicts[i]
             idx = request.idx
@@ -392,6 +393,9 @@ def insert_tasks_v1(self, req_dicts: List[Request]):
             prefill_end_index = request.prefill_end_index
             length = prefill_end_index - prefill_start_index
             input_ids = request.prompt_token_ids + request.output_token_ids
+            logger.debug(
+                f"Handle prefill request {request} at idx {idx} prefill_start_index {prefill_start_index} prefill_end_index {prefill_end_index} need_prefilled_token_num {len(input_ids)}"
+            )
             self.share_inputs["input_ids"][idx : idx + 1, :length] = np.array(
                 input_ids[prefill_start_index:prefill_end_index]
             )
@@ -401,6 +405,8 @@ def insert_tasks_v1(self, req_dicts: List[Request]):
             self.share_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
                 request.block_tables, dtype="int32"
             )
+            if self.share_inputs["is_block_step"][idx]:  # has tasks to continue to decode
+                has_decode_task = True
             self.share_inputs["stop_flags"][idx : idx + 1] = False
             self.share_inputs["seq_lens_decoder"][idx : idx + 1] = prefill_start_index
             self.share_inputs["seq_lens_this_time"][idx : idx + 1] = length
@@ -474,7 +480,7 @@ def insert_tasks_v1(self, req_dicts: List[Request]):
                 self.share_inputs["stop_seqs"][:stop_seqs_num, : len(request.get("stop_token_ids")[0])] = np.array(
                     request.get("stop_token_ids"), dtype="int64"
                 )
-        if has_prefill_task:
+        if has_prefill_task or has_decode_task:
             self.share_inputs["not_need_stop"][0] = True
 
     def process_prefill_inputs(self, req_dicts: List[Request]):
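The runner change fixes a stall: not_need_stop previously stayed False when a step carried only resumed-decode work (requests whose is_block_step flag is set), so the worker loop could stop early. A simplified sketch of the new bookkeeping over plain dicts (field names follow the diff; the standalone function and the "task" key are illustrative, not the runner's actual API):

from typing import Dict, List

def worker_should_keep_running(batch: List[Dict]) -> bool:
    # A request counts as a prefill task when it carries new prompt tokens,
    # and as a decode task when it is resuming from a block step.
    has_prefill_task = any(req.get("task") == "prefill" for req in batch)
    has_decode_task = any(req.get("is_block_step", False) for req in batch)
    # The pre-fix behavior returned only has_prefill_task, which let
    # decode-only batches leave not_need_stop unset and stall decoding.
    return has_prefill_task or has_decode_task

# A decode-only batch now keeps the worker alive:
assert worker_should_keep_running([{"task": "decode", "is_block_step": True}])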
