Commit 81092c0: add tool parser
1 parent: ad816f2

23 files changed, +1054 −39 lines

fastdeploy/engine/args_utils.py

Lines changed: 24 additions & 3 deletions

@@ -15,10 +15,10 @@
 """

 import json
+import os
 from dataclasses import asdict, dataclass
 from dataclasses import fields as dataclass_fields
 from typing import Any, Dict, List, Optional
-import os

 from fastdeploy.config import (
     CacheConfig,
@@ -93,6 +93,14 @@ class EngineArgs:
     """
     specifies the reasoning parser to use for extracting reasoning content from the model output
     """
+    tool_call_parser: str = None
+    """
+    specifies the tool call parser to use for extracting tool call from the model output
+    """
+    tool_parser_plugin: str = None
+    """
+    tool parser plugin used to register user defined tool parsers
+    """
     enable_mm: bool = False
     """
     Flags to enable multi-modal model
@@ -421,6 +429,18 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             help="Flag specifies the reasoning parser to use for extracting "
             "reasoning content from the model output",
         )
+        model_group.add_argument(
+            "--tool-call-parser",
+            type=str,
+            default=EngineArgs.tool_call_parser,
+            help="Flag specifies the tool call parser to use for extracting" "tool call from the model output",
+        )
+        model_group.add_argument(
+            "--tool-parser-plugin",
+            type=str,
+            default=EngineArgs.tool_parser_plugin,
+            help="tool parser plugin used to register user defined tool parsers",
+        )
         model_group.add_argument(
             "--speculative-config",
             type=json.loads,
@@ -866,10 +886,10 @@ def create_engine_config(self) -> Config:
             if self.enable_chunked_prefill:
                 self.max_num_batched_tokens = 2048
             else:
-                if not int(os.getenv('ENABLE_V1_KVCACHE_SCHEDULER', '0')):
+                if not int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")):
                     self.max_num_batched_tokens = self.max_model_len
                 else:
-                    self.max_num_batched_tokens = 8192
+                    self.max_num_batched_tokens = 8192  # if set to max_model_len, it's easy to be OOM

         all_dict = asdict(self)
         all_dict["model_cfg"] = model_cfg
@@ -908,6 +928,7 @@ def create_engine_config(self) -> Config:
             mm_processor_kwargs=self.mm_processor_kwargs,
             enable_mm=self.enable_mm,
             reasoning_parser=self.reasoning_parser,
+            tool_parser=self.tool_call_parser,
             splitwise_role=self.splitwise_role,
             innode_prefill_ports=self.innode_prefill_ports,
             max_num_partial_prefills=self.max_num_partial_prefills,
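
For orientation, a minimal sketch of setting the two new engine arguments programmatically; the model path and parser names are placeholders rather than values defined by this commit, and create_engine_config() is the method extended above to forward tool_parser into Config:

from fastdeploy.engine.args_utils import EngineArgs

engine_args = EngineArgs(
    model="./my-model",              # placeholder model path
    tool_call_parser="my_parser",    # new field: name of the tool call parser to use
    tool_parser_plugin="./my_tool_parsers.py",  # new field: optional plugin registering custom parsers
)
config = engine_args.create_engine_config()  # passes tool_parser=self.tool_call_parser into Config

Note that EngineArgs only stores tool_parser_plugin; the plugin file itself is imported by the entrypoints (llm.py and api_server.py below) via ToolParserManager.import_tool_parser.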

fastdeploy/engine/config.py

Lines changed: 5 additions & 3 deletions

@@ -85,6 +85,7 @@ def __init__(
         max_long_partial_prefills: int = 1,
         long_prefill_token_threshold: int = 0,
         reasoning_parser: str = None,
+        tool_parser: str = None,
         guided_decoding_backend: Optional[str] = None,
         disable_any_whitespace: bool = False,
         enable_logprob: bool = False,
@@ -165,6 +166,7 @@ def __init__(
         self.max_long_partial_prefills = max_long_partial_prefills
         self.long_prefill_token_threshold = long_prefill_token_threshold
         self.reasoning_parser = reasoning_parser
+        self.tool_parser = tool_parser
         self.graph_optimization_config = graph_optimization_config
         self.early_stop_config = early_stop_config
         self.guided_decoding_backend = guided_decoding_backend
@@ -236,10 +238,10 @@ def postprocess(self):
         if self.cache_config.enable_chunked_prefill:
             self.max_num_batched_tokens = 2048
         else:
-            if not int(os.getenv('ENABLE_V1_KVCACHE_SCHEDULER', '0')):
+            if not int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")):
                 self.max_num_batched_tokens = self.max_model_len
             else:
-                self.max_num_batched_tokens = 8192
+                self.max_num_batched_tokens = 8192  # if set to max_model_len, it's easy to be OOM

         if self.long_prefill_token_threshold == 0:
             self.long_prefill_token_threshold = int(self.max_model_len * 0.04)
@@ -287,7 +289,7 @@ def check(self):
             )

         if not self.cache_config.enable_chunked_prefill:
-            if not int(os.getenv('ENABLE_V1_KVCACHE_SCHEDULER', '0')):
+            if not int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")):
                 assert self.max_num_batched_tokens >= self.max_model_len, (
                     f"max_num_batched_tokens: {self.max_num_batched_tokens} "
                     f"should be larger than or equal to max_model_len: {self.max_model_len}"

fastdeploy/engine/engine.py

Lines changed: 1 addition & 0 deletions

@@ -106,6 +106,7 @@ def __init__(self, cfg):
             cfg.limit_mm_per_prompt,
             cfg.mm_processor_kwargs,
             cfg.enable_mm,
+            cfg.tool_parser,
         )

         self.start_queue_service()

fastdeploy/engine/request.py

Lines changed: 2 additions & 0 deletions

@@ -24,6 +24,7 @@
 import numpy as np

 from fastdeploy.engine.sampling_params import SamplingParams
+from fastdeploy.entrypoints.openai.protocol import ToolCall
 from fastdeploy.utils import data_processor_logger
 from fastdeploy.worker.output import LogprobsLists, SampleLogprobs

@@ -249,6 +250,7 @@ class CompletionOutput:
     draft_token_ids: list[int] = None
     text: Optional[str] = None
     reasoning_content: Optional[str] = None
+    tool_calls: Optional[ToolCall] = None

     def to_dict(self):
         """

fastdeploy/entrypoints/chat_utils.py

Lines changed: 5 additions & 0 deletions

@@ -14,6 +14,7 @@
 # limitations under the License.
 """

+import uuid
 from copy import deepcopy
 from typing import List, Literal, Union
 from urllib.parse import urlparse
@@ -156,3 +157,7 @@ def parse_chat_messages(messages):

         conversation.append({"role": role, "content": parsed_content})
     return conversation
+
+
+def random_tool_call_id() -> str:
+    return f"chatcmpl-tool-{str(uuid.uuid4().hex)}"

fastdeploy/entrypoints/engine_client.py

Lines changed: 2 additions & 0 deletions

@@ -45,13 +45,15 @@ def __init__(
         data_parallel_size=1,
         enable_logprob=False,
         workers=1,
+        tool_parser=None,
     ):
         input_processor = InputPreprocessor(
             tokenizer,
             reasoning_parser,
             limit_mm_per_prompt,
             mm_processor_kwargs,
             enable_mm,
+            tool_parser,
         )
         self.enable_logprob = enable_logprob
         self.enable_mm = enable_mm

fastdeploy/entrypoints/llm.py

Lines changed: 4 additions & 2 deletions

@@ -28,8 +28,7 @@
 from fastdeploy.engine.args_utils import EngineArgs
 from fastdeploy.engine.engine import LLMEngine
 from fastdeploy.engine.sampling_params import SamplingParams
-
-# from fastdeploy.entrypoints.chat_utils import ChatCompletionMessageParam
+from fastdeploy.entrypoints.openai.tool_parsers import ToolParserManager
 from fastdeploy.utils import llm_logger, retrive_model_from_server
 from fastdeploy.worker.output import Logprob, LogprobsLists

@@ -73,6 +72,9 @@ def __init__(
         **kwargs,
     ):
         model = retrive_model_from_server(model, revision)
+        tool_parser_plugin = kwargs.get("tool_parser_plugin")
+        if tool_parser_plugin:
+            ToolParserManager.import_tool_parser(tool_parser_plugin)
         engine_args = EngineArgs(
             model=model,
             tokenizer=tokenizer,
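
A rough sketch of how the offline entrypoint might be driven with the new kwarg; the model path, plugin path, and parser name are placeholders, and the plugin file is assumed to register its parser through ToolParserManager:

from fastdeploy.entrypoints.llm import LLM

# tool_parser_plugin is read from **kwargs and imported via
# ToolParserManager.import_tool_parser() before EngineArgs is built.
llm = LLM(
    model="./my-model",                                # placeholder model path
    tool_parser_plugin="/path/to/my_tool_parsers.py",  # placeholder plugin file
    tool_call_parser="my_parser",                      # placeholder: a name the plugin registers
)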

fastdeploy/entrypoints/openai/api_server.py

Lines changed: 4 additions & 1 deletion

@@ -41,6 +41,7 @@
 )
 from fastdeploy.entrypoints.openai.serving_chat import OpenAIServingChat
 from fastdeploy.entrypoints.openai.serving_completion import OpenAIServingCompletion
+from fastdeploy.entrypoints.openai.tool_parsers import ToolParserManager
 from fastdeploy.metrics.metrics import (
     EXCLUDE_LABELS,
     cleanup_prometheus_files,
@@ -73,7 +74,8 @@
 parser = EngineArgs.add_cli_args(parser)
 args = parser.parse_args()
 args.model = retrive_model_from_server(args.model, args.revision)
-
+if args.tool_parser_plugin:
+    ToolParserManager.import_tool_parser(args.tool_parser_plugin)
 llm_engine = None


@@ -126,6 +128,7 @@ async def lifespan(app: FastAPI):
         args.data_parallel_size,
         args.enable_logprob,
         args.workers,
+        args.tool_call_parser,
     )
     app.state.dynamic_load_weight = args.dynamic_load_weight
     chat_handler = OpenAIServingChat(engine_client, pid, args.ips, args.max_waiting_time)

fastdeploy/entrypoints/openai/protocol.py

Lines changed: 12 additions & 1 deletion

@@ -72,7 +72,6 @@ class ToolCall(BaseModel):
     id: str = None
     type: Literal["function"] = "function"
     function: FunctionCall
-    index: int


 class DeltaFunctionCall(BaseModel):
@@ -96,6 +95,18 @@ class DeltaToolCall(BaseModel):
     function: Optional[DeltaFunctionCall] = None


+class ExtractedToolCallInformation(BaseModel):
+    # indicate if tools were called
+    tools_called: bool
+
+    # extracted tool calls
+    tool_calls: Optional[list[ToolCall]] = None
+
+    # content - per OpenAI spec, content AND tool calls can be returned rarely
+    # But some models will do this intentionally
+    content: Optional[str] = None
+
+
 class FunctionDefinition(BaseModel):
     """
     Function definition.
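
ExtractedToolCallInformation appears to be the value a tool parser hands back after inspecting model output: either plain content, or one or more parsed tool calls. A small hedged example of constructing it (the function payload is invented, and FunctionCall is assumed to take name and arguments):

from fastdeploy.entrypoints.openai.protocol import (
    ExtractedToolCallInformation,
    FunctionCall,
    ToolCall,
)

# No tool call detected: the raw text passes through as content.
no_call = ExtractedToolCallInformation(tools_called=False, content="plain answer")

# Tool call detected: tools_called is True and tool_calls carries the parsed calls.
with_call = ExtractedToolCallInformation(
    tools_called=True,
    tool_calls=[ToolCall(function=FunctionCall(name="get_weather", arguments='{"city": "Beijing"}'))],
)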

fastdeploy/entrypoints/openai/serving_chat.py

Lines changed: 21 additions & 15 deletions

@@ -141,6 +141,7 @@ async def chat_completion_stream_generator(
         previous_num_tokens = 0
         num_prompt_tokens = 0
         num_choices = 1
+        tool_called = False
         max_streaming_response_tokens = (
             request.max_streaming_response_tokens
             if request.max_streaming_response_tokens is not None
@@ -244,20 +245,28 @@
                 output = res["outputs"]
                 delta_text = output["text"]
                 output_top_logprobs = output["top_logprobs"]
+                previous_num_tokens += len(output["token_ids"])
                 logprobs_res: Optional[LogProbs] = None
                 if request.logprobs and output_top_logprobs is not None:
                     logprobs_res = self._create_chat_logprobs(
                         output_top_logprobs, request.logprobs, request.top_logprobs
                     )
-
-                previous_num_tokens += len(output["token_ids"])
-                delta_message = DeltaMessage(
-                    content=delta_text,
-                    reasoning_content=output.get("reasoning_content"),
-                    prompt_token_ids=None,
-                    completion_token_ids=None,
-                    tool_calls=output.get("tool_call_content", []),
-                )
+                if self.engine_client.data_processor.tool_parser_obj and not res["finished"]:
+                    tool_delta_message = output["tool_delta_message"]
+                    if tool_delta_message is None:
+                        continue
+                    delta_message = tool_delta_message
+                    delta_message.reasoning_content = output.get("reasoning_content")
+                    if delta_message.tool_calls:
+                        tool_called = True
+                else:
+                    delta_message = DeltaMessage(
+                        content=delta_text,
+                        reasoning_content=output.get("reasoning_content"),
+                        prompt_token_ids=None,
+                        completion_token_ids=None,
+                        tool_calls=None,
+                    )

                 choice = ChatCompletionResponseStreamChoice(
                     index=0,
@@ -274,10 +283,7 @@
                 max_tokens = request.max_completion_tokens or request.max_tokens
                 if has_no_token_limit or previous_num_tokens != max_tokens:
                     choice.finish_reason = "stop"
-                    if (
-                        self.engine_client.reasoning_parser == "ernie_x1"
-                        and output.get("finish_reason", "") == "tool_calls"
-                    ):
+                    if tool_called:
                         choice.finish_reason = "tool_calls"
                 else:
                     choice.finish_reason = "length"
@@ -414,7 +420,7 @@ async def chat_completion_full_generator(
             role="assistant",
             content=output["text"],
             reasoning_content=output.get("reasoning_content"),
-            tool_calls=output.get("tool_call_content"),
+            tool_calls=output.get("tool_call"),
             prompt_token_ids=prompt_token_ids if request.return_token_ids else None,
             completion_token_ids=completion_token_ids if request.return_token_ids else None,
             text_after_process=text_after_process if request.return_token_ids else None,
@@ -434,7 +440,7 @@ async def chat_completion_full_generator(
         max_tokens = request.max_completion_tokens or request.max_tokens
         if has_no_token_limit or previous_num_tokens != max_tokens:
             choice.finish_reason = "stop"
-            if self.engine_client.reasoning_parser == "ernie_x1" and output.get("finish_reason", "") == "tool_calls":
+            if output.get("tool_call"):
                 choice.finish_reason = "tool_calls"
         else:
             choice.finish_reason = "length"
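
Taken together, a server launched with --tool-call-parser can now report finish_reason == "tool_calls" and populate message.tool_calls in both streaming and non-streaming chat completions, without the previous hard-coded "ernie_x1" check. A rough client-side sketch using the openai Python package; the endpoint, model name, and tool schema are placeholders:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")  # placeholder endpoint

tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Query the weather for a city",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}]

resp = client.chat.completions.create(
    model="default",  # placeholder model name
    messages=[{"role": "user", "content": "What's the weather in Beijing?"}],
    tools=tools,
)
choice = resp.choices[0]
if choice.finish_reason == "tool_calls":
    print(choice.message.tool_calls)  # tool calls extracted by the server-side parser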
