Commit bb8c9b8

[feat] Completion API supports passing input token ids in either prompt or prompt_token_ids
1 parent 37569cc commit bb8c9b8
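
As a quick illustration of what this commit enables, here is a minimal client-side sketch. It assumes a FastDeploy OpenAI-compatible server is already running; the URL, port, model name, and token-id values below are illustrative assumptions, not part of this commit.

import requests  # assumed: any HTTP client would work

URL = "http://localhost:8000/v1/completions"  # assumed local endpoint

# A single prompt passed as a list of token ids.
single = {"model": "default", "prompt_token_ids": [1, 23, 456], "max_tokens": 16}

# A batch of prompts passed as a list of token-id lists (the newly supported shape).
batched = {"model": "default", "prompt_token_ids": [[1, 23], [4, 56, 7]], "max_tokens": 16}

for payload in (single, batched):
    resp = requests.post(URL, json=payload, timeout=60)
    print(resp.status_code, resp.json())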

File tree

7 files changed: +354 additions, -78 deletions

fastdeploy/entrypoints/engine_client.py

Lines changed: 7 additions & 7 deletions

@@ -94,19 +94,19 @@ def create_zmq_client(self, model, mode):
         self.zmq_client = ZmqClient(model, mode)
         self.zmq_client.connect()

-    def format_and_add_data(self, prompts: dict):
+    def format_and_add_data(self, req_dict: dict):
         """
         Format the request data and send the request to the server.
         """
-        if "request_id" not in prompts:
+        if "request_id" not in req_dict:
             request_id = str(uuid.uuid4())
-            prompts["request_id"] = request_id
+            req_dict["request_id"] = request_id

-        if "max_tokens" not in prompts:
-            prompts["max_tokens"] = self.max_model_len - 1
+        if "max_tokens" not in req_dict:
+            req_dict["max_tokens"] = self.max_model_len - 1

-        self.add_requests(prompts)
-        return prompts["prompt_token_ids"]
+        self.add_requests(req_dict)
+        return req_dict["prompt_token_ids"]

     def add_requests(self, task):
         """

fastdeploy/entrypoints/openai/protocol.py

Lines changed: 7 additions & 7 deletions

@@ -375,7 +375,7 @@ class CompletionRequest(BaseModel):

     max_streaming_response_tokens: Optional[int] = None
     return_token_ids: Optional[bool] = None
-    prompt_token_ids: Optional[List[int]] = None
+    prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None
     # doc: end-completion-extra-params

     def to_dict_for_infer(self, request_id=None, prompt=None):
@@ -400,11 +400,11 @@ def to_dict_for_infer(self, request_id=None, prompt=None):
         if prompt is not None:
             req_dict["prompt"] = prompt

-        if "prompt_token_ids" in req_dict:
-            if "prompt" in req_dict:
-                del req_dict["prompt"]
-            else:
-                assert len(prompt) > 0
+        # if "prompt_token_ids" in req_dict:
+        #     if "prompt" in req_dict:
+        #         del req_dict["prompt"]
+        #     else:
+        #         assert len(prompt) > 0

         guided_json_object = None
         if self.response_format is not None:
@@ -503,7 +503,7 @@ class ChatCompletionRequest(BaseModel):
     stop_token_ids: Optional[List[int]] = Field(default_factory=list)
     # doc: end-chat-completion-sampling-params

-    # doc: start-completion-extra-params
+    # doc: start-chat-completion-extra-params
     chat_template_kwargs: Optional[dict] = None
     reasoning_max_tokens: Optional[int] = None
     structural_tag: Optional[str] = None
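
The widened annotation means both a single token-id list and a batch of lists now validate. A minimal Pydantic sketch of just that field (an illustrative model, not the real CompletionRequest):

from typing import List, Optional, Union
from pydantic import BaseModel

class TokenIdsOnly(BaseModel):
    # Same annotation as the widened CompletionRequest field.
    prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None

print(TokenIdsOnly(prompt_token_ids=[1, 2, 3]).prompt_token_ids)         # single prompt
print(TokenIdsOnly(prompt_token_ids=[[1, 2], [3, 4]]).prompt_token_ids)  # batched prompts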

fastdeploy/entrypoints/openai/serving_completion.py

Lines changed: 41 additions & 20 deletions

@@ -25,6 +25,7 @@
 from aiozmq import zmq

 from fastdeploy.engine.request import RequestOutput
+from fastdeploy.entrypoints.engine_client import EngineClient
 from fastdeploy.entrypoints.openai.protocol import (
     CompletionLogprobs,
     CompletionRequest,
@@ -41,7 +42,7 @@

 class OpenAIServingCompletion:
     def __init__(self, engine_client, pid, ips):
-        self.engine_client = engine_client
+        self.engine_client: EngineClient = engine_client
         self.pid = pid
         self.master_ip = ips
         self.host_ip = get_host_ip()
@@ -71,41 +72,57 @@ async def create_completion(self, request: CompletionRequest):
             request_id = f"cmpl-{request.user}-{uuid.uuid4()}"
         else:
             request_id = f"cmpl-{uuid.uuid4()}"
-        api_server_logger.info(f"initialize request {request_id}")
+        api_server_logger.info(f"Initialize request {request_id}: {request}")
         request_prompt_ids = None
         request_prompts = None
+
+        # Handle prompt and prompt_token_ids
         try:
-            if isinstance(request.prompt, str):
-                request_prompts = [request.prompt]
-            elif isinstance(request.prompt, list) and all(isinstance(item, int) for item in request.prompt):
-                request_prompt_ids = [request.prompt]
-            elif isinstance(request.prompt, list) and all(isinstance(item, str) for item in request.prompt):
-                request_prompts = request.prompt
-            elif isinstance(request.prompt, list):
-                for item in request.prompt:
-                    if isinstance(item, list) and all(isinstance(x, int) for x in item):
-                        continue
-                    else:
-                        raise ValueError("Prompt must be a string, a list of strings or a list of integers.")
-                request_prompt_ids = request.prompt
+            if request.prompt_token_ids is not None:  # let `prompt_token_ids` support batch inference
+                assert len(request.prompt_token_ids) > 0, "prompt_token_ids should not be an empty list"
+                if isinstance(request.prompt_token_ids[0], list):
+                    request_prompt_ids = request.prompt_token_ids
+                elif isinstance(request.prompt_token_ids[0], int):
+                    request_prompt_ids = [request.prompt_token_ids]
+                else:
+                    raise ValueError(
+                        "If prompt_token_ids is provided, its type should be one of: list[int], list[list[int]]"
+                    )
+                # reset `prompt_token_ids` because token ids are expected to be passed in `prompt`
+                request.prompt_token_ids = None
             else:
-                raise ValueError("Prompt must be a string, a list of strings or a list of integers.")
+                if isinstance(request.prompt, str):
+                    request_prompts = [request.prompt]
+                elif isinstance(request.prompt, list) and all(isinstance(item, int) for item in request.prompt):
+                    request_prompt_ids = [request.prompt]
+                elif isinstance(request.prompt, list) and all(isinstance(item, str) for item in request.prompt):
+                    request_prompts = request.prompt
+                elif isinstance(request.prompt, list):
+                    for item in request.prompt:
+                        if isinstance(item, list) and all(isinstance(x, int) for x in item):
+                            continue
+                        else:
+                            raise ValueError("If prompt is a list, each item type must be one of: str, list[int]")
+                    request_prompt_ids = request.prompt
+                else:
+                    raise ValueError("Prompt type must be one of: str, list[str], list[int], list[list[int]]")
         except Exception as e:
            return ErrorResponse(message=str(e), code=400)

         if request_prompt_ids is not None:
             request_prompts = request_prompt_ids
-        num_choices = len(request_prompts)

-        api_server_logger.info(f"start inference for request {num_choices}")
+        num_choices = len(request_prompts)
+        api_server_logger.info(f"Start preprocessing request: req_id={request_id}), num_choices={num_choices}")
         prompt_batched_token_ids = []
         try:
-            for idx, prompt in enumerate(request_prompts):
+            for idx, prompt in enumerate(request_prompts):  # process each prompt for this batch completion request
                 request_id_idx = f"{request_id}-{idx}"
                 current_req_dict = request.to_dict_for_infer(request_id_idx, prompt)
+                api_server_logger.debug(f"current_req_dict: {current_req_dict}")
                 try:
                     current_req_dict["arrival_time"] = time.time()
-                    prompt_token_ids = self.engine_client.format_and_add_data(current_req_dict)
+                    prompt_token_ids = self.engine_client.format_and_add_data(current_req_dict)  # tokenize
                     if isinstance(prompt_token_ids, np.ndarray):
                         prompt_token_ids = prompt_token_ids.tolist()
                     prompt_batched_token_ids.append(prompt_token_ids)
@@ -114,6 +131,10 @@ async def create_completion(self, request: CompletionRequest):

             del current_req_dict

+            api_server_logger.info(
+                f"Finish preprocessing request: req_id={request_id}, lengths={[len(t) for t in prompt_batched_token_ids]}"
+            )
+
             if request.stream:
                 return self.completion_stream_generator(
                     request=request,
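
The branching above normalizes every accepted input shape into one batch before preprocessing. A standalone sketch of that normalization, stripped of the server plumbing (function and variable names here are illustrative):

from typing import List, Optional, Union

Prompt = Union[str, List[str], List[int], List[List[int]]]

def normalize(prompt: Optional[Prompt], prompt_token_ids=None) -> list:
    """Return a batch: a list whose items are either strings or token-id lists."""
    if prompt_token_ids is not None:
        assert len(prompt_token_ids) > 0, "prompt_token_ids should not be an empty list"
        if isinstance(prompt_token_ids[0], list):
            return prompt_token_ids          # already batched: list[list[int]]
        if isinstance(prompt_token_ids[0], int):
            return [prompt_token_ids]        # single prompt: list[int]
        raise ValueError("prompt_token_ids must be list[int] or list[list[int]]")
    if isinstance(prompt, str):
        return [prompt]
    if isinstance(prompt, list) and all(isinstance(i, int) for i in prompt):
        return [prompt]
    if isinstance(prompt, list) and all(isinstance(i, str) for i in prompt):
        return prompt
    if isinstance(prompt, list) and all(
        isinstance(i, list) and all(isinstance(x, int) for x in i) for i in prompt
    ):
        return prompt
    raise ValueError("Prompt type must be one of: str, list[str], list[int], list[list[int]]")

print(normalize("hello"))                               # ['hello']
print(normalize(None, prompt_token_ids=[1, 2]))         # [[1, 2]]
print(normalize(None, prompt_token_ids=[[1], [2, 3]]))  # [[1], [2, 3]]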

fastdeploy/input/ernie_processor.py

Lines changed: 43 additions & 25 deletions

@@ -91,40 +91,52 @@ def process_request(self, request, max_model_len=None, **kwargs):
         request = self._apply_default_parameters(request)
         if request.get("eos_token_ids") is None or len(request.eos_token_ids) == 0:
             request.eos_token_ids = self.eos_token_ids
+
+        # processing stop_sequences
         stop_sequences = request.get("stop", [])
         if stop_sequences is not None and len(stop_sequences) != 0:
             stop_seqs, stop_seqs_len = self.update_stop_seq(stop_sequences)
             request.set("stop_token_ids", stop_seqs)
             request.set("stop_seqs_len", stop_seqs_len)

+        # processing prompt_token_ids
         if request.prompt_token_ids is None or len(request.prompt_token_ids) == 0:
-            if request.prompt is None and request.messages is None:
-                raise ValueError(f"The request should have `prompt_token_ids`, `prompt` or `messages`: {request}.")
             if request.prompt is not None:
-                prompt = request.prompt if request.prompt is not None else request.messages[0]
-                prompt = prompt[0] if isinstance(prompt, list) else prompt
-                tokens = self.tokenizer.tokenize(prompt)
-                token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
-                request.prompt_token_ids = token_ids
-                data_processor_logger.info(f"req_id:{request.request_id}, tokens:{tokens}, token_ids: {token_ids}")
-            else:
+                # prompt = request.prompt if request.prompt is not None else request.messages[0]
+                prompt = request.prompt
+                assert isinstance(prompt, str) or (
+                    isinstance(prompt, list) and all([isinstance(t, int) for t in prompt])
+                ), f"prompt must be a string or a list of integers, but got {type(prompt)}"
+
+                if isinstance(prompt, list):  # if prompt is a token id list
+                    request["prompt_token_ids"] = prompt
+                else:
+                    tokens = self.tokenizer.tokenize(prompt)
+                    token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
+                    request.prompt_token_ids = token_ids
+                    data_processor_logger.debug(
+                        f"request_ids: {request.request_id}, prompt: {prompt}, tokens: {tokens}, token_ids: {token_ids}"
+                    )
+            elif request.messages is not None:
                 request.prompt_token_ids = self.messages2ids(request.to_dict())
+            else:
+                raise ValueError(f"The request should have `prompt_token_ids`, `prompt` or `messages`: {request}.")

         if len(request.prompt_token_ids) == 0:
             raise ValueError("Invalid input: prompt_token_ids must be a non-empty sequence of token IDs")
+
+        # truncate prompts that exceed the length limit
         if max_model_len is not None and len(request.prompt_token_ids) > max_model_len:
             request.prompt_token_ids = request.prompt_token_ids[: max_model_len - 1]
         if request.get("max_tokens") is None:
-            request.set(
-                "max_tokens",
-                max(1, max_model_len - len(request.prompt_token_ids)),
-            )
+            request.set("max_tokens", max(1, max_model_len - len(request.prompt_token_ids)))
         if request.get("temperature") < _SAMPLING_EPS:
             # zero temperature is equivalent to greedy sampling
             request.set("temperature", 1)
         if request.get("top_p") < _SAMPLING_EPS:
             request.set("top_p", _SAMPLING_EPS)
-        data_processor_logger.info(f"Processed request {request}")
+
+        data_processor_logger.info(f"Processed request: {request}")
         return request

     def process_request_dict(self, request, max_model_len=None):
@@ -151,19 +163,25 @@ def process_request_dict(self, request, max_model_len=None):

         # processing prompt_token_ids
         if not request.get("prompt_token_ids"):
-            if request.get("prompt") is None and request.get("messages") is None:
-                raise ValueError(f"Request must contain 'prompt_token_ids', 'prompt', or 'messages': {request}")
             if request.get("prompt"):
                 prompt = request.get("prompt")
-                prompt = prompt[0] if isinstance(prompt, list) else prompt
-
-                tokens = self.tokenizer.tokenize(prompt)
-                token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
-                request["prompt_token_ids"] = token_ids
-                req_id = request.get("request_id", None)
-                data_processor_logger.info(f"req_id:{req_id}, tokens:{tokens}, token_ids: {token_ids}")
-            else:
+                assert isinstance(prompt, str) or (
+                    isinstance(prompt, list) and all([isinstance(t, int) for t in prompt])
+                ), f"prompt must be a string or a list of integers, but got {type(prompt)}"
+                if isinstance(prompt, list):  # if prompt is a token id list
+                    request["prompt_token_ids"] = prompt
+                else:
+                    tokens = self.tokenizer.tokenize(prompt)
+                    token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
+                    request["prompt_token_ids"] = token_ids
+                    data_processor_logger.debug(
+                        f"request_ids: {request.get('request_id')}, prompt: {prompt}, tokens: {tokens}, token_ids: {token_ids}"
+                    )
+            elif request.get("messages"):
                 request["prompt_token_ids"] = self.messages2ids(request)
+            else:
+                raise ValueError(f"Request must contain 'prompt_token_ids', 'prompt', or 'messages': {request}")
+
         if len(request["prompt_token_ids"]) == 0:
             raise ValueError("Invalid input: prompt_token_ids must be a non-empty sequence of token IDs")

@@ -177,8 +195,8 @@ def process_request_dict(self, request, max_model_len=None):
             request["temperature"] = 1
         if request.get("top_p") < _SAMPLING_EPS:
             request["top_p"] = _SAMPLING_EPS
-        data_processor_logger.info(f"Processed request {request}")

+        data_processor_logger.info(f"Processed request dict: {request}")
         return request

     def process_response(self, response_dict, **kwargs):
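
Both processor changes share the same idea: a prompt that is already a token-id list skips tokenization, the ids are truncated to the model window, and max_tokens defaults to the remaining room. A rough sketch under those rules (the placeholder tokenizer below is an assumption, not the real ERNIE tokenizer):

def prepare_ids(prompt, max_model_len: int, max_tokens=None):
    # Pass pre-tokenized prompts through; tokenize strings with a stand-in tokenizer.
    if isinstance(prompt, list) and all(isinstance(t, int) for t in prompt):
        token_ids = prompt
    elif isinstance(prompt, str):
        token_ids = [hash(w) % 30000 for w in prompt.split()]  # placeholder tokenizer
    else:
        raise ValueError(f"prompt must be a string or a list of integers, but got {type(prompt)}")
    if len(token_ids) > max_model_len:
        token_ids = token_ids[: max_model_len - 1]  # truncate, leaving room for one new token
    if max_tokens is None:
        max_tokens = max(1, max_model_len - len(token_ids))
    return token_ids, max_tokens

print(prepare_ids([5, 6, 7], max_model_len=8))     # ([5, 6, 7], 5)
print(prepare_ids("a b c d e", max_model_len=4))   # truncated to 3 ids, max_tokens=1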

fastdeploy/input/text_processor.py

Lines changed: 35 additions & 11 deletions

@@ -223,15 +223,24 @@ def process_request(self, request, max_model_len=None, **kwargs):
         if request.get("eos_token_ids") is None or len(request.eos_token_ids) == 0:
             request.eos_token_ids = self.eos_token_ids

+        # processing stop_sequences
         stop_sequences = request.get("stop", [])
         if stop_sequences is not None and len(stop_sequences) != 0:
             stop_seqs, stop_seqs_len = self.update_stop_seq(stop_sequences)
             request.set("stop_token_ids", stop_seqs)
             request.set("stop_seqs_len", stop_seqs_len)

+        # processing prompt_token_ids
         if request.prompt_token_ids is None or len(request.prompt_token_ids) == 0:
             if request.prompt is not None:
-                request.prompt_token_ids = self.text2ids(request.prompt, max_model_len)
+                prompt = request.prompt
+                assert isinstance(prompt, str) or (
+                    isinstance(prompt, list) and all([isinstance(t, int) for t in prompt])
+                ), f"prompt must be a string or a list of integers, but got {type(prompt)}"
+                if isinstance(prompt, list):  # if prompt is a token id list
+                    request.prompt_token_ids = prompt
+                else:
+                    request.prompt_token_ids = self.text2ids(request.prompt, max_model_len)
             elif request.messages is not None:
                 if self.tokenizer.chat_template is None:
                     raise ValueError("This model does not support chat_template.")
@@ -240,19 +249,22 @@ def process_request(self, request, max_model_len=None, **kwargs):
             request.prompt_token_ids = self.messages2ids(task)
         else:
             raise ValueError(f"The request should have `input_ids`, `text` or `messages`: {request}.")
+
         if len(request.prompt_token_ids) == 0:
             raise ValueError("Invalid input: prompt_token_ids must be a non-empty sequence of token IDs")
+
+        # truncate prompts that exceed the length limit
+        if max_model_len is not None and len(request.prompt_token_ids) > max_model_len:
+            request.prompt_token_ids = request.prompt_token_ids[: max_model_len - 1]
         if request.get("max_tokens") is None:
-            request.set(
-                "max_tokens",
-                max(1, max_model_len - len(request.prompt_token_ids)),
-            )
+            request.set("max_tokens", max(1, max_model_len - len(request.prompt_token_ids)))
         if request.get("temperature") < _SAMPLING_EPS:
             # zero temperature is equivalent to greedy sampling
             request.set("temperature", 1)
         if request.get("top_p") < _SAMPLING_EPS:
             request.set("top_p", _SAMPLING_EPS)
-        data_processor_logger.info(f"Processed request {request}")
+
+        data_processor_logger.info(f"Processed request: {request}")
         return request

     def process_request_dict(self, request, max_model_len=None, **kwargs):
@@ -277,27 +289,39 @@ def process_request_dict(self, request, max_model_len=None, **kwargs):
             request["stop_token_ids"] = stop_seqs
             request["stop_seqs_len"] = stop_seqs_len

-        data_processor_logger.info(f"Processing request {request}")
         # processing prompt_token_ids
         if not request.get("prompt_token_ids"):
-            if "prompt" in request:
-                request["prompt_token_ids"] = self.text2ids(request["prompt"], max_model_len).tolist()
-            elif "messages" in request:
+            if request.get("prompt"):
+                prompt = request.get("prompt")
+                assert isinstance(prompt, str) or (
+                    isinstance(prompt, list) and all([isinstance(t, int) for t in prompt])
+                ), f"prompt must be a string or a list of integers, but got {type(prompt)}"
+                if isinstance(prompt, list):  # if prompt is a token id list
+                    request["prompt_token_ids"] = prompt
+                else:
+                    request["prompt_token_ids"] = self.text2ids(request["prompt"], max_model_len).tolist()
+            elif request.get("messages"):
                 if self.tokenizer.chat_template is None:
                     raise ValueError("This model does not support chat_template.")
                 request["prompt_token_ids"] = self.messages2ids(request)
             else:
                 raise ValueError(f"Request must contain 'prompt_token_ids', 'prompt', or 'messages': {request}")
+
         if len(request["prompt_token_ids"]) == 0:
             raise ValueError("Invalid input: prompt_token_ids must be a non-empty sequence of token IDs")
+
+        # truncate prompts that exceed the length limit
+        if max_model_len is not None and len(request["prompt_token_ids"]) > max_model_len:
+            request["prompt_token_ids"] = request["prompt_token_ids"][: max_model_len - 1]
         if request.get("max_tokens") is None:
             request["max_tokens"] = max(1, max_model_len - len(request["prompt_token_ids"]))
         if request.get("temperature") < _SAMPLING_EPS:
             # zero temperature is equivalent to greedy sampling
             request["temperature"] = 1
         if request.get("top_p") < _SAMPLING_EPS:
             request["top_p"] = _SAMPLING_EPS
-        data_processor_logger.info(f"Processed request {request}")
+
+        data_processor_logger.info(f"Processed request dict: {request}")
         return request

     def process_logprob_response(self, token_ids, **kwargs):
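
For the dict-based path, the precedence is now: use prompt_token_ids if present, else prompt (text or token ids), else messages, else raise. A self-contained sketch of that fallback order (text_to_ids and messages_to_ids are placeholder callables, not the real tokenizer helpers):

def fill_prompt_token_ids(request: dict, text_to_ids, messages_to_ids) -> dict:
    # Fallback order: prompt_token_ids -> prompt (str or list[int]) -> messages -> error.
    if not request.get("prompt_token_ids"):
        if request.get("prompt"):
            prompt = request["prompt"]
            if isinstance(prompt, list) and all(isinstance(t, int) for t in prompt):
                request["prompt_token_ids"] = prompt          # pre-tokenized input
            elif isinstance(prompt, str):
                request["prompt_token_ids"] = text_to_ids(prompt)
            else:
                raise ValueError(f"prompt must be a string or a list of integers, but got {type(prompt)}")
        elif request.get("messages"):
            request["prompt_token_ids"] = messages_to_ids(request["messages"])
        else:
            raise ValueError(f"Request must contain 'prompt_token_ids', 'prompt', or 'messages': {request}")
    if len(request["prompt_token_ids"]) == 0:
        raise ValueError("Invalid input: prompt_token_ids must be a non-empty sequence of token IDs")
    return request

demo = fill_prompt_token_ids({"prompt": [11, 12, 13]}, text_to_ids=None, messages_to_ids=None)
print(demo["prompt_token_ids"])  # [11, 12, 13]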
