Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Changelog

## Unreleased

- functiontool: support non-struct inputs from LLM function calls
- Automatically wraps function tools whose input type is a non-struct (e.g. `string`, `int`) so they accept LLM function-call arguments of the form `{ "input": <value> }`.
- Adds `nonStructInputWrapper` implementation, unit + integration tests, and documentation.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ Agent Development Kit (ADK) is a flexible and modular framework that applies sof

This Go version of ADK is ideal for developers building cloud-native agent applications, leveraging Go's strengths in concurrency and performance.

Note: Function tools that accept primitive (non-struct) inputs are now automatically supported — such handlers will be exposed to models as a function taking an object with a single `"input"` property. See `CHANGELOG.md` for details.

---

## ✨ Key Features
Expand Down
46 changes: 46 additions & 0 deletions agent/llmagent/llmagent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"testing"

"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"google.golang.org/genai"

"google.golang.org/adk/agent"
Expand Down Expand Up @@ -738,6 +739,51 @@ func TestFunctionTool(t *testing.T) {
}
}

// TestLLMAgent_StringFunctionToolIntegration verifies that a function tool
// whose input type is a plain string can be driven end to end by an LLM agent:
// the model issues a function call with arguments {"input": <value>}, the
// handler receives the unwrapped value, and the handler's string result is
// reported back as {"result": <value>}.
func TestLLMAgent_StringFunctionToolIntegration(t *testing.T) {
	// Simulate a model that first asks to call "super_tool" with a primitive
	// input and then produces a final text answer.
	responses := []*genai.Content{
		genai.NewContentFromFunctionCall("super_tool", map[string]any{"input": "callval"}, "model"),
		genai.NewContentFromText("final", "model"),
	}
	mockModel := &testutil.MockModel{Responses: responses}

	// Record every value the handler receives. Because the handler returns a
	// constant, the expected parts alone cannot prove that the "input" value
	// was actually unwrapped and delivered.
	var gotInputs []string
	superTool, err := functiontool.New(functiontool.Config{Name: "super_tool", Description: "echo"}, functiontool.Func[string, string](
		func(ctx tool.Context, input string) (string, error) {
			gotInputs = append(gotInputs, input)
			return "Hello", nil
		},
	))
	if err != nil {
		t.Fatalf("failed to create tool: %v", err)
	}

	a, err := llmagent.New(llmagent.Config{
		Name:        "super_tool_caller",
		Model:       mockModel,
		Description: "Agent to invoke the super_tool",
		Instruction: "Call the super_tool",
		Tools:       []tool.Tool{superTool},
	})
	if err != nil {
		t.Fatalf("failed to create llm agent: %v", err)
	}

	runner := testutil.NewTestAgentRunner(t, a)
	parts, err := testutil.CollectParts(runner.Run(t, "session1", "prompt"))
	if err != nil {
		t.Fatalf("failed to collect parts: %v", err)
	}

	// The handler must have been invoked exactly once with the unwrapped value.
	if diff := cmp.Diff([]string{"callval"}, gotInputs); diff != "" {
		t.Errorf("handler inputs mismatch (-want +got):\n%s", diff)
	}

	wantParts := []*genai.Part{
		genai.NewPartFromFunctionCall("super_tool", map[string]any{"input": "callval"}),
		genai.NewPartFromFunctionResponse("super_tool", map[string]any{"result": "Hello"}),
		genai.NewPartFromText("final"),
	}

	// Function call/response IDs are generated at runtime, so ignore them.
	if diff := cmp.Diff(wantParts, parts, cmpopts.IgnoreFields(genai.FunctionCall{}, "ID"), cmpopts.IgnoreFields(genai.FunctionResponse{}, "ID")); diff != "" {
		t.Fatalf("event parts mismatch (-want +got):\n%s", diff)
	}
}

