
Commit bfd4ec1

Configure the tool_choice to use a specific tool
Signed-off-by: MondayCha <[email protected]>
1 parent 0b25c88 commit bfd4ec1

File tree: 5 files changed (+280 −59 lines)

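What the change enables, from the client side: tool_choice can now name one specific function, and the simulator will only emit calls to that function. Below is a minimal sketch, assuming the openai-go v3 client already used by the simulator's tests; the base URL, port, API key, model name, prompt, and the option import path are placeholders or assumptions, not part of this commit.

// Minimal client-side sketch: force the simulator to call get_weather via tool_choice.
// Assumes the openai-go v3 client; endpoint, key, and model are placeholders.
package main

import (
	"context"
	"fmt"

	"github.com/openai/openai-go/v3"
	"github.com/openai/openai-go/v3/option" // assumed import path for client options
)

func main() {
	ctx := context.Background()
	client := openai.NewClient(
		option.WithBaseURL("http://localhost:8000/v1"), // simulator endpoint (placeholder)
		option.WithAPIKey("not-checked"),               // placeholder key
	)

	// One tool definition, mirroring the get_weather tool used in tools_test.go.
	tools := []openai.ChatCompletionToolUnionParam{{
		OfFunction: &openai.ChatCompletionFunctionToolParam{
			Function: openai.FunctionDefinitionParam{
				Name:        "get_weather",
				Description: openai.String("Get weather at the given location"),
				Parameters: openai.FunctionParameters{
					"type": "object",
					"properties": map[string]any{
						"location": map[string]any{"type": "string"},
					},
					"required": []string{"location"},
				},
			},
		},
	}}

	params := openai.ChatCompletionNewParams{
		Model:    "my-model", // placeholder model name
		Messages: []openai.ChatCompletionMessageParamUnion{openai.UserMessage("What is the weather in Paris?")},
		Tools:    tools,
		// Force a specific function, as described in the new CreateToolCalls doc comment.
		ToolChoice: openai.ToolChoiceOptionFunctionToolChoice(openai.ChatCompletionNamedToolChoiceFunctionParam{
			Name: "get_weather",
		}),
	}

	resp, err := client.Chat.Completions.New(ctx, params)
	if err != nil {
		panic(err)
	}
	for _, tc := range resp.Choices[0].Message.ToolCalls {
		fmt.Println(tc.Function.Name, tc.Function.Arguments) // every call should name get_weather
	}
}

A name in tool_choice that is not present in tools is expected to produce an error, per the CreateToolCalls change below.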

pkg/common/tools_utils.go

Lines changed: 102 additions & 36 deletions
@@ -21,6 +21,7 @@ import (
 	"fmt"
 
 	openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api"
+	"github.com/openai/openai-go/v3/packages/param"
 	"github.com/santhosh-tekuri/jsonschema/v5"
 )
 
@@ -52,50 +53,115 @@ var fakeStringArguments = []string{
 	`lifetime`,
 }
 
-// CreateToolCalls creates and returns response payload based on this request
-// (tool calls or nothing in case we randomly choose not to generate calls),
-// and the number of generated completion token sand the finish reason
-func CreateToolCalls(tools []openaiserverapi.Tool, toolChoice string, config *Configuration) ([]openaiserverapi.ToolCall, int, error) {
-	// This function is called if tool choice is either 'required' or 'auto'.
-	// In case of 'required' at least one tool call has to be created, and we randomly choose
-	// the number of calls starting from one. Otherwise, we start from 0, and in case we randomly
-	// choose the number of calls to be 0, response text will be generated instead of a tool call.
-	min := 0
-	if toolChoice == ToolChoiceRequired {
-		min = 1
-	}
-	numberOfCalls := RandomInt(min, len(tools))
-	if numberOfCalls == 0 {
-		return nil, 0, nil
+// IsToolChoiceNone checks if the tool_choice is set to "none".
+func IsToolChoiceNone(toolChoice openaiserverapi.ToolChoice) bool {
+	if !param.IsOmitted(toolChoice.OfAuto) {
+		val := toolChoice.OfAuto.Or("")
+		return val == ToolChoiceNone
 	}
+	return false
+}
 
-	calls := make([]openaiserverapi.ToolCall, 0)
-	for i := range numberOfCalls {
-		// Randomly choose which tools to call. We may call the same tool more than once.
-		index := RandomInt(0, len(tools)-1)
-		args, err := generateToolArguments(tools[index], config)
-		if err != nil {
-			return nil, 0, err
+// CreateToolCalls creates and returns tool calls based on the request's tool
+// definitions and tool_choice parameter.
+// The tool_choice parameter controls how the model responds to function calls:
+//   - "none": The model does not call any tools. This case should be handled
+//     before calling this function.
+//   - "auto": The model can choose to either generate a message or call one or
+//     more tools. This is the default behavior.
+//   - "required": The model must call one or more tools.
+//   - Specific function: A specific tool can be forced by providing an object
+//     like `{"type": "function", "function": {"name": "my_function"}}`. The
+//     model will be constrained to call that exact tool.
+//
+// This function returns the generated tool calls, the number of completion
+// tokens used, and an error if one occurs (e.g., if a specified tool is not found).
+func CreateToolCalls(
+	tools []openaiserverapi.Tool,
+	toolChoice openaiserverapi.ToolChoice,
+	config *Configuration,
+) ([]openaiserverapi.ToolCall, int, error) {
+	generateCalls := func(availableTools []openaiserverapi.Tool, minCalls int) ([]openaiserverapi.ToolCall, int, error) {
+		if len(availableTools) == 0 {
+			// If no tools are available to choose from, no calls can be made.
+			// If minCalls > 0, this indicates a configuration error, but we return no calls.
+			return nil, 0, nil
 		}
-		argsJson, err := json.Marshal(args)
-		if err != nil {
-			return nil, 0, err
+
+		numberOfCalls := minCalls
+		if len(availableTools) > minCalls {
+			// Randomly decide how many tools to call, between minCalls and the total available.
+			numberOfCalls = RandomInt(minCalls, len(availableTools))
+		}
+
+		if numberOfCalls == 0 {
+			return nil, 0, nil
+		}
+
+		calls := make([]openaiserverapi.ToolCall, 0, numberOfCalls)
+		for i := range numberOfCalls {
+			// Randomly choose which tool to call. We may call the same tool more than once.
+			index := 0
+			if len(availableTools) > 1 {
+				index = RandomInt(0, len(availableTools)-1)
+			}
+			chosenTool := availableTools[index]
+
+			args, err := generateToolArguments(chosenTool, config)
+			if err != nil {
+				return nil, 0, err
+			}
+			argsJson, err := json.Marshal(args)
+			if err != nil {
+				return nil, 0, err
+			}
+
+			call := openaiserverapi.ToolCall{
+				Function: openaiserverapi.FunctionCall{
+					Arguments:          string(argsJson),
+					TokenizedArguments: Tokenize(string(argsJson)),
+					Name:               &chosenTool.Function.Name,
+				},
+				ID:    "chatcmpl-tool-" + RandomNumericString(10),
+				Type:  "function",
+				Index: i,
+			}
+			calls = append(calls, call)
 		}
+		return calls, CountTokensForToolCalls(calls), nil
+	}
+
+	// A specific function is forced.
+	if functionChoice := toolChoice.GetFunction(); functionChoice != nil {
+		requiredFuncName := functionChoice.Name
+		var targetTool *openaiserverapi.Tool
 
-		call := openaiserverapi.ToolCall{
-			Function: openaiserverapi.FunctionCall{
-				Arguments:          string(argsJson),
-				TokenizedArguments: Tokenize(string(argsJson)),
-				Name:               &tools[index].Function.Name,
-			},
-			ID:    "chatcmpl-tool-" + RandomNumericString(10),
-			Type:  "function",
-			Index: i,
+		// Find the specified tool in the list of available tools.
+		for i, tool := range tools {
+			if tool.Function.Name == requiredFuncName {
+				targetTool = &tools[i]
+				break
+			}
 		}
-		calls = append(calls, call)
+
+		if targetTool == nil {
+			return nil, 0, fmt.Errorf("tool with name '%s' requested in tool_choice but not found in the tools list", requiredFuncName)
+		}
+
+		specificTools := []openaiserverapi.Tool{*targetTool}
+
+		// Generate arguments for the specific tool.
+		return generateCalls(specificTools, len(specificTools))
+	}
+
+	// Default behavior for "auto" or "required".
+	// The model can choose from any of the provided tools.
+	min := 0
+	if !param.IsOmitted(toolChoice.OfAuto) && toolChoice.OfAuto.Or("") == ToolChoiceRequired {
+		min = 1
 	}
 
-	return calls, CountTokensForToolCalls(calls), nil
+	return generateCalls(tools, min)
 }
 
 func getRequiredAsMap(property map[string]any) map[string]struct{} {
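The new control flow above can be summarized by a standalone sketch. This is an illustration only: Tool and ToolCall are simplified hypothetical stand-ins for the openaiserverapi types, and math/rand replaces the package's RandomInt helper.

// Simplified sketch of the selection logic in CreateToolCalls (hypothetical types).
package main

import (
	"fmt"
	"math/rand"
)

type Tool struct{ Name string }
type ToolCall struct{ Name string }

// pickTools mirrors the generateCalls closure: choose between minCalls and
// len(available) calls, possibly repeating the same tool.
func pickTools(available []Tool, minCalls int) []ToolCall {
	if len(available) == 0 {
		return nil
	}
	n := minCalls
	if len(available) > minCalls {
		// between minCalls and len(available), inclusive
		n = minCalls + rand.Intn(len(available)-minCalls+1)
	}
	calls := make([]ToolCall, 0, n)
	for i := 0; i < n; i++ {
		// the same tool may be chosen more than once
		calls = append(calls, ToolCall{Name: available[rand.Intn(len(available))].Name})
	}
	return calls
}

// forceTool mirrors the new "specific function" branch: restrict the pool to
// the named tool (so exactly one call is generated) or fail if it is unknown.
func forceTool(tools []Tool, name string) ([]ToolCall, error) {
	for _, t := range tools {
		if t.Name == name {
			return pickTools([]Tool{t}, 1), nil
		}
	}
	return nil, fmt.Errorf("tool %q requested in tool_choice but not found", name)
}

func main() {
	tools := []Tool{{Name: "get_weather"}, {Name: "get_temperature"}}

	fmt.Println(pickTools(tools, 1)) // "required": at least one call
	fmt.Println(pickTools(tools, 0)) // "auto": possibly zero calls

	calls, err := forceTool(tools, "get_weather") // forced specific function
	fmt.Println(calls, err)
}

The real function additionally generates and marshals fake arguments for each call and counts completion tokens; that part is unchanged in spirit and omitted here.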

pkg/llm-d-inference-sim/tools_test.go

Lines changed: 70 additions & 12 deletions
@@ -32,11 +32,16 @@ import (
 	"github.com/openai/openai-go/v3/packages/param"
 )
 
+const (
+	functionNameGetWeather     = "get_weather"
+	functionNameGetTemperature = "get_temperature"
+)
+
 var tools = []openai.ChatCompletionToolUnionParam{
 	{
 		OfFunction: &openai.ChatCompletionFunctionToolParam{
 			Function: openai.FunctionDefinitionParam{
-				Name:        "get_weather",
+				Name:        functionNameGetWeather,
 				Description: openai.String("Get weather at the given location"),
 				Parameters: openai.FunctionParameters{
 					"type": "object",
@@ -53,7 +58,7 @@ var tools = []openai.ChatCompletionToolUnionParam{
 	{
 		OfFunction: &openai.ChatCompletionFunctionToolParam{
 			Function: openai.FunctionDefinitionParam{
-				Name:        "get_temperature",
+				Name:        functionNameGetTemperature,
 				Description: openai.String("Get temperature at the given location"),
 				Parameters: openai.FunctionParameters{
 					"type": "object",
@@ -78,7 +83,7 @@ var invalidTools = [][]openai.ChatCompletionToolUnionParam{
 	{
 		OfFunction: &openai.ChatCompletionFunctionToolParam{
 			Function: openai.FunctionDefinitionParam{
-				Name:        "get_weather",
+				Name:        functionNameGetWeather,
 				Description: openai.String("Get weather at the given location"),
 				Parameters: openai.FunctionParameters{
 					"type": "object",
@@ -95,7 +100,7 @@ var invalidTools = [][]openai.ChatCompletionToolUnionParam{
 	{
 		OfFunction: &openai.ChatCompletionFunctionToolParam{
 			Function: openai.FunctionDefinitionParam{
-				Name:        "get_temperature",
+				Name:        functionNameGetTemperature,
 				Description: openai.String("Get temperature at the given location"),
 				Parameters: openai.FunctionParameters{
 					"type": "object",
@@ -119,7 +124,7 @@ var invalidTools = [][]openai.ChatCompletionToolUnionParam{
 	{
 		OfFunction: &openai.ChatCompletionFunctionToolParam{
 			Function: openai.FunctionDefinitionParam{
-				Name:        "get_weather",
+				Name:        functionNameGetWeather,
 				Description: openai.String("Get weather at the given location"),
 				Parameters: openai.FunctionParameters{
 					"type": "object",
@@ -139,7 +144,7 @@ var invalidTools = [][]openai.ChatCompletionToolUnionParam{
 	{
 		OfFunction: &openai.ChatCompletionFunctionToolParam{
 			Function: openai.FunctionDefinitionParam{
-				Name:        "get_weather",
+				Name:        functionNameGetWeather,
 				Description: openai.String("Get weather at the given location"),
 			},
 		},
@@ -312,7 +317,7 @@ var toolWithoutRequiredParams = []openai.ChatCompletionToolUnionParam{
 	{
 		OfFunction: &openai.ChatCompletionFunctionToolParam{
 			Function: openai.FunctionDefinitionParam{
-				Name:        "get_temperature",
+				Name:        functionNameGetTemperature,
 				Description: openai.String("Get temperature at the given location"),
 				Parameters: openai.FunctionParameters{
 					"type": "object",
@@ -398,7 +403,7 @@ var _ = Describe("Simulator for request with tools", func() {
 			tc := toolCalls[0]
 			Expect(tc.Index).To(Or(BeNumerically("==", lastIndex), BeNumerically("==", lastIndex+1)))
 			if tc.Index > int64(lastIndex) {
-				Expect(tc.Function.Name).To(Or(Equal("get_weather"), Equal("get_temperature")))
+				Expect(tc.Function.Name).To(Or(Equal(functionNameGetWeather), Equal(functionNameGetTemperature)))
 				lastIndex++
 				args[tc.Function.Name] = []string{tc.Function.Arguments}
 				functionName = tc.Function.Name
@@ -429,7 +434,7 @@ var _ = Describe("Simulator for request with tools", func() {
 			err := json.Unmarshal([]byte(joinedArgs), &argsMap)
 			Expect(err).NotTo(HaveOccurred())
 
-			if functionName == "get_weather" {
+			if functionName == functionNameGetWeather {
 				Expect(joinedArgs).To(ContainSubstring("location"))
 			} else {
 				Expect(joinedArgs).To(ContainSubstring("city"))
@@ -473,14 +478,14 @@ var _ = Describe("Simulator for request with tools", func() {
 			toolCalls := resp.Choices[0].Message.ToolCalls
 			Expect(toolCalls).ToNot(BeEmpty())
 			for _, tc := range toolCalls {
-				Expect(tc.Function.Name).To(Or(Equal("get_weather"), Equal("get_temperature")))
+				Expect(tc.Function.Name).To(Or(Equal(functionNameGetWeather), Equal(functionNameGetTemperature)))
 				Expect(tc.ID).NotTo(BeEmpty())
 				Expect(tc.Type).To(Equal("function"))
 				args := make(map[string]string)
 				err := json.Unmarshal([]byte(tc.Function.Arguments), &args)
 				Expect(err).NotTo(HaveOccurred())
 
-				if tc.Function.Name == "get_weather" {
+				if tc.Function.Name == functionNameGetWeather {
 					Expect(tc.Function.Arguments).To(ContainSubstring("location"))
 				} else {
 					Expect(tc.Function.Arguments).To(ContainSubstring("city"))
@@ -499,6 +504,59 @@ var _ = Describe("Simulator for request with tools", func() {
 		Entry(nil, common.ModeRandom),
 	)
 
+	DescribeTable("no streaming, a specific tool",
+		func(mode string, specificTool string) {
+			ctx := context.TODO()
+			client, err := startServer(ctx, mode)
+			Expect(err).NotTo(HaveOccurred())
+
+			openaiclient, params := getOpenAIClientAndChatParams(client, model, userMessage, false)
+			params.ToolChoice = openai.ToolChoiceOptionFunctionToolChoice(openai.ChatCompletionNamedToolChoiceFunctionParam{
+				Name: specificTool,
+			})
+			params.Tools = tools
+
+			resp, err := openaiclient.Chat.Completions.New(ctx, params)
+			Expect(err).NotTo(HaveOccurred())
+			Expect(resp.Choices).ShouldNot(BeEmpty())
+			Expect(string(resp.Object)).To(Equal(chatCompletionObject))
+
+			Expect(resp.Usage.PromptTokens).To(Equal(userMsgTokens))
+			Expect(resp.Usage.CompletionTokens).To(BeNumerically(">", 0))
+			Expect(resp.Usage.TotalTokens).To(Equal(resp.Usage.PromptTokens + resp.Usage.CompletionTokens))
+
+			content := resp.Choices[0].Message.Content
+			Expect(content).Should(BeEmpty())
+
+			toolCalls := resp.Choices[0].Message.ToolCalls
+			Expect(toolCalls).ToNot(BeEmpty())
+			for _, tc := range toolCalls {
+				Expect(tc.Function.Name).To(Equal(specificTool))
+				Expect(tc.ID).NotTo(BeEmpty())
+				Expect(tc.Type).To(Equal("function"))
+				args := make(map[string]string)
+				err := json.Unmarshal([]byte(tc.Function.Arguments), &args)
+				Expect(err).NotTo(HaveOccurred())
+
+				if tc.Function.Name == functionNameGetWeather {
+					Expect(tc.Function.Arguments).To(ContainSubstring("location"))
+				} else {
+					Expect(tc.Function.Arguments).To(ContainSubstring("city"))
+					Expect(tc.Function.Arguments).To(ContainSubstring("unit"))
+					Expect(args["unit"]).To(Or(Equal("C"), Equal("F")))
+				}
+			}
+		},
+		func(mode string, specificTool string) string {
+			return "mode: " + mode + ", specificTool: " + specificTool
+		},
+		// Call several times because the tools and arguments are chosen randomly
+		Entry(nil, common.ModeRandom, functionNameGetWeather),
+		Entry(nil, common.ModeRandom, functionNameGetTemperature),
+		Entry(nil, common.ModeRandom, functionNameGetWeather),
+		Entry(nil, common.ModeRandom, functionNameGetTemperature),
+	)
+
 	DescribeTable("check validator",
 		func(mode string) {
 			ctx := context.TODO()
@@ -778,7 +836,7 @@ var _ = Describe("Simulator for request with tools", func() {
 			toolCalls := resp.Choices[0].Message.ToolCalls
 			Expect(toolCalls).To(HaveLen(1))
 			tc := toolCalls[0]
-			Expect(tc.Function.Name).To(Equal("get_temperature"))
+			Expect(tc.Function.Name).To(Equal(functionNameGetTemperature))
 			Expect(tc.ID).NotTo(BeEmpty())
 			Expect(tc.Type).To(Equal("function"))
 			args := make(map[string]string)

pkg/llm-d-inference-sim/worker.go

Lines changed: 1 addition & 1 deletion
@@ -88,7 +88,7 @@ func (s *VllmSimulator) processRequest(reqCtx *openaiserverapi.CompletionReqCtx)
 	var toolCalls []openaiserverapi.ToolCall
 	var completionTokens int
 	if reqCtx.IsChatCompletion &&
-		req.GetToolChoice() != common.ToolChoiceNone &&
+		!common.IsToolChoiceNone(req.GetToolChoice()) &&
 		req.GetTools() != nil {
 		toolCalls, completionTokens, err =
 			common.CreateToolCalls(req.GetTools(), req.GetToolChoice(), s.config)
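The gate in processRequest changes because tool_choice is no longer a plain string: it is now a union that can hold either a string option ("none", "auto", "required") or a named function. A hypothetical illustration of the "none" check, assuming OfAuto is a param.Opt[string] as the tools_utils.go change suggests (this struct is not the real openaiserverapi type):

// Hypothetical stand-in for the union type; only the string branch is modeled.
package main

import (
	"fmt"

	"github.com/openai/openai-go/v3/packages/param"
)

type toolChoice struct {
	OfAuto param.Opt[string]
}

// isNone mirrors common.IsToolChoiceNone: only a present string value equal to "none" matches.
func isNone(tc toolChoice) bool {
	if !param.IsOmitted(tc.OfAuto) {
		return tc.OfAuto.Or("") == "none"
	}
	return false
}

func main() {
	fmt.Println(isNone(toolChoice{OfAuto: param.NewOpt("none")})) // true
	fmt.Println(isNone(toolChoice{OfAuto: param.NewOpt("auto")})) // false
	fmt.Println(isNone(toolChoice{}))                             // false: OfAuto omitted (e.g., a named function was set instead)
}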
