
Commit 46c8491

merge logprob into batch_output (#3266)
1 parent 566badb commit 46c8491

2 files changed: +41 -157 lines changed

fastdeploy/engine/args_utils.py

Lines changed: 8 additions & 0 deletions

@@ -30,6 +30,7 @@
     TaskOption,
 )
 from fastdeploy.engine.config import Config
+from fastdeploy.platforms import current_platform
 from fastdeploy.scheduler.config import SchedulerConfig
 from fastdeploy.utils import DeprecatedOptionWarning, FlexibleArgumentParser
 
@@ -344,6 +345,13 @@ def __post_init__(self):
         """
         if not self.tokenizer:
             self.tokenizer = self.model
+        if self.enable_logprob:
+            if self.speculative_config is not None:
+                raise NotImplementedError("Logprob does not support speculation_config.")
+            if self.enable_expert_parallel:
+                raise NotImplementedError("Logprob does not support enable_expert_parallel.")
+            if not current_platform.is_cuda():
+                raise NotImplementedError("Only CUDA platform supports logprob.")
 
     @staticmethod
     def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
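As a usage illustration only, here is a minimal standalone sketch of the validation the new __post_init__ block performs. The dataclass name, field defaults, and the is_cuda stand-in below are assumptions for the sketch, not the actual FastDeploy definitions.

from dataclasses import dataclass
from typing import Optional


@dataclass
class _ArgsSketch:
    # Hypothetical stand-in for the engine args dataclass; field names mirror the diff.
    model: str = "my-model"
    tokenizer: Optional[str] = None
    enable_logprob: bool = False
    enable_expert_parallel: bool = False
    speculative_config: Optional[dict] = None
    is_cuda: bool = True  # stands in for current_platform.is_cuda()

    def __post_init__(self):
        if not self.tokenizer:
            self.tokenizer = self.model
        if self.enable_logprob:
            if self.speculative_config is not None:
                raise NotImplementedError("Logprob does not support speculation_config.")
            if self.enable_expert_parallel:
                raise NotImplementedError("Logprob does not support enable_expert_parallel.")
            if not self.is_cuda:
                raise NotImplementedError("Only CUDA platform supports logprob.")


# Rejected up front: logprob together with a speculative config.
try:
    _ArgsSketch(enable_logprob=True, speculative_config={"method": "mtp"})
except NotImplementedError as e:
    print(e)

With the checks moved into __post_init__, unsupported combinations fail at argument-construction time instead of being silently ignored by the token processor.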

fastdeploy/output/token_processor.py

Lines changed: 33 additions & 157 deletions

@@ -57,14 +57,15 @@ def __init__(self, cfg, cached_generated_tokens, engine_worker_queue, split_conn
         self.split_connector = split_connector
 
         self.speculative_decoding = self.cfg.speculative_config.method is not None
+        self.use_logprobs = self.cfg.enable_logprob
 
         if self.speculative_decoding:
             self.output_tokens = paddle.full(
                 shape=[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2],
                 fill_value=2,
                 dtype="int64",
             )
-        elif self.cfg.enable_logprob:
+        elif self.use_logprobs:
             self.output_tokens = paddle.full(shape=[MAX_BSZ * (K + 1) + 2, 1], fill_value=2, dtype="int64")
             self.output_scores = paddle.full(shape=[MAX_BSZ * (K + 1), 1], fill_value=0.0, dtype="float32")
             self.output_ranks = paddle.full(shape=[MAX_BSZ], fill_value=0, dtype="int64")
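Reading the buffer shapes above together with the reshape later in _process_batch_output (my interpretation of the diff, with made-up values for K, MAX_BSZ, and the batch size): slot [0, 0] of the flat token buffer is a status value (-2 means nothing ready yet), slot [1, 0] holds the batch size, and the remaining MAX_BSZ * (K + 1) slots hold K + 1 candidate token ids per request, with matching logprobs in output_scores. A NumPy sketch of that decode:

import numpy as np

K = 4        # top-K candidates kept per step (assumed value for the sketch)
MAX_BSZ = 8  # maximum batch size (assumed value for the sketch)
batch = 3    # requests actually present this step

# Flat buffers shaped like output_tokens / output_scores in the diff.
output_tokens = np.full([MAX_BSZ * (K + 1) + 2, 1], 2, dtype=np.int64)
output_scores = np.zeros([MAX_BSZ * (K + 1), 1], dtype=np.float32)
output_tokens[0, 0] = 1      # header slot 0: not the -2 "nothing ready" signal
output_tokens[1, 0] = batch  # header slot 1: batch size

# Decode the flat buffers the same way _process_batch_output does.
tokens = output_tokens[2 : batch * (K + 1) + 2].reshape([batch, K + 1])
scores = output_scores[: batch * (K + 1)].reshape([batch, K + 1])
print(tokens.shape, scores.shape)  # (3, 5) (3, 5)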
@@ -125,53 +126,12 @@ def run(self):
         assert self.resource_manager is not None, "The resource manager is None, cannot run."
         if self.worker is not None:
             raise Exception("Worker is already running!")
-        use_logprobs = (
-            self.cfg.enable_logprob
-            and not self.speculative_decoding
-            and not self.cfg.parallel_config.enable_expert_parallel
-        )
-
-        target_func = self.process_sampling_with_logprob_results if use_logprobs else self.process_sampling_results
 
-        self.worker = threading.Thread(target=target_func)
+        self.worker = threading.Thread(target=self.process_sampling_results)
 
         self.worker.daemon = True
         self.worker.start()
 
-    def process_sampling_with_logprob_results(self):
-        """
-        read tokens from paddle inference engine and process logprob results
-        """
-        if current_platform.is_cuda():
-            from fastdeploy.model_executor.ops.gpu import get_output_topk
-        else:
-            raise NotImplementedError("Only CUDA platform supports logprob.")
-
-        rank_id = self.cfg.parallel_config.local_data_parallel_id
-
-        while True:
-            try:
-                is_blocking = True
-                get_output_topk(
-                    self.output_tokens,
-                    self.output_scores,
-                    self.output_ranks,
-                    K,
-                    rank_id,
-                    is_blocking,
-                )
-
-                if self.output_tokens[0, 0] == -2:
-                    continue
-                llm_logger.debug(
-                    f"rank_id {rank_id} self.output_tokens[0, 0] {self.output_tokens[0, 0]}"
-                    f"rank_id {rank_id} self.output_scores[0, 0] {self.output_scores[0, 0]}"
-                )
-                self._process_prefill_metrics()
-                self._process_sampling_with_logprob_batch_output()
-            except Exception as e:
-                llm_logger.info(f"while get input_data error: {e} {traceback.format_exc()!s}")
-
     def process_sampling_results(self):
         """
         read tokens from paddle inference engine and process

@@ -187,6 +147,7 @@ def process_sampling_results(self):
             from fastdeploy.model_executor.ops.gpu import (
                 get_output,
                 get_output_ep,
+                get_output_topk,
                 speculate_get_output,
             )
         rank_id = self.cfg.parallel_config.local_data_parallel_id

@@ -207,7 +168,17 @@ def process_sampling_results(self):
                    get_output_ep(self.output_tokens, rank_id, is_blocking)
 
                else:
-                    get_output(self.output_tokens, rank_id, is_blocking)
+                    if self.use_logprobs:
+                        get_output_topk(
+                            self.output_tokens,
+                            self.output_scores,
+                            self.output_ranks,
+                            K,
+                            rank_id,
+                            is_blocking,
+                        )
+                    else:
+                        get_output(self.output_tokens, rank_id, is_blocking)
 
                if self.output_tokens[0, 0] == -2:
                    continue
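The structural effect of the merge, as a rough sketch (the queue, fetch helpers, and sentinel below are placeholders I made up, not FastDeploy APIs): rather than starting a separate thread for the logprob path, a single reader loop now picks the fetch op per iteration from the configuration flag.

import queue
import threading

_results = queue.Queue()  # stand-in for the device-side output channel

def _fetch_topk():   # placeholder for get_output_topk(...)
    return _results.get()

def _fetch_plain():  # placeholder for get_output(...)
    return _results.get()

def reader_loop(use_logprobs: bool):
    # One loop; the fetch op is chosen per iteration instead of per thread target.
    while True:
        item = _fetch_topk() if use_logprobs else _fetch_plain()
        if item is None:  # sentinel so the sketch terminates
            return
        print("processed", item)

worker = threading.Thread(target=reader_loop, args=(True,), daemon=True)
worker.start()
_results.put({"token_ids": [42]})
_results.put(None)
worker.join()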
@@ -305,129 +276,23 @@ def _compute_speculative_status(self):
         self.total_step = 0
         self.speculative_stats_step += 1
 
-    def _process_sampling_with_logprob_batch_output(self):
-        """
-        batch post-processing logprob output function
-        """
-
-        batch = self.output_tokens[1, 0]
-        tokens = self.output_tokens[2 : batch * (K + 1) + 2].numpy().reshape([batch, K + 1])[:, : (K + 1)]
-        scores = self.output_scores[: batch * (K + 1)].numpy().reshape([batch, K + 1])[:, : (K + 1)]
-        ranks = self.output_ranks[:batch].numpy()
-        batch_result = list()
-        for i in range(batch):
-            if self.resource_manager.stop_flags[i]:
-                continue
-            task = self.resource_manager.tasks_list[i]
-            task_id = task.request_id
-            token_id = int(tokens[i, 0])
-            token_ids = [token_id]
-            recovery_stop = token_id == RECOVERY_STOP_SIGNAL
-            if recovery_stop:
-                llm_logger.info(f"recovery stop signal found at task {task_id}")
-            if not recovery_stop and token_id < 0:
-                continue
-
-            if task.get("prefill_chunk_info", None) is not None:
-                prefill_chunk_num = task.get("prefill_chunk_num", 0)
-                task.prefill_chunk_num = prefill_chunk_num + 1
-
-                if task.prefill_chunk_num < len(task.prefill_chunk_info):
-                    continue
-
-            self.total_step += 1
-            current_time = time.time()
-            if self.tokens_counter[task_id] == 0:
-                metrics = RequestMetrics(
-                    arrival_time=task.arrival_time,
-                    inference_start_time=task.inference_start_time,
-                    first_token_time=time.time() - task.inference_start_time,
-                    time_in_queue=task.schedule_start_time - task.preprocess_end_time,
-                    preprocess_cost_time=task.preprocess_end_time - task.preprocess_start_time,
-                    request_start_time=task.arrival_time,
-                )
-
-                self._record_first_token_metrics(task, current_time)
-
-            else:
-                metrics = RequestMetrics(
-                    arrival_time=time.time(),
-                    request_start_time=task.arrival_time,
-                )
-            self.number_of_output_tokens += len(token_ids)
-            self._record_metrics(task, current_time, token_ids)
-            result = RequestOutput(
-                request_id=task_id,
-                outputs=CompletionOutput(
-                    index=i,
-                    send_idx=self.tokens_counter[task_id],
-                    token_ids=[],
-                    logprob=None,
-                    draft_token_ids=[],
-                    top_logprobs=None,
-                ),
-                finished=False,
-                metrics=metrics,
-            )
-            if self.tokens_counter[task_id] == 0:
-                if task.messages is not None:
-                    result.prompt = task.messages
-                result.num_cached_tokens = task.num_cached_tokens
-
-            is_prefill = task.disaggregate_info is not None and task.disaggregate_info["role"] == "prefill"
-
-            if is_prefill and len(token_ids) > 1:
-                result.outputs.draft_token_ids = copy.deepcopy(token_ids)
-
-            for idx, token_id in enumerate(token_ids):
-                self.tokens_counter[task_id] += 1
-                if token_id != RECOVERY_STOP_SIGNAL:
-                    result.outputs.token_ids.append(token_id)
-                    task.output_token_ids.append(token_id)
-                    result.outputs.logprob = float(scores[i, 0])
-                    # Construct top_logprobs
-                    topk_token_ids = tokens[i, :].tolist()
-                    topk_logprobs = scores[i, :].tolist()
-                    sampled_rank = ranks[i].item()
-
-                    result.outputs.top_logprobs = LogprobsLists(
-                        logprob_token_ids=[topk_token_ids],
-                        logprobs=[topk_logprobs],
-                        sampled_token_ranks=[sampled_rank],
-                    )
-
-                if token_id in task.eos_token_ids or is_prefill or recovery_stop:
-                    result.finished = True
-                    if recovery_stop:
-                        result.error_msg = "Recover is not supported, the result is incomplete!"
-                    llm_logger.info(
-                        f"Request: {task_id} finished, number of " f"generated tokens: {self.tokens_counter[task_id]}."
-                    )
-                    llm_logger.info(
-                        f"Request: {task_id} token ratio: {self.tokens_counter[task_id] / (time.time() - task.inference_start_time)}"
-                    )
-                    llm_logger.info(f"{self.resource_manager.info()}")
-                    if self.cfg.speculative_config.method:
-                        self._compute_speculative_status()
-                    if not is_prefill:
-                        self._record_completion_metrics(task, current_time)
-                    self._recycle_resources(task_id, i, task, result, is_prefill)
-                    break
-            if not is_prefill or self.cfg.scheduler_config.name == "splitwise":
-                batch_result.append(result)
-
-        self.postprocess(batch_result)
-
     def _process_batch_output(self):
         """
         batch post-processing function
         """
 
         tokens = self.output_tokens.numpy()
+        scores = None
+        ranks = None
         if self.cfg.speculative_config.method:
             batch = self.output_tokens[1]
             accept_num = tokens[2 : batch + 2]
             self._record_speculative_decoding_mertics(accept_num)
+        elif self.use_logprobs:
+            batch = self.output_tokens[1, 0]
+            tokens = tokens[2 : batch * (K + 1) + 2].reshape([batch, K + 1])[:, : (K + 1)]
+            scores = self.output_scores[: batch * (K + 1)].numpy().reshape([batch, K + 1])[:, : (K + 1)]
+            ranks = self.output_ranks[:batch].numpy()
         else:
             batch = self.output_tokens[1, 0]
             tokens = tokens[2 : batch + 2]
@@ -522,6 +387,17 @@ def _process_batch_output(self):
                if token_id != RECOVERY_STOP_SIGNAL:
                    result.outputs.token_ids.append(token_id)
                    task.output_token_ids.append(token_id)
+                    if self.use_logprobs:
+                        result.outputs.logprob = float(scores[i, 0])
+                        # Construct top_logprobs
+                        topk_token_ids = tokens[i, :].tolist()
+                        topk_logprobs = scores[i, :].tolist()
+                        sampled_rank = ranks[i].item()
+                        result.outputs.top_logprobs = LogprobsLists(
+                            logprob_token_ids=[topk_token_ids],
+                            logprobs=[topk_logprobs],
+                            sampled_token_ranks=[sampled_rank],
+                        )
                if token_id in task.eos_token_ids or is_prefill or recovery_stop:
                    result.finished = True
                    if recovery_stop:
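For orientation, here is one decoded row mapped onto the per-token fields. The NamedTuple below is a stand-in I defined for fastdeploy's LogprobsLists and the numbers are invented: column 0 of tokens/scores is the sampled token and its logprob, the full row becomes the top-K entry, and ranks[i] is the sampled token's rank (0 presumably meaning it was the top candidate).

from typing import List, NamedTuple

class _LogprobsListsSketch(NamedTuple):
    # Stand-in for fastdeploy's LogprobsLists; field names follow the diff.
    logprob_token_ids: List[List[int]]
    logprobs: List[List[float]]
    sampled_token_ranks: List[int]

# One decoded row for request i (made-up numbers): column 0 is the sampled token,
# the remaining columns are the other top-K candidates.
row_tokens = [1045, 318, 257, 290, 12]
row_scores = [-0.21, -1.87, -2.40, -3.05, -3.66]
row_rank = 0

top_logprobs = _LogprobsListsSketch(
    logprob_token_ids=[row_tokens],
    logprobs=[row_scores],
    sampled_token_ranks=[row_rank],
)
sampled_logprob = row_scores[0]  # what the diff stores in result.outputs.logprob
print(sampled_logprob, top_logprobs.sampled_token_ranks)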
