Commit 420435c

[feat] completion api supports passing input token ids in either prompt or prompt_token_ids
1 parent 21caa63 commit 420435c

File tree

7 files changed: +354 / -78 lines changed

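In short, the completion API now accepts pre-tokenized input either as a single token-id list or as a batch of lists via the `prompt_token_ids` field. A minimal client-side sketch follows; the server URL, port, model name, and token-id values are assumptions for illustration, not taken from this commit.

# Hypothetical client-side sketch exercising the new `prompt_token_ids` field of the
# completion API. Endpoint, port, and model name are placeholders.
import requests

BASE_URL = "http://localhost:8000/v1/completions"  # assumed OpenAI-compatible FastDeploy endpoint

# Single request: token ids passed as list[int]; no `prompt` text needed.
single = {
    "model": "default",  # placeholder model name
    "prompt_token_ids": [101, 2054, 2003, 102],
    "max_tokens": 16,
}

# Batched request: token ids passed as list[list[int]]; one completion per inner list.
batched = {
    "model": "default",
    "prompt_token_ids": [[101, 2054, 102], [101, 2129, 2079, 102]],
    "max_tokens": 16,
}

for payload in (single, batched):
    resp = requests.post(BASE_URL, json=payload, timeout=60)
    resp.raise_for_status()
    for choice in resp.json()["choices"]:
        print(choice["index"], choice["text"])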

fastdeploy/entrypoints/engine_client.py

Lines changed: 7 additions & 7 deletions
@@ -96,19 +96,19 @@ def create_zmq_client(self, model, mode):
         self.zmq_client = ZmqClient(model, mode)
         self.zmq_client.connect()
 
-    def format_and_add_data(self, prompts: dict):
+    def format_and_add_data(self, req_dict: dict):
         """
         Format the request data and send the request to the server.
         """
-        if "request_id" not in prompts:
+        if "request_id" not in req_dict:
             request_id = str(uuid.uuid4())
-            prompts["request_id"] = request_id
+            req_dict["request_id"] = request_id
 
-        if "max_tokens" not in prompts:
-            prompts["max_tokens"] = self.max_model_len - 1
+        if "max_tokens" not in req_dict:
+            req_dict["max_tokens"] = self.max_model_len - 1
 
-        self.add_requests(prompts)
-        return prompts["prompt_token_ids"]
+        self.add_requests(req_dict)
+        return req_dict["prompt_token_ids"]
 
     def add_requests(self, task):
         """

fastdeploy/entrypoints/openai/protocol.py

Lines changed: 7 additions & 7 deletions
@@ -375,7 +375,7 @@ class CompletionRequest(BaseModel):
 
     max_streaming_response_tokens: Optional[int] = None
     return_token_ids: Optional[bool] = None
-    prompt_token_ids: Optional[List[int]] = None
+    prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None
     # doc: end-completion-extra-params
 
     def to_dict_for_infer(self, request_id=None, prompt=None):
@@ -400,11 +400,11 @@ def to_dict_for_infer(self, request_id=None, prompt=None):
         if prompt is not None:
             req_dict["prompt"] = prompt
 
-        if "prompt_token_ids" in req_dict:
-            if "prompt" in req_dict:
-                del req_dict["prompt"]
-            else:
-                assert len(prompt) > 0
+        # if "prompt_token_ids" in req_dict:
+        #     if "prompt" in req_dict:
+        #         del req_dict["prompt"]
+        #     else:
+        #         assert len(prompt) > 0
 
         guided_json_object = None
         if self.response_format is not None:
@@ -503,7 +503,7 @@ class ChatCompletionRequest(BaseModel):
     stop_token_ids: Optional[List[int]] = Field(default_factory=list)
     # doc: end-chat-completion-sampling-params
 
-    # doc: start-completion-extra-params
+    # doc: start-chat-completion-extra-params
     chat_template_kwargs: Optional[dict] = None
     reasoning_max_tokens: Optional[int] = None
     structural_tag: Optional[str] = None
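
The widened field type can be exercised in isolation. The model below is a hypothetical stand-in for `CompletionRequest` (not the real class), sketched only to show that both shapes validate.

# Minimal pydantic sketch of the widened field; TokenIdsPayload is a stand-in,
# not the actual CompletionRequest from fastdeploy.
from typing import List, Optional, Union

from pydantic import BaseModel


class TokenIdsPayload(BaseModel):
    prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None


print(TokenIdsPayload(prompt_token_ids=[1, 2, 3]).prompt_token_ids)            # single sequence
print(TokenIdsPayload(prompt_token_ids=[[1, 2], [3, 4, 5]]).prompt_token_ids)  # batch of sequences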

fastdeploy/entrypoints/openai/serving_completion.py

Lines changed: 41 additions & 20 deletions
@@ -25,6 +25,7 @@
 from aiozmq import zmq
 
 from fastdeploy.engine.request import RequestOutput
+from fastdeploy.entrypoints.engine_client import EngineClient
 from fastdeploy.entrypoints.openai.protocol import (
     CompletionLogprobs,
     CompletionRequest,
@@ -41,7 +42,7 @@
 
 class OpenAIServingCompletion:
     def __init__(self, engine_client, pid, ips, max_waiting_time):
-        self.engine_client = engine_client
+        self.engine_client: EngineClient = engine_client
         self.pid = pid
         self.master_ip = ips
         self.host_ip = get_host_ip()
@@ -72,41 +73,57 @@ async def create_completion(self, request: CompletionRequest):
             request_id = f"cmpl-{request.user}-{uuid.uuid4()}"
         else:
             request_id = f"cmpl-{uuid.uuid4()}"
-        api_server_logger.info(f"initialize request {request_id}")
+        api_server_logger.info(f"Initialize request {request_id}: {request}")
         request_prompt_ids = None
         request_prompts = None
+
+        # Handle prompt and prompt_token_ids
         try:
-            if isinstance(request.prompt, str):
-                request_prompts = [request.prompt]
-            elif isinstance(request.prompt, list) and all(isinstance(item, int) for item in request.prompt):
-                request_prompt_ids = [request.prompt]
-            elif isinstance(request.prompt, list) and all(isinstance(item, str) for item in request.prompt):
-                request_prompts = request.prompt
-            elif isinstance(request.prompt, list):
-                for item in request.prompt:
-                    if isinstance(item, list) and all(isinstance(x, int) for x in item):
-                        continue
-                    else:
-                        raise ValueError("Prompt must be a string, a list of strings or a list of integers.")
-                request_prompt_ids = request.prompt
+            if request.prompt_token_ids is not None:  # let `prompt_token_ids` support batch inference
+                assert len(request.prompt_token_ids) > 0, "prompt_token_ids should not be an empty list"
+                if isinstance(request.prompt_token_ids[0], list):
+                    request_prompt_ids = request.prompt_token_ids
+                elif isinstance(request.prompt_token_ids[0], int):
+                    request_prompt_ids = [request.prompt_token_ids]
+                else:
+                    raise ValueError(
+                        "If prompt_token_ids is provided, its type should be one of: list[int], list[list[int]]"
+                    )
+                # reset `prompt_token_ids` because token ids are expected to be passed in `prompt`
+                request.prompt_token_ids = None
             else:
-                raise ValueError("Prompt must be a string, a list of strings or a list of integers.")
+                if isinstance(request.prompt, str):
+                    request_prompts = [request.prompt]
+                elif isinstance(request.prompt, list) and all(isinstance(item, int) for item in request.prompt):
+                    request_prompt_ids = [request.prompt]
+                elif isinstance(request.prompt, list) and all(isinstance(item, str) for item in request.prompt):
+                    request_prompts = request.prompt
+                elif isinstance(request.prompt, list):
+                    for item in request.prompt:
+                        if isinstance(item, list) and all(isinstance(x, int) for x in item):
+                            continue
+                        else:
+                            raise ValueError("If prompt is a list, each item type must be one of: str, list[int]")
                    request_prompt_ids = request.prompt
+                else:
+                    raise ValueError("Prompt type must be one of: str, list[str], list[int], list[list[int]]")
         except Exception as e:
             return ErrorResponse(message=str(e), code=400)
 
         if request_prompt_ids is not None:
             request_prompts = request_prompt_ids
-        num_choices = len(request_prompts)
 
-        api_server_logger.info(f"start inference for request {num_choices}")
+        num_choices = len(request_prompts)
+        api_server_logger.info(f"Start preprocessing request: req_id={request_id}), num_choices={num_choices}")
         prompt_batched_token_ids = []
         try:
-            for idx, prompt in enumerate(request_prompts):
+            for idx, prompt in enumerate(request_prompts):  # process each prompt for this batch completion request
                 request_id_idx = f"{request_id}-{idx}"
                 current_req_dict = request.to_dict_for_infer(request_id_idx, prompt)
+                api_server_logger.debug(f"current_req_dict: {current_req_dict}")
                 try:
                     current_req_dict["arrival_time"] = time.time()
-                    prompt_token_ids = self.engine_client.format_and_add_data(current_req_dict)
+                    prompt_token_ids = self.engine_client.format_and_add_data(current_req_dict)  # tokenize
                     if isinstance(prompt_token_ids, np.ndarray):
                         prompt_token_ids = prompt_token_ids.tolist()
                     prompt_batched_token_ids.append(prompt_token_ids)
@@ -115,6 +132,10 @@ async def create_completion(self, request: CompletionRequest):
 
                 del current_req_dict
 
+        api_server_logger.info(
+            f"Finish preprocessing request: req_id={request_id}, lengths={[len(t) for t in prompt_batched_token_ids]}"
+        )
+
         try:
             if self.max_waiting_time < 0:
                 await self.engine_client.semaphore.acquire()
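
To summarize the new branching in `create_completion`: `prompt_token_ids` takes precedence when present, and both fields are normalized into one entry per completion choice. The helper below is a hypothetical standalone sketch that mirrors that logic; it is not part of the commit.

# Hypothetical helper mirroring the normalization above: prompt_token_ids wins when
# present, and both fields end up as "one entry per completion choice".
from typing import List, Optional, Union

def normalize_completion_inputs(
    prompt: Optional[Union[str, List[str], List[int], List[List[int]]]],
    prompt_token_ids: Optional[Union[List[int], List[List[int]]]],
) -> list:
    if prompt_token_ids is not None:
        assert len(prompt_token_ids) > 0, "prompt_token_ids should not be an empty list"
        if isinstance(prompt_token_ids[0], list):
            return prompt_token_ids        # already a batch: list[list[int]]
        if isinstance(prompt_token_ids[0], int):
            return [prompt_token_ids]      # single sequence: wrap into a batch of one
        raise ValueError("prompt_token_ids must be list[int] or list[list[int]]")
    if isinstance(prompt, str):
        return [prompt]                    # single text prompt
    if isinstance(prompt, list) and all(isinstance(x, int) for x in prompt):
        return [prompt]                    # single token-id prompt
    if isinstance(prompt, list) and all(isinstance(x, str) for x in prompt):
        return prompt                      # batch of text prompts
    if isinstance(prompt, list) and all(
        isinstance(x, list) and all(isinstance(t, int) for t in x) for x in prompt
    ):
        return prompt                      # batch of token-id prompts
    raise ValueError("prompt must be str, list[str], list[int] or list[list[int]]")

assert normalize_completion_inputs(None, [1, 2, 3]) == [[1, 2, 3]]
assert normalize_completion_inputs("hi", None) == ["hi"]
assert normalize_completion_inputs([[1, 2], [3]], None) == [[1, 2], [3]]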

fastdeploy/input/ernie_processor.py

Lines changed: 43 additions & 25 deletions
@@ -91,40 +91,52 @@ def process_request(self, request, max_model_len=None, **kwargs):
         request = self._apply_default_parameters(request)
         if request.get("eos_token_ids") is None or len(request.eos_token_ids) == 0:
             request.eos_token_ids = self.eos_token_ids
+
+        # processing stop_sequences
         stop_sequences = request.get("stop", [])
         if stop_sequences is not None and len(stop_sequences) != 0:
             stop_seqs, stop_seqs_len = self.update_stop_seq(stop_sequences)
             request.set("stop_token_ids", stop_seqs)
             request.set("stop_seqs_len", stop_seqs_len)
 
+        # processing prompt_token_ids
         if request.prompt_token_ids is None or len(request.prompt_token_ids) == 0:
-            if request.prompt is None and request.messages is None:
-                raise ValueError(f"The request should have `prompt_token_ids`, `prompt` or `messages`: {request}.")
             if request.prompt is not None:
-                prompt = request.prompt if request.prompt is not None else request.messages[0]
-                prompt = prompt[0] if isinstance(prompt, list) else prompt
-                tokens = self.tokenizer.tokenize(prompt)
-                token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
-                request.prompt_token_ids = token_ids
-                data_processor_logger.info(f"req_id:{request.request_id}, tokens:{tokens}, token_ids: {token_ids}")
-            else:
+                # prompt = request.prompt if request.prompt is not None else request.messages[0]
+                prompt = request.prompt
+                assert isinstance(prompt, str) or (
+                    isinstance(prompt, list) and all([isinstance(t, int) for t in prompt])
+                ), f"prompt must be a string or a list of integers, but got {type(prompt)}"
+
+                if isinstance(prompt, list):  # if prompt is a token id list
+                    request["prompt_token_ids"] = prompt
+                else:
+                    tokens = self.tokenizer.tokenize(prompt)
+                    token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
+                    request.prompt_token_ids = token_ids
+                    data_processor_logger.debug(
+                        f"request_ids: {request.request_id}, prompt: {prompt}, tokens: {tokens}, token_ids: {token_ids}"
+                    )
+            elif request.messages is not None:
                 request.prompt_token_ids = self.messages2ids(request.to_dict())
+            else:
+                raise ValueError(f"The request should have `prompt_token_ids`, `prompt` or `messages`: {request}.")
 
         if len(request.prompt_token_ids) == 0:
             raise ValueError("Invalid input: prompt_token_ids must be a non-empty sequence of token IDs")
+
+        # truncate prompts that exceed the length limit
         if max_model_len is not None and len(request.prompt_token_ids) > max_model_len:
             request.prompt_token_ids = request.prompt_token_ids[: max_model_len - 1]
         if request.get("max_tokens") is None:
-            request.set(
-                "max_tokens",
-                max(1, max_model_len - len(request.prompt_token_ids)),
-            )
+            request.set("max_tokens", max(1, max_model_len - len(request.prompt_token_ids)))
         if request.get("temperature") < _SAMPLING_EPS:
             # zero temperature is equivalent to greedy sampling
             request.set("temperature", 1)
         if request.get("top_p") < _SAMPLING_EPS:
             request.set("top_p", _SAMPLING_EPS)
-        data_processor_logger.info(f"Processed request {request}")
+
+        data_processor_logger.info(f"Processed request: {request}")
         return request
 
     def process_request_dict(self, request, max_model_len=None):
@@ -151,19 +163,25 @@ def process_request_dict(self, request, max_model_len=None):
 
         # processing prompt_token_ids
         if not request.get("prompt_token_ids"):
-            if request.get("prompt") is None and request.get("messages") is None:
-                raise ValueError(f"Request must contain 'prompt_token_ids', 'prompt', or 'messages': {request}")
             if request.get("prompt"):
                 prompt = request.get("prompt")
-                prompt = prompt[0] if isinstance(prompt, list) else prompt
-
-                tokens = self.tokenizer.tokenize(prompt)
-                token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
-                request["prompt_token_ids"] = token_ids
-                req_id = request.get("request_id", None)
-                data_processor_logger.info(f"req_id:{req_id}, tokens:{tokens}, token_ids: {token_ids}")
-            else:
+                assert isinstance(prompt, str) or (
+                    isinstance(prompt, list) and all([isinstance(t, int) for t in prompt])
+                ), f"prompt must be a string or a list of integers, but got {type(prompt)}"
+                if isinstance(prompt, list):  # if prompt is a token id list
+                    request["prompt_token_ids"] = prompt
+                else:
+                    tokens = self.tokenizer.tokenize(prompt)
+                    token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
+                    request["prompt_token_ids"] = token_ids
+                    data_processor_logger.debug(
+                        f"request_ids: {request.get('request_id')}, prompt: {prompt}, tokens: {tokens}, token_ids: {token_ids}"
+                    )
+            elif request.get("messages"):
                 request["prompt_token_ids"] = self.messages2ids(request)
+            else:
+                raise ValueError(f"Request must contain 'prompt_token_ids', 'prompt', or 'messages': {request}")
+
         if len(request["prompt_token_ids"]) == 0:
             raise ValueError("Invalid input: prompt_token_ids must be a non-empty sequence of token IDs")
 
@@ -177,8 +195,8 @@ def process_request_dict(self, request, max_model_len=None):
             request["temperature"] = 1
         if request.get("top_p") < _SAMPLING_EPS:
             request["top_p"] = _SAMPLING_EPS
-        data_processor_logger.info(f"Processed request {request}")
 
+        data_processor_logger.info(f"Processed request dict: {request}")
         return request
 
     def process_response(self, response_dict, **kwargs):
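
Both processors also gain explicit truncation of over-length inputs plus a `max_tokens` default derived from the remaining context budget. Below is a standalone, illustrative-only sketch of that arithmetic with made-up numbers.

# Illustrative sketch of the truncation and max_tokens defaulting added above.
def budget_prompt(prompt_token_ids: list, max_model_len: int, max_tokens=None):
    # Truncate prompts that exceed the length limit, keeping room for at least one new token.
    if len(prompt_token_ids) > max_model_len:
        prompt_token_ids = prompt_token_ids[: max_model_len - 1]
    # Default generation budget: whatever context is left, but never less than 1.
    if max_tokens is None:
        max_tokens = max(1, max_model_len - len(prompt_token_ids))
    return prompt_token_ids, max_tokens

ids, budget = budget_prompt(list(range(10)), max_model_len=8)
print(len(ids), budget)  # 7 1 -> truncated to 7 ids, one token of generation budget left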

fastdeploy/input/text_processor.py

Lines changed: 35 additions & 11 deletions
@@ -223,15 +223,24 @@ def process_request(self, request, max_model_len=None, **kwargs):
         if request.get("eos_token_ids") is None or len(request.eos_token_ids) == 0:
             request.eos_token_ids = self.eos_token_ids
 
+        # processing stop_sequences
         stop_sequences = request.get("stop", [])
         if stop_sequences is not None and len(stop_sequences) != 0:
             stop_seqs, stop_seqs_len = self.update_stop_seq(stop_sequences)
             request.set("stop_token_ids", stop_seqs)
             request.set("stop_seqs_len", stop_seqs_len)
 
+        # processing prompt_token_ids
         if request.prompt_token_ids is None or len(request.prompt_token_ids) == 0:
             if request.prompt is not None:
-                request.prompt_token_ids = self.text2ids(request.prompt, max_model_len)
+                prompt = request.prompt
+                assert isinstance(prompt, str) or (
+                    isinstance(prompt, list) and all([isinstance(t, int) for t in prompt])
+                ), f"prompt must be a string or a list of integers, but got {type(prompt)}"
+                if isinstance(prompt, list):  # if prompt is a token id list
+                    request.prompt_token_ids = prompt
+                else:
+                    request.prompt_token_ids = self.text2ids(request.prompt, max_model_len)
             elif request.messages is not None:
                 if self.tokenizer.chat_template is None:
                     raise ValueError("This model does not support chat_template.")
@@ -240,19 +249,22 @@ def process_request(self, request, max_model_len=None, **kwargs):
             request.prompt_token_ids = self.messages2ids(task)
         else:
             raise ValueError(f"The request should have `input_ids`, `text` or `messages`: {request}.")
+
         if len(request.prompt_token_ids) == 0:
             raise ValueError("Invalid input: prompt_token_ids must be a non-empty sequence of token IDs")
+
+        # truncate prompts that exceed the length limit
+        if max_model_len is not None and len(request.prompt_token_ids) > max_model_len:
+            request.prompt_token_ids = request.prompt_token_ids[: max_model_len - 1]
         if request.get("max_tokens") is None:
-            request.set(
-                "max_tokens",
-                max(1, max_model_len - len(request.prompt_token_ids)),
-            )
+            request.set("max_tokens", max(1, max_model_len - len(request.prompt_token_ids)))
         if request.get("temperature") < _SAMPLING_EPS:
             # zero temperature is equivalent to greedy sampling
             request.set("temperature", 1)
         if request.get("top_p") < _SAMPLING_EPS:
             request.set("top_p", _SAMPLING_EPS)
-        data_processor_logger.info(f"Processed request {request}")
+
+        data_processor_logger.info(f"Processed request: {request}")
         return request
 
     def process_request_dict(self, request, max_model_len=None, **kwargs):
@@ -277,27 +289,39 @@ def process_request_dict(self, request, max_model_len=None, **kwargs):
             request["stop_token_ids"] = stop_seqs
             request["stop_seqs_len"] = stop_seqs_len
 
-        data_processor_logger.info(f"Processing request {request}")
         # processing prompt_token_ids
         if not request.get("prompt_token_ids"):
-            if "prompt" in request:
-                request["prompt_token_ids"] = self.text2ids(request["prompt"], max_model_len).tolist()
-            elif "messages" in request:
+            if request.get("prompt"):
+                prompt = request.get("prompt")
+                assert isinstance(prompt, str) or (
+                    isinstance(prompt, list) and all([isinstance(t, int) for t in prompt])
+                ), f"prompt must be a string or a list of integers, but got {type(prompt)}"
+                if isinstance(prompt, list):  # if prompt is a token id list
+                    request["prompt_token_ids"] = prompt
+                else:
+                    request["prompt_token_ids"] = self.text2ids(request["prompt"], max_model_len).tolist()
+            elif request.get("messages"):
                 if self.tokenizer.chat_template is None:
                     raise ValueError("This model does not support chat_template.")
                 request["prompt_token_ids"] = self.messages2ids(request)
             else:
                 raise ValueError(f"Request must contain 'prompt_token_ids', 'prompt', or 'messages': {request}")
+
         if len(request["prompt_token_ids"]) == 0:
             raise ValueError("Invalid input: prompt_token_ids must be a non-empty sequence of token IDs")
+
+        # truncate prompts that exceed the length limit
+        if max_model_len is not None and len(request["prompt_token_ids"]) > max_model_len:
+            request["prompt_token_ids"] = request["prompt_token_ids"][: max_model_len - 1]
         if request.get("max_tokens") is None:
            request["max_tokens"] = max(1, max_model_len - len(request["prompt_token_ids"]))
         if request.get("temperature") < _SAMPLING_EPS:
             # zero temperature is equivalent to greedy sampling
             request["temperature"] = 1
         if request.get("top_p") < _SAMPLING_EPS:
             request["top_p"] = _SAMPLING_EPS
-        data_processor_logger.info(f"Processed request {request}")
+
+        data_processor_logger.info(f"Processed request dict: {request}")
         return request
 
     def process_logprob_response(self, token_ids, **kwargs):
