
[feat] completion api supports passing input token ids in either prompt or prompt_token_ids #3311


Status: Open. Wants to merge 10 commits into base branch `develop`.
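For orientation, here is a minimal client-side sketch of the feature this PR adds: `prompt_token_ids` on the completion API may now be a flat `list[int]` (one pre-tokenized prompt) or a `list[list[int]]` (a batch of pre-tokenized prompts). The server URL, port, and model name below are placeholder assumptions for illustration, not values taken from this PR.

```python
# Illustrative only: the base URL, port, and model name are assumptions.
import requests

BASE_URL = "http://localhost:8188/v1/completions"  # hypothetical server address

# A single prompt passed as token ids (list[int]).
single = {
    "model": "default",
    "prompt_token_ids": [101, 2023, 2003, 1037, 3231, 102],
    "max_tokens": 16,
}

# Batched prompts passed as token ids (list[list[int]]).
batched = {
    "model": "default",
    "prompt_token_ids": [
        [101, 2023, 2003, 1037, 3231, 102],
        [101, 2178, 25732, 102],
    ],
    "max_tokens": 16,
}

# Depending on the server's request validation, a "prompt" field may also be
# required; it is omitted here for brevity.
for payload in (single, batched):
    resp = requests.post(BASE_URL, json=payload)
    print(resp.json())
```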
14 changes: 7 additions & 7 deletions fastdeploy/entrypoints/engine_client.py
@@ -96,19 +96,19 @@ def create_zmq_client(self, model, mode):
self.zmq_client = ZmqClient(model, mode)
self.zmq_client.connect()

def format_and_add_data(self, prompts: dict):
def format_and_add_data(self, req_dict: dict):
"""
Format the request data and send the request to the server.
"""
if "request_id" not in prompts:
if "request_id" not in req_dict:
request_id = str(uuid.uuid4())
prompts["request_id"] = request_id
req_dict["request_id"] = request_id

if "max_tokens" not in prompts:
prompts["max_tokens"] = self.max_model_len - 1
if "max_tokens" not in req_dict:
req_dict["max_tokens"] = self.max_model_len - 1

self.add_requests(prompts)
return prompts["prompt_token_ids"]
self.add_requests(req_dict)
return req_dict["prompt_token_ids"]

def add_requests(self, task):
"""
14 changes: 7 additions & 7 deletions fastdeploy/entrypoints/openai/protocol.py
@@ -375,7 +375,7 @@ class CompletionRequest(BaseModel):

max_streaming_response_tokens: Optional[int] = None
return_token_ids: Optional[bool] = None
prompt_token_ids: Optional[List[int]] = None
prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None
# doc: end-completion-extra-params

def to_dict_for_infer(self, request_id=None, prompt=None):
@@ -400,11 +400,11 @@ def to_dict_for_infer(self, request_id=None, prompt=None):
if prompt is not None:
req_dict["prompt"] = prompt

if "prompt_token_ids" in req_dict:
if "prompt" in req_dict:
del req_dict["prompt"]
else:
assert len(prompt) > 0
# if "prompt_token_ids" in req_dict:
# if "prompt" in req_dict:
# del req_dict["prompt"]
# else:
# assert len(prompt) > 0

guided_json_object = None
if self.response_format is not None:
@@ -503,7 +503,7 @@ class ChatCompletionRequest(BaseModel):
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
# doc: end-chat-completion-sampling-params

# doc: start-completion-extra-params
# doc: start-chat-completion-extra-params
chat_template_kwargs: Optional[dict] = None
reasoning_max_tokens: Optional[int] = None
structural_tag: Optional[str] = None
61 changes: 41 additions & 20 deletions fastdeploy/entrypoints/openai/serving_completion.py
@@ -25,6 +25,7 @@
from aiozmq import zmq

from fastdeploy.engine.request import RequestOutput
from fastdeploy.entrypoints.engine_client import EngineClient
from fastdeploy.entrypoints.openai.protocol import (
CompletionLogprobs,
CompletionRequest,
@@ -41,7 +42,7 @@

class OpenAIServingCompletion:
def __init__(self, engine_client, pid, ips, max_waiting_time):
self.engine_client = engine_client
self.engine_client: EngineClient = engine_client
self.pid = pid
self.master_ip = ips
self.host_ip = get_host_ip()
@@ -72,41 +73,57 @@ async def create_completion(self, request: CompletionRequest):
request_id = f"cmpl-{request.user}-{uuid.uuid4()}"
else:
request_id = f"cmpl-{uuid.uuid4()}"
api_server_logger.info(f"initialize request {request_id}")
api_server_logger.info(f"Initialize request {request_id}: {request}")
request_prompt_ids = None
request_prompts = None

# Handle prompt and prompt_token_ids
try:
if isinstance(request.prompt, str):
request_prompts = [request.prompt]
elif isinstance(request.prompt, list) and all(isinstance(item, int) for item in request.prompt):
request_prompt_ids = [request.prompt]
elif isinstance(request.prompt, list) and all(isinstance(item, str) for item in request.prompt):
request_prompts = request.prompt
elif isinstance(request.prompt, list):
for item in request.prompt:
if isinstance(item, list) and all(isinstance(x, int) for x in item):
continue
else:
raise ValueError("Prompt must be a string, a list of strings or a list of integers.")
request_prompt_ids = request.prompt
if request.prompt_token_ids is not None: # let `prompt_token_ids` support batch inference
assert len(request.prompt_token_ids) > 0, "prompt_token_ids should not be an empty list"
if isinstance(request.prompt_token_ids[0], list):
request_prompt_ids = request.prompt_token_ids
elif isinstance(request.prompt_token_ids[0], int):
request_prompt_ids = [request.prompt_token_ids]
else:
raise ValueError(
"If prompt_token_ids is provided, its type should be one of: list[int], list[list[int]]"
)
# reset `prompt_token_ids` to avoid data processor directly using it; let data processor fill it
request.prompt_token_ids = None
else:
raise ValueError("Prompt must be a string, a list of strings or a list of integers.")
if isinstance(request.prompt, str):
request_prompts = [request.prompt]
elif isinstance(request.prompt, list) and all(isinstance(item, int) for item in request.prompt):
request_prompt_ids = [request.prompt]
elif isinstance(request.prompt, list) and all(isinstance(item, str) for item in request.prompt):
request_prompts = request.prompt
elif isinstance(request.prompt, list):
for item in request.prompt:
if isinstance(item, list) and all(isinstance(x, int) for x in item):
continue
else:
raise ValueError("If prompt is a list, each item type must be one of: str, list[int]")
request_prompt_ids = request.prompt
else:
raise ValueError("Prompt type must be one of: str, list[str], list[int], list[list[int]]")
except Exception as e:
return ErrorResponse(message=str(e), code=400)

if request_prompt_ids is not None:
request_prompts = request_prompt_ids
num_choices = len(request_prompts)

api_server_logger.info(f"start inference for request {num_choices}")
num_choices = len(request_prompts)
api_server_logger.info(f"Start preprocessing request: req_id={request_id}, num_choices={num_choices}")
prompt_batched_token_ids = []
try:
for idx, prompt in enumerate(request_prompts):
for idx, prompt in enumerate(request_prompts): # process each prompt for this batch completion request
request_id_idx = f"{request_id}-{idx}"
current_req_dict = request.to_dict_for_infer(request_id_idx, prompt)
api_server_logger.debug(f"current_req_dict: {current_req_dict}")
try:
current_req_dict["arrival_time"] = time.time()
prompt_token_ids = self.engine_client.format_and_add_data(current_req_dict)
prompt_token_ids = self.engine_client.format_and_add_data(current_req_dict) # tokenize
if isinstance(prompt_token_ids, np.ndarray):
prompt_token_ids = prompt_token_ids.tolist()
prompt_batched_token_ids.append(prompt_token_ids)
@@ -115,6 +132,10 @@ async def create_completion(self, request: CompletionRequest):

del current_req_dict

api_server_logger.info(
f"Finish preprocessing request: req_id={request_id}, lengths={[len(t) for t in prompt_batched_token_ids]}"
)

try:
if self.max_waiting_time < 0:
await self.engine_client.semaphore.acquire()
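As a reading aid for the dispatch above, here is a standalone, simplified restatement of the accepted input combinations. It is a sketch, not the handler itself: it drops the logging, the `request.prompt_token_ids = None` reset, and the `ErrorResponse` wrapping.

```python
from typing import List, Union

def normalize_completion_inputs(
    prompt: Union[str, List[str], List[int], List[List[int]], None],
    prompt_token_ids: Union[List[int], List[List[int]], None],
):
    """Return (request_prompts, request_prompt_ids) following the dispatch above."""
    request_prompts = None
    request_prompt_ids = None
    if prompt_token_ids is not None:
        if len(prompt_token_ids) == 0:
            raise ValueError("prompt_token_ids should not be an empty list")
        if isinstance(prompt_token_ids[0], list):
            request_prompt_ids = prompt_token_ids       # already a batch
        elif isinstance(prompt_token_ids[0], int):
            request_prompt_ids = [prompt_token_ids]     # wrap a single prompt into a batch
        else:
            raise ValueError("prompt_token_ids must be list[int] or list[list[int]]")
    elif isinstance(prompt, str):
        request_prompts = [prompt]
    elif isinstance(prompt, list) and all(isinstance(item, int) for item in prompt):
        request_prompt_ids = [prompt]
    elif isinstance(prompt, list) and all(isinstance(item, str) for item in prompt):
        request_prompts = prompt
    elif isinstance(prompt, list) and all(
        isinstance(item, list) and all(isinstance(x, int) for x in item) for item in prompt
    ):
        request_prompt_ids = prompt
    else:
        raise ValueError("Prompt type must be one of: str, list[str], list[int], list[list[int]]")
    return request_prompts, request_prompt_ids

# A flat token-id list becomes a batch of one; a list of strings stays a batch of prompts.
print(normalize_completion_inputs(None, [1, 2, 3]))   # (None, [[1, 2, 3]])
print(normalize_completion_inputs(["a", "b"], None))  # (['a', 'b'], None)
```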
68 changes: 43 additions & 25 deletions fastdeploy/input/ernie_processor.py
@@ -88,40 +88,52 @@ def process_request(self, request, max_model_len=None, **kwargs):
request = self._apply_default_parameters(request)
if request.get("eos_token_ids") is None or len(request.eos_token_ids) == 0:
request.eos_token_ids = self.eos_token_ids

# processing stop_sequences
stop_sequences = request.get("stop", [])
if stop_sequences is not None and len(stop_sequences) != 0:
stop_seqs, stop_seqs_len = self.update_stop_seq(stop_sequences)
request.set("stop_token_ids", stop_seqs)
request.set("stop_seqs_len", stop_seqs_len)

# processing prompt_token_ids
if request.prompt_token_ids is None or len(request.prompt_token_ids) == 0:
if request.prompt is None and request.messages is None:
raise ValueError(f"The request should have `prompt_token_ids`, `prompt` or `messages`: {request}.")
if request.prompt is not None:
prompt = request.prompt if request.prompt is not None else request.messages[0]
prompt = prompt[0] if isinstance(prompt, list) else prompt
tokens = self.tokenizer.tokenize(prompt)
token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
request.prompt_token_ids = token_ids
data_processor_logger.info(f"req_id:{request.request_id}, tokens:{tokens}, token_ids: {token_ids}")
else:
# prompt = request.prompt if request.prompt is not None else request.messages[0]
prompt = request.prompt
assert isinstance(prompt, str) or (
isinstance(prompt, list) and all([isinstance(t, int) for t in prompt])
), f"prompt must be a string or a list of integers, but got {type(prompt)}"

if isinstance(prompt, list): # if prompt is a token id list
request["prompt_token_ids"] = prompt
else:
tokens = self.tokenizer.tokenize(prompt)
token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
request.prompt_token_ids = token_ids
data_processor_logger.debug(
f"request_ids: {request.request_id}, prompt: {prompt}, tokens: {tokens}, token_ids: {token_ids}"
)
elif request.messages is not None:
request.prompt_token_ids = self.messages2ids(request.to_dict())
else:
raise ValueError(f"The request should have `prompt_token_ids`, `prompt` or `messages`: {request}.")

if len(request.prompt_token_ids) == 0:
raise ValueError("Invalid input: prompt_token_ids must be a non-empty sequence of token IDs")

# truncate prompts that exceed the length limit
if max_model_len is not None and len(request.prompt_token_ids) > max_model_len:
request.prompt_token_ids = request.prompt_token_ids[: max_model_len - 1]
if request.get("max_tokens") is None:
request.set(
"max_tokens",
max(1, max_model_len - len(request.prompt_token_ids)),
)
request.set("max_tokens", max(1, max_model_len - len(request.prompt_token_ids)))
if request.get("temperature") < _SAMPLING_EPS:
# zero temperature is equivalent to greedy sampling
request.set("temperature", 1)
if request.get("top_p") < _SAMPLING_EPS:
request.set("top_p", _SAMPLING_EPS)
data_processor_logger.info(f"Processed request {request}")

data_processor_logger.info(f"Processed request: {request}")
return request

def process_request_dict(self, request, max_model_len=None):
@@ -148,19 +160,25 @@ def process_request_dict(self, request, max_model_len=None):

# processing prompt_token_ids
if not request.get("prompt_token_ids"):
if request.get("prompt") is None and request.get("messages") is None:
raise ValueError(f"Request must contain 'prompt_token_ids', 'prompt', or 'messages': {request}")
if request.get("prompt"):
prompt = request.get("prompt")
prompt = prompt[0] if isinstance(prompt, list) else prompt

tokens = self.tokenizer.tokenize(prompt)
token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
request["prompt_token_ids"] = token_ids
req_id = request.get("request_id", None)
data_processor_logger.info(f"req_id:{req_id}, tokens:{tokens}, token_ids: {token_ids}")
else:
assert isinstance(prompt, str) or (
isinstance(prompt, list) and all([isinstance(t, int) for t in prompt])
), f"prompt must be a string or a list of integers, but got {type(prompt)}"
if isinstance(prompt, list): # if prompt is a token id list
request["prompt_token_ids"] = prompt
else:
tokens = self.tokenizer.tokenize(prompt)
token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
request["prompt_token_ids"] = token_ids
data_processor_logger.debug(
f"request_ids: {request.get('request_id')}, prompt: {prompt}, tokens: {tokens}, token_ids: {token_ids}"
)
elif request.get("messages"):
request["prompt_token_ids"] = self.messages2ids(request)
else:
raise ValueError(f"Request must contain 'prompt_token_ids', 'prompt', or 'messages': {request}")

if len(request["prompt_token_ids"]) == 0:
raise ValueError("Invalid input: prompt_token_ids must be a non-empty sequence of token IDs")

@@ -174,8 +192,8 @@ def process_request_dict(self, request, max_model_len=None):
request["temperature"] = 1
if request.get("top_p") < _SAMPLING_EPS:
request["top_p"] = _SAMPLING_EPS
data_processor_logger.info(f"Processed request {request}")

data_processor_logger.info(f"Processed request dict: {request}")
return request

def process_response(self, response_dict, **kwargs):
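The core of the processor change above can be restated as a small standalone function (a sketch, not `ErnieProcessor` itself; the toy tokenizer stands in for `self.tokenizer`): a prompt that is already a list of token ids bypasses tokenization, while a string prompt is tokenized as before.

```python
from typing import Callable, List, Union

def to_prompt_token_ids(prompt: Union[str, List[int]], tokenize: Callable[[str], List[int]]) -> List[int]:
    assert isinstance(prompt, str) or (
        isinstance(prompt, list) and all(isinstance(t, int) for t in prompt)
    ), f"prompt must be a string or a list of integers, but got {type(prompt)}"
    if isinstance(prompt, list):  # already token ids, skip the tokenizer
        return prompt
    return tokenize(prompt)       # plain text, run the tokenizer

# Toy tokenizer for demonstration only.
toy_tokenize = lambda text: [ord(c) for c in text]
print(to_prompt_token_ids([101, 7592, 102], toy_tokenize))  # passed through unchanged
print(to_prompt_token_ids("hi", toy_tokenize))              # tokenized: [104, 105]
```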
46 changes: 35 additions & 11 deletions fastdeploy/input/text_processor.py
@@ -206,15 +206,24 @@ def process_request(self, request, max_model_len=None, **kwargs):
if request.get("eos_token_ids") is None or len(request.eos_token_ids) == 0:
request.eos_token_ids = self.eos_token_ids

# processing stop_sequences
stop_sequences = request.get("stop", [])
if stop_sequences is not None and len(stop_sequences) != 0:
stop_seqs, stop_seqs_len = self.update_stop_seq(stop_sequences)
request.set("stop_token_ids", stop_seqs)
request.set("stop_seqs_len", stop_seqs_len)

# processing prompt_token_ids
if request.prompt_token_ids is None or len(request.prompt_token_ids) == 0:
if request.prompt is not None:
request.prompt_token_ids = self.text2ids(request.prompt, max_model_len)
prompt = request.prompt
assert isinstance(prompt, str) or (
isinstance(prompt, list) and all([isinstance(t, int) for t in prompt])
), f"prompt must be a string or a list of integers, but got {type(prompt)}"
if isinstance(prompt, list): # if prompt is a token id list
request.prompt_token_ids = prompt
else:
request.prompt_token_ids = self.text2ids(request.prompt, max_model_len)
elif request.messages is not None:
if self.tokenizer.chat_template is None:
raise ValueError("This model does not support chat_template.")
@@ -223,19 +232,22 @@ def process_request(self, request, max_model_len=None, **kwargs):
request.prompt_token_ids = self.messages2ids(task)
else:
raise ValueError(f"The request should have `input_ids`, `text` or `messages`: {request}.")

if len(request.prompt_token_ids) == 0:
raise ValueError("Invalid input: prompt_token_ids must be a non-empty sequence of token IDs")

# truncate prompts that exceed the length limit
if max_model_len is not None and len(request.prompt_token_ids) > max_model_len:
request.prompt_token_ids = request.prompt_token_ids[: max_model_len - 1]
if request.get("max_tokens") is None:
request.set(
"max_tokens",
max(1, max_model_len - len(request.prompt_token_ids)),
)
request.set("max_tokens", max(1, max_model_len - len(request.prompt_token_ids)))
if request.get("temperature") < _SAMPLING_EPS:
# zero temperature is equivalent to greedy sampling
request.set("temperature", 1)
if request.get("top_p") < _SAMPLING_EPS:
request.set("top_p", _SAMPLING_EPS)
data_processor_logger.info(f"Processed request {request}")

data_processor_logger.info(f"Processed request: {request}")
return request

def process_request_dict(self, request, max_model_len=None, **kwargs):
@@ -260,27 +272,39 @@ def process_request_dict(self, request, max_model_len=None, **kwargs):
request["stop_token_ids"] = stop_seqs
request["stop_seqs_len"] = stop_seqs_len

data_processor_logger.info(f"Processing request {request}")
# processing prompt_token_ids
if not request.get("prompt_token_ids"):
if "prompt" in request:
request["prompt_token_ids"] = self.text2ids(request["prompt"], max_model_len).tolist()
elif "messages" in request:
if request.get("prompt"):
prompt = request.get("prompt")
assert isinstance(prompt, str) or (
isinstance(prompt, list) and all([isinstance(t, int) for t in prompt])
), f"prompt must be a string or a list of integers, but got {type(prompt)}"
if isinstance(prompt, list): # if prompt is a token id list
request["prompt_token_ids"] = prompt
else:
request["prompt_token_ids"] = self.text2ids(request["prompt"], max_model_len).tolist()
elif request.get("messages"):
if self.tokenizer.chat_template is None:
raise ValueError("This model does not support chat_template.")
request["prompt_token_ids"] = self.messages2ids(request)
else:
raise ValueError(f"Request must contain 'prompt_token_ids', 'prompt', or 'messages': {request}")

if len(request["prompt_token_ids"]) == 0:
raise ValueError("Invalid input: prompt_token_ids must be a non-empty sequence of token IDs")

# truncate prompts that exceed the length limit
if max_model_len is not None and len(request["prompt_token_ids"]) > max_model_len:
request["prompt_token_ids"] = request["prompt_token_ids"][: max_model_len - 1]
if request.get("max_tokens") is None:
request["max_tokens"] = max(1, max_model_len - len(request["prompt_token_ids"]))
if request.get("temperature") < _SAMPLING_EPS:
# zero temperature is equivalent to greedy sampling
request["temperature"] = 1
if request.get("top_p") < _SAMPLING_EPS:
request["top_p"] = _SAMPLING_EPS
data_processor_logger.info(f"Processed request {request}")

data_processor_logger.info(f"Processed request dict: {request}")
return request

def process_logprob_response(self, token_ids, **kwargs):
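Both processors share the same tail-end length handling, restated here as a hedged standalone sketch: prompts longer than `max_model_len` are truncated to `max_model_len - 1` tokens, and a missing `max_tokens` defaults to the remaining context budget.

```python
# Simplified sketch of the shared length handling above; numbers are illustrative.
from typing import List, Optional, Tuple

def apply_length_defaults(
    prompt_token_ids: List[int],
    max_model_len: Optional[int],
    max_tokens: Optional[int],
) -> Tuple[List[int], Optional[int]]:
    if max_model_len is not None and len(prompt_token_ids) > max_model_len:
        prompt_token_ids = prompt_token_ids[: max_model_len - 1]    # leave room for generation
    if max_tokens is None and max_model_len is not None:
        max_tokens = max(1, max_model_len - len(prompt_token_ids))  # fill the remaining window
    return prompt_token_ids, max_tokens

ids, budget = apply_length_defaults(list(range(10)), max_model_len=8, max_tokens=None)
print(len(ids), budget)  # 7 1
```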