
Commit eda83ca

luukunn and AuferGachet authored
add Tool Parser (#3272)
* add tool-parser
* add tool-parser
* add tool parser
* add tool parser
* fix
* add offline
* add offline
* fix
* parsers: tool & reasoning
* rename tool parser
* update
* fix reasoning-parser
* add requirements
* fix finish reason
* fix
* fix
* fix
* fix
* fix

---------

Co-authored-by: zhuzixuan <[email protected]>
1 parent 2d1a4ca commit eda83ca

23 files changed: +1054 −36 lines changed

fastdeploy/engine/args_utils.py

Lines changed: 24 additions & 3 deletions
```diff
@@ -15,10 +15,10 @@
 """

 import json
+import os
 from dataclasses import asdict, dataclass
 from dataclasses import fields as dataclass_fields
 from typing import Any, Dict, List, Optional
-import os

 from fastdeploy.config import (
     CacheConfig,
@@ -94,6 +94,14 @@ class EngineArgs:
     """
     specifies the reasoning parser to use for extracting reasoning content from the model output
     """
+    tool_call_parser: str = None
+    """
+    specifies the tool call parser to use for extracting tool calls from the model output
+    """
+    tool_parser_plugin: str = None
+    """
+    tool parser plugin used to register user-defined tool parsers
+    """
     enable_mm: bool = False
     """
     Flags to enable multi-modal model
@@ -434,6 +442,18 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             help="Flag specifies the reasoning parser to use for extracting "
             "reasoning content from the model output",
         )
+        model_group.add_argument(
+            "--tool-call-parser",
+            type=str,
+            default=EngineArgs.tool_call_parser,
+            help="Flag specifies the tool call parser to use for extracting "
+            "tool calls from the model output",
+        )
+        model_group.add_argument(
+            "--tool-parser-plugin",
+            type=str,
+            default=EngineArgs.tool_parser_plugin,
+            help="tool parser plugin used to register user-defined tool parsers",
+        )
         model_group.add_argument(
             "--speculative-config",
             type=json.loads,
@@ -885,10 +905,10 @@ def create_engine_config(self) -> Config:
         if self.enable_chunked_prefill:
             self.max_num_batched_tokens = 2048
         else:
-            if not int(os.getenv('ENABLE_V1_KVCACHE_SCHEDULER', '0')):
+            if not int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")):
                 self.max_num_batched_tokens = self.max_model_len
             else:
-                self.max_num_batched_tokens = 8192 # if set to max_model_len, it's easy to be OOM
+                self.max_num_batched_tokens = 8192  # if set to max_model_len, it's easy to be OOM

         all_dict = asdict(self)
         all_dict["model_cfg"] = model_cfg
@@ -927,6 +947,7 @@ def create_engine_config(self) -> Config:
             mm_processor_kwargs=self.mm_processor_kwargs,
             # enable_mm=self.enable_mm,
             reasoning_parser=self.reasoning_parser,
+            tool_parser=self.tool_call_parser,
             splitwise_role=self.splitwise_role,
             innode_prefill_ports=self.innode_prefill_ports,
             max_num_partial_prefills=self.max_num_partial_prefills,
```
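With both flags in place, the parsers can also be selected when building engine arguments programmatically. A minimal sketch (not from this commit): the model path, parser name, and plugin filename below are placeholders, and `tool_call_parser` must name a parser registered with `ToolParserManager`.

```python
from fastdeploy.engine.args_utils import EngineArgs

engine_args = EngineArgs(
    model="path/to/model",                # placeholder model path
    tool_call_parser="my_parser",         # placeholder; selects a registered tool-call parser
    tool_parser_plugin="tool_parsers.py", # placeholder file registering user-defined parsers
)
# Per the hunk above, create_engine_config() forwards tool_call_parser
# into Config as tool_parser.
config = engine_args.create_engine_config()
```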

fastdeploy/engine/config.py

Lines changed: 5 additions & 3 deletions
```diff
@@ -86,6 +86,7 @@ def __init__(
         max_long_partial_prefills: int = 1,
         long_prefill_token_threshold: int = 0,
         reasoning_parser: str = None,
+        tool_parser: str = None,
         guided_decoding_backend: Optional[str] = None,
         disable_any_whitespace: bool = False,
         enable_logprob: bool = False,
@@ -166,6 +167,7 @@ def __init__(
         self.max_long_partial_prefills = max_long_partial_prefills
         self.long_prefill_token_threshold = long_prefill_token_threshold
         self.reasoning_parser = reasoning_parser
+        self.tool_parser = tool_parser
         self.graph_optimization_config = graph_optimization_config
         self.early_stop_config = early_stop_config
         self.guided_decoding_backend = guided_decoding_backend
@@ -245,10 +247,10 @@ def postprocess(self):
         if self.cache_config.enable_chunked_prefill:
             self.max_num_batched_tokens = 2048
         else:
-            if not int(os.getenv('ENABLE_V1_KVCACHE_SCHEDULER', '0')):
+            if not int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")):
                 self.max_num_batched_tokens = self.max_model_len
             else:
-                self.max_num_batched_tokens = 8192 # if set to max_model_len, it's easy to be OOM
+                self.max_num_batched_tokens = 8192  # if set to max_model_len, it's easy to be OOM

         if self.long_prefill_token_threshold == 0:
             self.long_prefill_token_threshold = int(self.max_model_len * 0.04)
@@ -296,7 +298,7 @@ def check(self):
         )

         if not self.cache_config.enable_chunked_prefill:
-            if not int(os.getenv('ENABLE_V1_KVCACHE_SCHEDULER', '0')):
+            if not int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")):
                 assert self.max_num_batched_tokens >= self.max_model_len, (
                     f"max_num_batched_tokens: {self.max_num_batched_tokens} "
                     f"should be larger than or equal to max_model_len: {self.max_model_len}"
```

fastdeploy/engine/engine.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -106,6 +106,7 @@ def __init__(self, cfg):
             cfg.limit_mm_per_prompt,
             cfg.mm_processor_kwargs,
             cfg.enable_mm,
+            cfg.tool_parser,
         )

         self.start_queue_service()
```

fastdeploy/engine/request.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -24,6 +24,7 @@
 import numpy as np

 from fastdeploy.engine.sampling_params import SamplingParams
+from fastdeploy.entrypoints.openai.protocol import ToolCall
 from fastdeploy.utils import data_processor_logger
 from fastdeploy.worker.output import LogprobsLists, SampleLogprobs

@@ -249,6 +250,7 @@ class CompletionOutput:
     draft_token_ids: list[int] = None
     text: Optional[str] = None
    reasoning_content: Optional[str] = None
+    tool_calls: Optional[ToolCall] = None

     def to_dict(self):
         """
```

fastdeploy/entrypoints/chat_utils.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -14,6 +14,7 @@
 # limitations under the License.
 """

+import uuid
 from copy import deepcopy
 from typing import List, Literal, Union
 from urllib.parse import urlparse
@@ -156,3 +157,7 @@ def parse_chat_messages(messages):

     conversation.append({"role": role, "content": parsed_content})
     return conversation
+
+
+def random_tool_call_id() -> str:
+    return f"chatcmpl-tool-{str(uuid.uuid4().hex)}"
```

fastdeploy/entrypoints/engine_client.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -49,6 +49,7 @@ def __init__(
         data_parallel_size=1,
         enable_logprob=False,
         workers=1,
+        tool_parser=None,
     ):
         import fastdeploy.model_executor.models  # noqa: F401

@@ -64,6 +65,7 @@ def __init__(
             limit_mm_per_prompt,
             mm_processor_kwargs,
             self.enable_mm,
+            tool_parser,
         )
         self.enable_logprob = enable_logprob
         self.reasoning_parser = reasoning_parser
```

fastdeploy/entrypoints/llm.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -28,6 +28,7 @@
 from fastdeploy.engine.args_utils import EngineArgs
 from fastdeploy.engine.engine import LLMEngine
 from fastdeploy.engine.sampling_params import SamplingParams
+from fastdeploy.entrypoints.openai.tool_parsers import ToolParserManager
 from fastdeploy.plugins.model_register import load_model_register_plugins
 from fastdeploy.utils import (
     deprecated_kwargs_warning,
@@ -79,6 +80,9 @@ def __init__(

         load_model_register_plugins()
         model = retrive_model_from_server(model, revision)
+        tool_parser_plugin = kwargs.get("tool_parser_plugin")
+        if tool_parser_plugin:
+            ToolParserManager.import_tool_parser(tool_parser_plugin)
         engine_args = EngineArgs(
             model=model,
             tokenizer=tokenizer,
```
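This is the offline counterpart of the server flags: `tool_parser_plugin` is read from `**kwargs` and imported before `EngineArgs` is built. A hedged sketch; the model path, plugin file, and parser name are placeholders, and forwarding of `tool_call_parser` through the remaining kwargs into `EngineArgs` is assumed rather than shown in this hunk.

```python
from fastdeploy.entrypoints.llm import LLM

llm = LLM(
    model="path/to/model",                 # placeholder model path
    tool_parser_plugin="./my_parsers.py",  # imported via ToolParserManager before engine setup
    tool_call_parser="my_parser",          # assumed to flow into EngineArgs.tool_call_parser
)
```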

fastdeploy/entrypoints/openai/api_server.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -41,6 +41,7 @@
 )
 from fastdeploy.entrypoints.openai.serving_chat import OpenAIServingChat
 from fastdeploy.entrypoints.openai.serving_completion import OpenAIServingCompletion
+from fastdeploy.entrypoints.openai.tool_parsers import ToolParserManager
 from fastdeploy.metrics.metrics import (
     EXCLUDE_LABELS,
     cleanup_prometheus_files,
@@ -74,7 +75,8 @@
 parser = EngineArgs.add_cli_args(parser)
 args = parser.parse_args()
 args.model = retrive_model_from_server(args.model, args.revision)
-
+if args.tool_parser_plugin:
+    ToolParserManager.import_tool_parser(args.tool_parser_plugin)
 llm_engine = None


@@ -134,6 +136,7 @@ async def lifespan(app: FastAPI):
         args.data_parallel_size,
         args.enable_logprob,
         args.workers,
+        args.tool_call_parser,
     )
     app.state.dynamic_load_weight = args.dynamic_load_weight
     chat_handler = OpenAIServingChat(engine_client, pid, args.ips, args.max_waiting_time)
```
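The file passed via `--tool-parser-plugin` is loaded through `ToolParserManager.import_tool_parser` before the engine starts, so any parsers it registers are available when `--tool-call-parser` is resolved. A sketch of such a plugin module; note that only `import_tool_parser` appears in this diff, so the `register_module` decorator and the `extract_tool_calls` signature below are assumptions modeled on similar parser registries.

```python
# my_parsers.py -- hypothetical plugin file for --tool-parser-plugin.
from fastdeploy.entrypoints.openai.protocol import ExtractedToolCallInformation
from fastdeploy.entrypoints.openai.tool_parsers import ToolParserManager


@ToolParserManager.register_module("my_parser")  # assumed registration API
class MyToolParser:
    def extract_tool_calls(self, model_output: str, request) -> ExtractedToolCallInformation:
        # A real parser would scan model_output for the model's tool-call markup;
        # this stub reports that no tool was called.
        return ExtractedToolCallInformation(tools_called=False, content=model_output)
```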

fastdeploy/entrypoints/openai/protocol.py

Lines changed: 12 additions & 1 deletion
```diff
@@ -72,7 +72,6 @@ class ToolCall(BaseModel):
     id: str = None
     type: Literal["function"] = "function"
     function: FunctionCall
-    index: int


 class DeltaFunctionCall(BaseModel):
@@ -96,6 +95,18 @@ class DeltaToolCall(BaseModel):
     function: Optional[DeltaFunctionCall] = None


+class ExtractedToolCallInformation(BaseModel):
+    # indicate if tools were called
+    tools_called: bool
+
+    # extracted tool calls
+    tool_calls: Optional[list[ToolCall]] = None
+
+    # content - per OpenAI spec, content AND tool calls can be returned rarely
+    # But some models will do this intentionally
+    content: Optional[str] = None
+
+
 class FunctionDefinition(BaseModel):
     """
     Function definition.
```
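`ExtractedToolCallInformation` is the common return type for tool parsers. A small illustration of how a parser might package one extracted call; the `<tool_call>{...}</tool_call>` markup is hypothetical, and `FunctionCall`'s `name`/`arguments` fields are assumed to follow the OpenAI schema (only `ToolCall`, `ExtractedToolCallInformation`, and `random_tool_call_id` come from this commit).

```python
import json

from fastdeploy.entrypoints.chat_utils import random_tool_call_id
from fastdeploy.entrypoints.openai.protocol import (
    ExtractedToolCallInformation,
    FunctionCall,
    ToolCall,
)


def extract_tool_calls(model_output: str) -> ExtractedToolCallInformation:
    """Toy extraction over hypothetical <tool_call>{...}</tool_call> markup."""
    if "<tool_call>" not in model_output:
        # No tool call: return the model text as plain content.
        return ExtractedToolCallInformation(tools_called=False, content=model_output)
    payload = model_output.split("<tool_call>", 1)[1].split("</tool_call>", 1)[0]
    call = json.loads(payload)  # expected shape: {"name": ..., "arguments": {...}}
    return ExtractedToolCallInformation(
        tools_called=True,
        tool_calls=[
            ToolCall(
                id=random_tool_call_id(),
                function=FunctionCall(
                    name=call["name"],
                    arguments=json.dumps(call["arguments"]),
                ),
            )
        ],
    )
```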

fastdeploy/entrypoints/openai/serving_chat.py

Lines changed: 20 additions & 14 deletions
```diff
@@ -140,6 +140,7 @@ async def chat_completion_stream_generator(
         previous_num_tokens = 0
         num_prompt_tokens = 0
         num_choices = 1
+        tool_called = False
         max_streaming_response_tokens = (
             request.max_streaming_response_tokens
             if request.max_streaming_response_tokens is not None
@@ -243,20 +244,28 @@ async def chat_completion_stream_generator(
                 output = res["outputs"]
                 delta_text = output["text"]
                 output_top_logprobs = output["top_logprobs"]
+                previous_num_tokens += len(output["token_ids"])
                 logprobs_res: Optional[LogProbs] = None
                 if request.logprobs and output_top_logprobs is not None:
                     logprobs_res = self._create_chat_logprobs(
                         output_top_logprobs, request.logprobs, request.top_logprobs
                     )
-
-                previous_num_tokens += len(output["token_ids"])
-                delta_message = DeltaMessage(
-                    content=delta_text,
-                    reasoning_content=output.get("reasoning_content"),
-                    prompt_token_ids=None,
-                    completion_token_ids=None,
-                    tool_calls=output.get("tool_call_content", []),
-                )
+                if self.engine_client.data_processor.tool_parser_obj and not res["finished"]:
+                    tool_delta_message = output["tool_delta_message"]
+                    if tool_delta_message is None:
+                        continue
+                    delta_message = tool_delta_message
+                    delta_message.reasoning_content = output.get("reasoning_content")
+                    if delta_message.tool_calls:
+                        tool_called = True
+                else:
+                    delta_message = DeltaMessage(
+                        content=delta_text,
+                        reasoning_content=output.get("reasoning_content"),
+                        prompt_token_ids=None,
+                        completion_token_ids=None,
+                        tool_calls=None,
+                    )

                 choice = ChatCompletionResponseStreamChoice(
                     index=0,
@@ -273,10 +282,7 @@ async def chat_completion_stream_generator(
                     max_tokens = request.max_completion_tokens or request.max_tokens
                     if has_no_token_limit or previous_num_tokens != max_tokens:
                         choice.finish_reason = "stop"
-                        if (
-                            self.engine_client.reasoning_parser == "ernie_x1"
-                            and output.get("finish_reason", "") == "tool_calls"
-                        ):
+                        if tool_called:
                             choice.finish_reason = "tool_calls"
                     else:
                         choice.finish_reason = "length"
@@ -412,7 +418,7 @@ async def chat_completion_full_generator(
                     role="assistant",
                    content=output["text"],
                     reasoning_content=output.get("reasoning_content"),
-                    tool_calls=output.get("tool_call_content"),
+                    tool_calls=output.get("tool_call"),
                     prompt_token_ids=prompt_token_ids if request.return_token_ids else None,
                     completion_token_ids=completion_token_ids if request.return_token_ids else None,
                     text_after_process=text_after_process if request.return_token_ids else None,
```
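End to end, the streaming change means clients see tool-call deltas and a `tool_calls` finish reason whenever the configured parser fires, instead of the old behavior hard-coded to the `ernie_x1` reasoning parser. A client-side sketch against the OpenAI-compatible endpoint served by `api_server.py`; the URL, model name, and tool schema are placeholders.

```python
import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
stream = client.chat.completions.create(
    model="default",  # placeholder model name
    messages=[{"role": "user", "content": "What's the weather in Boston?"}],
    tools=[{
        "type": "function",
        "function": {
            "name": "get_weather",
            "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
        },
    }],
    stream=True,
)
for chunk in stream:
    choice = chunk.choices[0]
    if choice.delta.tool_calls:               # tool-call deltas emitted by the parser
        print(choice.delta.tool_calls)
    if choice.finish_reason == "tool_calls":  # set when tool_called is True server-side
        print("model requested a tool call")
```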
