
Commit bc0b92b

[BugFix] support real batch_size (#3109) (#3217)
* support real bsz
* fix
* fix xpu_model_runner.py, gpu_model_runner.py, gcu_model_runner.py, iluvatar_model_runner.py
* add event_loop_ep
* fix
* Add comments
* fix
* support mtp real_batch_size
* fix
* self.tmp_seq_lens_this_time -> self.seq_lens_this_time_buffer
* fix
* fix VL real_seq_lens_this_time
* fix
* fix mtp
* fix
* fix mtp
* fix xpu
* fix
1 parent 3dd8492 commit bc0b92b
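What the change does: instead of writing sequence lengths directly into a seq_lens_this_time tensor sized for max_num_seqs, each runner keeps a persistent seq_lens_this_time_buffer of that full size and exposes only its first num_running_requests entries as seq_lens_this_time, so downstream code sees the real batch size rather than the padded maximum. Below is a minimal sketch of that pattern; NumPy stands in for the Paddle tensors, and the helper insert_prefill and its requests argument are hypothetical, not the FastDeploy API. The same buffer-then-slice pattern appears in mtp.py, gcu_model_runner.py, and gcu_worker.py below (and, per the commit message, in the gpu/xpu/iluvatar runners not shown here).

import numpy as np

max_num_seqs = 8  # capacity of the persistent buffer: one slot per possible request

# Persistent buffer sized for the maximum batch; it survives across scheduling steps.
seq_lens_this_time_buffer = np.zeros(max_num_seqs, dtype=np.int32)

def insert_prefill(requests, num_running_requests):
    """Write each request's length into its slot, then expose a view that
    covers only the slots actually running this step (the 'real' batch)."""
    for idx, prompt_len in requests:
        seq_lens_this_time_buffer[idx : idx + 1] = prompt_len
    # Downstream code receives a tensor whose length equals the real batch size.
    return seq_lens_this_time_buffer[:num_running_requests]

# Two requests running out of eight possible slots.
seq_lens_this_time = insert_prefill([(0, 37), (1, 12)], num_running_requests=2)
print(seq_lens_this_time)        # [37 12]
print(seq_lens_this_time.shape)  # (2,) -- the real batch size, not max_num_seqs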

File tree

10 files changed, +110 -58 lines changed

fastdeploy/spec_decode/mtp.py

Lines changed: 9 additions & 6 deletions
@@ -107,7 +107,7 @@ def dummy_prefill_inputs(self, num_tokens: int, batch_size: int, expected_decode
             idx = i
             self.model_inputs["input_ids"][idx : idx + 1, :input_length] = np.array([5] * input_length)
             self.model_inputs["eos_token_id"][:] = np.array([2], dtype="int64").reshape(-1, 1)
-            self.model_inputs["seq_lens_this_time"][idx : idx + 1] = input_length
+            self.seq_lens_this_time_buffer[idx : idx + 1] = input_length
             self.model_inputs["seq_lens_encoder"][idx : idx + 1] = input_length
             self.model_inputs["seq_lens_decoder"][idx : idx + 1] = 0
             self.model_inputs["step_idx"][idx : idx + 1] = 0
@@ -118,6 +118,7 @@ def dummy_prefill_inputs(self, num_tokens: int, batch_size: int, expected_decode
             self.model_inputs["block_tables"][idx : idx + 1, :block_num] = np.arange(
                 idx * block_num, (idx + 1) * block_num, 1
             )
+        self.model_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer
 
     def initialize_kv_cache(self):
         """
@@ -263,7 +264,8 @@ def _init_model_inputs(self):
         # Same shape/dytpe with base model
         self.model_inputs["block_tables"] = paddle.clone(self.main_model_inputs["block_tables"])
         self.model_inputs["input_ids"] = paddle.clone(self.main_model_inputs["input_ids"])
-        self.model_inputs["seq_lens_this_time"] = paddle.clone(self.main_model_inputs["seq_lens_this_time"])
+        self.seq_lens_this_time_buffer = paddle.clone(self.main_model_inputs["seq_lens_this_time"])
+
         self.model_inputs["seq_lens_encoder"] = paddle.clone(self.main_model_inputs["seq_lens_encoder"])
         self.model_inputs["seq_lens_decoder"] = paddle.clone(self.main_model_inputs["seq_lens_decoder"])
         self.model_inputs["step_idx"] = paddle.clone(self.main_model_inputs["step_idx"])
@@ -338,7 +340,7 @@ def _init_model_inputs(self):
             self.main_model_inputs["seq_lens_this_time"], fill_value=-1, dtype="int32"
         )
 
-    def insert_prefill_inputs(self, req_dicts: List[Request]):
+    def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests: int):
         """
         Process inputs for prefill tasks and insert it to model_inputs buffer
         """
@@ -372,7 +374,7 @@ def insert_prefill_inputs(self, req_dicts: List[Request]):
 
                 self.model_inputs["seq_lens_encoder"][idx : idx + 1] = 0
                 self.model_inputs["seq_lens_decoder"][idx : idx + 1] = length
-                self.model_inputs["seq_lens_this_time"][idx : idx + 1] = prefill_token_num
+                self.seq_lens_this_time_buffer[idx : idx + 1] = prefill_token_num
 
                 self.model_inputs["stop_flags"][idx : idx + 1] = False
                 self.model_inputs["batch_drop"][idx : idx + 1] = False
@@ -397,10 +399,10 @@ def insert_prefill_inputs(self, req_dicts: List[Request]):
             if self.cache_config.enable_chunked_prefill:
                 token_chunk_size = request.prefill_chunk_info[0]
                 self.model_inputs["seq_lens_encoder"][idx : idx + 1] = token_chunk_size
-                self.model_inputs["seq_lens_this_time"][idx : idx + 1] = token_chunk_size
+                self.seq_lens_this_time_buffer[idx : idx + 1] = token_chunk_size
             else:
                 self.model_inputs["seq_lens_encoder"][idx : idx + 1] = length
-                self.model_inputs["seq_lens_this_time"][idx : idx + 1] = length
+                self.seq_lens_this_time_buffer[idx : idx + 1] = length
 
             self.model_inputs["seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0)
             self.model_inputs["stop_flags"][idx : idx + 1] = False
@@ -413,6 +415,7 @@ def insert_prefill_inputs(self, req_dicts: List[Request]):
                 request.get("block_tables"), dtype="int32"
             )
         self.model_inputs["not_need_stop"][0] = True
+        self.model_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer[:num_running_requests]
 
     def _initialize_forward_meta(self):
         """

fastdeploy/worker/gcu_model_runner.py

Lines changed: 16 additions & 7 deletions
@@ -152,9 +152,11 @@ def _init_logits_processor(self, request):
             schemata_key,
         )
 
-    def insert_prefill_inputs(self, req_dicts: List[Request]):
+    def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests: int = None):
         """
         Process inputs for prefill tasks and insert it to share_inputs buffer
+        req_dict: A list of Request dict
+        num_running_requests: batch_size
         """
 
         if req_dicts[-1].disaggregate_info is not None and req_dicts[-1].disaggregate_info["role"] == "prefill":
@@ -193,7 +195,7 @@ def get_attr_from_request(request, attr, default_value=None):
                 self.share_inputs["prompt_ids"][idx : idx + 1, :length] = np.array(request.prompt_token_ids)
                 self.share_inputs["seq_lens_encoder"][idx : idx + 1] = 0
                 self.share_inputs["seq_lens_decoder"][idx : idx + 1] = length
-                self.share_inputs["seq_lens_this_time"][idx : idx + 1] = 1
+                self.seq_lens_this_time_buffer[idx : idx + 1] = 1
                 self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = 0
                 self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = length
                 self.share_inputs["prompt_lens"][idx : idx + 1] = length
@@ -205,7 +207,7 @@ def get_attr_from_request(request, attr, default_value=None):
                         request.draft_token_ids[0:num_prefill_send_token],
                         dtype="int64",
                     )
-                    self.share_inputs["seq_lens_this_time"][idx : idx + 1] = num_prefill_send_token
+                    self.seq_lens_this_time_buffer[idx : idx + 1] = num_prefill_send_token
             else:
                 self.share_inputs["pre_ids"][idx : idx + 1] = -1
                 self.share_inputs["step_idx"][idx : idx + 1] = 0
@@ -222,14 +224,14 @@ def get_attr_from_request(request, attr, default_value=None):
                 )
                 self.share_inputs["seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0)
                 self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0)
-                self.share_inputs["seq_lens_this_time"][idx : idx + 1] = token_chunk_size
+                self.seq_lens_this_time_buffer[idx : idx + 1] = token_chunk_size
                 self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = token_chunk_size
                 self.share_inputs["seq_lens_encoder"][idx : idx + 1] = token_chunk_size
                 self.share_inputs["prompt_lens"][idx : idx + 1] = token_chunk_size
             else:
                 self.share_inputs["seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0)
                 self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0)
-                self.share_inputs["seq_lens_this_time"][idx : idx + 1] = length
+                self.seq_lens_this_time_buffer[idx : idx + 1] = length
                 self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = length
                 self.share_inputs["seq_lens_encoder"][idx : idx + 1] = length
                 self.share_inputs["prompt_lens"][idx : idx + 1] = length
@@ -293,6 +295,7 @@ def get_attr_from_request(request, attr, default_value=None):
 
         if self.speculative_method in ["mtp"]:
             self.proposer.insert_prefill_inputs(req_dicts)
+        self.share_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer[:num_running_requests]
 
     def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int, expected_decode_len: int):
         """Set dummy prefill inputs to share_inputs"""
@@ -311,7 +314,7 @@ def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int, expected_decod
             self.share_inputs["input_ids"][idx : idx + 1, :input_length] = np.array([5] * input_length)
             self.share_inputs["prompt_ids"][idx : idx + 1, :input_length] = np.array([5] * input_length)
             self.share_inputs["eos_token_id"][:] = np.array([2], dtype="int64").reshape(-1, 1)
-            self.share_inputs["seq_lens_this_time"][idx : idx + 1] = input_length
+            self.seq_lens_this_time_buffer[idx : idx + 1] = input_length
             self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = input_length
             self.share_inputs["seq_lens_encoder"][idx : idx + 1] = input_length
             self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0
@@ -329,6 +332,7 @@ def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int, expected_decod
             self.share_inputs["block_tables"][idx : idx + 1, :block_num] = np.arange(
                 idx * block_num, (idx + 1) * block_num, 1
             )
+        self.share_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer
 
     def _init_share_inputs(self, max_num_seqs: int):
         """
@@ -379,7 +383,7 @@ def _init_share_inputs(self, max_num_seqs: int):
         self.share_inputs["max_length"] = paddle.full(
             [max_num_seqs, 1], self.model_config.max_model_len, dtype="int64"
         )
-        self.share_inputs["seq_lens_this_time"] = paddle.full(max_num_seqs, 0, dtype="int32")
+        self.seq_lens_this_time_buffer = paddle.full(max_num_seqs, 0, dtype="int32")
         self.share_inputs["seq_lens_encoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
         self.share_inputs["seq_lens_decoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
         self.share_inputs["step_seq_lens_encoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
@@ -921,13 +925,15 @@ def _get_skip_idx(self, model_forward_batch: Optional[List[Request]] = None):
     def execute_model(
         self,
         model_forward_batch: Optional[List[Request]] = None,
+        num_running_requests: int = None,
     ) -> Optional[ModelRunnerOutput]:
         """
         The Entrance of model execute.
         Args:
             model_forward_batch: 'Request' contains information related to prompt and is an abstract
             class at the server level, which is too granular for ModelRunner.
             We plan to replace it with 'ModelForwardBatch'.
+            num_running_requests: batch_size
            intermediate_tensors:
         """
         # If `not_need_stop`` is False, it means the current worker is in an idle state.
@@ -1053,6 +1059,9 @@ class at the server level, which is too granular for ModelRunner.
 
         self._update_chunked_prefill(model_forward_batch)
         self._add_cache(model_forward_batch)
+        self.seq_lens_this_time_buffer[:num_running_requests].copy_(
+            self.share_inputs["seq_lens_this_time"][:num_running_requests], False
+        )
         return None
 
     def _add_cache(self, model_forward_batch) -> None:
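Note the copy added at the end of execute_model: the first num_running_requests entries of share_inputs["seq_lens_this_time"] are copied back into the persistent buffer. Presumably this keeps the full-size buffer in sync with whatever the step produced, so the next insert_prefill_inputs (which may publish a differently sized slice) starts from up-to-date lengths. A NumPy sketch of that round trip, with hypothetical begin_step/end_step helpers standing in for insert_prefill_inputs and the tail of execute_model:

import numpy as np

max_num_seqs = 4
buffer = np.zeros(max_num_seqs, dtype=np.int32)  # persistent across steps
share_inputs = {}

def begin_step(lengths):
    """Hypothetical prefill insert: write the lengths, publish the running slice."""
    n = len(lengths)
    buffer[:n] = lengths
    share_inputs["seq_lens_this_time"] = buffer[:n]
    return n

def end_step(num_running_requests):
    """Mirror of the copy added at the end of execute_model: write the
    post-step lengths back into the persistent buffer."""
    buffer[:num_running_requests] = share_inputs["seq_lens_this_time"][:num_running_requests]

n = begin_step([17, 3])
# Stand-in for the step updating the lengths (e.g. each running request will
# decode a single token next time), possibly rebinding the published tensor.
share_inputs["seq_lens_this_time"] = np.array([1, 1], dtype=np.int32)
end_step(n)
print(buffer)  # [1 1 0 0] -- the buffer reflects the post-step lengths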

fastdeploy/worker/gcu_worker.py

Lines changed: 4 additions & 3 deletions
@@ -105,17 +105,18 @@ def initialize_cache(self, num_gpu_blocks: int) -> None:
     def execute_model(
         self,
         model_forward_batch: Optional[List[Request]] = None,
+        num_running_requests: int = None,
     ) -> Optional[ModelRunnerOutput]:
         """ """
-        output = self.model_runner.execute_model(model_forward_batch)
+        output = self.model_runner.execute_model(model_forward_batch, num_running_requests)
         return output
 
-    def preprocess_new_task(self, req_dicts: List[Request]) -> None:
+    def preprocess_new_task(self, req_dicts: List[Request], num_running_requests: int) -> None:
         """Process new requests and then start the decode loop
         TODO(gongshaotian):The scheduler should schedule the handling of prefill,
         and workers and modelrunners should not perceive it.
         """
-        self.model_runner.insert_prefill_inputs(req_dicts=req_dicts)
+        self.model_runner.insert_prefill_inputs(req_dicts=req_dicts, num_running_requests=num_running_requests)
 
     def graph_optimize_and_warm_up_model(self) -> None:
         """
