88 changes: 42 additions & 46 deletions fastdeploy/entrypoints/engine_client.py
@@ -116,50 +116,37 @@ def create_zmq_client(self, model, mode):
         self.zmq_client = ZmqIpcClient(model, mode)
         self.zmq_client.connect()
 
-    async def format_and_add_data(self, prompts: dict):
+    async def format_request(self, request: dict):
         """
         Format the request data and send the request to the server.
         """
-        if "request_id" not in prompts:
+        if "request_id" not in request:
             request_id = str(uuid.uuid4())
-            prompts["request_id"] = request_id
+            request["request_id"] = request_id
 
-        if "max_tokens" not in prompts:
-            prompts["max_tokens"] = self.max_model_len - 1
+        if "max_tokens" not in request:
+            request["max_tokens"] = self.max_model_len - 1
 
-        await self.add_requests(prompts)
-        return prompts["prompt_token_ids"]
-
-    async def add_requests(self, task):
-        """
-        Add a new request to the queue.
-
-        Args:
-            task: Request A dictionary representing the request.
-            sampling_params: A dictionary representing the sampling parameters.
-
-        Returns:
-            None
-        """
-
-        task["preprocess_start_time"] = time.time()
+        request["preprocess_start_time"] = time.time()
         try:
-            chat_template_kwargs = task.get("chat_template_kwargs", {})
-            chat_template_kwargs.update({"chat_template": task.get("chat_template"), "tools": task.get("tools")})
-            task["chat_template_kwargs"] = chat_template_kwargs
+            chat_template_kwargs = request.get("chat_template_kwargs", {})
+            chat_template_kwargs.update({"chat_template": request.get("chat_template"), "tools": request.get("tools")})
+            request["chat_template_kwargs"] = chat_template_kwargs
             if inspect.iscoroutinefunction(self.data_processor.process_request_dict):
-                await self.data_processor.process_request_dict(task, self.max_model_len)
+                await self.data_processor.process_request_dict(request, self.max_model_len)
             else:
-                self.data_processor.process_request_dict(task, self.max_model_len)
-
-            task["prompt_token_ids_len"] = len(task["prompt_token_ids"])
-            input_ids_len = task["prompt_token_ids_len"]
-            task["max_tokens"] = min(self.max_model_len - input_ids_len, task.get("max_tokens"))
-            min_tokens = task.get("min_tokens", 1)
-            if "messages" in task:
-                del task["messages"]
-            api_server_logger.info(f"task['max_tokens']:{task['max_tokens']}")
-            work_process_metrics.request_params_max_tokens.observe(task["max_tokens"])
+                self.data_processor.process_request_dict(request, self.max_model_len)
+
+            request["prompt_token_ids_len"] = len(request["prompt_token_ids"])
+            input_ids_len = request["prompt_token_ids_len"]
+            request["max_tokens"] = min(self.max_model_len - input_ids_len, request.get("max_tokens"))
+            if request.get("reasoning_max_tokens", None) is None:
Collaborator review comment (on the line above): Let's drop the reasoning_max_tokens logic here.
request["reasoning_max_tokens"] = max(int(request["max_tokens"] * 0.8), 1)
min_tokens = request.get("min_tokens", 1)
if "messages" in request:
del request["messages"]
api_server_logger.info(f"request['max_tokens']:{request['max_tokens']}")
work_process_metrics.request_params_max_tokens.observe(request["max_tokens"])
work_process_metrics.prompt_tokens_total.inc(input_ids_len)
work_process_metrics.request_prompt_tokens.observe(input_ids_len)
except Exception as e:
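
Note: the token-budget arithmetic in the hunk above, traced with made-up numbers. This is an illustrative sketch only, not code from the PR, and the 0.8 reasoning default is exactly the logic the reviewer asks to drop:

    # Hypothetical values, for illustration only.
    max_model_len = 8192
    input_ids_len = 1000
    requested_max_tokens = 32768

    # Clamp generation so prompt + output fit inside the context window.
    max_tokens = min(max_model_len - input_ids_len, requested_max_tokens)  # 7192

    # Default reasoning budget: 80% of max_tokens, floored at 1 token.
    reasoning_max_tokens = max(int(max_tokens * 0.8), 1)  # 5753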
@@ -181,8 +168,8 @@ async def add_requests(self, task):
             api_server_logger.error(error_msg)
             raise EngineError(error_msg, error_code=400)
 
-        if "stop_seqs_len" in task:
-            stop_seqs_len = task["stop_seqs_len"]
+        if "stop_seqs_len" in request:
+            stop_seqs_len = request["stop_seqs_len"]
             max_stop_seqs_num = int(envs.FD_MAX_STOP_SEQS_NUM)
             if len(stop_seqs_len) > max_stop_seqs_num:
                 error_msg = (
@@ -201,15 +188,28 @@ async def add_requests(self, task):
                 api_server_logger.error(error_msg)
                 raise EngineError(error_msg, error_code=400)
 
-        task["preprocess_end_time"] = time.time()
-        preprocess_cost_time = task["preprocess_end_time"] - task["preprocess_start_time"]
+        request["preprocess_end_time"] = time.time()
+        preprocess_cost_time = request["preprocess_end_time"] - request["preprocess_start_time"]
         api_server_logger.info(
-            f"Cache request with request_id ({task.get('request_id')}), "
+            f"Cache request with request_id ({request.get('request_id')}), "
             f"preprocess time cost {preprocess_cost_time}"
         )
 
-        self.valid_parameters(task)
-        api_server_logger.debug(f"Receive task: {task}")
+        self.valid_parameters(request)
+        api_server_logger.debug(f"Receive request: {request}")

+    async def add_requests(self, task):
+        """
+        Add a new request to the queue.
+
+        Args:
+            task: A dictionary representing the request.
+
+        Returns:
+            None
+        """
+
         try:
             if not self.enable_mm:
                 self.zmq_client.send_json(task)
@@ -226,10 +226,6 @@ def valid_parameters(self, data):
         Moved up front into ChatCompletionRequest/CompletionRequest
         """
 
-        if data.get("n") is not None:
-            if data["n"] != 1:
-                raise ParameterError("n", "n only support 1.")
-
         if data.get("max_tokens") is not None:
            if data["max_tokens"] < 1 or data["max_tokens"] >= self.max_model_len:
                raise ParameterError("max_tokens", f"max_tokens can be defined [1, {self.max_model_len}).")
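
Note: taken together, the PR splits the old format_and_add_data into two steps: format_request prepares and validates the request dict, and add_requests hands it to the engine over ZMQ. A hypothetical caller under that reading (the handler itself is not part of the diff):

    async def submit(client, raw: dict) -> list[int]:
        await client.format_request(raw)  # fill request_id, clamp token budgets, validate
        await client.add_requests(raw)    # enqueue for the engine via the ZMQ client
        return raw["prompt_token_ids"]    # token ids produced during preprocessing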