diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 6ccaebf7..6bbe8c4e 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -192,6 +192,57 @@ func TestStreamableTransports(t *testing.T) { } } +func TestStreamableConcurrentHandling(t *testing.T) { + // This test checks that the streamable server and client transports can + // communicate. + type count struct { + Count int + } + + var mu sync.Mutex + counts := make(map[string]int) + + server := NewServer(testImpl, nil) + AddTool(server, &Tool{Name: "inc"}, func(ctx context.Context, req *CallToolRequest, _ any) (*CallToolResult, count, error) { + id := req.Session.ID() + mu.Lock() + defer mu.Unlock() + c := counts[id] + counts[id] = c + 1 + return nil, count{c}, nil + }) + handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) + httpServer := httptest.NewServer(mustNotPanic(t, handler)) + defer httpServer.Close() + + ctx := context.Background() + client := NewClient(testImpl, nil) + var wg sync.WaitGroup + for range 100 { + wg.Add(1) + go func() { + defer wg.Done() + clientSession, err := client.Connect(ctx, &StreamableClientTransport{Endpoint: httpServer.URL}, nil) + if err != nil { + t.Errorf("Connect failed: %v", err) + return + } + defer clientSession.Close() + for i := range 10 { + res, err := clientSession.CallTool(ctx, &CallToolParams{Name: "inc"}) + if err != nil { + t.Errorf("CallTool failed: %v", err) + return + } + if got := int(res.StructuredContent.(map[string]any)["Count"].(float64)); got != i { + t.Errorf("got count %d, want %d", got, i) + } + } + }() + } + wg.Wait() +} + func TestStreamableServerShutdown(t *testing.T) { ctx := context.Background()