
Commit 702c313

revert pr (#3286)
1 parent 6706ccb commit 702c313

2 files changed (+13 −21 lines):
fastdeploy/worker/xpu_model_runner.py
fastdeploy/worker/xpu_worker.py
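In effect, this revert removes the seq_lens_this_time_buffer staging tensor from the XPU backend: share_inputs["seq_lens_this_time"] is allocated directly and written in place again, and the num_running_requests parameter is no longer threaded through insert_tasks_v1, process_prefill_inputs, or the worker's call into the model runner.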

fastdeploy/worker/xpu_model_runner.py

Lines changed: 7 additions & 14 deletions
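Every hunk in this file either redirects writes from the staging buffer self.seq_lens_this_time_buffer back to self.share_inputs["seq_lens_this_time"], or deletes the slice-publish and copy-back steps that the buffer required.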
@@ -373,7 +373,7 @@ def __init__(self, fd_config: FDConfig, device: str, rank: int, local_rank: int)
         # Forward meta store the global meta information of the forward
         self.forward_meta: ForwardMeta = None
 
-    def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int = None):
+    def insert_tasks_v1(self, req_dicts: List[Request]):
         """
         Process scheduler output tasks, used when ENABLE_V1_KVCACHE_SCHEDULER=1
         """
@@ -403,7 +403,7 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int =
             )
             self.share_inputs["stop_flags"][idx : idx + 1] = False
             self.share_inputs["seq_lens_decoder"][idx : idx + 1] = prefill_start_index
-            self.seq_lens_this_time_buffer[idx : idx + 1] = length
+            self.share_inputs["seq_lens_this_time"][idx : idx + 1] = length
             self.share_inputs["seq_lens_encoder"][idx : idx + 1] = length
             self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = 0
             self.share_inputs["prompt_lens"][idx : idx + 1] = len(input_ids)
@@ -425,7 +425,7 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int =
                 logger.debug(f"Handle preempted request {request} at idx {idx}")
                 self.share_inputs["block_tables"][idx : idx + 1, :] = -1
                 self.share_inputs["stop_flags"][idx : idx + 1] = True
-                self.seq_lens_this_time_buffer[idx : idx + 1] = 0
+                self.share_inputs["seq_lens_this_time"][idx : idx + 1] = 0
                 self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0
                 self.share_inputs["seq_lens_encoder"][idx : idx + 1] = 0
                 self.share_inputs["is_block_step"][idx : idx + 1] = False
@@ -462,9 +462,8 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int =
             )
         if has_prefill_task:
             self.share_inputs["not_need_stop"][0] = True
-        self.share_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer[:num_running_requests]
 
-    def process_prefill_inputs(self, req_dicts: List[Request], num_running_requests: int = None):
+    def process_prefill_inputs(self, req_dicts: List[Request]):
         """Process inputs for prefill tasks and update share_inputs buffer"""
         req_len = len(req_dicts)
         for i in range(req_len):
@@ -483,7 +482,7 @@ def process_prefill_inputs(self, req_dicts: List[Request], num_running_requests:
             self.share_inputs["penalty_score"][idx : idx + 1] = request.get("repetition_penalty", 1.0)
             self.share_inputs["frequency_score"][idx : idx + 1] = request.get("frequency_penalty", 0.0)
             self.share_inputs["presence_score"][idx : idx + 1] = request.get("presence_penalty", 0.0)
-            self.seq_lens_this_time_buffer[idx : idx + 1] = length
+            self.share_inputs["seq_lens_this_time"][idx : idx + 1] = length
             self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = length
             self.share_inputs["seq_lens_encoder"][idx : idx + 1] = length
             self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0
@@ -527,7 +526,6 @@ def process_prefill_inputs(self, req_dicts: List[Request], num_running_requests:
             )
 
         self.share_inputs["not_need_stop"][0] = True
-        self.share_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer[:num_running_requests]
 
     def _init_share_inputs(self, max_num_seqs: int):
         """Initialize all share buffers for model inputs.
@@ -573,7 +571,7 @@ def _init_share_inputs(self, max_num_seqs: int):
         self.share_inputs["max_length"] = paddle.full(
             [max_num_seqs, 1], self.model_config.max_model_len, dtype="int64"
         )
-        self.seq_lens_this_time_buffer = paddle.full(max_num_seqs, 0, dtype="int32")
+        self.share_inputs["seq_lens_this_time"] = paddle.full(max_num_seqs, 0, dtype="int32")
         self.share_inputs["seq_lens_encoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
         self.share_inputs["seq_lens_decoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
         self.share_inputs["step_seq_lens_encoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
@@ -815,7 +813,7 @@ def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int):
             idx = i
             self.share_inputs["input_ids"][idx : idx + 1, :input_length] = np.array([5] * input_length)
             self.share_inputs["eos_token_id"][:] = np.array([2], dtype="int64").reshape(-1, 1)
-            self.seq_lens_this_time_buffer[idx : idx + 1] = input_length
+            self.share_inputs["seq_lens_this_time"][idx : idx + 1] = input_length
             self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = input_length
             self.share_inputs["seq_lens_encoder"][idx : idx + 1] = input_length
             self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0
@@ -831,7 +829,6 @@ def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int):
             self.share_inputs["block_tables"][idx : idx + 1, :block_num] = np.arange(
                 idx * block_num, (idx + 1) * block_num, 1
             )
-        self.share_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer
 
     def _dummy_run(
         self,
@@ -925,10 +922,6 @@ class at the server level, which is too granular for ModelRunner.
             self.cache_config.block_size,
             self.cache_config.enc_dec_block_num,
         )
-        if num_running_requests is not None:
-            self.seq_lens_this_time_buffer[:num_running_requests].copy_(
-                self.share_inputs["seq_lens_this_time"][:num_running_requests], False
-            )
 
         return None

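To make the restored pattern concrete, here is a minimal, self-contained sketch of the two buffer-handling styles. NumPy stands in for paddle tensors and the helper names are invented for illustration; only the slice-assignment shapes mirror the diff.

import numpy as np

max_num_seqs = 4

# Pattern restored by this revert: the tensor lives inside share_inputs and
# every producer writes into it in place.
share_inputs = {"seq_lens_this_time": np.zeros(max_num_seqs, dtype=np.int32)}

def insert_prefill(idx: int, length: int) -> None:
    # In-place slice assignment: anything holding a reference to the tensor
    # sees the update, and no per-step rebinding is required.
    share_inputs["seq_lens_this_time"][idx : idx + 1] = length

# Pattern removed by this revert: a separate staging buffer that had to be
# published into share_inputs as a slice, which is why num_running_requests
# was threaded through the whole call chain.
seq_lens_this_time_buffer = np.zeros(max_num_seqs, dtype=np.int32)

def insert_prefill_staged(idx: int, length: int) -> None:
    seq_lens_this_time_buffer[idx : idx + 1] = length

def publish(num_running_requests: int) -> None:
    # Rebinds the dict entry to a view of the first num_running_requests
    # slots, so the published tensor's length changes from step to step.
    share_inputs["seq_lens_this_time"] = seq_lens_this_time_buffer[:num_running_requests]

insert_prefill(0, 128)
insert_prefill(1, 64)
print(share_inputs["seq_lens_this_time"])  # [128  64   0   0]

The revert accordingly also deletes the copy-back in execute_model (last hunk above) that kept the staging buffer coherent with share_inputs after each step.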
fastdeploy/worker/xpu_worker.py

Lines changed: 6 additions & 7 deletions
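The worker keeps num_running_requests in its public signatures (preprocess_new_task now defaults it to -1) but stops forwarding it to the model runner.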
@@ -149,10 +149,9 @@ def execute_model(
         num_running_requests: Optional[int] = None,
     ) -> Optional[ModelRunnerOutput]:
         """ """
-        if is_dummy_run:
-            output = self.model_runner.execute_model(model_forward_batch)
-        else:
-            output = self.model_runner.execute_model(model_forward_batch, num_running_requests)
+
+        output = self.model_runner.execute_model(model_forward_batch)
+
         return output
 
     def exist_prefill(self):
@@ -161,15 +160,15 @@ def exist_prefill(self):
         """
         return self.model_runner.exist_prefill()
 
-    def preprocess_new_task(self, req_dicts: List[Request], num_running_requests: int) -> None:
+    def preprocess_new_task(self, req_dicts: List[Request], num_running_requests: int = -1) -> None:
         """Process new requests and then start the decode loop
         TODO(gongshaotian):The scheduler should schedule the handling of prefill,
         and workers and modelrunners should not perceive it.
         """
         if envs.ENABLE_V1_KVCACHE_SCHEDULER:
-            self.model_runner.insert_tasks_v1(req_dicts=req_dicts, num_running_requests=num_running_requests)
+            self.model_runner.insert_tasks_v1(req_dicts=req_dicts)
         else:
-            self.model_runner.process_prefill_inputs(req_dicts=req_dicts, num_running_requests=num_running_requests)
+            self.model_runner.process_prefill_inputs(req_dicts=req_dicts)
 
     def check_health(self) -> bool:
         """ """

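On the worker side the revert is a pure call-path simplification. A hedged sketch with stub classes (not the real FastDeploy types) shows the dispatch shape after the change:

from typing import Any, List, Optional

class RunnerStub:
    """Stands in for the XPU model runner after the revert."""

    def execute_model(self, model_forward_batch: Optional[List[Any]] = None) -> str:
        return f"executed {model_forward_batch!r}"

class WorkerStub:
    """Stands in for XpuWorker; only the dispatch shape is modeled."""

    def __init__(self) -> None:
        self.model_runner = RunnerStub()

    def execute_model(
        self,
        model_forward_batch: Optional[List[Any]] = None,
        num_running_requests: Optional[int] = None,
    ) -> str:
        # After the revert there is a single unconditional call;
        # num_running_requests stays in the signature but is unused.
        return self.model_runner.execute_model(model_forward_batch)

print(WorkerStub().execute_model(["req-0"], num_running_requests=1))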