diff --git a/pkg/common/tools_utils.go b/pkg/common/tools_utils.go index 8df7ebd2..e18bc924 100644 --- a/pkg/common/tools_utils.go +++ b/pkg/common/tools_utils.go @@ -18,9 +18,11 @@ package common import ( "encoding/json" + "errors" "fmt" openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api" + "github.com/openai/openai-go/v3/packages/param" "github.com/santhosh-tekuri/jsonschema/v5" ) @@ -52,50 +54,125 @@ var fakeStringArguments = []string{ `lifetime`, } -// CreateToolCalls creates and returns response payload based on this request -// (tool calls or nothing in case we randomly choose not to generate calls), -// and the number of generated completion token sand the finish reason -func CreateToolCalls(tools []openaiserverapi.Tool, toolChoice string, config *Configuration) ([]openaiserverapi.ToolCall, int, error) { - // This function is called if tool choice is either 'required' or 'auto'. - // In case of 'required' at least one tool call has to be created, and we randomly choose - // the number of calls starting from one. Otherwise, we start from 0, and in case we randomly - // choose the number of calls to be 0, response text will be generated instead of a tool call. - min := 0 - if toolChoice == ToolChoiceRequired { - min = 1 - } - numberOfCalls := RandomInt(min, len(tools)) - if numberOfCalls == 0 { - return nil, 0, nil +// IsToolChoiceNone checks if the tool_choice is set to "none". +func IsToolChoiceNone(toolChoice openaiserverapi.ToolChoice) bool { + if !param.IsOmitted(toolChoice.OfAuto) { + val := toolChoice.OfAuto.Or("") + return val == ToolChoiceNone } + return false +} - calls := make([]openaiserverapi.ToolCall, 0) - for i := range numberOfCalls { - // Randomly choose which tools to call. We may call the same tool more than once. - index := RandomInt(0, len(tools)-1) - args, err := generateToolArguments(tools[index], config) - if err != nil { - return nil, 0, err +// CreateToolCalls creates and returns tool calls based on the request's tool +// definitions and the tool_choice parameter. +// +// The [tool_choice](https://platform.openai.com/docs/guides/function-calling#tool-choice) +// parameter controls how the model responds to function calls. +// +// This function handles the following cases for tool_choice: +// - "none": The model will not call any tools. In this scenario, this function +// should ideally be bypassed, as no tool calls will be generated. +// - "auto": This is the default behavior where the model autonomously decides +// whether to generate a message or call one or more tools from the provided list. +// - "required": The model is constrained to call one or more of the available tools. +// - Forced Function: A specific tool can be forced by providing an object with the +// structure `{"type": "function", "function": {"name": "my_function"}}`. +// The model will be restricted to calling only that specified tool. +// +// This function currently does not handle the following `tool_choice` scenarios: +// - Forced Custom Tool: If `tool_choice` is set to `{"type": "custom", "name": "my_custom"}`, +// this function will not be able to enforce the calling of a custom tool, as custom +// tool types are not yet supported. +// - Allowed Tools Subset: The functionality to restrict the model's tool-calling +// capabilities to a specific subset of the available tools has not been implemented. 
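+//
+// For example, a request that forces a call to a function named "my_function"
+// (an illustrative name) carries:
+//
+//	"tool_choice": {"type": "function", "function": {"name": "my_function"}}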
+// +// This function returns the generated tool calls, the number of completion +// tokens used, and an error if one occurs (e.g., if a specified tool is not found). +func CreateToolCalls( + tools []openaiserverapi.Tool, + toolChoice openaiserverapi.ToolChoice, + config *Configuration, +) ([]openaiserverapi.ToolCall, int, error) { + generateCalls := func(availableTools []openaiserverapi.Tool, minCalls int) ([]openaiserverapi.ToolCall, int, error) { + if len(availableTools) == 0 { + // If no tools are available to choose from, no calls can be made. + return nil, 0, errors.New("no tools available to create tool calls") } - argsJson, err := json.Marshal(args) - if err != nil { - return nil, 0, err + + numberOfCalls := minCalls + if len(availableTools) > minCalls { + // Randomly decide how many tools to call, between minCalls and the total available. + numberOfCalls = RandomInt(minCalls, len(availableTools)) + } + + if numberOfCalls == 0 { + return nil, 0, nil + } + + calls := make([]openaiserverapi.ToolCall, 0, numberOfCalls) + for i := range numberOfCalls { + // Randomly choose which tool to call. We may call the same tool more than once. + index := 0 + if len(availableTools) > 1 { + index = RandomInt(0, len(availableTools)-1) + } + chosenTool := availableTools[index] + + args, err := generateToolArguments(chosenTool, config) + if err != nil { + return nil, 0, err + } + argsJson, err := json.Marshal(args) + if err != nil { + return nil, 0, err + } + + call := openaiserverapi.ToolCall{ + Function: openaiserverapi.FunctionCall{ + Arguments: string(argsJson), + TokenizedArguments: Tokenize(string(argsJson)), + Name: &chosenTool.Function.Name, + }, + ID: "chatcmpl-tool-" + RandomNumericString(10), + Type: "function", + Index: i, + } + calls = append(calls, call) } + return calls, CountTokensForToolCalls(calls), nil + } + + // A specific function is forced. + if functionChoice := toolChoice.GetFunction(); functionChoice != nil { + requiredFuncName := functionChoice.Name + var targetTool *openaiserverapi.Tool - call := openaiserverapi.ToolCall{ - Function: openaiserverapi.FunctionCall{ - Arguments: string(argsJson), - TokenizedArguments: Tokenize(string(argsJson)), - Name: &tools[index].Function.Name, - }, - ID: "chatcmpl-tool-" + RandomNumericString(10), - Type: "function", - Index: i, + // Find the specified tool in the list of available tools. + for i, tool := range tools { + if tool.Function.Name == requiredFuncName { + targetTool = &tools[i] + break + } } - calls = append(calls, call) + + if targetTool == nil { + return nil, 0, fmt.Errorf("tool with name '%s' requested in tool_choice but not found in the tools list", requiredFuncName) + } + + specificTools := []openaiserverapi.Tool{*targetTool} + + // Generate arguments for the specific tool. + return generateCalls(specificTools, len(specificTools)) + } + + // Default behavior for "auto" or "required". + // The model can choose from any of the provided tools. 
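+	// For "required" at least one tool call has to be created, so the number
+	// of calls is chosen starting from one; for "auto" it starts from zero, and
+	// choosing zero calls means response text is generated instead of a tool call.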
+ min := 0 + if !param.IsOmitted(toolChoice.OfAuto) && toolChoice.OfAuto.Or("") == ToolChoiceRequired { + min = 1 } - return calls, CountTokensForToolCalls(calls), nil + return generateCalls(tools, min) } func getRequiredAsMap(property map[string]any) map[string]struct{} { diff --git a/pkg/llm-d-inference-sim/tools_test.go b/pkg/llm-d-inference-sim/tools_test.go index e97db392..c504c5d5 100644 --- a/pkg/llm-d-inference-sim/tools_test.go +++ b/pkg/llm-d-inference-sim/tools_test.go @@ -32,11 +32,16 @@ import ( "github.com/openai/openai-go/v3/packages/param" ) +const ( + functionNameGetWeather = "get_weather" + functionNameGetTemperature = "get_temperature" +) + var tools = []openai.ChatCompletionToolUnionParam{ { OfFunction: &openai.ChatCompletionFunctionToolParam{ Function: openai.FunctionDefinitionParam{ - Name: "get_weather", + Name: functionNameGetWeather, Description: openai.String("Get weather at the given location"), Parameters: openai.FunctionParameters{ "type": "object", @@ -53,7 +58,7 @@ var tools = []openai.ChatCompletionToolUnionParam{ { OfFunction: &openai.ChatCompletionFunctionToolParam{ Function: openai.FunctionDefinitionParam{ - Name: "get_temperature", + Name: functionNameGetTemperature, Description: openai.String("Get temperature at the given location"), Parameters: openai.FunctionParameters{ "type": "object", @@ -78,7 +83,7 @@ var invalidTools = [][]openai.ChatCompletionToolUnionParam{ { OfFunction: &openai.ChatCompletionFunctionToolParam{ Function: openai.FunctionDefinitionParam{ - Name: "get_weather", + Name: functionNameGetWeather, Description: openai.String("Get weather at the given location"), Parameters: openai.FunctionParameters{ "type": "object", @@ -95,7 +100,7 @@ var invalidTools = [][]openai.ChatCompletionToolUnionParam{ { OfFunction: &openai.ChatCompletionFunctionToolParam{ Function: openai.FunctionDefinitionParam{ - Name: "get_temperature", + Name: functionNameGetTemperature, Description: openai.String("Get temperature at the given location"), Parameters: openai.FunctionParameters{ "type": "object", @@ -119,7 +124,7 @@ var invalidTools = [][]openai.ChatCompletionToolUnionParam{ { OfFunction: &openai.ChatCompletionFunctionToolParam{ Function: openai.FunctionDefinitionParam{ - Name: "get_weather", + Name: functionNameGetWeather, Description: openai.String("Get weather at the given location"), Parameters: openai.FunctionParameters{ "type": "object", @@ -139,7 +144,7 @@ var invalidTools = [][]openai.ChatCompletionToolUnionParam{ { OfFunction: &openai.ChatCompletionFunctionToolParam{ Function: openai.FunctionDefinitionParam{ - Name: "get_weather", + Name: functionNameGetWeather, Description: openai.String("Get weather at the given location"), }, }, @@ -312,7 +317,7 @@ var toolWithoutRequiredParams = []openai.ChatCompletionToolUnionParam{ { OfFunction: &openai.ChatCompletionFunctionToolParam{ Function: openai.FunctionDefinitionParam{ - Name: "get_temperature", + Name: functionNameGetTemperature, Description: openai.String("Get temperature at the given location"), Parameters: openai.FunctionParameters{ "type": "object", @@ -398,7 +403,7 @@ var _ = Describe("Simulator for request with tools", func() { tc := toolCalls[0] Expect(tc.Index).To(Or(BeNumerically("==", lastIndex), BeNumerically("==", lastIndex+1))) if tc.Index > int64(lastIndex) { - Expect(tc.Function.Name).To(Or(Equal("get_weather"), Equal("get_temperature"))) + Expect(tc.Function.Name).To(Or(Equal(functionNameGetWeather), Equal(functionNameGetTemperature))) lastIndex++ args[tc.Function.Name] = 
[]string{tc.Function.Arguments} functionName = tc.Function.Name @@ -429,7 +434,7 @@ var _ = Describe("Simulator for request with tools", func() { err := json.Unmarshal([]byte(joinedArgs), &argsMap) Expect(err).NotTo(HaveOccurred()) - if functionName == "get_weather" { + if functionName == functionNameGetWeather { Expect(joinedArgs).To(ContainSubstring("location")) } else { Expect(joinedArgs).To(ContainSubstring("city")) @@ -473,14 +478,14 @@ var _ = Describe("Simulator for request with tools", func() { toolCalls := resp.Choices[0].Message.ToolCalls Expect(toolCalls).ToNot(BeEmpty()) for _, tc := range toolCalls { - Expect(tc.Function.Name).To(Or(Equal("get_weather"), Equal("get_temperature"))) + Expect(tc.Function.Name).To(Or(Equal(functionNameGetWeather), Equal(functionNameGetTemperature))) Expect(tc.ID).NotTo(BeEmpty()) Expect(tc.Type).To(Equal("function")) args := make(map[string]string) err := json.Unmarshal([]byte(tc.Function.Arguments), &args) Expect(err).NotTo(HaveOccurred()) - if tc.Function.Name == "get_weather" { + if tc.Function.Name == functionNameGetWeather { Expect(tc.Function.Arguments).To(ContainSubstring("location")) } else { Expect(tc.Function.Arguments).To(ContainSubstring("city")) @@ -499,6 +504,59 @@ var _ = Describe("Simulator for request with tools", func() { Entry(nil, common.ModeRandom), ) + DescribeTable("no streaming, a specific tool", + func(mode string, specificTool string) { + ctx := context.TODO() + client, err := startServer(ctx, mode) + Expect(err).NotTo(HaveOccurred()) + + openaiclient, params := getOpenAIClientAndChatParams(client, model, userMessage, false) + params.ToolChoice = openai.ToolChoiceOptionFunctionToolChoice(openai.ChatCompletionNamedToolChoiceFunctionParam{ + Name: specificTool, + }) + params.Tools = tools + + resp, err := openaiclient.Chat.Completions.New(ctx, params) + Expect(err).NotTo(HaveOccurred()) + Expect(resp.Choices).ShouldNot(BeEmpty()) + Expect(string(resp.Object)).To(Equal(chatCompletionObject)) + + Expect(resp.Usage.PromptTokens).To(Equal(userMsgTokens)) + Expect(resp.Usage.CompletionTokens).To(BeNumerically(">", 0)) + Expect(resp.Usage.TotalTokens).To(Equal(resp.Usage.PromptTokens + resp.Usage.CompletionTokens)) + + content := resp.Choices[0].Message.Content + Expect(content).Should(BeEmpty()) + + toolCalls := resp.Choices[0].Message.ToolCalls + Expect(toolCalls).ToNot(BeEmpty()) + for _, tc := range toolCalls { + Expect(tc.Function.Name).To(Equal(specificTool)) + Expect(tc.ID).NotTo(BeEmpty()) + Expect(tc.Type).To(Equal("function")) + args := make(map[string]string) + err := json.Unmarshal([]byte(tc.Function.Arguments), &args) + Expect(err).NotTo(HaveOccurred()) + + if tc.Function.Name == functionNameGetWeather { + Expect(tc.Function.Arguments).To(ContainSubstring("location")) + } else { + Expect(tc.Function.Arguments).To(ContainSubstring("city")) + Expect(tc.Function.Arguments).To(ContainSubstring("unit")) + Expect(args["unit"]).To(Or(Equal("C"), Equal("F"))) + } + } + }, + func(mode string, specificTool string) string { + return "mode: " + mode + ", specificTool: " + specificTool + }, + // Call several times because the tools and arguments are chosen randomly + Entry(nil, common.ModeRandom, functionNameGetWeather), + Entry(nil, common.ModeRandom, functionNameGetTemperature), + Entry(nil, common.ModeRandom, functionNameGetWeather), + Entry(nil, common.ModeRandom, functionNameGetTemperature), + ) + DescribeTable("check validator", func(mode string) { ctx := context.TODO() @@ -778,7 +836,7 @@ var _ = Describe("Simulator for 
request with tools", func() { toolCalls := resp.Choices[0].Message.ToolCalls Expect(toolCalls).To(HaveLen(1)) tc := toolCalls[0] - Expect(tc.Function.Name).To(Equal("get_temperature")) + Expect(tc.Function.Name).To(Equal(functionNameGetTemperature)) Expect(tc.ID).NotTo(BeEmpty()) Expect(tc.Type).To(Equal("function")) args := make(map[string]string) diff --git a/pkg/llm-d-inference-sim/worker.go b/pkg/llm-d-inference-sim/worker.go index 3e05faca..b247a72b 100644 --- a/pkg/llm-d-inference-sim/worker.go +++ b/pkg/llm-d-inference-sim/worker.go @@ -88,7 +88,7 @@ func (s *VllmSimulator) processRequest(reqCtx *openaiserverapi.CompletionReqCtx) var toolCalls []openaiserverapi.ToolCall var completionTokens int if reqCtx.IsChatCompletion && - req.GetToolChoice() != common.ToolChoiceNone && + !common.IsToolChoiceNone(req.GetToolChoice()) && req.GetTools() != nil { toolCalls, completionTokens, err = common.CreateToolCalls(req.GetTools(), req.GetToolChoice(), s.config) diff --git a/pkg/openai-server-api/request.go b/pkg/openai-server-api/request.go index 8b17d679..9326b52b 100644 --- a/pkg/openai-server-api/request.go +++ b/pkg/openai-server-api/request.go @@ -46,10 +46,10 @@ type CompletionRequest interface { SetNumberOfCachedPromptTokens(cachedPromptTokens int) // GetPrompt returns the prompt GetPrompt() string - // GetTools() returns tools to use (in chat completion) + // GetTools returns tools to use (in chat completion) GetTools() []Tool - // GetToolChoice() returns tool choice (in chat completion) - GetToolChoice() string + // GetToolChoice returns tool choice (in chat completion) + GetToolChoice() ToolChoice // GetMaxCompletionTokens returns the maximum completion tokens requested GetMaxCompletionTokens() *int64 // GetIgnoreEOS returns true if the end-of-sequence tokens will be ignored @@ -184,12 +184,13 @@ type ChatCompletionRequest struct { // Tools is a list of tools the model may call. Tools []Tool `json:"tools,omitempty"` - // ToolChoice controls which (if any) tool is called by the model, - // possible values: none, auto, required. - // Sending an object with a specific tool, is currently not supported. - ToolChoice string `json:"tool_choice,omitempty"` + // ToolChoice controls which (if any) tool is called by the model. + // It can be a string ("none", "auto", "required") or an object specifying the function. 
+	ToolChoice ToolChoice `json:"tool_choice,omitzero"`
 }
 
+var _ CompletionRequest = (*ChatCompletionRequest)(nil)
+
 // function defines a tool
 type function struct {
 	// Name is the function's name
@@ -221,7 +222,7 @@ func (c *ChatCompletionRequest) GetTools() []Tool {
 	return c.Tools
 }
 
-func (c *ChatCompletionRequest) GetToolChoice() string {
+func (c *ChatCompletionRequest) GetToolChoice() ToolChoice {
 	return c.ToolChoice
 }
 
@@ -286,6 +287,8 @@ type TextCompletionRequest struct {
 	MaxTokens *int64 `json:"max_tokens"`
 }
 
+var _ CompletionRequest = (*TextCompletionRequest)(nil)
+
 func (t *TextCompletionRequest) GetPrompt() string {
 	return t.Prompt
 }
 
@@ -294,8 +297,8 @@ func (c *TextCompletionRequest) GetTools() []Tool {
 	return nil
 }
 
-func (c *TextCompletionRequest) GetToolChoice() string {
-	return ""
+func (c *TextCompletionRequest) GetToolChoice() ToolChoice {
+	return ToolChoice{}
 }
 
 func (c *TextCompletionRequest) GetMaxCompletionTokens() *int64 {
diff --git a/pkg/openai-server-api/tool_choice.go b/pkg/openai-server-api/tool_choice.go
new file mode 100644
index 00000000..6cdf6e1b
--- /dev/null
+++ b/pkg/openai-server-api/tool_choice.go
@@ -0,0 +1,93 @@
+/*
+Copyright 2025 The llm-d-inference-sim Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package openaiserverapi
+
+import (
+	"encoding/json"
+	"fmt"
+
+	"github.com/openai/openai-go/v3"
+	"github.com/openai/openai-go/v3/packages/param"
+)
+
+// ToolChoice is a wrapper around ChatCompletionToolChoiceOptionUnionParam that
+// provides custom JSON unmarshalling logic to correctly handle
+// the union type.
+type ToolChoice struct {
+	openai.ChatCompletionToolChoiceOptionUnionParam
+}
+
+// MarshalJSON forwards marshalling to the embedded
+// ChatCompletionToolChoiceOptionUnionParam's MarshalJSON method,
+// which already serializes the union correctly.
+func (t ToolChoice) MarshalJSON() ([]byte, error) {
+	return t.ChatCompletionToolChoiceOptionUnionParam.MarshalJSON()
+}
+
+// UnmarshalJSON provides custom logic to correctly deserialize the JSON data
+// into the appropriate field of the embedded union type. It inspects the JSON
+// structure to determine if it's a simple string or a complex object with a
+// 'type' discriminator field.
+func (t *ToolChoice) UnmarshalJSON(data []byte) error {
+	// Per the json.Unmarshaler convention, treat a JSON null as a no-op.
+	if string(data) == "null" {
+		return nil
+	}
+
+	// If the input is a simple string (e.g., "auto", "none", "required"),
+	// unmarshal it into the OfAuto field.
+	if data[0] == '"' {
+		var strValue string
+		if err := json.Unmarshal(data, &strValue); err != nil {
+			return err
+		}
+		t.OfAuto = param.NewOpt(strValue)
+		return nil
+	}
+
+	// If the input is a JSON object, we need to determine its type.
+	// We use a temporary struct to detect the 'type' field.
+	var typeDetector struct {
+		Type string `json:"type"`
+	}
+
+	// We only care about the type field; any other fields are ignored.
+	if err := json.Unmarshal(data, &typeDetector); err != nil {
+		return fmt.Errorf("failed to detect type for ToolChoice: %w", err)
+	}
+
+	// Based on the detected type, unmarshal the data into the correct struct.
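+	// Illustrative payload shapes for each discriminator value (assumed from
+	// the OpenAI API, for orientation only):
+	//   {"type": "function", "function": {"name": "my_function"}}
+	//   {"type": "custom", "name": "my_custom"}
+	//   {"type": "allowed_tools", ...}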
+ switch typeDetector.Type { + case "function": + var functionChoice openai.ChatCompletionNamedToolChoiceParam + if err := functionChoice.UnmarshalJSON(data); err != nil { + return err + } + t.OfFunctionToolChoice = &functionChoice + case "custom": + var customChoice openai.ChatCompletionNamedToolChoiceCustomParam + if err := customChoice.UnmarshalJSON(data); err != nil { + return err + } + t.OfCustomToolChoice = &customChoice + case "allowed_tools": + var allowedToolsChoice openai.ChatCompletionAllowedToolChoiceParam + if err := allowedToolsChoice.UnmarshalJSON(data); err != nil { + return err + } + t.OfAllowedTools = &allowedToolsChoice + default: + return fmt.Errorf("unknown ToolChoice type: %s", typeDetector.Type) + } + + return nil +}