func TestAgentTransfer(t *testing.T) {
// Helpers to create genai.Content conveniently.
transferCall := func(agentName string) *genai.Content {
Expand Down
35 changes: 35 additions & 0 deletions tool/functiontool/doc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package functiontool provides a tool that wraps a Go function.
//
// Special behavior for non-struct inputs:
//
// The arguments of a model function call (`genai.FunctionCall.Args`) are always represented as `map[string]any`.
// For Go handlers whose input type (`TArgs`) is a non-struct (for example
// `string`, `int`, or `bool`), ADK's `functiontool.New` wraps the handler with
// a small decorator that:
// - exposes a function declaration whose `parameters` JSON schema is an
// `object` with a single required property `"input"` whose schema is the
// schema inferred for `TArgs`;
// - at runtime accepts `map[string]any{"input": <value>}` from the model,
// extracts the `input` value, converts it to `TArgs`, calls the original
// handler, and converts the result back to `map[string]any` (wrapping
// primitive outputs as `{ "result": <value> }` when necessary).
//
// This automatic wrapping makes it seamless to author function tools that
// accept primitive arguments while keeping the function-call declaration and
// runtime conversion consistent with the model's map-based function call
// format.
package functiontool
11 changes: 5 additions & 6 deletions tool/functiontool/function.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ type Func[TArgs, TResults any] func(tool.Context, TArgs) (TResults, error)

// New creates a new tool with a name, description, and the provided handler.
// Input schema is automatically inferred from the input and output types.
//
// If TArgs is a non-struct type (e.g., string, int), the tool is automatically wrapped
// to handle the LLM's map[string]any format with an "input" key. This allows non-struct
// function tools to work seamlessly with llmagent.
func New[TArgs, TResults any](cfg Config, handler Func[TArgs, TResults]) (tool.Tool, error) {
// TODO: How can we improve UX for functions that do not require an argument, return a simple type value, or return no result?
// https://github.com/modelcontextprotocol/go-sdk/discussions/37
Expand All @@ -63,12 +67,7 @@ func New[TArgs, TResults any](cfg Config, handler Func[TArgs, TResults]) (tool.T
return nil, fmt.Errorf("failed to infer output schema: %w", err)
}

return &functionTool[TArgs, TResults]{
cfg: cfg,
inputSchema: ischema,
outputSchema: oschema,
handler: handler,
}, nil
return wrapNonStructInput(cfg, ischema, oschema, handler)
}

// functionTool wraps a Go function.
Expand Down
49 changes: 49 additions & 0 deletions tool/functiontool/function_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,55 @@ func TestFunctionTool_CustomSchema(t *testing.T) {
})
}

func TestFunctionTool_StringInputWrapper(t *testing.T) {
// string identity function should be wrapped to accept {"input": string}
stringIdentityFunc := func(ctx tool.Context, input string) (string, error) {
return input, nil
}
stringIdentityTool, err := functiontool.New(
functiontool.Config{
Name: "string_identity",
Description: "returns the input value",
},
stringIdentityFunc)
if err != nil {
t.Fatalf("New(function) failed: %v", err)
}

// Declaration should expect an object with an "input" property
funcTool, ok := stringIdentityTool.(toolinternal.FunctionTool)
if !ok {
t.Fatal("stringIdentityTool does not implement FunctionTool")
}
decl := funcTool.Declaration()
if decl == nil || decl.ParametersJsonSchema == nil {
t.Fatalf("declaration or parameters schema is nil: %v", decl)
}
// Ensure the top-level type is object and contains "input" property
params := decl.ParametersJsonSchema
// marshal for easy substring checks
raw := stringify(params)
if !strings.Contains(raw, `"type": "object"`) {
t.Fatalf("expected object schema for parameters, got: %s", raw)
}
if !strings.Contains(raw, `"input"`) {
t.Fatalf("expected 'input' property in parameters schema, got: %s", raw)
}

// Run should accept map[string]any{"input": "value"}
result, err := funcTool.Run(nil, map[string]any{"input": "hello"})
if err != nil {
t.Fatalf("Run failed: %v", err)
}
got, err := typeutil.ConvertToWithJSONSchema[map[string]any, map[string]string](result, nil)
if err != nil {
t.Fatalf("unexpected result type: %v", err)
}
if gotVal, ok := got["result"]; !ok || gotVal != "hello" {
t.Fatalf("unexpected run result = %v", got)
}
}

