Skip to content

Commit a1894f6

Browse files
authored
Support object parameters in tools (#66)
* Support object parameters in tools Signed-off-by: Ira <[email protected]> * Fix the schema Signed-off-by: Ira <[email protected]> * Don't allow enums for arrays and objects Signed-off-by: Ira <[email protected]> --------- Signed-off-by: Ira <[email protected]>
1 parent dcbaa20 commit a1894f6

File tree

2 files changed

+267
-13
lines changed

2 files changed

+267
-13
lines changed

pkg/llm-d-inference-sim/tools_test.go

Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,85 @@ var toolWith3DArray = []openai.ChatCompletionToolParam{
177177
},
178178
}
179179

180+
var toolWithObjects = []openai.ChatCompletionToolParam{
181+
{
182+
Function: openai.FunctionDefinitionParam{
183+
Name: "process_order",
184+
Description: openai.String("Process a customer order"),
185+
Parameters: openai.FunctionParameters{
186+
"type": "object",
187+
"properties": map[string]interface{}{
188+
"order_info": map[string]interface{}{
189+
"type": "object",
190+
"properties": map[string]interface{}{
191+
"item": map[string]interface{}{
192+
"type": "string",
193+
},
194+
"quantity": map[string]string{
195+
"type": "number",
196+
},
197+
"address": map[string]interface{}{
198+
"type": "object",
199+
"properties": map[string]interface{}{
200+
"street": map[string]interface{}{
201+
"type": "string",
202+
},
203+
"number": map[string]interface{}{
204+
"type": "number",
205+
},
206+
"home": map[string]interface{}{
207+
"type": "boolean",
208+
},
209+
},
210+
"required": []string{"street", "number", "home"},
211+
},
212+
},
213+
"required": []string{"item", "quantity", "address"},
214+
},
215+
"name": map[string]interface{}{
216+
"type": "string",
217+
},
218+
},
219+
"required": []string{"order_info", "name"},
220+
},
221+
},
222+
},
223+
}
224+
225+
var toolWithObjectAndArray = []openai.ChatCompletionToolParam{
226+
{
227+
Function: openai.FunctionDefinitionParam{
228+
Name: "submit_survey",
229+
Description: openai.String("Submit a survey with user information."),
230+
Parameters: openai.FunctionParameters{
231+
"type": "object",
232+
"properties": map[string]interface{}{
233+
"user_info": map[string]interface{}{
234+
"type": "object",
235+
"properties": map[string]interface{}{
236+
"name": map[string]interface{}{
237+
"type": "string",
238+
"description": "The user's name",
239+
},
240+
"age": map[string]string{
241+
"type": "number",
242+
"description": "The user's age",
243+
},
244+
"hobbies": map[string]interface{}{
245+
"type": "array",
246+
"items": map[string]string{"type": "string"},
247+
"description": "A list of the user's hobbies",
248+
},
249+
},
250+
"required": []string{"name", "age", "hobbies"},
251+
},
252+
},
253+
"required": []string{"user_info"},
254+
},
255+
},
256+
},
257+
}
258+
180259
var _ = Describe("Simulator for request with tools", func() {
181260

182261
DescribeTable("streaming",
@@ -456,4 +535,121 @@ var _ = Describe("Simulator for request with tools", func() {
456535
Entry(nil, modeRandom),
457536
Entry(nil, modeRandom),
458537
)
538+
539+
DescribeTable("objects, no streaming",
540+
func(mode string) {
541+
ctx := context.TODO()
542+
client, err := startServer(ctx, mode)
543+
Expect(err).NotTo(HaveOccurred())
544+
545+
openaiclient := openai.NewClient(
546+
option.WithBaseURL(baseURL),
547+
option.WithHTTPClient(client))
548+
549+
params := openai.ChatCompletionNewParams{
550+
Messages: []openai.ChatCompletionMessageParamUnion{openai.UserMessage(userMessage)},
551+
Model: model,
552+
ToolChoice: openai.ChatCompletionToolChoiceOptionUnionParam{OfAuto: param.NewOpt("required")},
553+
Tools: toolWithObjects,
554+
}
555+
556+
resp, err := openaiclient.Chat.Completions.New(ctx, params)
557+
Expect(err).NotTo(HaveOccurred())
558+
Expect(resp.Choices).ShouldNot(BeEmpty())
559+
Expect(string(resp.Object)).To(Equal(chatCompletionObject))
560+
561+
Expect(resp.Usage.PromptTokens).To(Equal(int64(4)))
562+
Expect(resp.Usage.CompletionTokens).To(BeNumerically(">", 0))
563+
Expect(resp.Usage.TotalTokens).To(Equal(resp.Usage.PromptTokens + resp.Usage.CompletionTokens))
564+
565+
content := resp.Choices[0].Message.Content
566+
Expect(content).Should(BeEmpty())
567+
568+
toolCalls := resp.Choices[0].Message.ToolCalls
569+
Expect(toolCalls).To(HaveLen(1))
570+
tc := toolCalls[0]
571+
Expect(tc.Function.Name).To(Equal("process_order"))
572+
Expect(tc.ID).NotTo(BeEmpty())
573+
Expect(string(tc.Type)).To(Equal("function"))
574+
575+
args := make(map[string]any)
576+
err = json.Unmarshal([]byte(tc.Function.Arguments), &args)
577+
Expect(err).NotTo(HaveOccurred())
578+
Expect(args["name"]).ToNot(BeEmpty())
579+
Expect(args["order_info"]).ToNot(BeEmpty())
580+
orderInfo, ok := args["order_info"].(map[string]any)
581+
Expect(ok).To(BeTrue())
582+
Expect(orderInfo["item"]).ToNot(BeEmpty())
583+
Expect(orderInfo).To(HaveKey("quantity"))
584+
Expect(orderInfo["address"]).ToNot(BeEmpty())
585+
address, ok := orderInfo["address"].(map[string]any)
586+
Expect(ok).To(BeTrue())
587+
Expect(address["street"]).ToNot(BeEmpty())
588+
_, ok = address["street"].(string)
589+
Expect(ok).To(BeTrue())
590+
_, ok = address["number"].(float64)
591+
Expect(ok).To(BeTrue())
592+
_, ok = address["home"].(bool)
593+
Expect(ok).To(BeTrue())
594+
},
595+
func(mode string) string {
596+
return "mode: " + mode
597+
},
598+
Entry(nil, modeRandom),
599+
)
600+
601+
DescribeTable("objects with array field, no streaming",
602+
func(mode string) {
603+
ctx := context.TODO()
604+
client, err := startServer(ctx, mode)
605+
Expect(err).NotTo(HaveOccurred())
606+
607+
openaiclient := openai.NewClient(
608+
option.WithBaseURL(baseURL),
609+
option.WithHTTPClient(client))
610+
611+
params := openai.ChatCompletionNewParams{
612+
Messages: []openai.ChatCompletionMessageParamUnion{openai.UserMessage(userMessage)},
613+
Model: model,
614+
ToolChoice: openai.ChatCompletionToolChoiceOptionUnionParam{OfAuto: param.NewOpt("required")},
615+
Tools: toolWithObjectAndArray,
616+
}
617+
618+
resp, err := openaiclient.Chat.Completions.New(ctx, params)
619+
Expect(err).NotTo(HaveOccurred())
620+
Expect(resp.Choices).ShouldNot(BeEmpty())
621+
Expect(string(resp.Object)).To(Equal(chatCompletionObject))
622+
623+
Expect(resp.Usage.PromptTokens).To(Equal(int64(4)))
624+
Expect(resp.Usage.CompletionTokens).To(BeNumerically(">", 0))
625+
Expect(resp.Usage.TotalTokens).To(Equal(resp.Usage.PromptTokens + resp.Usage.CompletionTokens))
626+
627+
content := resp.Choices[0].Message.Content
628+
Expect(content).Should(BeEmpty())
629+
630+
toolCalls := resp.Choices[0].Message.ToolCalls
631+
Expect(toolCalls).To(HaveLen(1))
632+
tc := toolCalls[0]
633+
Expect(tc.Function.Name).To(Equal("submit_survey"))
634+
Expect(tc.ID).NotTo(BeEmpty())
635+
Expect(string(tc.Type)).To(Equal("function"))
636+
637+
args := make(map[string]any)
638+
err = json.Unmarshal([]byte(tc.Function.Arguments), &args)
639+
Expect(err).NotTo(HaveOccurred())
640+
Expect(args["user_info"]).ToNot(BeEmpty())
641+
642+
userInfo, ok := args["user_info"].(map[string]any)
643+
Expect(ok).To(BeTrue())
644+
Expect(userInfo).To(HaveKey("age"))
645+
Expect(userInfo["name"]).ToNot(BeEmpty())
646+
Expect(userInfo["hobbies"]).ToNot(BeEmpty())
647+
_, ok = userInfo["hobbies"].([]any)
648+
Expect(ok).To(BeTrue())
649+
},
650+
func(mode string) string {
651+
return "mode: " + mode
652+
},
653+
Entry(nil, modeRandom),
654+
)
459655
})

pkg/llm-d-inference-sim/tools_utils.go

Lines changed: 71 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -87,19 +87,24 @@ func createToolCalls(tools []tool, toolChoice string) ([]toolCall, string, int,
8787
return calls, toolsFinishReason, countTokensForToolCalls(calls), nil
8888
}
8989

90-
func generateToolArguments(tool tool) (map[string]any, error) {
91-
arguments := make(map[string]any)
92-
properties, _ := tool.Function.Parameters["properties"].(map[string]any)
93-
90+
func getRequiredAsMap(property map[string]any) map[string]struct{} {
9491
required := make(map[string]struct{})
95-
requiredParams, ok := tool.Function.Parameters["required"]
92+
requiredParams, ok := property["required"]
9693
if ok {
9794
requiredArray, _ := requiredParams.([]any)
9895
for _, requiredParam := range requiredArray {
9996
param, _ := requiredParam.(string)
10097
required[param] = struct{}{}
10198
}
10299
}
100+
return required
101+
}
102+
103+
func generateToolArguments(tool tool) (map[string]any, error) {
104+
arguments := make(map[string]any)
105+
properties, _ := tool.Function.Parameters["properties"].(map[string]any)
106+
107+
required := getRequiredAsMap(tool.Function.Parameters)
103108

104109
for param, property := range properties {
105110
_, paramIsRequired := required[param]
@@ -150,8 +155,24 @@ func createArgument(property any) (any, error) {
150155
array[i] = elem
151156
}
152157
return array, nil
158+
case "object":
159+
required := getRequiredAsMap(propertyMap)
160+
objectProperties := propertyMap["properties"].(map[string]any)
161+
object := make(map[string]interface{})
162+
for fieldName, fieldProperties := range objectProperties {
163+
_, fieldIsRequired := required[fieldName]
164+
if !fieldIsRequired && !flipCoin() {
165+
continue
166+
}
167+
fieldValue, err := createArgument(fieldProperties)
168+
if err != nil {
169+
return nil, err
170+
}
171+
object[fieldName] = fieldValue
172+
}
173+
return object, nil
153174
default:
154-
return nil, fmt.Errorf("tool parameters of type %s are currently not supported", paramType)
175+
return nil, fmt.Errorf("tool parameters of type %s are not supported", paramType)
155176
}
156177
}
157178

@@ -274,6 +295,7 @@ const schema = `{
274295
"number",
275296
"boolean",
276297
"array",
298+
"object",
277299
"null"
278300
]
279301
},
@@ -286,9 +308,7 @@ const schema = `{
286308
"type": [
287309
"string",
288310
"number",
289-
"boolean",
290-
"array",
291-
"null"
311+
"boolean"
292312
]
293313
}
294314
},
@@ -310,6 +330,12 @@ const schema = `{
310330
}
311331
}
312332
]
333+
},
334+
"required": {
335+
"type": "array",
336+
"items": {
337+
"type": "string"
338+
}
313339
}
314340
},
315341
"required": [
@@ -376,11 +402,29 @@ const schema = `{
376402
},
377403
{
378404
"if": {
379-
"properties": {
380-
"type": {
381-
"const": "null"
405+
"anyOf": [
406+
{
407+
"properties": {
408+
"type": {
409+
"const": "null"
410+
}
411+
}
412+
},
413+
{
414+
"properties": {
415+
"type": {
416+
"const": "object"
417+
}
418+
}
419+
},
420+
{
421+
"properties": {
422+
"type": {
423+
"const": "array"
424+
}
425+
}
382426
}
383-
}
427+
]
384428
},
385429
"then": {
386430
"not": {
@@ -403,6 +447,20 @@ const schema = `{
403447
"items"
404448
]
405449
}
450+
},
451+
{
452+
"if": {
453+
"properties": {
454+
"type": {
455+
"const": "object"
456+
}
457+
}
458+
},
459+
"then": {
460+
"required": [
461+
"properties"
462+
]
463+
}
406464
}
407465
]
408466
}

0 commit comments

Comments
 (0)