diff --git a/fastdeploy/engine/request.py b/fastdeploy/engine/request.py index 3906cd29b5..4bac4f04a5 100644 --- a/fastdeploy/engine/request.py +++ b/fastdeploy/engine/request.py @@ -294,6 +294,7 @@ class CompletionOutput: decode_type: int = 0 logprob: Optional[float] = None top_logprobs: Optional[LogprobsLists] = None + draft_top_logprobs: Optional[LogprobsLists] = None logprobs: Optional[SampleLogprobs] = None draft_token_ids: list[int] = None text: Optional[str] = None @@ -308,9 +309,9 @@ def to_dict(self): "index": self.index, "send_idx": self.send_idx, "token_ids": self.token_ids, - "decode_type": self.decode_type, "logprob": self.logprob, "top_logprobs": self.top_logprobs, + "draft_top_logprobs": self.draft_top_logprobs, "logprobs": self.logprobs, "draft_token_ids": self.draft_token_ids, "text": self.text, @@ -336,6 +337,8 @@ def __repr__(self) -> str: f"draft_token_ids={self.draft_token_ids}, " f"reasoning_content={self.reasoning_content!r}, " f"logprobs={self.logprobs}, " + f"top_logprobs={self.top_logprobs}, " + f"draft_top_logprobs={self.draft_top_logprobs}, " ) @@ -420,6 +423,7 @@ def __init__( request_id: str, prompt: Optional[str] = None, prompt_token_ids: Optional[list[int]] = None, + output_type: Optional[int] = 3, outputs: CompletionOutput = None, finished: bool = False, metrics: Optional[RequestMetrics] = None, @@ -430,6 +434,7 @@ def __init__( self.request_id = request_id self.prompt = prompt self.prompt_token_ids = prompt_token_ids + self.output_type = output_type self.outputs = outputs self.finished = finished self.metrics = metrics @@ -458,12 +463,21 @@ def add(self, next_output: RequestOutput) -> None: self.outputs.top_logprobs.logprob_token_ids.extend(next_output.outputs.top_logprobs.logprob_token_ids) self.outputs.top_logprobs.logprobs.extend(next_output.outputs.top_logprobs.logprobs) self.outputs.top_logprobs.sampled_token_ranks.extend(next_output.outputs.top_logprobs.sampled_token_ranks) + if next_output.outputs.draft_top_logprobs is not None: + self.outputs.draft_top_logprobs.logprob_token_ids.extend( + next_output.outputs.draft_top_logprobs.logprob_token_ids + ) + self.outputs.draft_top_logprobs.logprobs.extend(next_output.outputs.draft_top_logprobs.logprobs) + self.outputs.draft_top_logprobs.sampled_token_ranks.extend( + next_output.outputs.draft_top_logprobs.sampled_token_ranks + ) def __repr__(self) -> str: return ( f"RequestOutput(request_id={self.request_id}, " f"prompt={self.prompt!r}, " f"prompt_token_ids={self.prompt_token_ids}, " + f"output_type={self.output_type}, " f"outputs={self.outputs}, " f"finished={self.finished}, " f"num_cached_tokens={self.num_cached_tokens}, " @@ -484,6 +498,7 @@ def to_dict(self): "request_id": self.request_id, "prompt": self.prompt, "prompt_token_ids": self.prompt_token_ids, + "output_type": self.output_type, "outputs": None if self.outputs is None else self.outputs.to_dict(), "metrics": None if self.metrics is None else self.metrics.to_dict(), "finished": self.finished, diff --git a/fastdeploy/entrypoints/openai/protocol.py b/fastdeploy/entrypoints/openai/protocol.py index 0e6802e76b..ad767f563b 100644 --- a/fastdeploy/entrypoints/openai/protocol.py +++ b/fastdeploy/entrypoints/openai/protocol.py @@ -189,6 +189,7 @@ class ChatCompletionResponseChoice(BaseModel): index: int message: ChatMessage logprobs: Optional[LogProbs] = None + draft_logprobs: Optional[LogProbs] = None finish_reason: Optional[Literal["stop", "length", "tool_calls", "recover_stop"]] @@ -251,6 +252,7 @@ class ChatCompletionResponseStreamChoice(BaseModel): 
index: int delta: DeltaMessage logprobs: Optional[LogProbs] = None + draft_logprobs: Optional[LogProbs] = None finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None arrival_time: Optional[float] = None @@ -283,6 +285,7 @@ class CompletionResponseChoice(BaseModel): completion_tokens: Optional[str] = None arrival_time: Optional[float] = None logprobs: Optional[CompletionLogprobs] = None + draft_logprobs: Optional[CompletionLogprobs] = None reasoning_content: Optional[str] = None finish_reason: Optional[Literal["stop", "length", "tool_calls"]] tool_calls: Optional[List[DeltaToolCall | ToolCall]] = None @@ -321,6 +324,7 @@ class CompletionResponseStreamChoice(BaseModel): text: str arrival_time: float = None logprobs: Optional[CompletionLogprobs] = None + draft_logprobs: Optional[CompletionLogprobs] = None prompt_token_ids: Optional[List[int]] = None completion_token_ids: Optional[List[int]] = None text_after_process: Optional[str] = None @@ -410,6 +414,7 @@ class CompletionRequest(BaseModel): echo: Optional[bool] = False frequency_penalty: Optional[float] = Field(default=None, ge=-2, le=2) logprobs: Optional[int] = None + include_draft_logprobs: Optional[bool] = False # For logits and logprobs post processing temp_scaled_logprobs: bool = False top_p_normalized_logprobs: bool = False @@ -545,6 +550,7 @@ class ChatCompletionRequest(BaseModel): frequency_penalty: Optional[float] = Field(None, le=2, ge=-2) logprobs: Optional[bool] = False top_logprobs: Optional[int] = 0 + include_draft_logprobs: Optional[bool] = False # For logits and logprobs post processing temp_scaled_logprobs: bool = False diff --git a/fastdeploy/entrypoints/openai/serving_chat.py b/fastdeploy/entrypoints/openai/serving_chat.py index 52cd556916..70bbd6a756 100644 --- a/fastdeploy/entrypoints/openai/serving_chat.py +++ b/fastdeploy/entrypoints/openai/serving_chat.py @@ -303,12 +303,18 @@ async def chat_completion_stream_generator( output = res["outputs"] output_top_logprobs = output["top_logprobs"] + output_draft_top_logprobs = output["draft_top_logprobs"] previous_num_tokens += len(output["token_ids"]) logprobs_res: Optional[LogProbs] = None + draft_logprobs_res: Optional[LogProbs] = None if request.logprobs and output_top_logprobs is not None: logprobs_res = self._create_chat_logprobs( output_top_logprobs, request.logprobs, request.top_logprobs ) + if request.include_draft_logprobs and output_draft_top_logprobs is not None: + draft_logprobs_res = self._create_chat_logprobs( + output_draft_top_logprobs, request.logprobs, request.top_logprobs + ) delta_message = DeltaMessage( reasoning_content="", @@ -336,6 +342,7 @@ async def chat_completion_stream_generator( index=0, delta=delta_message, logprobs=logprobs_res, + draft_logprobs=draft_logprobs_res, arrival_time=arrival_time, ) if res["finished"]: @@ -430,6 +437,7 @@ async def chat_completion_full_generator( previous_num_tokens = 0 current_waiting_time = 0 logprob_contents = [] + draft_logprob_contents = [] completion_token_ids = [] response_processor = ChatResponseProcessor( data_processor=self.engine_client.data_processor, @@ -476,12 +484,23 @@ async def chat_completion_full_generator( # The logprob for handling the response output = data["outputs"] output_top_logprobs = output["top_logprobs"] + output_draft_top_logprobs = output["draft_top_logprobs"] if output_top_logprobs is not None: + # logprobs logprobs_res = self._create_chat_logprobs( output_top_logprobs, request.logprobs, request.top_logprobs ) if logprobs_res and logprobs_res.content is not None: 
logprob_contents.extend(logprobs_res.content) + + # draf_logprobs + if request.include_draft_logprobs and output_draft_top_logprobs is not None: + draft_logprobs_res = self._create_chat_logprobs( + output_draft_top_logprobs, request.logprobs, request.top_logprobs + ) + if draft_logprobs_res and draft_logprobs_res.content is not None: + draft_logprob_contents.extend(draft_logprobs_res.content) + if data["finished"]: final_res = data task_is_finished = True @@ -515,11 +534,15 @@ async def chat_completion_full_generator( logprobs_full_res = None if logprob_contents: logprobs_full_res = LogProbs(content=logprob_contents) + draft_logprobs_full_res = None + if draft_logprob_contents: + draft_logprobs_full_res = LogProbs(content=draft_logprob_contents) choice = ChatCompletionResponseChoice( index=0, message=message, logprobs=logprobs_full_res, + draft_logprobs=draft_logprobs_full_res, finish_reason=None, ) has_no_token_limit = request.max_tokens is None and request.max_completion_tokens is None diff --git a/fastdeploy/entrypoints/openai/serving_completion.py b/fastdeploy/entrypoints/openai/serving_completion.py index 646b282abf..762b18941a 100644 --- a/fastdeploy/entrypoints/openai/serving_completion.py +++ b/fastdeploy/entrypoints/openai/serving_completion.py @@ -223,6 +223,7 @@ async def completion_full_generator( valid_results = [dict()] * num_choices output_tokens = [0] * num_choices aggregated_top_logprobs = [[[], [], []] for _ in range(num_choices)] + aggregated_draft_top_logprobs = [[[], [], []] for _ in range(num_choices)] aggregated_token_ids = [[] for _ in range(num_choices)] completion_batched_token_ids = [[] for _ in range(num_choices)] current_waiting_time = 0 @@ -256,11 +257,18 @@ async def completion_full_generator( output = data["outputs"] output_top_logprobs = output["top_logprobs"] + output_draft_top_logprobs = output["draft_top_logprobs"] if output_top_logprobs is not None: aggregated_top_logprobs[rid][0].extend(output_top_logprobs[0]) aggregated_top_logprobs[rid][1].extend(output_top_logprobs[1]) aggregated_top_logprobs[rid][2].extend(output_top_logprobs[2]) + # draft logprobs + if request.include_draft_logprobs and output_draft_top_logprobs is not None: + aggregated_draft_top_logprobs[rid][0].extend(output_draft_top_logprobs[0]) + aggregated_draft_top_logprobs[rid][1].extend(output_draft_top_logprobs[1]) + aggregated_draft_top_logprobs[rid][2].extend(output_draft_top_logprobs[2]) + aggregated_token_ids[rid].extend(data["outputs"]["token_ids"]) self.engine_client.data_processor.process_response_dict( @@ -271,6 +279,7 @@ async def completion_full_generator( if data.get("finished", False): data["output_token_ids"] = output_tokens[rid] data["outputs"]["top_logprobs"] = aggregated_top_logprobs[rid] + data["outputs"]["draft_top_logprobs"] = aggregated_draft_top_logprobs[rid] data["outputs"]["token_ids"] = aggregated_token_ids[rid] valid_results[rid] = data num_choices -= 1 @@ -423,10 +432,17 @@ async def completion_stream_generator( await self._process_echo_logic(request, idx, res["outputs"]) output = res["outputs"] output_top_logprobs = output["top_logprobs"] + output_draft_top_logprobs = output["draft_top_logprobs"] logprobs_res: Optional[CompletionLogprobs] = None + draft_logprobs_res: Optional[CompletionLogprobs] = None if request.logprobs and output_top_logprobs is not None: logprobs_res = self._create_completion_logprobs(output_top_logprobs, request.logprobs, 0) + # draft logprobs + if request.include_draft_logprobs and output_draft_top_logprobs is not None: + 
draft_logprobs_res = self._create_completion_logprobs( + output_draft_top_logprobs, request.logprobs, 0 + ) output_tokens[idx] += 1 delta_message = CompletionResponseStreamChoice( index=idx, @@ -439,6 +455,7 @@ async def completion_stream_generator( reasoning_content="", arrival_time=arrival_time, logprobs=logprobs_res, + draft_logprobs=draft_logprobs_res, ) if not res["finished"] and "delta_message" in output: delta_message_output = output["delta_message"] @@ -523,15 +540,23 @@ def request_output_to_completion_response( final_res = final_res_batch[idx] prompt_token_ids = prompt_batched_token_ids[idx] assert prompt_token_ids is not None + prompt_text = request.prompt completion_token_ids = completion_batched_token_ids[idx] output = final_res["outputs"] output_top_logprobs = output["top_logprobs"] + output_draft_top_logprobs = output["draft_top_logprobs"] aggregated_logprobs: Optional[CompletionLogprobs] = None if output_top_logprobs is not None: aggregated_logprobs = self._create_completion_logprobs(output_top_logprobs, request.logprobs, 0) + aggregated_draft_logprobs: Optional[CompletionLogprobs] = None + if output_draft_top_logprobs is not None: + aggregated_draft_logprobs = self._create_completion_logprobs( + output_draft_top_logprobs, request.logprobs, 0 + ) + if request.echo: prompt_text = self._echo_back_prompt(request, idx) token_ids = [*prompt_token_ids, *output["token_ids"]] @@ -554,6 +579,7 @@ def request_output_to_completion_response( reasoning_content=output.get("reasoning_content"), tool_calls=output.get("tool_call"), logprobs=aggregated_logprobs, + draft_logprobs=aggregated_draft_logprobs, finish_reason=finish_reason, ) choices.append(choice_data) diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py index 4e16e3d0a2..a1bb79d377 100644 --- a/fastdeploy/output/token_processor.py +++ b/fastdeploy/output/token_processor.py @@ -22,6 +22,7 @@ import weakref from collections import Counter from concurrent.futures import ThreadPoolExecutor +from typing import List import numpy as np import paddle @@ -67,15 +68,24 @@ def __init__(self, cfg, cached_generated_tokens, engine_worker_queue, split_conn self.use_logprobs = self.cfg.model_config.enable_logprob if self.speculative_decoding: - self.output_tokens = paddle.full( - shape=[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2], - fill_value=2, - dtype="int64", - ) + if self.use_logprobs: + self.output_tokens = paddle.full( + shape=[MAX_BSZ * MAX_DRAFT_TOKENS * (K + 1) + MAX_BSZ + 3, 1], fill_value=2, dtype="int64" + ) + self.output_scores = paddle.full( + shape=[MAX_BSZ * MAX_DRAFT_TOKENS * (K + 1), 1], fill_value=0.0, dtype="float32" + ) + self.output_ranks = paddle.full(shape=[MAX_BSZ * MAX_DRAFT_TOKENS], fill_value=0, dtype="int64") + else: + self.output_tokens = paddle.full( + shape=[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2], + fill_value=2, + dtype="int64", + ) elif self.use_logprobs: self.output_tokens = paddle.full(shape=[MAX_BSZ * (K + 1) + 2, 1], fill_value=2, dtype="int64") self.output_scores = paddle.full(shape=[MAX_BSZ * (K + 1), 1], fill_value=0.0, dtype="float32") - self.output_ranks = paddle.full(shape=[MAX_BSZ], fill_value=0, dtype="int64") + self.output_ranks = paddle.full(shape=[MAX_BSZ], fill_value=0, dtype="int64") else: self.output_tokens = paddle.full(shape=[MAX_BSZ + 2, 1], fill_value=2, dtype="int64") self.worker = None @@ -107,6 +117,7 @@ def __init__(self, cfg, cached_generated_tokens, engine_worker_queue, split_conn self.executor = 
ThreadPoolExecutor(max_workers=1) self.prefill_result_status = dict() self._finalizer = weakref.finalize(self, self._cleanup_resources) + self._batch_result_buffer = None def _cleanup_resources(self): """Cleaning up shared memory resources""" @@ -323,7 +334,14 @@ def process_sampling_results(self): self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1 ): - speculate_get_output(self.output_tokens, rank_id, is_blocking, True) + if self.use_logprobs: + # TODO speculate_get_output_with_topk + pass + else: + speculate_get_output(self.output_tokens, rank_id, is_blocking, True) + elif self.use_logprobs: + # TODO speculate_get_output_with_topk + pass else: speculate_get_output(self.output_tokens, rank_id, is_blocking, False) if self.output_tokens[0] == -2: @@ -372,7 +390,7 @@ def process_metrics(): self.executor.submit(process_metrics) - def postprocess(self, batch_result): + def postprocess(self, batch_result: List[RequestOutput], mtype=3): """ single post-processing function @@ -380,7 +398,25 @@ def postprocess(self, batch_result): batch_result (list): batch results """ try: - self.cached_generated_tokens.put_results(batch_result) + if self.cfg.speculative_config.method and self.use_logprobs: + if mtype == 3: # target + has_finished = any(r.finished for r in batch_result) + if has_finished: + self.cached_generated_tokens.put_results(batch_result) + else: + self._batch_result_buffer = batch_result + elif mtype == 4: # draft + target_batch_result = [] + draft_batch_result = batch_result + for target, decode in zip(self._batch_result_buffer, draft_batch_result): + target.outputs.draft_top_logprobs = decode.outputs.draft_top_logprobs + target_batch_result.append(target) + self._batch_result_buffer = None + self.cached_generated_tokens.put_results(target_batch_result) + else: + self.cached_generated_tokens.put_results(batch_result) + else: + self.cached_generated_tokens.put_results(batch_result) except Exception as e: llm_logger.error(f"Error in TokenProcessor's postprocess: {e}, {str(traceback.format_exc())}") @@ -468,9 +504,25 @@ def _process_batch_output(self): tokens = self.output_tokens.numpy() scores = None ranks = None + # target:3, draft:4 + mtype = 3 if self.cfg.speculative_config.method: - batch = self.output_tokens[1] - accept_num = tokens[2 : batch + 2] + if self.use_logprobs: + mtype = int(self.output_tokens[1, 0].item()) + batch = self.output_tokens[2, 0] + accept_num = [int(num[0]) for num in self.output_tokens[3 : batch + 3]] + tokens = tokens[3 + MAX_BSZ : 3 + MAX_BSZ + batch * MAX_DRAFT_TOKENS * (K + 1)].reshape( + [batch, MAX_DRAFT_TOKENS, K + 1] + ) + scores = ( + self.output_scores[: batch * MAX_DRAFT_TOKENS * (K + 1)] + .numpy() + .reshape([batch, MAX_DRAFT_TOKENS, K + 1]) + ) + ranks = self.output_ranks[: batch * MAX_DRAFT_TOKENS].numpy().reshape([batch, MAX_DRAFT_TOKENS]) + else: + batch = self.output_tokens[1] + accept_num = tokens[2 : batch + 2] self._record_speculative_decoding_mertics(accept_num) elif self.use_logprobs: batch = self.output_tokens[1, 0] @@ -498,6 +550,8 @@ def _process_batch_output(self): if recovery_stop: llm_logger.info(f"recovery stop signal found at task {task_id}") token_ids = [RECOVERY_STOP_SIGNAL] + elif self.use_logprobs: + token_ids = tokens[i][:, 0].tolist()[: accept_num[i]] else: token_ids = tokens[ 2 @@ -553,6 +607,7 @@ def _process_batch_output(self): self._record_metrics(task, current_time, token_ids) result = RequestOutput( request_id=task_id, + output_type=mtype, outputs=CompletionOutput( 
index=i, send_idx=self.tokens_counter[task_id], @@ -572,23 +627,47 @@ def _process_batch_output(self): if is_prefill and len(token_ids) > 1: result.outputs.draft_token_ids = copy.deepcopy(token_ids) - for token_id in token_ids: + for batch_token_index in range(len(token_ids)): + token_id = token_ids[batch_token_index] self.tokens_counter[task_id] += 1 if token_id != RECOVERY_STOP_SIGNAL: if not (envs.FD_ENABLE_INTERNAL_ADAPTER and token_id in task.eos_token_ids): result.outputs.token_ids.append(token_id) task.output_token_ids.append(token_id) if self.use_logprobs: - result.outputs.logprob = float(scores[i, 0]) - # Construct top_logprobs - topk_token_ids = tokens[i, :].tolist() - topk_logprobs = scores[i, :].tolist() - sampled_rank = ranks[i].item() - result.outputs.top_logprobs = LogprobsLists( - logprob_token_ids=[topk_token_ids], - logprobs=[topk_logprobs], - sampled_token_ranks=[sampled_rank], - ) + if self.cfg.speculative_config.method: + result.outputs.logprob = float(scores[i, batch_token_index, 0]) + topk_token_ids = tokens[i, batch_token_index, :].tolist() + topk_logprobs = scores[i, batch_token_index, :].tolist() + sampled_rank = ranks[i, batch_token_index].item() + else: + result.outputs.logprob = float(scores[i, 0]) + topk_token_ids = tokens[i, :].tolist() + topk_logprobs = scores[i, :].tolist() + sampled_rank = ranks[i].item() + + if mtype == 3: # top_logprobs + if result.outputs.top_logprobs is None: + result.outputs.top_logprobs = LogprobsLists( + logprob_token_ids=[topk_token_ids], + logprobs=[topk_logprobs], + sampled_token_ranks=[sampled_rank], + ) + else: + result.outputs.top_logprobs.logprob_token_ids.extend([topk_token_ids]) + result.outputs.top_logprobs.logprobs.extend([topk_logprobs]) + result.outputs.top_logprobs.sampled_token_ranks.extend([sampled_rank]) + elif mtype == 4: # draft_top_logprobs + if result.outputs.draft_top_logprobs is None: + result.outputs.draft_top_logprobs = LogprobsLists( + logprob_token_ids=[topk_token_ids], + logprobs=[topk_logprobs], + sampled_token_ranks=[sampled_rank], + ) + else: + result.outputs.draft_top_logprobs.logprob_token_ids.extend([topk_token_ids]) + result.outputs.draft_top_logprobs.logprobs.extend([topk_logprobs]) + result.outputs.draft_top_logprobs.sampled_token_ranks.extend([sampled_rank]) if token_id in task.eos_token_ids or is_prefill or recovery_stop: result.finished = True if recovery_stop: @@ -613,7 +692,7 @@ def _process_batch_output(self): ): batch_result.append(result) - self.postprocess(batch_result) + self.postprocess(batch_result, mtype) def _record_metrics(self, task, current_time, token_ids): """Record all metrics for a task""" diff --git a/tests/output/test_process_batch_output.py b/tests/output/test_process_batch_output.py new file mode 100644 index 0000000000..2d31ca327d --- /dev/null +++ b/tests/output/test_process_batch_output.py @@ -0,0 +1,191 @@ +import time +import unittest +from unittest.mock import Mock + +import paddle + +from fastdeploy.output.token_processor import TokenProcessor + +paddle.set_device("cpu") + + +# Mock classes and constants needed for the test +class MockConfig: + class ParallelConfig: + local_data_parallel_id = 0 + + class SpeculativeConfig: + method = None + + class ModelConfig: + enable_logprob = False + + class SchedulerConfig: + name = "default" + + parallel_config = ParallelConfig() + speculative_config = SpeculativeConfig() + model_config = ModelConfig() + scheduler_config = SchedulerConfig() + + +class MockTask: + def __init__(self): + self.request_id = "test_request_1" + 
self.arrival_time = time.time() + self.inference_start_time = time.time() + self.schedule_start_time = time.time() + self.preprocess_end_time = time.time() - 0.1 + self.preprocess_start_time = time.time() - 0.2 + self.eos_token_ids = [2] + self.output_token_ids = [] + self.messages = "Test prompt" + self.num_cached_tokens = 0 + self.disaggregate_info = None + self.prefill_chunk_info = None + self.prefill_chunk_num = 0 + + def get(self, key: str, default_value=None): + if hasattr(self, key): + return getattr(self, key) + elif hasattr(self.sampling_params, key): + return getattr(self.sampling_params, key) + else: + return default_value + + +class MockResourceManager: + def __init__(self): + self.stop_flags = [False] + self.tasks_list = [MockTask()] + self.to_be_rescheduled_request_id_set = set() + + def info(self): + return "Mock resource manager info" + + def reschedule_preempt_task(self, task_id): + pass + + +# Constants +RECOVERY_STOP_SIGNAL = -3 +MAX_BSZ = 512 +K = 20 +MAX_DRAFT_TOKENS = 6 +SPECULATE_MAX_BSZ = 256 + + +class TestTokenProcessorProcessBatchOutput(unittest.TestCase): + + def setup_token_processor(self, speculative_decoding=False, use_logprobs=False): + """Helper method to setup TokenProcessor with different configurations""" + cfg = MockConfig() + cfg.speculative_config.method = "mtp" if speculative_decoding else None + cfg.speculative_config.num_speculative_tokens = 1 + cfg.model_config.enable_logprob = use_logprobs + + processor = TokenProcessor.__new__(TokenProcessor) + processor.cfg = cfg + processor.cached_generated_tokens = [] + processor.executor = Mock() + processor.engine_worker_queue = Mock() + processor.split_connector = Mock() + processor.resource_manager = MockResourceManager() + task = MockTask() + processor.resource_manager.tasks_list = [task] + processor.tokens_counter = {task.request_id: 0} + processor.total_step = 0 + processor.number_of_output_tokens = 0 + processor.prefill_result_status = {} + processor.use_logprobs = use_logprobs + processor.num_draft_tokens = 0 + processor.num_accepted_tokens = 0 + processor.num_emitted_tokens = 0 + processor.max_num_emitted_tokens = 0 + processor.num_rest_requests_per_head = [ + 0, + ] * MAX_DRAFT_TOKENS + processor.num_accept_requests_per_head = [ + 0, + ] * MAX_DRAFT_TOKENS + processor.speculative_stats_step = 0 + + # processor._recycle_resources = Mock() + + if speculative_decoding: + if use_logprobs: + processor.output_tokens = paddle.full( + shape=[MAX_BSZ * MAX_DRAFT_TOKENS * (K + 1) + MAX_BSZ + 3, 1], + fill_value=2, + dtype="int64", + ) + processor.output_scores = paddle.full( + shape=[MAX_BSZ * MAX_DRAFT_TOKENS * (K + 1), 1], + fill_value=0.0, + dtype="float32", + ) + processor.output_ranks = paddle.full( + shape=[MAX_BSZ * MAX_DRAFT_TOKENS], + fill_value=0, + dtype="int64", + ) + else: + processor.output_tokens = paddle.full( + shape=[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2], + fill_value=2, + dtype="int64", + ) + elif use_logprobs: + processor.output_tokens = paddle.full(shape=[MAX_BSZ * (K + 1) + 2, 1], fill_value=2, dtype="int64") + processor.output_scores = paddle.full(shape=[MAX_BSZ * (K + 1), 1], fill_value=0.0, dtype="float32") + processor.output_ranks = paddle.full(shape=[MAX_BSZ], fill_value=0, dtype="int64") + else: + processor.output_tokens = paddle.full(shape=[MAX_BSZ + 2, 1], fill_value=2, dtype="int64") + + return processor + + def test_speculative_decoding_use_logprobs(self): + """Test basic speculative decoding scenario""" + processor = 
self.setup_token_processor(speculative_decoding=True, use_logprobs=True) + + # stop_flag + processor.output_tokens[0, 0] = 2 + # mtype + processor.output_tokens[1, 0] = 3 # target = 3, decode = 4 + # batch + processor.output_tokens[2, 0] = 2 + # accept_num + processor.output_tokens[3, 0] = 3 + processor.output_tokens[4, 0] = 3 + + batch = processor.output_tokens[2, 0] + accept_num = [int(num[0]) for num in processor.output_tokens[3 : batch + 3]] + + # init + print(f"\nbatch: {batch}, accept_num: {accept_num}") + for i in range(batch): + for j in range(accept_num[i]): + for k in range(K + 1): + index = ( + 3 + + batch + + i * MAX_DRAFT_TOKENS * (K + 1) + + j * (K + 1) + + k + ) + print(f"i:{i}, j:{j} k:{k} index: {index}") + processor.output_tokens[index, 0] = 5 + i * 10 + j * 2 + k + processor.output_scores[i * MAX_DRAFT_TOKENS * (K + 1) + j * (K + 1) + k, 0] = float( + 0.1 * (5 + i * 10 + j * 2 + k) + ) + processor.output_ranks[i * MAX_DRAFT_TOKENS + j] = j + 1 + + print(f"{processor.output_tokens}") + print(f"{processor.output_scores}") + print(f"{processor.output_ranks}") + + # processor._process_batch_output() + + +if __name__ == "__main__": + unittest.main(verbosity=2, buffer=False)
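
A minimal sketch of how the enlarged output buffers are laid out and decoded when speculative decoding and logprobs are both enabled, mirroring the indexing in `TokenProcessor._process_batch_output` above: slot 1 carries the message type (3 = target model, 4 = draft model), slot 2 the batch size, slots 3..3+batch the per-request accept counts, and the flattened top-(K+1) blocks start at offset `3 + MAX_BSZ`. In `postprocess`, a target batch (mtype 3) is buffered in `_batch_result_buffer` until the matching draft batch (mtype 4) arrives, at which point `draft_top_logprobs` is merged in and the combined results are published. The NumPy helper below and its name `decode_spec_logprob_buffers` are illustrative only, not part of FastDeploy.

```python
# Sketch: decode the flattened speculative + logprobs buffers, following the
# indexing used in the patched TokenProcessor._process_batch_output.
import numpy as np

MAX_BSZ = 512          # max batch size
MAX_DRAFT_TOKENS = 6   # max accepted tokens per speculative step
K = 20                 # top-K logprobs kept per sampled token


def decode_spec_logprob_buffers(output_tokens, output_scores, output_ranks):
    """Decode buffers of shape:
    output_tokens : int64   [MAX_BSZ * MAX_DRAFT_TOKENS * (K + 1) + MAX_BSZ + 3, 1]
    output_scores : float32 [MAX_BSZ * MAX_DRAFT_TOKENS * (K + 1), 1]
    output_ranks  : int64   [MAX_BSZ * MAX_DRAFT_TOKENS]
    """
    mtype = int(output_tokens[1, 0])                 # 3 = target, 4 = draft
    batch = int(output_tokens[2, 0])
    accept_num = output_tokens[3 : 3 + batch, 0].tolist()

    # Token ids start after the (fixed-size) accept_num region.
    block = output_tokens[3 + MAX_BSZ : 3 + MAX_BSZ + batch * MAX_DRAFT_TOKENS * (K + 1)]
    tokens = block.reshape(batch, MAX_DRAFT_TOKENS, K + 1)
    scores = output_scores[: batch * MAX_DRAFT_TOKENS * (K + 1)].reshape(batch, MAX_DRAFT_TOKENS, K + 1)
    ranks = output_ranks[: batch * MAX_DRAFT_TOKENS].reshape(batch, MAX_DRAFT_TOKENS)

    results = []
    for i in range(batch):
        n = accept_num[i]
        results.append(
            {
                "mtype": mtype,
                "sampled_token_ids": tokens[i, :n, 0].tolist(),  # column 0 is the sampled token
                "topk_token_ids": tokens[i, :n, :].tolist(),     # per-step top-(K+1) candidates
                "topk_logprobs": scores[i, :n, :].tolist(),
                "sampled_ranks": ranks[i, :n].tolist(),
            }
        )
    return results
```

This is the same layout the unit test at the end of the diff populates by hand (stop flag, mtype, batch, accept_num, then the flattened top-(K+1) blocks per accepted token).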
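
A hedged end-to-end usage sketch of the new request/response fields, assuming a FastDeploy OpenAI-compatible server is already running with speculative decoding and logprobs enabled; the host, port, and model name below are placeholders. The `include_draft_logprobs` request field and the `draft_logprobs` field on each choice come directly from the protocol changes above; the existing `logprobs`/`top_logprobs` behaviour for the target model is unchanged.

```python
# Sketch of a client request exercising the new draft-logprobs fields.
import requests

resp = requests.post(
    "http://localhost:8188/v1/chat/completions",   # placeholder endpoint
    json={
        "model": "default",                        # placeholder model name
        "messages": [{"role": "user", "content": "Hello"}],
        "logprobs": True,
        "top_logprobs": 5,
        "include_draft_logprobs": True,            # new field added by this patch
    },
    timeout=60,
)
choice = resp.json()["choices"][0]
print(choice["logprobs"])        # target-model top logprobs (existing behaviour)
print(choice["draft_logprobs"])  # draft-model top logprobs (new in this patch)
```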