Skip to content

Commit 1cf2c6e

Browse files
committed
mcp: add a test for streamable sampling during a tool call
Add a test that attempts (and fails) to reproduce the bug reported in issue #285. For #285
1 parent e097918 commit 1cf2c6e

File tree

1 file changed

+37
-1
lines changed

1 file changed

+37
-1
lines changed

mcp/streamable_test.go

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,25 @@ func TestStreamableTransports(t *testing.T) {
4040
// 1. Create a server with a simple "greet" tool.
4141
server := NewServer(testImpl, nil)
4242
AddTool(server, &Tool{Name: "greet", Description: "say hi"}, sayHi)
43+
AddTool(server, &Tool{Name: "sample"}, func(ctx context.Context, ss *ServerSession, _ *CallToolParamsFor[any]) (*CallToolResultFor[any], error) {
44+
// Test that we can make sampling requests during tool handling.
45+
//
46+
// Try this on both the request context and a background context, so
47+
// that messages may be delivered on either the POST or GET connection.
48+
for _, ctx := range map[string]context.Context{
49+
"request context": ctx,
50+
"background context": context.Background(),
51+
} {
52+
res, err := ss.CreateMessage(ctx, &CreateMessageParams{})
53+
if err != nil {
54+
return nil, err
55+
}
56+
if g, w := res.Model, "aModel"; g != w {
57+
return nil, fmt.Errorf("got %q, want %q", g, w)
58+
}
59+
}
60+
return &CallToolResultFor[any]{}, nil
61+
})
4362

4463
// 2. Start an httptest.Server with the StreamableHTTPHandler, wrapped in a
4564
// cookie-checking middleware.
@@ -81,7 +100,11 @@ func TestStreamableTransports(t *testing.T) {
81100
Endpoint: httpServer.URL,
82101
HTTPClient: httpClient,
83102
}
84-
client := NewClient(testImpl, nil)
103+
client := NewClient(testImpl, &ClientOptions{
104+
CreateMessageHandler: func(context.Context, *ClientSession, *CreateMessageParams) (*CreateMessageResult, error) {
105+
return &CreateMessageResult{Model: "aModel", Content: &TextContent{}}, nil
106+
},
107+
})
85108
session, err := client.Connect(ctx, transport, nil)
86109
if err != nil {
87110
t.Fatalf("client.Connect() failed: %v", err)
@@ -119,6 +142,19 @@ func TestStreamableTransports(t *testing.T) {
119142
if diff := cmp.Diff(want, got); diff != "" {
120143
t.Errorf("CallTool() returned unexpected content (-want +got):\n%s", diff)
121144
}
145+
146+
// 6. Run the "sampling" tool and verify that the streamable server can
147+
// call tools.
148+
result, err := session.CallTool(ctx, &CallToolParams{
149+
Name: "sample",
150+
Arguments: map[string]any{},
151+
})
152+
if err != nil {
153+
t.Fatal(err)
154+
}
155+
if result.IsError {
156+
t.Fatalf("tool failed: %s", result.Content[0].(*TextContent).Text)
157+
}
122158
})
123159
}
124160
}

0 commit comments

Comments
 (0)