diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index fd1dc3e4..25dd224e 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -40,6 +40,25 @@ func TestStreamableTransports(t *testing.T) { // 1. Create a server with a simple "greet" tool. server := NewServer(testImpl, nil) AddTool(server, &Tool{Name: "greet", Description: "say hi"}, sayHi) + AddTool(server, &Tool{Name: "sample"}, func(ctx context.Context, req *ServerRequest[*CallToolParams]) (*CallToolResult, error) { + // Test that we can make sampling requests during tool handling. + // + // Try this on both the request context and a background context, so + // that messages may be delivered on either the POST or GET connection. + for _, ctx := range map[string]context.Context{ + "request context": ctx, + "background context": context.Background(), + } { + res, err := req.Session.CreateMessage(ctx, &CreateMessageParams{}) + if err != nil { + return nil, err + } + if g, w := res.Model, "aModel"; g != w { + return nil, fmt.Errorf("got %q, want %q", g, w) + } + } + return &CallToolResultFor[any]{}, nil + }) // 2. Start an httptest.Server with the StreamableHTTPHandler, wrapped in a // cookie-checking middleware. @@ -81,7 +100,11 @@ func TestStreamableTransports(t *testing.T) { Endpoint: httpServer.URL, HTTPClient: httpClient, } - client := NewClient(testImpl, nil) + client := NewClient(testImpl, &ClientOptions{ + CreateMessageHandler: func(context.Context, *ClientRequest[*CreateMessageParams]) (*CreateMessageResult, error) { + return &CreateMessageResult{Model: "aModel", Content: &TextContent{}}, nil + }, + }) session, err := client.Connect(ctx, transport, nil) if err != nil { t.Fatalf("client.Connect() failed: %v", err) @@ -119,6 +142,19 @@ func TestStreamableTransports(t *testing.T) { if diff := cmp.Diff(want, got); diff != "" { t.Errorf("CallTool() returned unexpected content (-want +got):\n%s", diff) } + + // 6. Run the "sampling" tool and verify that the streamable server can + // call tools. + result, err := session.CallTool(ctx, &CallToolParams{ + Name: "sample", + Arguments: map[string]any{}, + }) + if err != nil { + t.Fatal(err) + } + if result.IsError { + t.Fatalf("tool failed: %s", result.Content[0].(*TextContent).Text) + } }) } }