Commit 2bc6b38

Configure the tool_choice to use a specific tool
Signed-off-by: Yilong Li <[email protected]>
1 parent: 0b25c88

File tree: 3 files changed (91 additions, 21 deletions)

pkg/common/tools_utils.go

Lines changed: 76 additions & 10 deletions
@@ -21,6 +21,8 @@ import (
     "fmt"

     openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api"
+    "github.com/openai/openai-go/v3"
+    "github.com/openai/openai-go/v3/packages/param"
     "github.com/santhosh-tekuri/jsonschema/v5"
 )

@@ -52,24 +54,88 @@ var fakeStringArguments = []string{
     `lifetime`,
 }

-// CreateToolCalls creates and returns response payload based on this request
-// (tool calls or nothing in case we randomly choose not to generate calls),
-// and the number of generated completion token sand the finish reason
-func CreateToolCalls(tools []openaiserverapi.Tool, toolChoice string, config *Configuration) ([]openaiserverapi.ToolCall, int, error) {
-    // This function is called if tool choice is either 'required' or 'auto'.
-    // In case of 'required' at least one tool call has to be created, and we randomly choose
-    // the number of calls starting from one. Otherwise, we start from 0, and in case we randomly
-    // choose the number of calls to be 0, response text will be generated instead of a tool call.
+// IsToolChoiceNone checks if the tool_choice is set to "none".
+func IsToolChoiceNone(toolChoice openai.ChatCompletionToolChoiceOptionUnionParam) bool {
+    if !param.IsOmitted(toolChoice.OfAuto) {
+        val := toolChoice.OfAuto.Or("")
+        return val == ToolChoiceNone
+    }
+    return false
+}
+
+// CreateToolCalls creates and returns tool calls based on the request's tool
+// definitions and tool_choice parameter.
+// The tool_choice parameter controls how the model responds to function calls:
+//   - "none": The model does not call any tools. This case should be handled
+//     before calling this function.
+//   - "auto": The model can choose to either generate a message or call one or
+//     more tools. This is the default behavior.
+//   - "required": The model must call one or more tools.
+//   - Specific function: A specific tool can be forced by providing an object
+//     like `{"type": "function", "function": {"name": "my_function"}}`. The
+//     model will be constrained to call that exact tool.
+//
+// This function returns the generated tool calls, the number of completion
+// tokens used, and an error if one occurs (e.g., if a specified tool is not found).
+func CreateToolCalls(
+    tools []openaiserverapi.Tool,
+    toolChoice openai.ChatCompletionToolChoiceOptionUnionParam,
+    config *Configuration,
+) ([]openaiserverapi.ToolCall, int, error) {
+    // If a specific function is required.
+    if functionChoice := toolChoice.GetFunction(); functionChoice != nil {
+        requiredFuncName := functionChoice.Name
+        var targetTool *openaiserverapi.Tool
+
+        // Find the specified tool in the list of available tools.
+        for i, tool := range tools {
+            if tool.Function.Name == requiredFuncName {
+                targetTool = &tools[i]
+                break
+            }
+        }
+
+        if targetTool == nil {
+            return nil, 0, fmt.Errorf("tool with name '%s' requested in tool_choice but not found in the tools list", requiredFuncName)
+        }
+
+        // Generate arguments for the specific tool.
+        args, err := generateToolArguments(*targetTool, config)
+        if err != nil {
+            return nil, 0, err
+        }
+        argsJson, err := json.Marshal(args)
+        if err != nil {
+            return nil, 0, err
+        }
+
+        call := openaiserverapi.ToolCall{
+            Function: openaiserverapi.FunctionCall{
+                Arguments:          string(argsJson),
+                TokenizedArguments: Tokenize(string(argsJson)),
+                Name:               &targetTool.Function.Name,
+            },
+            ID:    "chatcmpl-tool-" + RandomNumericString(10),
+            Type:  "function",
+            Index: 0,
+        }
+        calls := []openaiserverapi.ToolCall{call}
+        return calls, CountTokensForToolCalls(calls), nil
+    }
+
+    // Tool choice is 'auto' or 'required'.
+    // In 'required' mode, at least one tool call must be created.
+    // In 'auto' mode, the number of calls can be zero, leading to a text response instead.
     min := 0
-    if toolChoice == ToolChoiceRequired {
+    if !param.IsOmitted(toolChoice.OfAuto) && toolChoice.OfAuto.Or("") == ToolChoiceRequired {
         min = 1
     }
     numberOfCalls := RandomInt(min, len(tools))
     if numberOfCalls == 0 {
         return nil, 0, nil
     }

-    calls := make([]openaiserverapi.ToolCall, 0)
+    calls := make([]openaiserverapi.ToolCall, 0, numberOfCalls)
     for i := range numberOfCalls {
         // Randomly choose which tools to call. We may call the same tool more than once.
         index := RandomInt(0, len(tools)-1)
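
Below is a minimal sketch, not part of the commit, of how the three accepted tool_choice forms interact with the helpers above. It assumes the openai-go union param unmarshals raw JSON the same way the simulator's request struct relies on; the get_weather name is purely illustrative.

package main

import (
    "encoding/json"
    "fmt"

    "github.com/llm-d/llm-d-inference-sim/pkg/common"
    "github.com/openai/openai-go/v3"
)

func main() {
    // Each payload is one accepted form of the request's "tool_choice" field.
    payloads := []string{
        `"none"`,
        `"required"`,
        `{"type": "function", "function": {"name": "get_weather"}}`,
    }
    for _, raw := range payloads {
        var tc openai.ChatCompletionToolChoiceOptionUnionParam
        if err := json.Unmarshal([]byte(raw), &tc); err != nil {
            fmt.Println("unmarshal error:", err)
            continue
        }
        if common.IsToolChoiceNone(tc) {
            // The worker skips CreateToolCalls entirely in this case.
            fmt.Println(raw, "=> no tool calls are generated")
        } else if fn := tc.GetFunction(); fn != nil {
            // CreateToolCalls is constrained to this exact tool.
            fmt.Println(raw, "=> forced tool:", fn.Name)
        } else {
            // "auto" and "required" fall through to the random-call path.
            fmt.Println(raw, "=> auto/required behavior")
        }
    }
}

This mirrors the gating in worker.go below: the simulator skips tool-call generation only when tool_choice is the literal string "none".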

pkg/llm-d-inference-sim/worker.go

Lines changed: 1 addition & 1 deletion
@@ -88,7 +88,7 @@ func (s *VllmSimulator) processRequest(reqCtx *openaiserverapi.CompletionReqCtx)
     var toolCalls []openaiserverapi.ToolCall
     var completionTokens int
     if reqCtx.IsChatCompletion &&
-        req.GetToolChoice() != common.ToolChoiceNone &&
+        !common.IsToolChoiceNone(req.GetToolChoice()) &&
         req.GetTools() != nil {
         toolCalls, completionTokens, err =
             common.CreateToolCalls(req.GetTools(), req.GetToolChoice(), s.config)

pkg/openai-server-api/request.go

Lines changed: 14 additions & 10 deletions
@@ -20,6 +20,7 @@ package openaiserverapi
 import (
     "sync"

+    "github.com/openai/openai-go/v3"
     "github.com/valyala/fasthttp"
 )

@@ -46,10 +47,10 @@ type CompletionRequest interface {
     SetNumberOfCachedPromptTokens(cachedPromptTokens int)
     // GetPrompt returns the prompt
     GetPrompt() string
-    // GetTools() returns tools to use (in chat completion)
+    // GetTools returns tools to use (in chat completion)
     GetTools() []Tool
-    // GetToolChoice() returns tool choice (in chat completion)
-    GetToolChoice() string
+    // GetToolChoice returns tool choice (in chat completion)
+    GetToolChoice() openai.ChatCompletionToolChoiceOptionUnionParam
     // GetMaxCompletionTokens returns the maximum completion tokens requested
     GetMaxCompletionTokens() *int64
     // GetIgnoreEOS returns true if the end-of-sequence tokens will be ignored
@@ -184,12 +185,13 @@ type ChatCompletionRequest struct {
     // Tools is a list of tools the model may call.
     Tools []Tool `json:"tools,omitempty"`

-    // ToolChoice controls which (if any) tool is called by the model,
-    // possible values: none, auto, required.
-    // Sending an object with a specific tool, is currently not supported.
-    ToolChoice string `json:"tool_choice,omitempty"`
+    // ToolChoice controls which (if any) tool is called by the model.
+    // It can be a string ("none", "auto", "required") or an object specifying the function.
+    ToolChoice openai.ChatCompletionToolChoiceOptionUnionParam `json:"tool_choice,omitzero"`
 }

+var _ CompletionRequest = (*ChatCompletionRequest)(nil)
+
 // function defines a tool
 type function struct {
     // Name is the function's name
@@ -221,7 +223,7 @@ func (c *ChatCompletionRequest) GetTools() []Tool {
     return c.Tools
 }

-func (c *ChatCompletionRequest) GetToolChoice() string {
+func (c *ChatCompletionRequest) GetToolChoice() openai.ChatCompletionToolChoiceOptionUnionParam {
     return c.ToolChoice
 }

@@ -286,6 +288,8 @@ type TextCompletionRequest struct {
     MaxTokens *int64 `json:"max_tokens"`
 }

+var _ CompletionRequest = (*TextCompletionRequest)(nil)
+
 func (t *TextCompletionRequest) GetPrompt() string {
     return t.Prompt
 }
@@ -294,8 +298,8 @@ func (c *TextCompletionRequest) GetTools() []Tool {
     return nil
 }

-func (c *TextCompletionRequest) GetToolChoice() string {
-    return ""
+func (c *TextCompletionRequest) GetToolChoice() openai.ChatCompletionToolChoiceOptionUnionParam {
+    return openai.ChatCompletionToolChoiceOptionUnionParam{}
 }

 func (c *TextCompletionRequest) GetMaxCompletionTokens() *int64 {
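
As a hedged sketch (not part of the commit), the request below shows the object form of tool_choice that the updated ChatCompletionRequest can now carry. The model and tool names are hypothetical, and it assumes the struct's JSON tags unmarshal a chat-completions payload as the simulator's handler would.

package main

import (
    "encoding/json"
    "fmt"

    openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api"
)

const body = `{
  "model": "my-model",
  "messages": [{"role": "user", "content": "What is the weather in Paris?"}],
  "tools": [{"type": "function", "function": {"name": "get_weather"}}],
  "tool_choice": {"type": "function", "function": {"name": "get_weather"}}
}`

func main() {
    var req openaiserverapi.ChatCompletionRequest
    if err := json.Unmarshal([]byte(body), &req); err != nil {
        panic(err)
    }
    // The compile-time assertion added above guarantees this assignment,
    // so the forced function is visible through the generic interface.
    var cr openaiserverapi.CompletionRequest = &req
    tc := cr.GetToolChoice()
    if fn := tc.GetFunction(); fn != nil {
        fmt.Println("forced tool:", fn.Name) // forced tool: get_weather
    }
}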
