Skip to content

Commit d77a3d2

Browse files
committed
mcp: add a explicit test for concurrent streamable sessions
1 parent 31e97ad commit d77a3d2

File tree

1 file changed

+53
-0
lines changed

1 file changed

+53
-0
lines changed

mcp/streamable_test.go

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"errors"
1212
"fmt"
1313
"io"
14+
"log"
1415
"maps"
1516
"net"
1617
"net/http"
@@ -192,6 +193,58 @@ func TestStreamableTransports(t *testing.T) {
192193
}
193194
}
194195

196+
func TestStreamableConcurrentHandling(t *testing.T) {
197+
// This test checks that the streamable server and client transports can
198+
// communicate.
199+
type count struct {
200+
Count int
201+
}
202+
203+
var mu sync.Mutex
204+
counts := make(map[string]int)
205+
206+
server := NewServer(testImpl, nil)
207+
AddTool(server, &Tool{Name: "inc"}, func(ctx context.Context, req *CallToolRequest, _ any) (*CallToolResult, count, error) {
208+
log.Println("incing")
209+
id := req.Session.ID()
210+
mu.Lock()
211+
defer mu.Unlock()
212+
c := counts[id]
213+
counts[id] = c + 1
214+
return nil, count{c}, nil
215+
})
216+
handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil)
217+
httpServer := httptest.NewServer(mustNotPanic(t, handler))
218+
defer httpServer.Close()
219+
220+
ctx := context.Background()
221+
client := NewClient(testImpl, nil)
222+
var wg sync.WaitGroup
223+
for range 100 {
224+
wg.Add(1)
225+
go func() {
226+
defer wg.Done()
227+
clientSession, err := client.Connect(ctx, &StreamableClientTransport{Endpoint: httpServer.URL}, nil)
228+
if err != nil {
229+
t.Errorf("Connect failed: %v", err)
230+
return
231+
}
232+
defer clientSession.Close()
233+
for i := range 10 {
234+
res, err := clientSession.CallTool(ctx, &CallToolParams{Name: "inc"})
235+
if err != nil {
236+
t.Errorf("CallTool failed: %v", err)
237+
return
238+
}
239+
if got := int(res.StructuredContent.(map[string]any)["Count"].(float64)); got != i {
240+
t.Errorf("got count %d, want %d", got, i)
241+
}
242+
}
243+
}()
244+
}
245+
wg.Wait()
246+
}
247+
195248
func TestStreamableServerShutdown(t *testing.T) {
196249
ctx := context.Background()
197250

0 commit comments

Comments
 (0)