
Commit 8829724

[feat] completion api supports passing input token ids in either prompt or prompt_token_ids (#3311)
* [feat] completion api supports passing input token ids in either `prompt` or `prompt_token_ids`
* [fix] update comment
* [fix] fix type error
* [test] add a unittest file for serving api test
* [test] try to fix ci error
* [chore] rename test function names
* [test] try to fix ci error
* [test] try to fix ci error
* [test] add tests for qwen
1 parent 17b414c commit 8829724
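
For orientation, here is a minimal client-side sketch of how the new parameter could be exercised. The `prompt_token_ids` field and the accepted shapes (`list[int]` for one prompt, `list[list[int]]` for a batch) come from the diffs below; the endpoint URL and model name are placeholders, assuming the usual OpenAI-compatible `/v1/completions` route.

```python
# Hypothetical client-side sketch; endpoint URL and model name are placeholders.
import requests

BASE_URL = "http://localhost:8000/v1/completions"  # assumed OpenAI-compatible serving address

# Option 1: pass pre-tokenized input via the extra `prompt_token_ids` field.
# A flat list[int] is one prompt; list[list[int]] is a batch of prompts.
resp = requests.post(
    BASE_URL,
    json={
        "model": "default",  # placeholder model name
        "prompt_token_ids": [[101, 202, 303], [404, 505]],  # two prompts as token ids
        "max_tokens": 16,
    },
)
print(resp.json())

# Option 2: keep using `prompt`, which now also accepts token ids
# (str, list[str], list[int], or list[list[int]]).
resp = requests.post(
    BASE_URL,
    json={
        "model": "default",
        "prompt": [101, 202, 303],  # a single prompt given as token ids
        "max_tokens": 16,
    },
)
print(resp.json())
```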

File tree

6 files changed (+343, -70 lines)

6 files changed

+343
-70
lines changed

fastdeploy/entrypoints/openai/protocol.py

Lines changed: 7 additions & 7 deletions
```diff
@@ -438,7 +438,7 @@ class CompletionRequest(BaseModel):
 
     max_streaming_response_tokens: Optional[int] = None
     return_token_ids: Optional[bool] = None
-    prompt_token_ids: Optional[List[int]] = None
+    prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None
     # doc: end-completion-extra-params
 
     def to_dict_for_infer(self, request_id=None, prompt=None):
@@ -463,11 +463,11 @@ def to_dict_for_infer(self, request_id=None, prompt=None):
         if prompt is not None:
             req_dict["prompt"] = prompt
 
-        if "prompt_token_ids" in req_dict:
-            if "prompt" in req_dict:
-                del req_dict["prompt"]
-            else:
-                assert len(prompt) > 0
+        # if "prompt_token_ids" in req_dict:
+        #     if "prompt" in req_dict:
+        #         del req_dict["prompt"]
+        #     else:
+        #         assert len(prompt) > 0
 
         guided_json_object = None
         if self.response_format is not None:
@@ -572,7 +572,7 @@ class ChatCompletionRequest(BaseModel):
     stop_token_ids: Optional[List[int]] = Field(default_factory=list)
     # doc: end-chat-completion-sampling-params
 
-    # doc: start-completion-extra-params
+    # doc: start-chat-completion-extra-params
     chat_template_kwargs: Optional[dict] = None
     chat_template: Optional[str] = None
     reasoning_max_tokens: Optional[int] = None
```
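
The widened annotation above lets the request model accept either a flat token list or a batch of token lists. A minimal standalone sketch of what the change permits, using a toy Pydantic model rather than the real `CompletionRequest`:

```python
# Toy model illustrating the widened field type; not the actual CompletionRequest.
from typing import List, Optional, Union

from pydantic import BaseModel


class TokenIdsDemo(BaseModel):
    prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None


print(TokenIdsDemo(prompt_token_ids=[1, 2, 3]).prompt_token_ids)         # single prompt -> [1, 2, 3]
print(TokenIdsDemo(prompt_token_ids=[[1, 2], [3, 4]]).prompt_token_ids)  # batch -> [[1, 2], [3, 4]]
print(TokenIdsDemo().prompt_token_ids)                                   # omitted -> None
```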

fastdeploy/entrypoints/openai/serving_completion.py

Lines changed: 33 additions & 18 deletions
```diff
@@ -81,35 +81,50 @@ async def create_completion(self, request: CompletionRequest):
             request_id = f"cmpl-{request.user}-{uuid.uuid4()}"
         else:
             request_id = f"cmpl-{uuid.uuid4()}"
-        api_server_logger.info(f"initialize request {request_id}")
+        api_server_logger.info(f"Initialize request {request_id}: {request}")
         request_prompt_ids = None
         request_prompts = None
+
+        # Handle prompt and prompt_token_ids
         try:
-            if isinstance(request.prompt, str):
-                request_prompts = [request.prompt]
-            elif isinstance(request.prompt, list) and all(isinstance(item, int) for item in request.prompt):
-                request_prompt_ids = [request.prompt]
-            elif isinstance(request.prompt, list) and all(isinstance(item, str) for item in request.prompt):
-                request_prompts = request.prompt
-            elif isinstance(request.prompt, list):
-                for item in request.prompt:
-                    if isinstance(item, list) and all(isinstance(x, int) for x in item):
-                        continue
-                    else:
-                        raise ValueError("Prompt must be a string, a list of strings or a list of integers.")
-                request_prompt_ids = request.prompt
+            if request.prompt_token_ids is not None:  # let `prompt_token_ids` support batch inference
+                assert len(request.prompt_token_ids) > 0, "prompt_token_ids should not be an empty list"
+                if isinstance(request.prompt_token_ids[0], list):
+                    request_prompt_ids = request.prompt_token_ids
+                elif isinstance(request.prompt_token_ids[0], int):
+                    request_prompt_ids = [request.prompt_token_ids]
+                else:
+                    raise ValueError(
+                        "If prompt_token_ids is provided, its type should be one of: list[int], list[list[int]]"
+                    )
+                # reset `prompt_token_ids` to avoid data processor directly using it; let data processor fill it
+                request.prompt_token_ids = None
             else:
-                raise ValueError("Prompt must be a string, a list of strings or a list of integers.")
+                if isinstance(request.prompt, str):
+                    request_prompts = [request.prompt]
+                elif isinstance(request.prompt, list) and all(isinstance(item, int) for item in request.prompt):
+                    request_prompt_ids = [request.prompt]
+                elif isinstance(request.prompt, list) and all(isinstance(item, str) for item in request.prompt):
+                    request_prompts = request.prompt
+                elif isinstance(request.prompt, list):
+                    for item in request.prompt:
+                        if isinstance(item, list) and all(isinstance(x, int) for x in item):
+                            continue
+                        else:
+                            raise ValueError("If prompt is a list, each item type must be one of: str, list[int]")
+                    request_prompt_ids = request.prompt
+                else:
+                    raise ValueError("Prompt type must be one of: str, list[str], list[int], list[list[int]]")
         except Exception as e:
             error_msg = f"OpenAIServingCompletion create_completion: {e}, {str(traceback.format_exc())}"
             api_server_logger.error(error_msg)
             return ErrorResponse(message=error_msg, code=400)
 
         if request_prompt_ids is not None:
             request_prompts = request_prompt_ids
-        num_choices = len(request_prompts)
 
-        api_server_logger.info(f"start inference for request {num_choices}")
+        num_choices = len(request_prompts)
+        api_server_logger.info(f"Start preprocessing request: req_id={request_id}), num_choices={num_choices}")
         prompt_batched_token_ids = []
         text_after_process_list = []
         try:
@@ -131,7 +146,7 @@ async def create_completion(self, request: CompletionRequest):
                 request_id_idx = f"{request_id}-{idx}"
                 current_req_dict = request.to_dict_for_infer(request_id_idx, prompt)
                 current_req_dict["arrival_time"] = time.time()
-                prompt_token_ids = self.engine_client.format_and_add_data(current_req_dict)
+                prompt_token_ids = self.engine_client.format_and_add_data(current_req_dict)  # tokenize
                 if isinstance(prompt_token_ids, np.ndarray):
                     prompt_token_ids = prompt_token_ids.tolist()
                 text_after_process_list.append(current_req_dict.get("text_after_process"))
```
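
The first hunk above normalizes whatever the client sent into a batch before preprocessing. Below is a simplified, standalone sketch of that branching, for illustration only; the real handler additionally resets `request.prompt_token_ids` and wraps failures in an `ErrorResponse`.

```python
# Simplified sketch of the request normalization added above (not the actual handler).
from typing import List, Optional, Union


def normalize_to_batch(
    prompt: Optional[Union[str, List[str], List[int], List[List[int]]]],
    prompt_token_ids: Optional[Union[List[int], List[List[int]]]],
) -> list:
    """Return a list with one entry per generation choice (strings or token-id lists)."""
    if prompt_token_ids is not None:
        if not prompt_token_ids:
            raise ValueError("prompt_token_ids should not be an empty list")
        if isinstance(prompt_token_ids[0], list):
            return prompt_token_ids          # already a batch of token-id lists
        if isinstance(prompt_token_ids[0], int):
            return [prompt_token_ids]        # single token-id list -> batch of one
        raise ValueError("prompt_token_ids must be list[int] or list[list[int]]")

    if isinstance(prompt, str):
        return [prompt]                      # single string prompt
    if isinstance(prompt, list) and all(isinstance(p, int) for p in prompt):
        return [prompt]                      # single prompt given as token ids
    if isinstance(prompt, list) and all(isinstance(p, str) for p in prompt):
        return prompt                        # batch of string prompts
    if isinstance(prompt, list) and all(
        isinstance(p, list) and all(isinstance(x, int) for x in p) for p in prompt
    ):
        return prompt                        # batch of token-id lists
    raise ValueError("Prompt type must be one of: str, list[str], list[int], list[list[int]]")


# Example: both calls yield a batch of two prompts.
print(normalize_to_batch(None, [[1, 2, 3], [4, 5]]))
print(normalize_to_batch(["Hello", "World"], None))
```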

fastdeploy/input/ernie4_5_processor.py

Lines changed: 46 additions & 25 deletions
```diff
@@ -87,33 +87,45 @@ def process_request(self, request, max_model_len=None, **kwargs):
             bool: Whether preprocessing is successful
             str: error message
         """
+        data_processor_logger.info(f"Start processing request: {request}")
         request.chat_template = kwargs.get("chat_template")
         request = self._apply_default_parameters(request)
         if request.get("eos_token_ids") is None or len(request.eos_token_ids) == 0:
             request.eos_token_ids = self.eos_token_ids
+
+        # processing stop_sequences
         stop_sequences = request.get("stop", [])
         if stop_sequences is not None and len(stop_sequences) != 0:
             stop_seqs, stop_seqs_len = self.update_stop_seq(stop_sequences)
             request.set("stop_token_ids", stop_seqs)
             request.set("stop_seqs_len", stop_seqs_len)
 
+        # processing bad_words
         bad_words = request.get("bad_words")
         bad_words_token_ids = request.get("bad_words_token_ids")
         if bad_words:
             bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
             request["bad_words_token_ids"] = bad_words_token_ids
 
+        # processing prompt_token_ids
         if request.prompt_token_ids is None or len(request.prompt_token_ids) == 0:
-            if request.prompt is None and request.messages is None:
-                raise ValueError(f"The request should have `prompt_token_ids`, `prompt` or `messages`: {request}.")
             if request.prompt is not None:
-                prompt = request.prompt if request.prompt is not None else request.messages[0]
-                prompt = prompt[0] if isinstance(prompt, list) else prompt
-                tokens = self.tokenizer.tokenize(prompt)
-                token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
-                request.prompt_token_ids = token_ids
-                data_processor_logger.info(f"req_id:{request.request_id}, tokens:{tokens}, token_ids: {token_ids}")
-            else:
+                # prompt = request.prompt if request.prompt is not None else request.messages[0]
+                prompt = request.prompt
+                assert isinstance(prompt, str) or (
+                    isinstance(prompt, list) and all([isinstance(t, int) for t in prompt])
+                ), f"prompt must be a string or a list of integers, but got {type(prompt)}"
+
+                if isinstance(prompt, list):  # if prompt is a token id list
+                    request.prompt_token_ids = prompt
+                else:
+                    tokens = self.tokenizer.tokenize(prompt)
+                    token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
+                    request.prompt_token_ids = token_ids
+                    data_processor_logger.debug(
+                        f"request_ids: {request.request_id}, prompt: {prompt}, tokens: {tokens}, token_ids: {token_ids}"
+                    )
+            elif request.messages is not None:
                 task = request.to_dict()
                 chat_template_kwargs = kwargs.get("chat_template_kwargs")
                 if chat_template_kwargs:
```
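
In the hunk above, `process_request` now accepts a `prompt` that is either text to tokenize or an already-tokenized id list. A minimal sketch of that branch in isolation, with a stand-in tokenizer (not the ERNIE tokenizer) so it runs on its own:

```python
# Illustrative only: mirrors the new prompt branch with a stand-in tokenizer.
from typing import List, Union


def resolve_prompt_token_ids(prompt: Union[str, List[int]], tokenizer) -> List[int]:
    assert isinstance(prompt, str) or (
        isinstance(prompt, list) and all(isinstance(t, int) for t in prompt)
    ), f"prompt must be a string or a list of integers, but got {type(prompt)}"

    if isinstance(prompt, list):  # already token ids: pass through untouched
        return prompt
    tokens = tokenizer.tokenize(prompt)             # text path: tokenize ...
    return tokenizer.convert_tokens_to_ids(tokens)  # ... then map tokens to ids


class _ToyTokenizer:
    """Stand-in with the same two methods the processor relies on."""

    def tokenize(self, text: str) -> List[str]:
        return text.split()

    def convert_tokens_to_ids(self, tokens: List[str]) -> List[int]:
        return [len(t) for t in tokens]  # fake ids, just for the demo


tok = _ToyTokenizer()
print(resolve_prompt_token_ids("hello world", tok))  # text is tokenized -> [5, 5]
print(resolve_prompt_token_ids([11, 22, 33], tok))   # ids pass through -> [11, 22, 33]
```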
```diff
@@ -124,24 +136,26 @@ def process_request(self, request, max_model_len=None, **kwargs):
                 else:
                     raise ValueError("Invalid input: chat_template_kwargs must be a dict")
                 request.prompt_token_ids = self.messages2ids(task)
+            else:
+                raise ValueError(f"The request should have `prompt_token_ids`, `prompt` or `messages`: {request}.")
 
         if len(request.prompt_token_ids) == 0:
             raise ValueError("Invalid input: prompt_token_ids must be a non-empty sequence of token IDs")
+
+        # truncate prompts that exceed the length limit
         if max_model_len is not None and len(request.prompt_token_ids) > max_model_len:
             request.prompt_token_ids = request.prompt_token_ids[: max_model_len - 1]
         if request.get("max_tokens") is None:
-            request.set(
-                "max_tokens",
-                max(1, max_model_len - len(request.prompt_token_ids)),
-            )
+            request.set("max_tokens", max(1, max_model_len - len(request.prompt_token_ids)))
         if request.get("temperature") < _SAMPLING_EPS:
             # zero temperature is equivalent to greedy sampling
             request.set("temperature", 1)
         if request.get("top_p") < _SAMPLING_EPS:
             request.set("top_p", _SAMPLING_EPS)
         if self.reasoning_parser and self.reasoning_parser.__class__.__name__ == "ErnieX1ReasoningParser":
             request.enable_thinking = True
-        data_processor_logger.info(f"Processed request {request}")
+
+        data_processor_logger.info(f"Processed request: {request}")
         return request
 
     def process_request_dict(self, request, max_model_len=None):
```
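
The hunk above keeps the existing truncation and `max_tokens` defaulting, only collapsing the `request.set(...)` call onto one line. As a quick worked example of that arithmetic (the numbers below are made up):

```python
# Worked example of the truncation / max_tokens default above (made-up numbers).
max_model_len = 8
prompt_token_ids = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]  # longer than the model limit

if max_model_len is not None and len(prompt_token_ids) > max_model_len:
    prompt_token_ids = prompt_token_ids[: max_model_len - 1]  # keep 7 ids, leave room to generate

max_tokens = max(1, max_model_len - len(prompt_token_ids))    # 8 - 7 = 1
print(len(prompt_token_ids), max_tokens)  # 7 1
```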
```diff
@@ -155,6 +169,7 @@ def process_request_dict(self, request, max_model_len=None):
             bool: Whether preprocessing is successful
             str: error message
         """
+        data_processor_logger.info(f"Start processing request dict: {request}")
         request = self._apply_default_parameters(request)
         if not request.get("eos_token_ids"):
             request["eos_token_ids"] = self.eos_token_ids
@@ -175,18 +190,21 @@ def process_request_dict(self, request, max_model_len=None):
 
         # processing prompt_token_ids
         if not request.get("prompt_token_ids"):
-            if request.get("prompt") is None and request.get("messages") is None:
-                raise ValueError(f"Request must contain 'prompt_token_ids', 'prompt', or 'messages': {request}")
             if request.get("prompt"):
                 prompt = request.get("prompt")
-                prompt = prompt[0] if isinstance(prompt, list) else prompt
-                request["text_after_process"] = prompt
-                tokens = self.tokenizer.tokenize(prompt)
-                token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
-                request["prompt_token_ids"] = token_ids
-                req_id = request.get("request_id", None)
-                data_processor_logger.info(f"req_id:{req_id}, tokens:{tokens}, token_ids: {token_ids}")
-            else:
+                assert isinstance(prompt, str) or (
+                    isinstance(prompt, list) and all([isinstance(t, int) for t in prompt])
+                ), f"prompt must be a string or a list of integers, but got {type(prompt)}"
+                if isinstance(prompt, list):  # if prompt is a token id list
+                    request["prompt_token_ids"] = prompt
+                else:
+                    request["text_after_process"] = prompt
+                    tokens = self.tokenizer.tokenize(prompt)
+                    token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
+                    request["prompt_token_ids"] = token_ids
+                    req_id = request.get("request_id", None)
+                    data_processor_logger.info(f"req_id:{req_id}, tokens:{tokens}, token_ids: {token_ids}")
+            elif request.get("messages"):
                 chat_template_kwargs = request.get("chat_template_kwargs")
                 if chat_template_kwargs:
                     if isinstance(chat_template_kwargs, dict):
@@ -196,6 +214,9 @@ def process_request_dict(self, request, max_model_len=None):
                 else:
                     raise ValueError("Invalid input: chat_template_kwargs must be a dict")
                 request["prompt_token_ids"] = self.messages2ids(request)
+            else:
+                raise ValueError(f"Request must contain 'prompt_token_ids', 'prompt', or 'messages': {request}")
+
         if len(request["prompt_token_ids"]) == 0:
             raise ValueError("Invalid input: prompt_token_ids must be a non-empty sequence of token IDs")
 
@@ -211,8 +232,8 @@ def process_request_dict(self, request, max_model_len=None):
             request["top_p"] = _SAMPLING_EPS
         if self.reasoning_parser and self.reasoning_parser.__class__.__name__ == "ErnieX1ReasoningParser":
             request["enable_thinking"] = True
-        data_processor_logger.info(f"Processed request {request}")
 
+        data_processor_logger.info(f"Processed request dict: {request}")
         return request
 
     def process_response(self, response_dict, **kwargs):
```
