Skip to content

Commit 26dba7f

Browse files
authored
feat(go): improved tool interrupts API to be more ergonomic (#4112)
1 parent 28994d9 commit 26dba7f

File tree

19 files changed

+1100
-155
lines changed

19 files changed

+1100
-155
lines changed

go/ai/generate.go

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ import (
3030
"github.com/firebase/genkit/go/core/logger"
3131
"github.com/firebase/genkit/go/core/tracing"
3232
"github.com/firebase/genkit/go/internal/base"
33+
"github.com/google/uuid"
3334
)
3435

3536
// Model represents a model that can generate content based on a request.
@@ -361,6 +362,9 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi
361362
return nil, err
362363
}
363364

365+
// Ensure all tool requests have unique refs for matching during resume.
366+
ensureToolRequestRefs(resp.Message)
367+
364368
// If this is a long-running operation response, return it immediately without further processing
365369
if bm != nil && resp.Operation != nil {
366370
return resp, nil
@@ -552,6 +556,9 @@ func GenerateText(ctx context.Context, r api.Registry, opts ...GenerateOption) (
552556
}
553557

554558
// GenerateData runs a generate request and returns strongly-typed output.
559+
// If the response doesn't contain text output (e.g., contains tool requests
560+
// or interrupts instead), the output will be nil and no error is returned.
561+
// Check resp.Interrupts() or resp.ToolRequests() to handle these cases.
555562
func GenerateData[Out any](ctx context.Context, r api.Registry, opts ...GenerateOption) (*Out, *ModelResponse, error) {
556563
var value Out
557564
opts = append(opts, WithOutputType(value))
@@ -561,9 +568,16 @@ func GenerateData[Out any](ctx context.Context, r api.Registry, opts ...Generate
561568
return nil, nil, err
562569
}
563570

571+
// If there's no text content to parse (e.g., the response contains tool
572+
// requests or interrupts), return nil output. The caller should check
573+
// resp.Interrupts() or resp.ToolRequests() to handle these cases.
574+
if resp.Text() == "" {
575+
return nil, resp, nil
576+
}
577+
564578
err = resp.Output(&value)
565579
if err != nil {
566-
return nil, nil, err
580+
return nil, resp, err
567581
}
568582

569583
return &value, resp, nil
@@ -715,6 +729,20 @@ func (m *model) supportsConstrained(hasTools bool) bool {
715729
return true
716730
}
717731

732+
// ensureToolRequestRefs assigns unique refs to tool request parts that don't have one.
733+
// This ensures that when there are multiple calls to the same tool, each can be
734+
// individually matched when resuming with Restart or Respond directives.
735+
func ensureToolRequestRefs(msg *Message) {
736+
if msg == nil {
737+
return
738+
}
739+
for _, part := range msg.Content {
740+
if part.IsToolRequest() && part.ToolRequest.Ref == "" {
741+
part.ToolRequest.Ref = uuid.New().String()
742+
}
743+
}
744+
}
745+
718746
// clone creates a deep copy of the provided object using JSON marshaling and unmarshaling.
719747
func clone[T any](obj *T) *T {
720748
if obj == nil {
@@ -1166,7 +1194,7 @@ func handleResumedToolRequest(ctx context.Context, r api.Registry, genOpts *Gene
11661194
}
11671195
}
11681196
}
1169-
if originalInputVal, ok := restartPart.Metadata["originalInput"]; ok {
1197+
if originalInputVal, ok := restartPart.Metadata["replacedInput"]; ok {
11701198
resumedCtx = origInputCtxKey.NewContext(resumedCtx, originalInputVal)
11711199
}
11721200

go/ai/generate_test.go

Lines changed: 29 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -993,13 +993,19 @@ func validTestMessage(m *Message, output *ModelOutputConfig) (*Message, error) {
993993
return handler.ParseMessage(m)
994994
}
995995

996+
type conditionalToolInput struct {
997+
Value string
998+
Interrupt bool
999+
}
1000+
1001+
type resumableToolInput struct {
1002+
Action string
1003+
Data string
1004+
}
1005+
9961006
func TestToolInterruptsAndResume(t *testing.T) {
9971007
conditionalTool := DefineTool(r, "conditional", "tool that may interrupt based on input",
998-
func(ctx *ToolContext, input struct {
999-
Value string
1000-
Interrupt bool
1001-
},
1002-
) (string, error) {
1008+
func(ctx *ToolContext, input conditionalToolInput) (string, error) {
10031009
if input.Interrupt {
10041010
return "", ctx.Interrupt(&InterruptOptions{
10051011
Metadata: map[string]any{
@@ -1014,11 +1020,7 @@ func TestToolInterruptsAndResume(t *testing.T) {
10141020
)
10151021

10161022
resumableTool := DefineTool(r, "resumable", "tool that can be resumed",
1017-
func(ctx *ToolContext, input struct {
1018-
Action string
1019-
Data string
1020-
},
1021-
) (string, error) {
1023+
func(ctx *ToolContext, input resumableToolInput) (string, error) {
10221024
if ctx.Resumed != nil {
10231025
resumedData, ok := ctx.Resumed["data"].(string)
10241026
if ok {
@@ -1156,11 +1158,12 @@ func TestToolInterruptsAndResume(t *testing.T) {
11561158

11571159
interruptedPart := res.Message.Content[1]
11581160

1161+
newInput := conditionalToolInput{
1162+
Value: "new_test_data",
1163+
Interrupt: false,
1164+
}
11591165
restartPart := conditionalTool.Restart(interruptedPart, &RestartOptions{
1160-
ReplaceInput: map[string]any{
1161-
"Value": "new_test_data",
1162-
"Interrupt": false,
1163-
},
1166+
ReplaceInput: newInput,
11641167
ResumedMetadata: map[string]any{
11651168
"data": "resumed_data",
11661169
"source": "restart",
@@ -1175,17 +1178,17 @@ func TestToolInterruptsAndResume(t *testing.T) {
11751178
t.Errorf("expected tool request name 'conditional', got %q", restartPart.ToolRequest.Name)
11761179
}
11771180

1178-
newInput, ok := restartPart.ToolRequest.Input.(map[string]any)
1181+
replacedInput, ok := restartPart.ToolRequest.Input.(conditionalToolInput)
11791182
if !ok {
1180-
t.Fatal("expected input to be map[string]any")
1183+
t.Fatalf("expected input to be conditionalInput, got %T", restartPart.ToolRequest.Input)
11811184
}
11821185

1183-
if newInput["Value"] != "new_test_data" {
1184-
t.Errorf("expected new input value 'new_test_data', got %v", newInput["Value"])
1186+
if replacedInput.Value != "new_test_data" {
1187+
t.Errorf("expected new input value 'new_test_data', got %v", replacedInput.Value)
11851188
}
11861189

1187-
if newInput["Interrupt"] != false {
1188-
t.Errorf("expected interrupt to be false, got %v", newInput["Interrupt"])
1190+
if replacedInput.Interrupt != false {
1191+
t.Errorf("expected interrupt to be false, got %v", replacedInput.Interrupt)
11891192
}
11901193

11911194
if _, hasInterrupt := restartPart.Metadata["interrupt"]; hasInterrupt {
@@ -1242,11 +1245,13 @@ func TestToolInterruptsAndResume(t *testing.T) {
12421245
}
12431246

12441247
interruptedPart := res.Message.Content[1]
1248+
1249+
newInput := conditionalToolInput{
1250+
Value: "restarted_data",
1251+
Interrupt: false,
1252+
}
12451253
restartPart := conditionalTool.Restart(interruptedPart, &RestartOptions{
1246-
ReplaceInput: map[string]any{
1247-
"Value": "restarted_data",
1248-
"Interrupt": false,
1249-
},
1254+
ReplaceInput: newInput,
12501255
ResumedMetadata: map[string]any{
12511256
"data": "restart_context",
12521257
},

go/ai/option.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -915,12 +915,15 @@ func (o *generateOptions) applyGenerate(genOpts *generateOptions) error {
915915
return nil
916916
}
917917

918-
// WithToolResponses sets the tool responses to return from interrupted tool calls.
918+
// WithToolResponses provides resolved responses for interrupted tool calls.
919+
// Use this when you already have the result and want to skip re-executing the tool.
919920
func WithToolResponses(parts ...*Part) GenerateOption {
920921
return &generateOptions{RespondParts: parts}
921922
}
922923

923-
// WithToolRestarts sets the tool requests to restart interrupted tools with.
924+
// WithToolRestarts re-executes interrupted tool calls with additional metadata.
925+
// Use this when the original call lacked required context (e.g., auth, user confirmation)
926+
// that should now allow the tool to complete successfully.
924927
func WithToolRestarts(parts ...*Part) GenerateOption {
925928
return &generateOptions{RestartParts: parts}
926929
}

go/ai/option_test.go

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"context"
2121
"testing"
2222

23+
"github.com/firebase/genkit/go/core/api"
2324
"github.com/google/go-cmp/cmp"
2425
"github.com/google/go-cmp/cmp/cmpopts"
2526
)
@@ -658,15 +659,7 @@ func (t *mockTool) RunRawMultipart(ctx context.Context, input any) (*MultipartTo
658659
return nil, nil
659660
}
660661

661-
func (t *mockTool) Respond(toolReq *Part, outputData any, opts *RespondOptions) *Part {
662-
return nil
663-
}
664-
665-
func (t *mockTool) Restart(toolReq *Part, opts *RestartOptions) *Part {
666-
return nil
667-
}
668-
669-
func (t *mockTool) Register(r interface{ RegisterValue(string, any) }) {
662+
func (t *mockTool) Register(r api.Registry) {
670663
}
671664

672665
func TestWithInputSchemaName(t *testing.T) {

0 commit comments

Comments
 (0)