func toolDeclaration(cfg *genai.GenerateContentConfig) *genai.FunctionDeclaration {
if cfg == nil || len(cfg.Tools) == 0 {
return nil
Expand Down
185 changes: 185 additions & 0 deletions tool/functiontool/wrapper.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package functiontool

import (
"fmt"
"reflect"

"github.com/google/jsonschema-go/jsonschema"
"google.golang.org/genai"

"google.golang.org/adk/internal/toolinternal/toolutils"
"google.golang.org/adk/internal/typeutil"
"google.golang.org/adk/model"
"google.golang.org/adk/tool"
)

// nonStructInputWrapper wraps a functionTool with a non-struct input type,
// converting the LLM-provided map[string]any with an "input" key to the expected type.
//
// TArgs is the handler's input type and TResults is its result type.
type nonStructInputWrapper[TArgs, TResults any] struct {
// cfg supplies the tool's name, description, and long-running flag.
cfg Config
// inputSchema is the schema placed under the "input" property of the
// declaration's parameters schema.
inputSchema *jsonschema.Resolved
// outputSchema, when non-nil, is used as the declaration's response schema
// and by Run when converting and validating the handler's output.
outputSchema *jsonschema.Resolved
// handler is the user-supplied function invoked by Run.
handler Func[TArgs, TResults]
// innerTool is populated by wrapNonStructInput with a functionTool built from
// the same cfg, schemas, and handler.
// NOTE(review): no method visible in this section reads it — confirm whether
// the field is still needed.
innerTool *functionTool[TArgs, TResults]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The innerTool field is declared here but appears to be unused throughout the nonStructInputWrapper's methods. To improve code clarity and maintainability, it's best to remove this unused field and its corresponding initialization in the wrapNonStructInput function.

}

// wrapNonStructInput returns the tool implementation appropriate for TArgs.
//
// Function-call arguments always arrive from the model as a JSON object
// (map[string]any). Input types that already decode from a JSON object —
// structs, maps, pointers to either, and interface types such as any — are
// served by the plain functionTool. Every other input type (string, int, bool,
// slices, ...) is wrapped in a nonStructInputWrapper, which exposes the value
// to the model under a single required "input" key.
//
// The returned error is always nil; it is kept so the function can be returned
// directly from New.
func wrapNonStructInput[TArgs, TResults any](
	cfg Config,
	inputSchema *jsonschema.Resolved,
	outputSchema *jsonschema.Resolved,
	handler Func[TArgs, TResults],
) (tool.Tool, error) {
	inner := &functionTool[TArgs, TResults]{
		cfg:          cfg,
		inputSchema:  inputSchema,
		outputSchema: outputSchema,
		handler:      handler,
	}

	// Derive the static type through a typed nil pointer: reflect.TypeOf on a
	// zero TArgs returns nil when TArgs is an interface type (e.g. any), and
	// calling Kind on that nil Type would panic.
	argsType := reflect.TypeOf((*TArgs)(nil)).Elem()
	if !needsInputWrapper(argsType) {
		return inner, nil
	}

	return &nonStructInputWrapper[TArgs, TResults]{
		cfg:          cfg,
		inputSchema:  inputSchema,
		outputSchema: outputSchema,
		handler:      handler,
		innerTool:    inner,
	}, nil
}

// needsInputWrapper reports whether values of type t must be carried under an
// "input" key because they do not map onto a JSON object by themselves.
func needsInputWrapper(t reflect.Type) bool {
	// A pointer decodes from the same JSON as the value it points to.
	for t.Kind() == reflect.Pointer {
		t = t.Elem()
	}
	switch t.Kind() {
	case reflect.Struct, reflect.Map, reflect.Interface:
		// Structs and maps are JSON objects; interface values can hold the
		// argument map as-is.
		return false
	default:
		return true
	}
}

// Description implements tool.Tool.
// It returns the Description field of the Config stored on the wrapper.
func (w *nonStructInputWrapper[TArgs, TResults]) Description() string {
return w.cfg.Description
}

// Name implements tool.Tool.
// It returns the Name field of the Config stored on the wrapper.
func (w *nonStructInputWrapper[TArgs, TResults]) Name() string {
return w.cfg.Name
}

// IsLongRunning implements tool.Tool.
// It returns the IsLongRunning flag of the Config stored on the wrapper.
func (w *nonStructInputWrapper[TArgs, TResults]) IsLongRunning() bool {
return w.cfg.IsLongRunning
}

// ProcessRequest implements tool.Tool.
// It delegates to toolutils.PackTool, passing req and the wrapper itself, and
// returns PackTool's error unchanged. ctx is not used.
// NOTE(review): PackTool presumably registers the tool's declaration on the
// request — confirm against toolutils.
func (w *nonStructInputWrapper[TArgs, TResults]) ProcessRequest(ctx tool.Context, req *model.LLMRequest) error {
return toolutils.PackTool(req, w)
}

// Declaration implements FunctionTool interface.
// It modifies the schema to expect {"input": <value>} format.
//
// The returned declaration carries the wrapper's name and description, a
// parameters schema that is an object with a single required "input" property,
// and, when outputSchema is set, a response schema taken from it. For
// long-running tools a fixed note is appended to the description.
func (w *nonStructInputWrapper[TArgs, TResults]) Declaration() *genai.FunctionDeclaration {
decl := &genai.FunctionDeclaration{
Name: w.Name(),
Description: w.Description(),
}

// Create a wrapper schema that expects {"input": <value>} format
// The empty "input" schema is a placeholder; it is always overwritten below.
wrappedSchema := &jsonschema.Schema{
Type: "object",
Properties: map[string]*jsonschema.Schema{
"input": {},
},
Required: []string{"input"},
}

// Use the original input schema as the schema of the "input" property.
if w.inputSchema != nil {
wrappedSchema.Properties["input"] = w.inputSchema.Schema()
} else {
// Use string as default if no schema
wrappedSchema.Properties["input"] = &jsonschema.Schema{Type: "string"}
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This if-else block can be simplified. The wrapNonStructInput function is called from New, which ensures that inputSchema is non-nil; otherwise, it would have returned an error. Therefore, the else block appears to be unreachable. You can remove the conditional and directly assign w.inputSchema.Schema().

wrappedSchema.Properties["input"] = w.inputSchema.Schema()


// wrappedSchema was assigned from a literal above, so this condition always holds.
if wrappedSchema != nil {
decl.ParametersJsonSchema = wrappedSchema
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The wrappedSchema variable is initialized using a struct literal, so it cannot be nil. This if check is redundant and can be removed to make the code more concise.

decl.ParametersJsonSchema = wrappedSchema

// The response schema is optional and only set when an output schema exists.
if w.outputSchema != nil {
decl.ResponseJsonSchema = w.outputSchema.Schema()
}

// For long-running tools, append a note telling the model not to call the
// tool again; the note becomes the whole description when none was configured.
if w.cfg.IsLongRunning {
instruction := "NOTE: This is a long-running operation. Do not call this tool again if it has already returned some intermediate or pending status."
if decl.Description != "" {
decl.Description += "\n\n" + instruction
} else {
decl.Description = instruction
}
}

return decl
}

// Run implements FunctionTool.
// It unwraps the {"input": <value>} format and calls the handler.
//
// args must be a map[string]any containing an "input" key; otherwise an error
// is returned without calling the handler. The "input" value is converted to
// TArgs and passed to the handler, whose error (if any) is returned unchanged.
// The handler's output is returned as a map: either the direct conversion of
// the output, or {"result": output} when the output does not convert to a map.
func (w *nonStructInputWrapper[TArgs, TResults]) Run(ctx tool.Context, args any) (map[string]any, error) {
// Extract the map
m, ok := args.(map[string]any)
if !ok {
return nil, fmt.Errorf("unexpected args type, got: %T", args)
}

// Extract the "input" key
inputVal, ok := m["input"]
if !ok {
return nil, fmt.Errorf("missing required 'input' argument")
}

// Convert to TArgs - use JSON marshaling to handle the conversion.
// Do NOT pass the original resolved input schema to ConvertToWithJSONSchema here,
// because resolvedSchema validation expects a map[string]any and will fail for
// primitive JSON values (e.g. string). Validation is handled by the declaration
// schema sent to the model; at runtime we only need to convert the value.
input, err := typeutil.ConvertToWithJSONSchema[any, TArgs](inputVal, nil)
if err != nil {
return nil, fmt.Errorf("failed to convert input: %w", err)
}

// Call the handler
output, err := w.handler(ctx, input)
if err != nil {
return nil, err
}

// Convert output to map[string]any; if that succeeds the map is returned as-is.
resp, err := typeutil.ConvertToWithJSONSchema[TResults, map[string]any](output, w.outputSchema)
if err == nil {
return resp, nil
}

// The output did not convert to a map. If outputSchema is set, validate the
// raw output against it; on validation failure the original conversion error
// (not the validation error) is returned, together with whatever resp the
// failed conversion produced.
if w.outputSchema != nil {
if err1 := w.outputSchema.Validate(output); err1 != nil {
return resp, err // if it fails propagate original err.
}
Comment on lines 166 to 168
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

There's an issue with the error handling in this block. If w.outputSchema.Validate(output) fails, the specific validation error err1 is discarded, and the original, more generic conversion error err is returned. This hides the true cause of the failure. The more specific validation error err1 should be propagated to aid in debugging.

if err1 := w.outputSchema.Validate(output); err1 != nil {
			return nil, fmt.Errorf("output validation failed: %w", err1)
		}

}
// Non-map output (no schema set, or validation passed) is wrapped under the "result" key.
wrappedOutput := map[string]any{"result": output}
return wrappedOutput, nil
}