Skip to content

Commit 149a0f4

Browse files
committed
Configure the tool_choice to use a specific tool
Signed-off-by: MondayCha <[email protected]>
1 parent 0b25c88 commit 149a0f4

File tree

5 files changed

+236
-21
lines changed

5 files changed

+236
-21
lines changed

pkg/common/tools_utils.go

Lines changed: 75 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"fmt"
2222

2323
openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api"
24+
"github.com/openai/openai-go/v3/packages/param"
2425
"github.com/santhosh-tekuri/jsonschema/v5"
2526
)
2627

@@ -52,24 +53,88 @@ var fakeStringArguments = []string{
5253
`lifetime`,
5354
}
5455

55-
// CreateToolCalls creates and returns response payload based on this request
56-
// (tool calls or nothing in case we randomly choose not to generate calls),
57-
// and the number of generated completion token sand the finish reason
58-
func CreateToolCalls(tools []openaiserverapi.Tool, toolChoice string, config *Configuration) ([]openaiserverapi.ToolCall, int, error) {
59-
// This function is called if tool choice is either 'required' or 'auto'.
60-
// In case of 'required' at least one tool call has to be created, and we randomly choose
61-
// the number of calls starting from one. Otherwise, we start from 0, and in case we randomly
62-
// choose the number of calls to be 0, response text will be generated instead of a tool call.
56+
// IsToolChoiceNone checks if the tool_choice is set to "none".
57+
func IsToolChoiceNone(toolChoice openaiserverapi.ToolChoice) bool {
58+
if !param.IsOmitted(toolChoice.OfAuto) {
59+
val := toolChoice.OfAuto.Or("")
60+
return val == ToolChoiceNone
61+
}
62+
return false
63+
}
64+
65+
// CreateToolCalls creates and returns tool calls based on the request's tool
66+
// definitions and tool_choice parameter.
67+
// The tool_choice parameter controls how the model responds to function calls:
68+
// - "none": The model does not call any tools. This case should be handled
69+
// before calling this function.
70+
// - "auto": The model can choose to either generate a message or call one or
71+
// more tools. This is the default behavior.
72+
// - "required": The model must call one or more tools.
73+
// - Specific function: A specific tool can be forced by providing an object
74+
// like `{"type": "function", "function": {"name": "my_function"}}`. The
75+
// model will be constrained to call that exact tool.
76+
//
77+
// This function returns the generated tool calls, the number of completion
78+
// tokens used, and an error if one occurs (e.g., if a specified tool is not found).
79+
func CreateToolCalls(
80+
tools []openaiserverapi.Tool,
81+
toolChoice openaiserverapi.ToolChoice,
82+
config *Configuration,
83+
) ([]openaiserverapi.ToolCall, int, error) {
84+
// If a specific function is required.
85+
if functionChoice := toolChoice.GetFunction(); functionChoice != nil {
86+
requiredFuncName := functionChoice.Name
87+
var targetTool *openaiserverapi.Tool
88+
89+
// Find the specified tool in the list of available tools.
90+
for i, tool := range tools {
91+
if tool.Function.Name == requiredFuncName {
92+
targetTool = &tools[i]
93+
break
94+
}
95+
}
96+
97+
if targetTool == nil {
98+
return nil, 0, fmt.Errorf("tool with name '%s' requested in tool_choice but not found in the tools list", requiredFuncName)
99+
}
100+
101+
// Generate arguments for the specific tool.
102+
args, err := generateToolArguments(*targetTool, config)
103+
if err != nil {
104+
return nil, 0, err
105+
}
106+
argsJson, err := json.Marshal(args)
107+
if err != nil {
108+
return nil, 0, err
109+
}
110+
111+
call := openaiserverapi.ToolCall{
112+
Function: openaiserverapi.FunctionCall{
113+
Arguments: string(argsJson),
114+
TokenizedArguments: Tokenize(string(argsJson)),
115+
Name: &targetTool.Function.Name,
116+
},
117+
ID: "chatcmpl-tool-" + RandomNumericString(10),
118+
Type: "function",
119+
Index: 0,
120+
}
121+
calls := []openaiserverapi.ToolCall{call}
122+
return calls, CountTokensForToolCalls(calls), nil
123+
}
124+
125+
// Tool choice is 'auto' or 'required'.
126+
// In 'required' mode, at least one tool call must be created.
127+
// In 'auto' mode, the number of calls can be zero, leading to a text response instead.
63128
min := 0
64-
if toolChoice == ToolChoiceRequired {
129+
if !param.IsOmitted(toolChoice.OfAuto) && toolChoice.OfAuto.Or("") == ToolChoiceRequired {
65130
min = 1
66131
}
67132
numberOfCalls := RandomInt(min, len(tools))
68133
if numberOfCalls == 0 {
69134
return nil, 0, nil
70135
}
71136

72-
calls := make([]openaiserverapi.ToolCall, 0)
137+
calls := make([]openaiserverapi.ToolCall, 0, numberOfCalls)
73138
for i := range numberOfCalls {
74139
// Randomly choose which tools to call. We may call the same tool more than once.
75140
index := RandomInt(0, len(tools)-1)

pkg/llm-d-inference-sim/tools_test.go

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -499,6 +499,59 @@ var _ = Describe("Simulator for request with tools", func() {
499499
Entry(nil, common.ModeRandom),
500500
)
501501

502+
DescribeTable("no streaming, a specific tool",
503+
func(mode string, specificTool string) {
504+
ctx := context.TODO()
505+
client, err := startServer(ctx, mode)
506+
Expect(err).NotTo(HaveOccurred())
507+
508+
openaiclient, params := getOpenAIClientAndChatParams(client, model, userMessage, false)
509+
params.ToolChoice = openai.ToolChoiceOptionFunctionToolChoice(openai.ChatCompletionNamedToolChoiceFunctionParam{
510+
Name: specificTool,
511+
})
512+
params.Tools = tools
513+
514+
resp, err := openaiclient.Chat.Completions.New(ctx, params)
515+
Expect(err).NotTo(HaveOccurred())
516+
Expect(resp.Choices).ShouldNot(BeEmpty())
517+
Expect(string(resp.Object)).To(Equal(chatCompletionObject))
518+
519+
Expect(resp.Usage.PromptTokens).To(Equal(userMsgTokens))
520+
Expect(resp.Usage.CompletionTokens).To(BeNumerically(">", 0))
521+
Expect(resp.Usage.TotalTokens).To(Equal(resp.Usage.PromptTokens + resp.Usage.CompletionTokens))
522+
523+
content := resp.Choices[0].Message.Content
524+
Expect(content).Should(BeEmpty())
525+
526+
toolCalls := resp.Choices[0].Message.ToolCalls
527+
Expect(toolCalls).ToNot(BeEmpty())
528+
for _, tc := range toolCalls {
529+
Expect(tc.Function.Name).To(Equal(specificTool))
530+
Expect(tc.ID).NotTo(BeEmpty())
531+
Expect(tc.Type).To(Equal("function"))
532+
args := make(map[string]string)
533+
err := json.Unmarshal([]byte(tc.Function.Arguments), &args)
534+
Expect(err).NotTo(HaveOccurred())
535+
536+
if tc.Function.Name == "get_weather" {
537+
Expect(tc.Function.Arguments).To(ContainSubstring("location"))
538+
} else {
539+
Expect(tc.Function.Arguments).To(ContainSubstring("city"))
540+
Expect(tc.Function.Arguments).To(ContainSubstring("unit"))
541+
Expect(args["unit"]).To(Or(Equal("C"), Equal("F")))
542+
}
543+
}
544+
},
545+
func(mode string, specificTool string) string {
546+
return "mode: " + mode + ", specificTool: " + specificTool
547+
},
548+
// Call several times because the tools and arguments are chosen randomly
549+
Entry(nil, common.ModeRandom, "get_weather"),
550+
Entry(nil, common.ModeRandom, "get_temperature"),
551+
Entry(nil, common.ModeRandom, "get_weather"),
552+
Entry(nil, common.ModeRandom, "get_temperature"),
553+
)
554+
502555
DescribeTable("check validator",
503556
func(mode string) {
504557
ctx := context.TODO()

pkg/llm-d-inference-sim/worker.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ func (s *VllmSimulator) processRequest(reqCtx *openaiserverapi.CompletionReqCtx)
8888
var toolCalls []openaiserverapi.ToolCall
8989
var completionTokens int
9090
if reqCtx.IsChatCompletion &&
91-
req.GetToolChoice() != common.ToolChoiceNone &&
91+
!common.IsToolChoiceNone(req.GetToolChoice()) &&
9292
req.GetTools() != nil {
9393
toolCalls, completionTokens, err =
9494
common.CreateToolCalls(req.GetTools(), req.GetToolChoice(), s.config)

pkg/openai-server-api/request.go

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,10 @@ type CompletionRequest interface {
4646
SetNumberOfCachedPromptTokens(cachedPromptTokens int)
4747
// GetPrompt returns the prompt
4848
GetPrompt() string
49-
// GetTools() returns tools to use (in chat completion)
49+
// GetTools returns tools to use (in chat completion)
5050
GetTools() []Tool
51-
// GetToolChoice() returns tool choice (in chat completion)
52-
GetToolChoice() string
51+
// GetToolChoice returns tool choice (in chat completion)
52+
GetToolChoice() ToolChoice
5353
// GetMaxCompletionTokens returns the maximum completion tokens requested
5454
GetMaxCompletionTokens() *int64
5555
// GetIgnoreEOS returns true if the end-of-sequence tokens will be ignored
@@ -184,12 +184,13 @@ type ChatCompletionRequest struct {
184184
// Tools is a list of tools the model may call.
185185
Tools []Tool `json:"tools,omitempty"`
186186

187-
// ToolChoice controls which (if any) tool is called by the model,
188-
// possible values: none, auto, required.
189-
// Sending an object with a specific tool, is currently not supported.
190-
ToolChoice string `json:"tool_choice,omitempty"`
187+
// ToolChoice controls which (if any) tool is called by the model.
188+
// It can be a string ("none", "auto", "required") or an object specifying the function.
189+
ToolChoice ToolChoice `json:"tool_choice,omitzero"`
191190
}
192191

192+
var _ CompletionRequest = (*ChatCompletionRequest)(nil)
193+
193194
// function defines a tool
194195
type function struct {
195196
// Name is the function's name
@@ -221,7 +222,7 @@ func (c *ChatCompletionRequest) GetTools() []Tool {
221222
return c.Tools
222223
}
223224

224-
func (c *ChatCompletionRequest) GetToolChoice() string {
225+
func (c *ChatCompletionRequest) GetToolChoice() ToolChoice {
225226
return c.ToolChoice
226227
}
227228

@@ -286,6 +287,8 @@ type TextCompletionRequest struct {
286287
MaxTokens *int64 `json:"max_tokens"`
287288
}
288289

290+
var _ CompletionRequest = (*TextCompletionRequest)(nil)
291+
289292
func (t *TextCompletionRequest) GetPrompt() string {
290293
return t.Prompt
291294
}
@@ -294,8 +297,8 @@ func (c *TextCompletionRequest) GetTools() []Tool {
294297
return nil
295298
}
296299

297-
func (c *TextCompletionRequest) GetToolChoice() string {
298-
return ""
300+
func (c *TextCompletionRequest) GetToolChoice() ToolChoice {
301+
return ToolChoice{}
299302
}
300303

301304
func (c *TextCompletionRequest) GetMaxCompletionTokens() *int64 {
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
/*
2+
Copyright 2025 The llm-d-inference-sim Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
// Contains structures and functions related to requests for all supported APIs
18+
package openaiserverapi
19+
20+
import (
21+
"encoding/json"
22+
"fmt"
23+
24+
"github.com/openai/openai-go/v3"
25+
"github.com/openai/openai-go/v3/packages/param"
26+
)
27+
28+
// ToolChoice is a wrapper around ChatCompletionToolChoiceOptionUnionParam that
29+
// provides custom JSON unmarshalling logic to correctly handle
30+
// the union type.
31+
type ToolChoice struct {
32+
openai.ChatCompletionToolChoiceOptionUnionParam
33+
}
34+
35+
// MarshalJSON forwards the marshalling process to the embedded
36+
// ChatCompletionToolChoiceOptionUnionParam's MarshalJSON method,
37+
// which is known to work correctly.
38+
func (t ToolChoice) MarshalJSON() ([]byte, error) {
39+
return t.ChatCompletionToolChoiceOptionUnionParam.MarshalJSON()
40+
}
41+
42+
// UnmarshalJSON provides custom logic to correctly deserialize the JSON data
43+
// into the appropriate field of the embedded union type. It inspects the JSON
44+
// structure to determine if it's a simple string or a complex object with a
45+
// 'type' discriminator field.
46+
func (t *ToolChoice) UnmarshalJSON(data []byte) error {
47+
// If the input is a simple string (e.g., "auto", "none", "required"),
48+
// unmarshal it into the OfAuto field.
49+
if data[0] == '"' {
50+
var strValue string
51+
if err := json.Unmarshal(data, &strValue); err != nil {
52+
return err
53+
}
54+
t.OfAuto = param.NewOpt(strValue)
55+
return nil
56+
}
57+
58+
// If the input is a JSON object, we need to determine its type.
59+
// We use a temporary struct to detect the 'type' field.
60+
var typeDetector struct {
61+
Type string `json:"type"`
62+
}
63+
64+
// We only care about the type field, ignore other fields
65+
if err := json.Unmarshal(data, &typeDetector); err != nil {
66+
return fmt.Errorf("failed to detect type for ToolChoice: %w", err)
67+
}
68+
69+
// Based on the detected type, unmarshal the data into the correct struct.
70+
switch typeDetector.Type {
71+
case "function":
72+
var functionChoice openai.ChatCompletionNamedToolChoiceParam
73+
if err := functionChoice.UnmarshalJSON(data); err != nil {
74+
return err
75+
}
76+
t.OfFunctionToolChoice = &functionChoice
77+
case "custom":
78+
var customChoice openai.ChatCompletionNamedToolChoiceCustomParam
79+
if err := customChoice.UnmarshalJSON(data); err != nil {
80+
return err
81+
}
82+
t.OfCustomToolChoice = &customChoice
83+
case "allowed_tools":
84+
var allowedToolsChoice openai.ChatCompletionAllowedToolChoiceParam
85+
if err := allowedToolsChoice.UnmarshalJSON(data); err != nil {
86+
return err
87+
}
88+
t.OfAllowedTools = &allowedToolsChoice
89+
default:
90+
return fmt.Errorf("unknown ToolChoice type: %s", typeDetector.Type)
91+
}
92+
93+
return nil
94+
}

0 commit comments

Comments
 (0)