
Commit 377ab7d

Configure the tool_choice to use a specific tool
Signed-off-by: MondayCha <[email protected]>
1 parent 0b25c88 commit 377ab7d

File tree

5 files changed (+290, -59 lines)


pkg/common/tools_utils.go

Lines changed: 113 additions & 36 deletions
@@ -18,9 +18,11 @@ package common
 
 import (
 	"encoding/json"
+	"errors"
 	"fmt"
 
 	openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api"
+	"github.com/openai/openai-go/v3/packages/param"
 	"github.com/santhosh-tekuri/jsonschema/v5"
 )
 
@@ -52,50 +54,125 @@ var fakeStringArguments = []string{
 	`lifetime`,
 }
 
-// CreateToolCalls creates and returns response payload based on this request
-// (tool calls or nothing in case we randomly choose not to generate calls),
-// and the number of generated completion token sand the finish reason
-func CreateToolCalls(tools []openaiserverapi.Tool, toolChoice string, config *Configuration) ([]openaiserverapi.ToolCall, int, error) {
-	// This function is called if tool choice is either 'required' or 'auto'.
-	// In case of 'required' at least one tool call has to be created, and we randomly choose
-	// the number of calls starting from one. Otherwise, we start from 0, and in case we randomly
-	// choose the number of calls to be 0, response text will be generated instead of a tool call.
-	min := 0
-	if toolChoice == ToolChoiceRequired {
-		min = 1
-	}
-	numberOfCalls := RandomInt(min, len(tools))
-	if numberOfCalls == 0 {
-		return nil, 0, nil
+// IsToolChoiceNone checks if the tool_choice is set to "none".
+func IsToolChoiceNone(toolChoice openaiserverapi.ToolChoice) bool {
+	if !param.IsOmitted(toolChoice.OfAuto) {
+		val := toolChoice.OfAuto.Or("")
+		return val == ToolChoiceNone
 	}
+	return false
+}
 
-	calls := make([]openaiserverapi.ToolCall, 0)
-	for i := range numberOfCalls {
-		// Randomly choose which tools to call. We may call the same tool more than once.
-		index := RandomInt(0, len(tools)-1)
-		args, err := generateToolArguments(tools[index], config)
-		if err != nil {
-			return nil, 0, err
+// CreateToolCalls creates and returns tool calls based on the request's tool
+// definitions and the tool_choice parameter.
+//
+// The [tool_choice](https://platform.openai.com/docs/guides/function-calling#tool-choice)
+// parameter controls how the model responds to function calls.
+//
+// This function handles the following cases for tool_choice:
+//   - "none": The model will not call any tools. In this scenario, this function
+//     should ideally be bypassed, as no tool calls will be generated.
+//   - "auto": This is the default behavior where the model autonomously decides
+//     whether to generate a message or call one or more tools from the provided list.
+//   - "required": The model is constrained to call one or more of the available tools.
+//   - Forced Function: A specific tool can be forced by providing an object with the
+//     structure `{"type": "function", "function": {"name": "my_function"}}`.
+//     The model will be restricted to calling only that specified tool.
+//
+// This function currently does not handle the following `tool_choice` scenarios:
+//   - Forced Custom Tool: If `tool_choice` is set to `{"type": "custom", "name": "my_custom"}`,
+//     this function will not be able to enforce the calling of a custom tool, as custom
+//     tool types are not yet supported.
+//   - Allowed Tools Subset: The functionality to restrict the model's tool-calling
+//     capabilities to a specific subset of the available tools has not been implemented.
+//
+// This function returns the generated tool calls, the number of completion
+// tokens used, and an error if one occurs (e.g., if a specified tool is not found).
+func CreateToolCalls(
+	tools []openaiserverapi.Tool,
+	toolChoice openaiserverapi.ToolChoice,
+	config *Configuration,
+) ([]openaiserverapi.ToolCall, int, error) {
+	generateCalls := func(availableTools []openaiserverapi.Tool, minCalls int) ([]openaiserverapi.ToolCall, int, error) {
+		if len(availableTools) == 0 {
+			// If no tools are available to choose from, no calls can be made.
+			return nil, 0, errors.New("no tools available to create tool calls")
 		}
-		argsJson, err := json.Marshal(args)
-		if err != nil {
-			return nil, 0, err
+
+		numberOfCalls := minCalls
+		if len(availableTools) > minCalls {
+			// Randomly decide how many tools to call, between minCalls and the total available.
+			numberOfCalls = RandomInt(minCalls, len(availableTools))
+		}
+
+		if numberOfCalls == 0 {
+			return nil, 0, nil
+		}
+
+		calls := make([]openaiserverapi.ToolCall, 0, numberOfCalls)
+		for i := range numberOfCalls {
+			// Randomly choose which tool to call. We may call the same tool more than once.
+			index := 0
+			if len(availableTools) > 1 {
+				index = RandomInt(0, len(availableTools)-1)
+			}
+			chosenTool := availableTools[index]
+
+			args, err := generateToolArguments(chosenTool, config)
+			if err != nil {
+				return nil, 0, err
+			}
+			argsJson, err := json.Marshal(args)
+			if err != nil {
+				return nil, 0, err
+			}
+
+			call := openaiserverapi.ToolCall{
+				Function: openaiserverapi.FunctionCall{
+					Arguments:          string(argsJson),
+					TokenizedArguments: Tokenize(string(argsJson)),
+					Name:               &chosenTool.Function.Name,
+				},
+				ID:    "chatcmpl-tool-" + RandomNumericString(10),
+				Type:  "function",
+				Index: i,
+			}
+			calls = append(calls, call)
 		}
+		return calls, CountTokensForToolCalls(calls), nil
+	}
+
+	// A specific function is forced.
+	if functionChoice := toolChoice.GetFunction(); functionChoice != nil {
+		requiredFuncName := functionChoice.Name
+		var targetTool *openaiserverapi.Tool
 
-		call := openaiserverapi.ToolCall{
-			Function: openaiserverapi.FunctionCall{
-				Arguments:          string(argsJson),
-				TokenizedArguments: Tokenize(string(argsJson)),
-				Name:               &tools[index].Function.Name,
-			},
-			ID:    "chatcmpl-tool-" + RandomNumericString(10),
-			Type:  "function",
-			Index: i,
+		// Find the specified tool in the list of available tools.
+		for i, tool := range tools {
+			if tool.Function.Name == requiredFuncName {
+				targetTool = &tools[i]
+				break
+			}
 		}
-		calls = append(calls, call)
+
+		if targetTool == nil {
+			return nil, 0, fmt.Errorf("tool with name '%s' requested in tool_choice but not found in the tools list", requiredFuncName)
+		}
+
+		specificTools := []openaiserverapi.Tool{*targetTool}
+
+		// Generate arguments for the specific tool.
+		return generateCalls(specificTools, len(specificTools))
+	}
+
+	// Default behavior for "auto" or "required".
+	// The model can choose from any of the provided tools.
+	min := 0
+	if !param.IsOmitted(toolChoice.OfAuto) && toolChoice.OfAuto.Or("") == ToolChoiceRequired {
+		min = 1
 	}
 
-	return calls, CountTokensForToolCalls(calls), nil
+	return generateCalls(tools, min)
 }
 
 func getRequiredAsMap(property map[string]any) map[string]struct{} {
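
Usage note (not part of the diff): a minimal client-side sketch of exercising the new forced-function path against the simulator, mirroring the "no streaming, a specific tool" test added below. The base URL http://localhost:8000/v1, the model name "my-model", and the abbreviated parameter schema are assumptions, not values taken from this commit.

package main

import (
	"context"
	"fmt"

	"github.com/openai/openai-go/v3"
	"github.com/openai/openai-go/v3/option"
)

func main() {
	// Assumed simulator endpoint; adjust the base URL and model to your deployment.
	client := openai.NewClient(option.WithBaseURL("http://localhost:8000/v1"))

	params := openai.ChatCompletionNewParams{
		Model:    "my-model",
		Messages: []openai.ChatCompletionMessageParamUnion{openai.UserMessage("What is the weather in Paris?")},
		// Advertise a single tool; the schema here is an abbreviated stand-in.
		Tools: []openai.ChatCompletionToolUnionParam{{
			OfFunction: &openai.ChatCompletionFunctionToolParam{
				Function: openai.FunctionDefinitionParam{
					Name:        "get_weather",
					Description: openai.String("Get weather at the given location"),
					Parameters: openai.FunctionParameters{
						"type": "object",
						"properties": map[string]any{
							"location": map[string]any{"type": "string"},
						},
						"required": []string{"location"},
					},
				},
			},
		}},
		// Force the model to call exactly this function; CreateToolCalls above
		// resolves it via toolChoice.GetFunction().
		ToolChoice: openai.ToolChoiceOptionFunctionToolChoice(openai.ChatCompletionNamedToolChoiceFunctionParam{
			Name: "get_weather",
		}),
	}

	resp, err := client.Chat.Completions.New(context.TODO(), params)
	if err != nil {
		panic(err)
	}
	// With a forced function, every returned tool call should name get_weather.
	for _, tc := range resp.Choices[0].Message.ToolCalls {
		fmt.Println(tc.Function.Name, tc.Function.Arguments)
	}
}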

pkg/llm-d-inference-sim/tools_test.go

Lines changed: 70 additions & 12 deletions
@@ -32,11 +32,16 @@ import (
 	"github.com/openai/openai-go/v3/packages/param"
 )
 
+const (
+	functionNameGetWeather     = "get_weather"
+	functionNameGetTemperature = "get_temperature"
+)
+
 var tools = []openai.ChatCompletionToolUnionParam{
 	{
 		OfFunction: &openai.ChatCompletionFunctionToolParam{
 			Function: openai.FunctionDefinitionParam{
-				Name:        "get_weather",
+				Name:        functionNameGetWeather,
 				Description: openai.String("Get weather at the given location"),
 				Parameters: openai.FunctionParameters{
 					"type": "object",
@@ -53,7 +58,7 @@ var tools = []openai.ChatCompletionToolUnionParam{
 	{
 		OfFunction: &openai.ChatCompletionFunctionToolParam{
 			Function: openai.FunctionDefinitionParam{
-				Name:        "get_temperature",
+				Name:        functionNameGetTemperature,
 				Description: openai.String("Get temperature at the given location"),
 				Parameters: openai.FunctionParameters{
 					"type": "object",
@@ -78,7 +83,7 @@ var invalidTools = [][]openai.ChatCompletionToolUnionParam{
 	{
 		OfFunction: &openai.ChatCompletionFunctionToolParam{
 			Function: openai.FunctionDefinitionParam{
-				Name:        "get_weather",
+				Name:        functionNameGetWeather,
 				Description: openai.String("Get weather at the given location"),
 				Parameters: openai.FunctionParameters{
 					"type": "object",
@@ -95,7 +100,7 @@ var invalidTools = [][]openai.ChatCompletionToolUnionParam{
 	{
 		OfFunction: &openai.ChatCompletionFunctionToolParam{
 			Function: openai.FunctionDefinitionParam{
-				Name:        "get_temperature",
+				Name:        functionNameGetTemperature,
 				Description: openai.String("Get temperature at the given location"),
 				Parameters: openai.FunctionParameters{
 					"type": "object",
@@ -119,7 +124,7 @@ var invalidTools = [][]openai.ChatCompletionToolUnionParam{
 	{
 		OfFunction: &openai.ChatCompletionFunctionToolParam{
 			Function: openai.FunctionDefinitionParam{
-				Name:        "get_weather",
+				Name:        functionNameGetWeather,
 				Description: openai.String("Get weather at the given location"),
 				Parameters: openai.FunctionParameters{
 					"type": "object",
@@ -139,7 +144,7 @@ var invalidTools = [][]openai.ChatCompletionToolUnionParam{
 	{
 		OfFunction: &openai.ChatCompletionFunctionToolParam{
 			Function: openai.FunctionDefinitionParam{
-				Name:        "get_weather",
+				Name:        functionNameGetWeather,
 				Description: openai.String("Get weather at the given location"),
 			},
 		},
@@ -312,7 +317,7 @@ var toolWithoutRequiredParams = []openai.ChatCompletionToolUnionParam{
 	{
 		OfFunction: &openai.ChatCompletionFunctionToolParam{
 			Function: openai.FunctionDefinitionParam{
-				Name:        "get_temperature",
+				Name:        functionNameGetTemperature,
 				Description: openai.String("Get temperature at the given location"),
 				Parameters: openai.FunctionParameters{
 					"type": "object",
@@ -398,7 +403,7 @@ var _ = Describe("Simulator for request with tools", func() {
 					tc := toolCalls[0]
 					Expect(tc.Index).To(Or(BeNumerically("==", lastIndex), BeNumerically("==", lastIndex+1)))
 					if tc.Index > int64(lastIndex) {
-						Expect(tc.Function.Name).To(Or(Equal("get_weather"), Equal("get_temperature")))
+						Expect(tc.Function.Name).To(Or(Equal(functionNameGetWeather), Equal(functionNameGetTemperature)))
 						lastIndex++
 						args[tc.Function.Name] = []string{tc.Function.Arguments}
 						functionName = tc.Function.Name
@@ -429,7 +434,7 @@ var _ = Describe("Simulator for request with tools", func() {
 			err := json.Unmarshal([]byte(joinedArgs), &argsMap)
 			Expect(err).NotTo(HaveOccurred())
 
-			if functionName == "get_weather" {
+			if functionName == functionNameGetWeather {
 				Expect(joinedArgs).To(ContainSubstring("location"))
 			} else {
 				Expect(joinedArgs).To(ContainSubstring("city"))
@@ -473,14 +478,14 @@ var _ = Describe("Simulator for request with tools", func() {
 			toolCalls := resp.Choices[0].Message.ToolCalls
 			Expect(toolCalls).ToNot(BeEmpty())
 			for _, tc := range toolCalls {
-				Expect(tc.Function.Name).To(Or(Equal("get_weather"), Equal("get_temperature")))
+				Expect(tc.Function.Name).To(Or(Equal(functionNameGetWeather), Equal(functionNameGetTemperature)))
 				Expect(tc.ID).NotTo(BeEmpty())
 				Expect(tc.Type).To(Equal("function"))
 				args := make(map[string]string)
 				err := json.Unmarshal([]byte(tc.Function.Arguments), &args)
 				Expect(err).NotTo(HaveOccurred())
 
-				if tc.Function.Name == "get_weather" {
+				if tc.Function.Name == functionNameGetWeather {
 					Expect(tc.Function.Arguments).To(ContainSubstring("location"))
 				} else {
 					Expect(tc.Function.Arguments).To(ContainSubstring("city"))
@@ -499,6 +504,59 @@ var _ = Describe("Simulator for request with tools", func() {
 		Entry(nil, common.ModeRandom),
 	)
 
+	DescribeTable("no streaming, a specific tool",
+		func(mode string, specificTool string) {
+			ctx := context.TODO()
+			client, err := startServer(ctx, mode)
+			Expect(err).NotTo(HaveOccurred())
+
+			openaiclient, params := getOpenAIClientAndChatParams(client, model, userMessage, false)
+			params.ToolChoice = openai.ToolChoiceOptionFunctionToolChoice(openai.ChatCompletionNamedToolChoiceFunctionParam{
+				Name: specificTool,
+			})
+			params.Tools = tools
+
+			resp, err := openaiclient.Chat.Completions.New(ctx, params)
+			Expect(err).NotTo(HaveOccurred())
+			Expect(resp.Choices).ShouldNot(BeEmpty())
+			Expect(string(resp.Object)).To(Equal(chatCompletionObject))
+
+			Expect(resp.Usage.PromptTokens).To(Equal(userMsgTokens))
+			Expect(resp.Usage.CompletionTokens).To(BeNumerically(">", 0))
+			Expect(resp.Usage.TotalTokens).To(Equal(resp.Usage.PromptTokens + resp.Usage.CompletionTokens))
+
+			content := resp.Choices[0].Message.Content
+			Expect(content).Should(BeEmpty())
+
+			toolCalls := resp.Choices[0].Message.ToolCalls
+			Expect(toolCalls).ToNot(BeEmpty())
+			for _, tc := range toolCalls {
+				Expect(tc.Function.Name).To(Equal(specificTool))
+				Expect(tc.ID).NotTo(BeEmpty())
+				Expect(tc.Type).To(Equal("function"))
+				args := make(map[string]string)
+				err := json.Unmarshal([]byte(tc.Function.Arguments), &args)
+				Expect(err).NotTo(HaveOccurred())
+
+				if tc.Function.Name == functionNameGetWeather {
+					Expect(tc.Function.Arguments).To(ContainSubstring("location"))
+				} else {
+					Expect(tc.Function.Arguments).To(ContainSubstring("city"))
+					Expect(tc.Function.Arguments).To(ContainSubstring("unit"))
+					Expect(args["unit"]).To(Or(Equal("C"), Equal("F")))
+				}
+			}
+		},
+		func(mode string, specificTool string) string {
+			return "mode: " + mode + ", specificTool: " + specificTool
+		},
+		// Call several times because the tools and arguments are chosen randomly
+		Entry(nil, common.ModeRandom, functionNameGetWeather),
+		Entry(nil, common.ModeRandom, functionNameGetTemperature),
+		Entry(nil, common.ModeRandom, functionNameGetWeather),
+		Entry(nil, common.ModeRandom, functionNameGetTemperature),
+	)
+
 	DescribeTable("check validator",
 		func(mode string) {
 			ctx := context.TODO()
@@ -778,7 +836,7 @@ var _ = Describe("Simulator for request with tools", func() {
 			toolCalls := resp.Choices[0].Message.ToolCalls
 			Expect(toolCalls).To(HaveLen(1))
 			tc := toolCalls[0]
-			Expect(tc.Function.Name).To(Equal("get_temperature"))
+			Expect(tc.Function.Name).To(Equal(functionNameGetTemperature))
 			Expect(tc.ID).NotTo(BeEmpty())
 			Expect(tc.Type).To(Equal("function"))
 			args := make(map[string]string)

pkg/llm-d-inference-sim/worker.go

Lines changed: 1 addition & 1 deletion
@@ -88,7 +88,7 @@ func (s *VllmSimulator) processRequest(reqCtx *openaiserverapi.CompletionReqCtx)
 	var toolCalls []openaiserverapi.ToolCall
 	var completionTokens int
 	if reqCtx.IsChatCompletion &&
-		req.GetToolChoice() != common.ToolChoiceNone &&
+		!common.IsToolChoiceNone(req.GetToolChoice()) &&
 		req.GetTools() != nil {
 		toolCalls, completionTokens, err =
 			common.CreateToolCalls(req.GetTools(), req.GetToolChoice(), s.config)
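
Side note (not part of the diff): the gate above now inspects a structured tool_choice union instead of comparing a plain string. A minimal sketch of how such a union can be examined, using only accessors that appear in this commit (GetFunction, param.IsOmitted, OfAuto.Or); the helper name describeToolChoice is hypothetical and not part of the commit.

package common

import (
	openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api"
	"github.com/openai/openai-go/v3/packages/param"
)

// describeToolChoice illustrates the branches the simulator now distinguishes:
// a forced function carried in the function variant of the union, and the
// "none"/"auto"/"required" string values carried in OfAuto.
func describeToolChoice(tc openaiserverapi.ToolChoice) string {
	if fn := tc.GetFunction(); fn != nil {
		return "forced function: " + fn.Name
	}
	if !param.IsOmitted(tc.OfAuto) {
		return "string choice: " + tc.OfAuto.Or("auto")
	}
	return "unset (defaults to auto)"
}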
