Skip to content

Commit 7b596d0

Browse files
authored
[BugFix] fix real_bsz in ep (#3366)
* fix real_bsz in expert parallel * fix ep * delete cuda_graph
1 parent 0ea8712 commit 7b596d0

File tree

3 files changed

+6
-2
lines changed

3 files changed

+6
-2
lines changed

fastdeploy/model_executor/models/ernie4_5_moe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -450,7 +450,7 @@ def empty_input_forward(self):
450450
self.fd_config.model_config.moe_layer_start_index,
451451
self.fd_config.model_config.num_hidden_layers,
452452
):
453-
self.ernie.layers[i].mlp.expert(fake_hidden_states)
453+
self.ernie.layers[i].mlp.experts(fake_hidden_states, self.ernie.layers[i].mlp.gate)
454454

455455
def forward(
456456
self,

fastdeploy/worker/gpu_model_runner.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -636,7 +636,9 @@ def _init_share_inputs(self, max_num_seqs: int):
636636
self.share_inputs["max_length"] = paddle.full(
637637
[max_num_seqs, 1], self.model_config.max_model_len, dtype="int64"
638638
)
639-
self.seq_lens_this_time_buffer = paddle.full(max_num_seqs, 0, dtype="int32")
639+
self.seq_lens_this_time_buffer = paddle.full([max_num_seqs, 1], 0, dtype="int32")
640+
if self.fd_config.parallel_config.enable_expert_parallel:
641+
self.share_inputs["seq_lens_this_time"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
640642
self.share_inputs["seq_lens_encoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
641643
self.share_inputs["seq_lens_decoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
642644
self.share_inputs["step_seq_lens_encoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")

fastdeploy/worker/worker_process.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,7 @@ def event_loop_ep(self) -> None:
250250
while True:
251251
self.worker_healthy_live_signal.value[self.local_rank % self.max_chips_per_node] = int(time.time())
252252

253+
num_running_requests = 0
253254
if self.fd_config.parallel_config.tensor_parallel_rank == 0 and self.task_queue.num_tasks() > 0:
254255
tasks, read_finish = self.task_queue.get_tasks()
255256

@@ -276,6 +277,7 @@ def event_loop_normal(self) -> None:
276277
self.nnode = int((self.parallel_config.tensor_parallel_size + 7) // 8)
277278
mp_num_per_node = self.parallel_config.tensor_parallel_size // self.nnode
278279
req_ids = []
280+
num_running_requests = 0
279281
while True:
280282
if self.local_rank == 0:
281283
if self.model_weights_status.value[0] != 0:

0 commit comments

Comments
 (0)