Commit d71d5ac

Configure the tool_choice to use a specific tool
Signed-off-by: MondayCha <[email protected]>
1 parent 0b25c88 commit d71d5ac

File tree

5 files changed: +289 −59 lines

pkg/common/tools_utils.go

Lines changed: 112 additions & 36 deletions
@@ -21,6 +21,7 @@ import (
 	"fmt"
 
 	openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api"
+	"github.com/openai/openai-go/v3/packages/param"
 	"github.com/santhosh-tekuri/jsonschema/v5"
 )
 
@@ -52,50 +53,125 @@ var fakeStringArguments = []string{
 	`lifetime`,
 }
 
-// CreateToolCalls creates and returns response payload based on this request
-// (tool calls or nothing in case we randomly choose not to generate calls),
-// and the number of generated completion token sand the finish reason
-func CreateToolCalls(tools []openaiserverapi.Tool, toolChoice string, config *Configuration) ([]openaiserverapi.ToolCall, int, error) {
-	// This function is called if tool choice is either 'required' or 'auto'.
-	// In case of 'required' at least one tool call has to be created, and we randomly choose
-	// the number of calls starting from one. Otherwise, we start from 0, and in case we randomly
-	// choose the number of calls to be 0, response text will be generated instead of a tool call.
-	min := 0
-	if toolChoice == ToolChoiceRequired {
-		min = 1
-	}
-	numberOfCalls := RandomInt(min, len(tools))
-	if numberOfCalls == 0 {
-		return nil, 0, nil
+// IsToolChoiceNone checks if the tool_choice is set to "none".
+func IsToolChoiceNone(toolChoice openaiserverapi.ToolChoice) bool {
+	if !param.IsOmitted(toolChoice.OfAuto) {
+		val := toolChoice.OfAuto.Or("")
+		return val == ToolChoiceNone
 	}
+	return false
+}
 
-	calls := make([]openaiserverapi.ToolCall, 0)
-	for i := range numberOfCalls {
-		// Randomly choose which tools to call. We may call the same tool more than once.
-		index := RandomInt(0, len(tools)-1)
-		args, err := generateToolArguments(tools[index], config)
-		if err != nil {
-			return nil, 0, err
+// CreateToolCalls creates and returns tool calls based on the request's tool
+// definitions and the tool_choice parameter.
+//
+// The [tool_choice](https://platform.openai.com/docs/guides/function-calling#tool-choice)
+// parameter controls how the model responds to function calls.
+//
+// This function handles the following cases for tool_choice:
+//   - "none": The model will not call any tools. In this scenario, this function
+//     should ideally be bypassed, as no tool calls will be generated.
+//   - "auto": This is the default behavior where the model autonomously decides
+//     whether to generate a message or call one or more tools from the provided list.
+//   - "required": The model is constrained to call one or more of the available tools.
+//   - Forced Function: A specific tool can be forced by providing an object with the
+//     structure `{"type": "function", "function": {"name": "my_function"}}`.
+//     The model will be restricted to calling only that specified tool.
+//
+// This function currently does not handle the following `tool_choice` scenarios:
+//   - Forced Custom Tool: If `tool_choice` is set to `{"type": "custom", "name": "my_custom"}`,
+//     this function will not be able to enforce the calling of a custom tool, as custom
+//     tool types are not yet supported.
+//   - Allowed Tools Subset: The functionality to restrict the model's tool-calling
+//     capabilities to a specific subset of the available tools has not been implemented.
+//
+// This function returns the generated tool calls, the number of completion
+// tokens used, and an error if one occurs (e.g., if a specified tool is not found).
+func CreateToolCalls(
+	tools []openaiserverapi.Tool,
+	toolChoice openaiserverapi.ToolChoice,
+	config *Configuration,
+) ([]openaiserverapi.ToolCall, int, error) {
+	generateCalls := func(availableTools []openaiserverapi.Tool, minCalls int) ([]openaiserverapi.ToolCall, int, error) {
+		if len(availableTools) == 0 {
+			// If no tools are available to choose from, no calls can be made.
+			return nil, 0, fmt.Errorf("no tools available to create tool calls")
 		}
-		argsJson, err := json.Marshal(args)
-		if err != nil {
-			return nil, 0, err
+
+		numberOfCalls := minCalls
+		if len(availableTools) > minCalls {
+			// Randomly decide how many tools to call, between minCalls and the total available.
+			numberOfCalls = RandomInt(minCalls, len(availableTools))
+		}
+
+		if numberOfCalls == 0 {
+			return nil, 0, nil
+		}
+
+		calls := make([]openaiserverapi.ToolCall, 0, numberOfCalls)
+		for i := range numberOfCalls {
+			// Randomly choose which tool to call. We may call the same tool more than once.
+			index := 0
+			if len(availableTools) > 1 {
+				index = RandomInt(0, len(availableTools)-1)
+			}
+			chosenTool := availableTools[index]
+
+			args, err := generateToolArguments(chosenTool, config)
+			if err != nil {
+				return nil, 0, err
+			}
+			argsJson, err := json.Marshal(args)
+			if err != nil {
+				return nil, 0, err
+			}
+
+			call := openaiserverapi.ToolCall{
+				Function: openaiserverapi.FunctionCall{
+					Arguments:          string(argsJson),
+					TokenizedArguments: Tokenize(string(argsJson)),
+					Name:               &chosenTool.Function.Name,
+				},
+				ID:    "chatcmpl-tool-" + RandomNumericString(10),
+				Type:  "function",
+				Index: i,
+			}
+			calls = append(calls, call)
 		}
+		return calls, CountTokensForToolCalls(calls), nil
+	}
+
+	// A specific function is forced.
+	if functionChoice := toolChoice.GetFunction(); functionChoice != nil {
+		requiredFuncName := functionChoice.Name
+		var targetTool *openaiserverapi.Tool
 
-		call := openaiserverapi.ToolCall{
-			Function: openaiserverapi.FunctionCall{
-				Arguments:          string(argsJson),
-				TokenizedArguments: Tokenize(string(argsJson)),
-				Name:               &tools[index].Function.Name,
-			},
-			ID:    "chatcmpl-tool-" + RandomNumericString(10),
-			Type:  "function",
-			Index: i,
+		// Find the specified tool in the list of available tools.
+		for i, tool := range tools {
+			if tool.Function.Name == requiredFuncName {
+				targetTool = &tools[i]
+				break
+			}
 		}
-		calls = append(calls, call)
+
+		if targetTool == nil {
+			return nil, 0, fmt.Errorf("tool with name '%s' requested in tool_choice but not found in the tools list", requiredFuncName)
		}
+
+		specificTools := []openaiserverapi.Tool{*targetTool}
+
+		// Generate arguments for the specific tool.
+		return generateCalls(specificTools, len(specificTools))
+	}
+
+	// Default behavior for "auto" or "required".
+	// The model can choose from any of the provided tools.
+	min := 0
+	if !param.IsOmitted(toolChoice.OfAuto) && toolChoice.OfAuto.Or("") == ToolChoiceRequired {
+		min = 1
 	}
 
-	return calls, CountTokensForToolCalls(calls), nil
+	return generateCalls(tools, min)
 }
 
 func getRequiredAsMap(property map[string]any) map[string]struct{} {
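
For orientation, here is a minimal, self-contained sketch of the forced-function branch added above: parse the forced-function tool_choice shape, look the requested function up by name, and fail if it is absent. The `Tool` struct and `findForcedTool` helper are simplified stand-ins invented for this sketch, not the simulator's real types; only the lookup-by-name logic and the not-found error mirror the diff.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Tool is a simplified stand-in for the simulator's tool type; the real
// definition lives in the openai-server-api package.
type Tool struct {
	Name string
}

// findForcedTool mirrors the forced-function branch added to CreateToolCalls:
// when tool_choice names a specific function, only that tool may be called,
// and an unknown name is an error.
func findForcedTool(tools []Tool, requiredFuncName string) (*Tool, error) {
	for i := range tools {
		if tools[i].Name == requiredFuncName {
			return &tools[i], nil
		}
	}
	return nil, fmt.Errorf("tool with name '%s' requested in tool_choice but not found in the tools list", requiredFuncName)
}

func main() {
	tools := []Tool{{Name: "get_weather"}, {Name: "get_temperature"}}

	// A forced-function tool_choice as it appears on the wire.
	raw := `{"type": "function", "function": {"name": "get_temperature"}}`
	var choice struct {
		Type     string `json:"type"`
		Function struct {
			Name string `json:"name"`
		} `json:"function"`
	}
	if err := json.Unmarshal([]byte(raw), &choice); err != nil {
		panic(err)
	}

	tool, err := findForcedTool(tools, choice.Function.Name)
	if err != nil {
		panic(err)
	}
	fmt.Println("forced tool:", tool.Name) // forced tool: get_temperature
}
```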

pkg/llm-d-inference-sim/tools_test.go

Lines changed: 70 additions & 12 deletions
@@ -32,11 +32,16 @@ import (
 	"github.com/openai/openai-go/v3/packages/param"
 )
 
+const (
+	functionNameGetWeather     = "get_weather"
+	functionNameGetTemperature = "get_temperature"
+)
+
 var tools = []openai.ChatCompletionToolUnionParam{
 	{
 		OfFunction: &openai.ChatCompletionFunctionToolParam{
 			Function: openai.FunctionDefinitionParam{
-				Name:        "get_weather",
+				Name:        functionNameGetWeather,
 				Description: openai.String("Get weather at the given location"),
 				Parameters: openai.FunctionParameters{
 					"type": "object",
@@ -53,7 +58,7 @@ var tools = []openai.ChatCompletionToolUnionParam{
 	{
 		OfFunction: &openai.ChatCompletionFunctionToolParam{
 			Function: openai.FunctionDefinitionParam{
-				Name:        "get_temperature",
+				Name:        functionNameGetTemperature,
 				Description: openai.String("Get temperature at the given location"),
 				Parameters: openai.FunctionParameters{
 					"type": "object",
@@ -78,7 +83,7 @@ var invalidTools = [][]openai.ChatCompletionToolUnionParam{
 		{
 			OfFunction: &openai.ChatCompletionFunctionToolParam{
 				Function: openai.FunctionDefinitionParam{
-					Name:        "get_weather",
+					Name:        functionNameGetWeather,
 					Description: openai.String("Get weather at the given location"),
 					Parameters: openai.FunctionParameters{
 						"type": "object",
@@ -95,7 +100,7 @@ var invalidTools = [][]openai.ChatCompletionToolUnionParam{
 		{
 			OfFunction: &openai.ChatCompletionFunctionToolParam{
 				Function: openai.FunctionDefinitionParam{
-					Name:        "get_temperature",
+					Name:        functionNameGetTemperature,
 					Description: openai.String("Get temperature at the given location"),
 					Parameters: openai.FunctionParameters{
 						"type": "object",
@@ -119,7 +124,7 @@ var invalidTools = [][]openai.ChatCompletionToolUnionParam{
 		{
 			OfFunction: &openai.ChatCompletionFunctionToolParam{
 				Function: openai.FunctionDefinitionParam{
-					Name:        "get_weather",
+					Name:        functionNameGetWeather,
 					Description: openai.String("Get weather at the given location"),
 					Parameters: openai.FunctionParameters{
 						"type": "object",
@@ -139,7 +144,7 @@ var invalidTools = [][]openai.ChatCompletionToolUnionParam{
 		{
 			OfFunction: &openai.ChatCompletionFunctionToolParam{
 				Function: openai.FunctionDefinitionParam{
-					Name:        "get_weather",
+					Name:        functionNameGetWeather,
 					Description: openai.String("Get weather at the given location"),
 				},
 			},
@@ -312,7 +317,7 @@ var toolWithoutRequiredParams = []openai.ChatCompletionToolUnionParam{
 	{
 		OfFunction: &openai.ChatCompletionFunctionToolParam{
 			Function: openai.FunctionDefinitionParam{
-				Name:        "get_temperature",
+				Name:        functionNameGetTemperature,
 				Description: openai.String("Get temperature at the given location"),
 				Parameters: openai.FunctionParameters{
 					"type": "object",
@@ -398,7 +403,7 @@ var _ = Describe("Simulator for request with tools", func() {
 				tc := toolCalls[0]
 				Expect(tc.Index).To(Or(BeNumerically("==", lastIndex), BeNumerically("==", lastIndex+1)))
 				if tc.Index > int64(lastIndex) {
-					Expect(tc.Function.Name).To(Or(Equal("get_weather"), Equal("get_temperature")))
+					Expect(tc.Function.Name).To(Or(Equal(functionNameGetWeather), Equal(functionNameGetTemperature)))
 					lastIndex++
 					args[tc.Function.Name] = []string{tc.Function.Arguments}
 					functionName = tc.Function.Name
@@ -429,7 +434,7 @@ var _ = Describe("Simulator for request with tools", func() {
 			err := json.Unmarshal([]byte(joinedArgs), &argsMap)
 			Expect(err).NotTo(HaveOccurred())
 
-			if functionName == "get_weather" {
+			if functionName == functionNameGetWeather {
 				Expect(joinedArgs).To(ContainSubstring("location"))
 			} else {
 				Expect(joinedArgs).To(ContainSubstring("city"))
@@ -473,14 +478,14 @@ var _ = Describe("Simulator for request with tools", func() {
 			toolCalls := resp.Choices[0].Message.ToolCalls
 			Expect(toolCalls).ToNot(BeEmpty())
 			for _, tc := range toolCalls {
-				Expect(tc.Function.Name).To(Or(Equal("get_weather"), Equal("get_temperature")))
+				Expect(tc.Function.Name).To(Or(Equal(functionNameGetWeather), Equal(functionNameGetTemperature)))
 				Expect(tc.ID).NotTo(BeEmpty())
 				Expect(tc.Type).To(Equal("function"))
 				args := make(map[string]string)
 				err := json.Unmarshal([]byte(tc.Function.Arguments), &args)
 				Expect(err).NotTo(HaveOccurred())
 
-				if tc.Function.Name == "get_weather" {
+				if tc.Function.Name == functionNameGetWeather {
 					Expect(tc.Function.Arguments).To(ContainSubstring("location"))
 				} else {
 					Expect(tc.Function.Arguments).To(ContainSubstring("city"))
@@ -499,6 +504,59 @@ var _ = Describe("Simulator for request with tools", func() {
 		Entry(nil, common.ModeRandom),
 	)
 
+	DescribeTable("no streaming, a specific tool",
+		func(mode string, specificTool string) {
+			ctx := context.TODO()
+			client, err := startServer(ctx, mode)
+			Expect(err).NotTo(HaveOccurred())
+
+			openaiclient, params := getOpenAIClientAndChatParams(client, model, userMessage, false)
+			params.ToolChoice = openai.ToolChoiceOptionFunctionToolChoice(openai.ChatCompletionNamedToolChoiceFunctionParam{
+				Name: specificTool,
+			})
+			params.Tools = tools
+
+			resp, err := openaiclient.Chat.Completions.New(ctx, params)
+			Expect(err).NotTo(HaveOccurred())
+			Expect(resp.Choices).ShouldNot(BeEmpty())
+			Expect(string(resp.Object)).To(Equal(chatCompletionObject))
+
+			Expect(resp.Usage.PromptTokens).To(Equal(userMsgTokens))
+			Expect(resp.Usage.CompletionTokens).To(BeNumerically(">", 0))
+			Expect(resp.Usage.TotalTokens).To(Equal(resp.Usage.PromptTokens + resp.Usage.CompletionTokens))
+
+			content := resp.Choices[0].Message.Content
+			Expect(content).Should(BeEmpty())
+
+			toolCalls := resp.Choices[0].Message.ToolCalls
+			Expect(toolCalls).ToNot(BeEmpty())
+			for _, tc := range toolCalls {
+				Expect(tc.Function.Name).To(Equal(specificTool))
+				Expect(tc.ID).NotTo(BeEmpty())
+				Expect(tc.Type).To(Equal("function"))
+				args := make(map[string]string)
+				err := json.Unmarshal([]byte(tc.Function.Arguments), &args)
+				Expect(err).NotTo(HaveOccurred())
+
+				if tc.Function.Name == functionNameGetWeather {
+					Expect(tc.Function.Arguments).To(ContainSubstring("location"))
+				} else {
+					Expect(tc.Function.Arguments).To(ContainSubstring("city"))
+					Expect(tc.Function.Arguments).To(ContainSubstring("unit"))
+					Expect(args["unit"]).To(Or(Equal("C"), Equal("F")))
+				}
+			}
+		},
+		func(mode string, specificTool string) string {
+			return "mode: " + mode + ", specificTool: " + specificTool
+		},
+		// Call several times because the tools and arguments are chosen randomly
+		Entry(nil, common.ModeRandom, functionNameGetWeather),
+		Entry(nil, common.ModeRandom, functionNameGetTemperature),
+		Entry(nil, common.ModeRandom, functionNameGetWeather),
+		Entry(nil, common.ModeRandom, functionNameGetTemperature),
+	)
+
 	DescribeTable("check validator",
 		func(mode string) {
 			ctx := context.TODO()
@@ -778,7 +836,7 @@ var _ = Describe("Simulator for request with tools", func() {
 			toolCalls := resp.Choices[0].Message.ToolCalls
 			Expect(toolCalls).To(HaveLen(1))
 			tc := toolCalls[0]
-			Expect(tc.Function.Name).To(Equal("get_temperature"))
+			Expect(tc.Function.Name).To(Equal(functionNameGetTemperature))
 			Expect(tc.ID).NotTo(BeEmpty())
 			Expect(tc.Type).To(Equal("function"))
 			args := make(map[string]string)
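
The new test table drives the simulator with a forced-function tool_choice; on the wire this is simply a different JSON shape of the same field. The sketch below is a standalone illustration (standard library only, no simulator or openai-go types) of distinguishing the string forms from the forced-function object; `classifyToolChoice` is a hypothetical helper written for this sketch, not part of the repository.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// classifyToolChoice is a simplified stand-in for the union handling in the
// simulator: tool_choice may arrive as a plain string ("auto", "required",
// "none") or as a forced-function object.
func classifyToolChoice(raw json.RawMessage) (string, error) {
	var s string
	if err := json.Unmarshal(raw, &s); err == nil {
		return "string: " + s, nil
	}
	var obj struct {
		Type     string `json:"type"`
		Function struct {
			Name string `json:"name"`
		} `json:"function"`
	}
	if err := json.Unmarshal(raw, &obj); err != nil {
		return "", err
	}
	return "forced function: " + obj.Function.Name, nil
}

func main() {
	for _, raw := range []string{
		`"auto"`,
		`"required"`,
		`"none"`,
		`{"type": "function", "function": {"name": "get_weather"}}`,
	} {
		kind, err := classifyToolChoice(json.RawMessage(raw))
		if err != nil {
			panic(err)
		}
		fmt.Println(kind)
	}
}
```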

pkg/llm-d-inference-sim/worker.go

Lines changed: 1 addition & 1 deletion
@@ -88,7 +88,7 @@ func (s *VllmSimulator) processRequest(reqCtx *openaiserverapi.CompletionReqCtx)
 	var toolCalls []openaiserverapi.ToolCall
 	var completionTokens int
 	if reqCtx.IsChatCompletion &&
-		req.GetToolChoice() != common.ToolChoiceNone &&
+		!common.IsToolChoiceNone(req.GetToolChoice()) &&
 		req.GetTools() != nil {
 		toolCalls, completionTokens, err =
 			common.CreateToolCalls(req.GetTools(), req.GetToolChoice(), s.config)
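
As a rough illustration of the updated gate in processRequest, the sketch below reproduces the condition with simplified stand-in types (the real code uses `common.IsToolChoiceNone` and the request's accessor methods): tool calls are generated only for chat completions whose tool_choice is not the string "none" and that actually supply tools.

```go
package main

import "fmt"

// toolChoice is a simplified stand-in for the request's tool_choice union:
// either one of the string values ("auto", "required", "none") or a forced
// function name.
type toolChoice struct {
	auto           string // string form, if used
	forcedFunction string // non-empty when a specific function is forced
}

// isToolChoiceNone mirrors common.IsToolChoiceNone for the string form only.
func isToolChoiceNone(tc toolChoice) bool {
	return tc.auto == "none"
}

// shouldGenerateToolCalls reproduces the condition in processRequest.
func shouldGenerateToolCalls(isChatCompletion bool, tc toolChoice, hasTools bool) bool {
	return isChatCompletion && !isToolChoiceNone(tc) && hasTools
}

func main() {
	fmt.Println(shouldGenerateToolCalls(true, toolChoice{auto: "none"}, true))                  // false
	fmt.Println(shouldGenerateToolCalls(true, toolChoice{forcedFunction: "get_weather"}, true)) // true
	fmt.Println(shouldGenerateToolCalls(true, toolChoice{auto: "auto"}, false))                 // false
}
```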
