Commit b01cfd6

[BugFix] support real batch_size (#3109)
* support real bsz
* fix
* fix xpu_model_runner.py, gpu_model_runner.py, gcu_model_runner.py, iluvatar_model_runner.py
* add event_loop_ep
* fix
* Add comments
* fix
* support mtp real_batch_size
* fix
* self.tmp_seq_lens_this_time->self.seq_lens_this_time_buffer
* fix
* fix VL real_seq_lens_this_time
* fix
* fix mtp
* fix
* fix mtp
* fix xpu
* fix
Parent: 55939f7

10 files changed: +110 −58 lines changed
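The change repeated across these files follows one pattern: each runner keeps a persistent, max_num_seqs-sized `seq_lens_this_time_buffer`, per-request writes go into that buffer, and only a slice of length `num_running_requests` is published as `seq_lens_this_time`, so downstream code sees the real batch size. Below is a minimal, self-contained sketch of that pattern (plain Python with numpy standing in for paddle tensors; the class and its methods are illustrative, not FastDeploy code):

```python
import numpy as np

class RunnerSketch:
    """Illustrative stand-in for the buffer-plus-slice pattern in this commit."""

    def __init__(self, max_num_seqs: int):
        self.share_inputs = {}
        # Persistent full-size buffer; previously the tensor lived directly
        # under share_inputs["seq_lens_this_time"].
        self.seq_lens_this_time_buffer = np.zeros(max_num_seqs, dtype=np.int32)

    def insert_prefill_inputs(self, lengths, num_running_requests: int):
        for idx, length in enumerate(lengths):
            self.seq_lens_this_time_buffer[idx] = length
        # Publish only the real batch: downstream code sees exactly
        # num_running_requests entries instead of max_num_seqs.
        self.share_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer[:num_running_requests]

    def execute_model(self, num_running_requests: int):
        # (the model forward would run here and may update seq_lens_this_time)
        # Copy the slice back so the persistent buffer stays current.
        self.seq_lens_this_time_buffer[:num_running_requests] = \
            self.share_inputs["seq_lens_this_time"][:num_running_requests]

runner = RunnerSketch(max_num_seqs=8)
runner.insert_prefill_inputs([128, 64], num_running_requests=2)
print(runner.share_inputs["seq_lens_this_time"])  # [128  64]
runner.execute_model(num_running_requests=2)
```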

fastdeploy/spec_decode/mtp.py

Lines changed: 9 additions & 6 deletions
```diff
@@ -107,7 +107,7 @@ def dummy_prefill_inputs(self, num_tokens: int, batch_size: int, expected_decode
             idx = i
             self.model_inputs["input_ids"][idx : idx + 1, :input_length] = np.array([5] * input_length)
             self.model_inputs["eos_token_id"][:] = np.array([2], dtype="int64").reshape(-1, 1)
-            self.model_inputs["seq_lens_this_time"][idx : idx + 1] = input_length
+            self.seq_lens_this_time_buffer[idx : idx + 1] = input_length
             self.model_inputs["seq_lens_encoder"][idx : idx + 1] = input_length
             self.model_inputs["seq_lens_decoder"][idx : idx + 1] = 0
             self.model_inputs["step_idx"][idx : idx + 1] = 0
@@ -118,6 +118,7 @@ def dummy_prefill_inputs(self, num_tokens: int, batch_size: int, expected_decode
             self.model_inputs["block_tables"][idx : idx + 1, :block_num] = np.arange(
                 idx * block_num, (idx + 1) * block_num, 1
             )
+        self.model_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer
 
     def initialize_kv_cache(self):
         """
@@ -263,7 +264,8 @@ def _init_model_inputs(self):
         # Same shape/dytpe with base model
         self.model_inputs["block_tables"] = paddle.clone(self.main_model_inputs["block_tables"])
         self.model_inputs["input_ids"] = paddle.clone(self.main_model_inputs["input_ids"])
-        self.model_inputs["seq_lens_this_time"] = paddle.clone(self.main_model_inputs["seq_lens_this_time"])
+        self.seq_lens_this_time_buffer = paddle.clone(self.main_model_inputs["seq_lens_this_time"])
+
         self.model_inputs["seq_lens_encoder"] = paddle.clone(self.main_model_inputs["seq_lens_encoder"])
         self.model_inputs["seq_lens_decoder"] = paddle.clone(self.main_model_inputs["seq_lens_decoder"])
         self.model_inputs["step_idx"] = paddle.clone(self.main_model_inputs["step_idx"])
@@ -338,7 +340,7 @@ def _init_model_inputs(self):
             self.main_model_inputs["seq_lens_this_time"], fill_value=-1, dtype="int32"
         )
 
-    def insert_prefill_inputs(self, req_dicts: List[Request]):
+    def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests: int):
         """
         Process inputs for prefill tasks and insert it to model_inputs buffer
         """
@@ -372,7 +374,7 @@ def insert_prefill_inputs(self, req_dicts: List[Request]):
 
                 self.model_inputs["seq_lens_encoder"][idx : idx + 1] = 0
                 self.model_inputs["seq_lens_decoder"][idx : idx + 1] = length
-                self.model_inputs["seq_lens_this_time"][idx : idx + 1] = prefill_token_num
+                self.seq_lens_this_time_buffer[idx : idx + 1] = prefill_token_num
 
                 self.model_inputs["stop_flags"][idx : idx + 1] = False
                 self.model_inputs["batch_drop"][idx : idx + 1] = False
@@ -397,10 +399,10 @@ def insert_prefill_inputs(self, req_dicts: List[Request]):
             if self.cache_config.enable_chunked_prefill:
                 token_chunk_size = request.prefill_chunk_info[0]
                 self.model_inputs["seq_lens_encoder"][idx : idx + 1] = token_chunk_size
-                self.model_inputs["seq_lens_this_time"][idx : idx + 1] = token_chunk_size
+                self.seq_lens_this_time_buffer[idx : idx + 1] = token_chunk_size
             else:
                 self.model_inputs["seq_lens_encoder"][idx : idx + 1] = length
-                self.model_inputs["seq_lens_this_time"][idx : idx + 1] = length
+                self.seq_lens_this_time_buffer[idx : idx + 1] = length
 
             self.model_inputs["seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0)
             self.model_inputs["stop_flags"][idx : idx + 1] = False
@@ -413,6 +415,7 @@ def insert_prefill_inputs(self, req_dicts: List[Request]):
                 request.get("block_tables"), dtype="int32"
             )
         self.model_inputs["not_need_stop"][0] = True
+        self.model_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer[:num_running_requests]
 
     def _initialize_forward_meta(self):
         """
```
fastdeploy/worker/gcu_model_runner.py

Lines changed: 16 additions & 7 deletions
```diff
@@ -152,9 +152,11 @@ def _init_logits_processor(self, request):
             schemata_key,
         )
 
-    def insert_prefill_inputs(self, req_dicts: List[Request]):
+    def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests: int = None):
         """
         Process inputs for prefill tasks and insert it to share_inputs buffer
+        req_dict: A list of Request dict
+        num_running_requests: batch_size
         """
 
         if req_dicts[-1].disaggregate_info is not None and req_dicts[-1].disaggregate_info["role"] == "prefill":
@@ -193,7 +195,7 @@ def get_attr_from_request(request, attr, default_value=None):
                 self.share_inputs["prompt_ids"][idx : idx + 1, :length] = np.array(request.prompt_token_ids)
                 self.share_inputs["seq_lens_encoder"][idx : idx + 1] = 0
                 self.share_inputs["seq_lens_decoder"][idx : idx + 1] = length
-                self.share_inputs["seq_lens_this_time"][idx : idx + 1] = 1
+                self.seq_lens_this_time_buffer[idx : idx + 1] = 1
                 self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = 0
                 self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = length
                 self.share_inputs["prompt_lens"][idx : idx + 1] = length
@@ -205,7 +207,7 @@ def get_attr_from_request(request, attr, default_value=None):
                         request.draft_token_ids[0:num_prefill_send_token],
                         dtype="int64",
                     )
-                    self.share_inputs["seq_lens_this_time"][idx : idx + 1] = num_prefill_send_token
+                    self.seq_lens_this_time_buffer[idx : idx + 1] = num_prefill_send_token
             else:
                 self.share_inputs["pre_ids"][idx : idx + 1] = -1
                 self.share_inputs["step_idx"][idx : idx + 1] = 0
@@ -222,14 +224,14 @@ def get_attr_from_request(request, attr, default_value=None):
                     )
                     self.share_inputs["seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0)
                     self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0)
-                    self.share_inputs["seq_lens_this_time"][idx : idx + 1] = token_chunk_size
+                    self.seq_lens_this_time_buffer[idx : idx + 1] = token_chunk_size
                     self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = token_chunk_size
                     self.share_inputs["seq_lens_encoder"][idx : idx + 1] = token_chunk_size
                     self.share_inputs["prompt_lens"][idx : idx + 1] = token_chunk_size
                 else:
                     self.share_inputs["seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0)
                     self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0)
-                    self.share_inputs["seq_lens_this_time"][idx : idx + 1] = length
+                    self.seq_lens_this_time_buffer[idx : idx + 1] = length
                     self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = length
                     self.share_inputs["seq_lens_encoder"][idx : idx + 1] = length
                     self.share_inputs["prompt_lens"][idx : idx + 1] = length
@@ -295,6 +297,7 @@ def get_attr_from_request(request, attr, default_value=None):
 
         if self.speculative_method in ["mtp"]:
             self.proposer.insert_prefill_inputs(req_dicts)
+        self.share_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer[:num_running_requests]
 
     def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int, expected_decode_len: int):
         """Set dummy prefill inputs to share_inputs"""
@@ -313,7 +316,7 @@ def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int, expected_decod
             self.share_inputs["input_ids"][idx : idx + 1, :input_length] = np.array([5] * input_length)
             self.share_inputs["prompt_ids"][idx : idx + 1, :input_length] = np.array([5] * input_length)
             self.share_inputs["eos_token_id"][:] = np.array([2], dtype="int64").reshape(-1, 1)
-            self.share_inputs["seq_lens_this_time"][idx : idx + 1] = input_length
+            self.seq_lens_this_time_buffer[idx : idx + 1] = input_length
             self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = input_length
             self.share_inputs["seq_lens_encoder"][idx : idx + 1] = input_length
             self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0
@@ -331,6 +334,7 @@ def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int, expected_decod
             self.share_inputs["block_tables"][idx : idx + 1, :block_num] = np.arange(
                 idx * block_num, (idx + 1) * block_num, 1
             )
+        self.share_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer
 
     def _init_share_inputs(self, max_num_seqs: int):
         """
@@ -381,7 +385,7 @@ def _init_share_inputs(self, max_num_seqs: int):
         self.share_inputs["max_length"] = paddle.full(
             [max_num_seqs, 1], self.model_config.max_model_len, dtype="int64"
         )
-        self.share_inputs["seq_lens_this_time"] = paddle.full(max_num_seqs, 0, dtype="int32")
+        self.seq_lens_this_time_buffer = paddle.full(max_num_seqs, 0, dtype="int32")
         self.share_inputs["seq_lens_encoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
         self.share_inputs["seq_lens_decoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
         self.share_inputs["step_seq_lens_encoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
@@ -923,13 +927,15 @@ def _get_skip_idx(self, model_forward_batch: Optional[List[Request]] = None):
     def execute_model(
         self,
         model_forward_batch: Optional[List[Request]] = None,
+        num_running_requests: int = None,
     ) -> Optional[ModelRunnerOutput]:
         """
        The Entrance of model execute.
         Args:
             model_forward_batch: 'Request' contains information related to prompt and is an abstract
                 class at the server level, which is too granular for ModelRunner.
                 We plan to replace it with 'ModelForwardBatch'.
+            num_running_requests: batch_size
             intermediate_tensors:
         """
         # If `not_need_stop`` is False, it means the current worker is in an idle state.
@@ -1055,6 +1061,9 @@ class at the server level, which is too granular for ModelRunner.
 
         self._update_chunked_prefill(model_forward_batch)
         self._add_cache(model_forward_batch)
+        self.seq_lens_this_time_buffer[:num_running_requests].copy_(
+            self.share_inputs["seq_lens_this_time"][:num_running_requests], False
+        )
         return None
 
     def _add_cache(self, model_forward_batch) -> None:
```
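The last hunk adds a copy-back at the end of execute_model: since `seq_lens_this_time` is now a slice that may be updated in place during the step, its values are written back into the persistent buffer before the next scheduling round re-slices it. A minimal paddle sketch of that step, assuming `Tensor.copy_(src, blocking)` semantics where `False` requests a non-blocking copy and the slice shares storage with the buffer, as the diff's call relies on:

```python
import paddle

max_num_seqs, num_running_requests = 8, 3
seq_lens_this_time_buffer = paddle.zeros([max_num_seqs], dtype="int32")

# Stand-in for the tensor the forward step updated for the live requests.
updated = paddle.to_tensor([128, 64, 32], dtype="int32")

# Persist the updated per-request lengths into the full-size buffer so the
# next insert/slice cycle starts from current state.
seq_lens_this_time_buffer[:num_running_requests].copy_(updated, False)
print(seq_lens_this_time_buffer.numpy())  # [128  64  32   0   0   0   0   0]
```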

fastdeploy/worker/gcu_worker.py

Lines changed: 4 additions & 3 deletions
```diff
@@ -105,17 +105,18 @@ def initialize_cache(self, num_gpu_blocks: int) -> None:
     def execute_model(
         self,
         model_forward_batch: Optional[List[Request]] = None,
+        num_running_requests: int = None,
     ) -> Optional[ModelRunnerOutput]:
         """ """
-        output = self.model_runner.execute_model(model_forward_batch)
+        output = self.model_runner.execute_model(model_forward_batch, num_running_requests)
         return output
 
-    def preprocess_new_task(self, req_dicts: List[Request]) -> None:
+    def preprocess_new_task(self, req_dicts: List[Request], num_running_requests: int) -> None:
         """Process new requests and then start the decode loop
         TODO(gongshaotian):The scheduler should schedule the handling of prefill,
         and workers and modelrunners should not perceive it.
         """
-        self.model_runner.insert_prefill_inputs(req_dicts=req_dicts)
+        self.model_runner.insert_prefill_inputs(req_dicts=req_dicts, num_running_requests=num_running_requests)
 
     def graph_optimize_and_warm_up_model(self) -> None:
         """
```
