Skip to content

Commit 2485333

Browse files
authored
ep support logprob (#4089)
1 parent 10768a4 commit 2485333

File tree

3 files changed

+11
-17
lines changed

3 files changed

+11
-17
lines changed

custom_ops/gpu_ops/get_output_msg_with_topk.cc

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,6 @@ void GetOutputTopK(const paddle::Tensor& x,
3939
int k,
4040
int64_t rank_id,
4141
bool wait_flag) {
42-
if (rank_id > 0) {
43-
return;
44-
}
4542

4643
static struct msgdata msg_rcv;
4744
int msg_queue_id = 1;

fastdeploy/engine/args_utils.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -400,8 +400,6 @@ def __post_init__(self):
400400
if self.enable_logprob:
401401
if self.speculative_config is not None:
402402
raise NotImplementedError("Logprob does not support speculation_config.")
403-
if self.enable_expert_parallel:
404-
raise NotImplementedError("Logprob does not support enable_expert_parallel.")
405403
if not current_platform.is_cuda():
406404
raise NotImplementedError("Only CUDA platform supports logprob.")
407405
if self.splitwise_role != "mixed":

fastdeploy/output/token_processor.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -161,24 +161,23 @@ def process_sampling_results(self):
161161
continue
162162

163163
else:
164-
if (
164+
if self.use_logprobs:
165+
get_output_topk(
166+
self.output_tokens,
167+
self.output_scores,
168+
self.output_ranks,
169+
K,
170+
rank_id,
171+
is_blocking,
172+
)
173+
elif (
165174
self.cfg.parallel_config.enable_expert_parallel
166175
and self.cfg.parallel_config.data_parallel_size > 1
167176
):
168177
get_output_ep(self.output_tokens, rank_id, is_blocking)
169178

170179
else:
171-
if self.use_logprobs:
172-
get_output_topk(
173-
self.output_tokens,
174-
self.output_scores,
175-
self.output_ranks,
176-
K,
177-
rank_id,
178-
is_blocking,
179-
)
180-
else:
181-
get_output(self.output_tokens, rank_id, is_blocking)
180+
get_output(self.output_tokens, rank_id, is_blocking)
182181

183182
if self.output_tokens[0, 0] == -2:
184183
continue

0 commit comments

Comments (0)