Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 113 additions & 36 deletions pkg/common/tools_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@ package common

import (
"encoding/json"
"errors"
"fmt"

openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api"
"github.com/openai/openai-go/v3/packages/param"
"github.com/santhosh-tekuri/jsonschema/v5"
)

Expand Down Expand Up @@ -52,50 +54,125 @@ var fakeStringArguments = []string{
`lifetime`,
}

// CreateToolCalls creates and returns response payload based on this request
// (tool calls or nothing in case we randomly choose not to generate calls),
// and the number of generated completion tokens and the finish reason
func CreateToolCalls(tools []openaiserverapi.Tool, toolChoice string, config *Configuration) ([]openaiserverapi.ToolCall, int, error) {
// This function is called if tool choice is either 'required' or 'auto'.
// In case of 'required' at least one tool call has to be created, and we randomly choose
// the number of calls starting from one. Otherwise, we start from 0, and in case we randomly
// choose the number of calls to be 0, response text will be generated instead of a tool call.
min := 0
if toolChoice == ToolChoiceRequired {
min = 1
}
numberOfCalls := RandomInt(min, len(tools))
if numberOfCalls == 0 {
return nil, 0, nil
// IsToolChoiceNone reports whether the request's tool_choice parameter is
// the literal string "none", meaning the model must not call any tools.
func IsToolChoiceNone(toolChoice openaiserverapi.ToolChoice) bool {
	if param.IsOmitted(toolChoice.OfAuto) {
		// tool_choice is not the plain-string variant (e.g. it is a forced
		// function object), so it cannot be the string "none".
		return false
	}
	return toolChoice.OfAuto.Or("") == ToolChoiceNone
}

calls := make([]openaiserverapi.ToolCall, 0)
for i := range numberOfCalls {
// Randomly choose which tools to call. We may call the same tool more than once.
index := RandomInt(0, len(tools)-1)
args, err := generateToolArguments(tools[index], config)
if err != nil {
return nil, 0, err
// CreateToolCalls creates and returns tool calls based on the request's tool
// definitions and the tool_choice parameter.
//
// The [tool_choice](https://platform.openai.com/docs/guides/function-calling#tool-choice)
// parameter controls how the model responds to function calls.
//
// This function handles the following cases for tool_choice:
// - "none": The model will not call any tools. In this scenario, this function
// should ideally be bypassed, as no tool calls will be generated.
// - "auto": This is the default behavior where the model autonomously decides
// whether to generate a message or call one or more tools from the provided list.
// - "required": The model is constrained to call one or more of the available tools.
// - Forced Function: A specific tool can be forced by providing an object with the
// structure `{"type": "function", "function": {"name": "my_function"}}`.
// The model will be restricted to calling only that specified tool.
//
// This function currently does not handle the following `tool_choice` scenarios:
// - Forced Custom Tool: If `tool_choice` is set to `{"type": "custom", "name": "my_custom"}`,
// this function will not be able to enforce the calling of a custom tool, as custom
// tool types are not yet supported.
// - Allowed Tools Subset: The functionality to restrict the model's tool-calling
// capabilities to a specific subset of the available tools has not been implemented.
//
// This function returns the generated tool calls, the number of completion
// tokens used, and an error if one occurs (e.g., if a specified tool is not found).
func CreateToolCalls(
tools []openaiserverapi.Tool,
toolChoice openaiserverapi.ToolChoice,
config *Configuration,
) ([]openaiserverapi.ToolCall, int, error) {
generateCalls := func(availableTools []openaiserverapi.Tool, minCalls int) ([]openaiserverapi.ToolCall, int, error) {
if len(availableTools) == 0 {
// If no tools are available to choose from, no calls can be made.
return nil, 0, errors.New("no tools available to create tool calls")
}
argsJson, err := json.Marshal(args)
if err != nil {
return nil, 0, err

numberOfCalls := minCalls
if len(availableTools) > minCalls {
// Randomly decide how many tools to call, between minCalls and the total available.
numberOfCalls = RandomInt(minCalls, len(availableTools))
}

if numberOfCalls == 0 {
return nil, 0, nil
}

calls := make([]openaiserverapi.ToolCall, 0, numberOfCalls)
for i := range numberOfCalls {
// Randomly choose which tool to call. We may call the same tool more than once.
index := 0
if len(availableTools) > 1 {
index = RandomInt(0, len(availableTools)-1)
}
chosenTool := availableTools[index]

args, err := generateToolArguments(chosenTool, config)
if err != nil {
return nil, 0, err
}
argsJson, err := json.Marshal(args)
if err != nil {
return nil, 0, err
}

call := openaiserverapi.ToolCall{
Function: openaiserverapi.FunctionCall{
Arguments: string(argsJson),
TokenizedArguments: Tokenize(string(argsJson)),
Name: &chosenTool.Function.Name,
},
ID: "chatcmpl-tool-" + RandomNumericString(10),
Type: "function",
Index: i,
}
calls = append(calls, call)
}
return calls, CountTokensForToolCalls(calls), nil
}

// A specific function is forced.
if functionChoice := toolChoice.GetFunction(); functionChoice != nil {
requiredFuncName := functionChoice.Name
var targetTool *openaiserverapi.Tool

call := openaiserverapi.ToolCall{
Function: openaiserverapi.FunctionCall{
Arguments: string(argsJson),
TokenizedArguments: Tokenize(string(argsJson)),
Name: &tools[index].Function.Name,
},
ID: "chatcmpl-tool-" + RandomNumericString(10),
Type: "function",
Index: i,
// Find the specified tool in the list of available tools.
for i, tool := range tools {
if tool.Function.Name == requiredFuncName {
targetTool = &tools[i]
break
}
}
calls = append(calls, call)

if targetTool == nil {
return nil, 0, fmt.Errorf("tool with name '%s' requested in tool_choice but not found in the tools list", requiredFuncName)
}

specificTools := []openaiserverapi.Tool{*targetTool}

// Generate arguments for the specific tool.
return generateCalls(specificTools, len(specificTools))
}

// Default behavior for "auto" or "required".
// The model can choose from any of the provided tools.
min := 0
if !param.IsOmitted(toolChoice.OfAuto) && toolChoice.OfAuto.Or("") == ToolChoiceRequired {
min = 1
}

return calls, CountTokensForToolCalls(calls), nil
return generateCalls(tools, min)
}

func getRequiredAsMap(property map[string]any) map[string]struct{} {
Expand Down
82 changes: 70 additions & 12 deletions pkg/llm-d-inference-sim/tools_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,16 @@ import (
"github.com/openai/openai-go/v3/packages/param"
)

const (
functionNameGetWeather = "get_weather"
functionNameGetTemperature = "get_temperature"
)

var tools = []openai.ChatCompletionToolUnionParam{
{
OfFunction: &openai.ChatCompletionFunctionToolParam{
Function: openai.FunctionDefinitionParam{
Name: "get_weather",
Name: functionNameGetWeather,
Description: openai.String("Get weather at the given location"),
Parameters: openai.FunctionParameters{
"type": "object",
Expand All @@ -53,7 +58,7 @@ var tools = []openai.ChatCompletionToolUnionParam{
{
OfFunction: &openai.ChatCompletionFunctionToolParam{
Function: openai.FunctionDefinitionParam{
Name: "get_temperature",
Name: functionNameGetTemperature,
Description: openai.String("Get temperature at the given location"),
Parameters: openai.FunctionParameters{
"type": "object",
Expand All @@ -78,7 +83,7 @@ var invalidTools = [][]openai.ChatCompletionToolUnionParam{
{
OfFunction: &openai.ChatCompletionFunctionToolParam{
Function: openai.FunctionDefinitionParam{
Name: "get_weather",
Name: functionNameGetWeather,
Description: openai.String("Get weather at the given location"),
Parameters: openai.FunctionParameters{
"type": "object",
Expand All @@ -95,7 +100,7 @@ var invalidTools = [][]openai.ChatCompletionToolUnionParam{
{
OfFunction: &openai.ChatCompletionFunctionToolParam{
Function: openai.FunctionDefinitionParam{
Name: "get_temperature",
Name: functionNameGetTemperature,
Description: openai.String("Get temperature at the given location"),
Parameters: openai.FunctionParameters{
"type": "object",
Expand All @@ -119,7 +124,7 @@ var invalidTools = [][]openai.ChatCompletionToolUnionParam{
{
OfFunction: &openai.ChatCompletionFunctionToolParam{
Function: openai.FunctionDefinitionParam{
Name: "get_weather",
Name: functionNameGetWeather,
Description: openai.String("Get weather at the given location"),
Parameters: openai.FunctionParameters{
"type": "object",
Expand All @@ -139,7 +144,7 @@ var invalidTools = [][]openai.ChatCompletionToolUnionParam{
{
OfFunction: &openai.ChatCompletionFunctionToolParam{
Function: openai.FunctionDefinitionParam{
Name: "get_weather",
Name: functionNameGetWeather,
Description: openai.String("Get weather at the given location"),
},
},
Expand Down Expand Up @@ -312,7 +317,7 @@ var toolWithoutRequiredParams = []openai.ChatCompletionToolUnionParam{
{
OfFunction: &openai.ChatCompletionFunctionToolParam{
Function: openai.FunctionDefinitionParam{
Name: "get_temperature",
Name: functionNameGetTemperature,
Description: openai.String("Get temperature at the given location"),
Parameters: openai.FunctionParameters{
"type": "object",
Expand Down Expand Up @@ -398,7 +403,7 @@ var _ = Describe("Simulator for request with tools", func() {
tc := toolCalls[0]
Expect(tc.Index).To(Or(BeNumerically("==", lastIndex), BeNumerically("==", lastIndex+1)))
if tc.Index > int64(lastIndex) {
Expect(tc.Function.Name).To(Or(Equal("get_weather"), Equal("get_temperature")))
Expect(tc.Function.Name).To(Or(Equal(functionNameGetWeather), Equal(functionNameGetTemperature)))
lastIndex++
args[tc.Function.Name] = []string{tc.Function.Arguments}
functionName = tc.Function.Name
Expand Down Expand Up @@ -429,7 +434,7 @@ var _ = Describe("Simulator for request with tools", func() {
err := json.Unmarshal([]byte(joinedArgs), &argsMap)
Expect(err).NotTo(HaveOccurred())

if functionName == "get_weather" {
if functionName == functionNameGetWeather {
Expect(joinedArgs).To(ContainSubstring("location"))
} else {
Expect(joinedArgs).To(ContainSubstring("city"))
Expand Down Expand Up @@ -473,14 +478,14 @@ var _ = Describe("Simulator for request with tools", func() {
toolCalls := resp.Choices[0].Message.ToolCalls
Expect(toolCalls).ToNot(BeEmpty())
for _, tc := range toolCalls {
Expect(tc.Function.Name).To(Or(Equal("get_weather"), Equal("get_temperature")))
Expect(tc.Function.Name).To(Or(Equal(functionNameGetWeather), Equal(functionNameGetTemperature)))
Expect(tc.ID).NotTo(BeEmpty())
Expect(tc.Type).To(Equal("function"))
args := make(map[string]string)
err := json.Unmarshal([]byte(tc.Function.Arguments), &args)
Expect(err).NotTo(HaveOccurred())

if tc.Function.Name == "get_weather" {
if tc.Function.Name == functionNameGetWeather {
Expect(tc.Function.Arguments).To(ContainSubstring("location"))
} else {
Expect(tc.Function.Arguments).To(ContainSubstring("city"))
Expand All @@ -499,6 +504,59 @@ var _ = Describe("Simulator for request with tools", func() {
Entry(nil, common.ModeRandom),
)

// Verifies that a forced function in tool_choice (named-tool-choice form)
// restricts the simulator to calling exactly that tool, and that the
// generated arguments validate against the chosen tool's schema.
DescribeTable("no streaming, a specific tool",
func(mode string, specificTool string) {
ctx := context.TODO()
client, err := startServer(ctx, mode)
Expect(err).NotTo(HaveOccurred())

// Force the named tool via tool_choice while still offering both tools.
openaiclient, params := getOpenAIClientAndChatParams(client, model, userMessage, false)
params.ToolChoice = openai.ToolChoiceOptionFunctionToolChoice(openai.ChatCompletionNamedToolChoiceFunctionParam{
Name: specificTool,
})
params.Tools = tools

resp, err := openaiclient.Chat.Completions.New(ctx, params)
Expect(err).NotTo(HaveOccurred())
Expect(resp.Choices).ShouldNot(BeEmpty())
Expect(string(resp.Object)).To(Equal(chatCompletionObject))

Expect(resp.Usage.PromptTokens).To(Equal(userMsgTokens))
Expect(resp.Usage.CompletionTokens).To(BeNumerically(">", 0))
Expect(resp.Usage.TotalTokens).To(Equal(resp.Usage.PromptTokens + resp.Usage.CompletionTokens))

// A tool-call response carries no text content.
content := resp.Choices[0].Message.Content
Expect(content).Should(BeEmpty())

toolCalls := resp.Choices[0].Message.ToolCalls
Expect(toolCalls).ToNot(BeEmpty())
for _, tc := range toolCalls {
// Every call must target the forced tool, never the other one.
Expect(tc.Function.Name).To(Equal(specificTool))
Expect(tc.ID).NotTo(BeEmpty())
Expect(tc.Type).To(Equal("function"))
args := make(map[string]string)
err := json.Unmarshal([]byte(tc.Function.Arguments), &args)
Expect(err).NotTo(HaveOccurred())

// Argument checks mirror each tool's parameter schema.
if tc.Function.Name == functionNameGetWeather {
Expect(tc.Function.Arguments).To(ContainSubstring("location"))
} else {
Expect(tc.Function.Arguments).To(ContainSubstring("city"))
Expect(tc.Function.Arguments).To(ContainSubstring("unit"))
Expect(args["unit"]).To(Or(Equal("C"), Equal("F")))
}
}
},
func(mode string, specificTool string) string {
return "mode: " + mode + ", specificTool: " + specificTool
},
// Call several times because the tools and arguments are chosen randomly
Entry(nil, common.ModeRandom, functionNameGetWeather),
Entry(nil, common.ModeRandom, functionNameGetTemperature),
Entry(nil, common.ModeRandom, functionNameGetWeather),
Entry(nil, common.ModeRandom, functionNameGetTemperature),
)

DescribeTable("check validator",
func(mode string) {
ctx := context.TODO()
Expand Down Expand Up @@ -778,7 +836,7 @@ var _ = Describe("Simulator for request with tools", func() {
toolCalls := resp.Choices[0].Message.ToolCalls
Expect(toolCalls).To(HaveLen(1))
tc := toolCalls[0]
Expect(tc.Function.Name).To(Equal("get_temperature"))
Expect(tc.Function.Name).To(Equal(functionNameGetTemperature))
Expect(tc.ID).NotTo(BeEmpty())
Expect(tc.Type).To(Equal("function"))
args := make(map[string]string)
Expand Down
2 changes: 1 addition & 1 deletion pkg/llm-d-inference-sim/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ func (s *VllmSimulator) processRequest(reqCtx *openaiserverapi.CompletionReqCtx)
var toolCalls []openaiserverapi.ToolCall
var completionTokens int
if reqCtx.IsChatCompletion &&
req.GetToolChoice() != common.ToolChoiceNone &&
!common.IsToolChoiceNone(req.GetToolChoice()) &&
req.GetTools() != nil {
toolCalls, completionTokens, err =
common.CreateToolCalls(req.GetTools(), req.GetToolChoice(), s.config)
Expand Down
Loading
Loading