Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 37 additions & 1 deletion mcp/streamable_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,25 @@ func TestStreamableTransports(t *testing.T) {
// 1. Create a server with a simple "greet" tool.
server := NewServer(testImpl, nil)
AddTool(server, &Tool{Name: "greet", Description: "say hi"}, sayHi)
AddTool(server, &Tool{Name: "sample"}, func(ctx context.Context, req *ServerRequest[*CallToolParamsFor[any]]) (*CallToolResultFor[any], error) {
// Test that we can make sampling requests during tool handling.
//
// Try this on both the request context and a background context, so
// that messages may be delivered on either the POST or GET connection.
for _, ctx := range map[string]context.Context{
"request context": ctx,
"background context": context.Background(),
} {
res, err := req.Session.CreateMessage(ctx, &CreateMessageParams{})
if err != nil {
return nil, err
}
if g, w := res.Model, "aModel"; g != w {
return nil, fmt.Errorf("got %q, want %q", g, w)
}
}
return &CallToolResultFor[any]{}, nil
})

// 2. Start an httptest.Server with the StreamableHTTPHandler, wrapped in a
// cookie-checking middleware.
Expand Down Expand Up @@ -81,7 +100,11 @@ func TestStreamableTransports(t *testing.T) {
Endpoint: httpServer.URL,
HTTPClient: httpClient,
}
client := NewClient(testImpl, nil)
client := NewClient(testImpl, &ClientOptions{
CreateMessageHandler: func(context.Context, *ClientRequest[*CreateMessageParams]) (*CreateMessageResult, error) {
return &CreateMessageResult{Model: "aModel", Content: &TextContent{}}, nil
},
})
session, err := client.Connect(ctx, transport, nil)
if err != nil {
t.Fatalf("client.Connect() failed: %v", err)
Expand Down Expand Up @@ -119,6 +142,19 @@ func TestStreamableTransports(t *testing.T) {
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("CallTool() returned unexpected content (-want +got):\n%s", diff)
}

// 6. Run the "sampling" tool and verify that the streamable server can
// call tools.
result, err := session.CallTool(ctx, &CallToolParams{
Name: "sample",
Arguments: map[string]any{},
})
if err != nil {
t.Fatal(err)
}
if result.IsError {
t.Fatalf("tool failed: %s", result.Content[0].(*TextContent).Text)
}
})
}
}
Expand Down