@@ -11,6 +11,7 @@ import (
1111 "errors"
1212 "fmt"
1313 "io"
14+ "log"
1415 "maps"
1516 "net"
1617 "net/http"
@@ -192,6 +193,58 @@ func TestStreamableTransports(t *testing.T) {
192193 }
193194}
194195
196+ func TestStreamableConcurrentHandling (t * testing.T ) {
197+ // This test checks that the streamable server and client transports can
198+ // communicate.
199+ type count struct {
200+ Count int
201+ }
202+
203+ var mu sync.Mutex
204+ counts := make (map [string ]int )
205+
206+ server := NewServer (testImpl , nil )
207+ AddTool (server , & Tool {Name : "inc" }, func (ctx context.Context , req * CallToolRequest , _ any ) (* CallToolResult , count , error ) {
208+ log .Println ("incing" )
209+ id := req .Session .ID ()
210+ mu .Lock ()
211+ defer mu .Unlock ()
212+ c := counts [id ]
213+ counts [id ] = c + 1
214+ return nil , count {c }, nil
215+ })
216+ handler := NewStreamableHTTPHandler (func (req * http.Request ) * Server { return server }, nil )
217+ httpServer := httptest .NewServer (mustNotPanic (t , handler ))
218+ defer httpServer .Close ()
219+
220+ ctx := context .Background ()
221+ client := NewClient (testImpl , nil )
222+ var wg sync.WaitGroup
223+ for range 100 {
224+ wg .Add (1 )
225+ go func () {
226+ defer wg .Done ()
227+ clientSession , err := client .Connect (ctx , & StreamableClientTransport {Endpoint : httpServer .URL }, nil )
228+ if err != nil {
229+ t .Errorf ("Connect failed: %v" , err )
230+ return
231+ }
232+ defer clientSession .Close ()
233+ for i := range 10 {
234+ res , err := clientSession .CallTool (ctx , & CallToolParams {Name : "inc" })
235+ if err != nil {
236+ t .Errorf ("CallTool failed: %v" , err )
237+ return
238+ }
239+ if got := int (res .StructuredContent .(map [string ]any )["Count" ].(float64 )); got != i {
240+ t .Errorf ("got count %d, want %d" , got , i )
241+ }
242+ }
243+ }()
244+ }
245+ wg .Wait ()
246+ }
247+
195248func TestStreamableServerShutdown (t * testing.T ) {
196249 ctx := context .Background ()
197250
0 commit comments