Skip to content

Commit 1696b59

Browse files
authored
mcp: fix several goroutine leaks in tests (#494)
Fix some (but not all) goroutine leaks in tests. Others were nontrivial, because they relate to streamable http. For #489.
1 parent 49d45a8 commit 1696b59

File tree

5 files changed

+43
-22
lines changed

5 files changed

+43
-22
lines changed

mcp/client_example_test.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,12 @@ func Example_roots() {
4545
if _, err := s.Connect(ctx, t1, nil); err != nil {
4646
log.Fatal(err)
4747
}
48-
if _, err := c.Connect(ctx, t2, nil); err != nil {
48+
49+
clientSession, err := c.Connect(ctx, t2, nil)
50+
if err != nil {
4951
log.Fatal(err)
5052
}
53+
defer clientSession.Close()
5154

5255
// ...and add a root. The server is notified about the change.
5356
c.AddRoots(&mcp.Root{URI: "file://b"})

mcp/mcp_test.go

Lines changed: 30 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -592,7 +592,10 @@ func errorCode(err error) int64 {
592592
//
593593
// The caller should cancel either the client connection or server connection
594594
// when the connections are no longer needed.
595-
func basicConnection(t *testing.T, config func(*Server)) (*ClientSession, *ServerSession) {
595+
//
596+
// The returned func cleans up by closing the client and waiting for the server
597+
// to shut down.
598+
func basicConnection(t *testing.T, config func(*Server)) (*ClientSession, *ServerSession, func()) {
596599
return basicClientServerConnection(t, nil, nil, config)
597600
}
598601

@@ -604,7 +607,10 @@ func basicConnection(t *testing.T, config func(*Server)) (*ClientSession, *Serve
604607
//
605608
// The caller should cancel either the client connection or server connection
606609
// when the connections are no longer needed.
607-
func basicClientServerConnection(t *testing.T, client *Client, server *Server, config func(*Server)) (*ClientSession, *ServerSession) {
610+
//
611+
// The returned func cleans up by closing the client and waiting for the server
612+
// to shut down.
613+
func basicClientServerConnection(t *testing.T, client *Client, server *Server, config func(*Server)) (*ClientSession, *ServerSession, func()) {
608614
t.Helper()
609615

610616
ctx := context.Background()
@@ -628,14 +634,17 @@ func basicClientServerConnection(t *testing.T, client *Client, server *Server, c
628634
if err != nil {
629635
t.Fatal(err)
630636
}
631-
return cs, ss
637+
return cs, ss, func() {
638+
cs.Close()
639+
ss.Wait()
640+
}
632641
}
633642

634643
func TestServerClosing(t *testing.T) {
635-
cs, ss := basicConnection(t, func(s *Server) {
644+
cs, ss, cleanup := basicConnection(t, func(s *Server) {
636645
AddTool(s, greetTool(), sayHi)
637646
})
638-
defer cs.Close()
647+
defer cleanup()
639648

640649
ctx := context.Background()
641650
var wg sync.WaitGroup
@@ -715,10 +724,10 @@ func TestCancellation(t *testing.T) {
715724
}
716725
return nil, nil, nil
717726
}
718-
cs, _ := basicConnection(t, func(s *Server) {
727+
cs, _, cleanup := basicConnection(t, func(s *Server) {
719728
AddTool(s, &Tool{Name: "slow", InputSchema: &jsonschema.Schema{Type: "object"}}, slowTool)
720729
})
721-
defer cs.Close()
730+
defer cleanup()
722731

723732
ctx, cancel := context.WithCancel(context.Background())
724733
go cs.CallTool(ctx, &CallToolParams{Name: "slow"})
@@ -741,13 +750,10 @@ func TestMiddleware(t *testing.T) {
741750
t.Fatal(err)
742751
}
743752
// Wait for the server to exit after the client closes its connection.
744-
var clientWG sync.WaitGroup
745-
clientWG.Add(1)
746-
go func() {
753+
defer func() {
747754
if err := ss.Wait(); err != nil {
748755
t.Errorf("server failed: %v", err)
749756
}
750-
clientWG.Done()
751757
}()
752758

753759
var sbuf, cbuf bytes.Buffer
@@ -767,6 +773,8 @@ func TestMiddleware(t *testing.T) {
767773
if err != nil {
768774
t.Fatal(err)
769775
}
776+
defer cs.Close()
777+
770778
if _, err := cs.ListTools(ctx, nil); err != nil {
771779
t.Fatal(err)
772780
}
@@ -1511,7 +1519,7 @@ func TestKeepAliveFailure(t *testing.T) {
15111519
func TestAddTool_DuplicateNoPanicAndNoDuplicate(t *testing.T) {
15121520
// Adding the same tool pointer twice should not panic and should not
15131521
// produce duplicates in the server's tool list.
1514-
cs, _ := basicConnection(t, func(s *Server) {
1522+
cs, _, cleanup := basicConnection(t, func(s *Server) {
15151523
// Use two distinct Tool instances with the same name but different
15161524
// descriptions to ensure the second replaces the first
15171525
// This case was written specifically to reproduce a bug where duplicate tools where causing jsonschema errors
@@ -1520,7 +1528,7 @@ func TestAddTool_DuplicateNoPanicAndNoDuplicate(t *testing.T) {
15201528
s.AddTool(t1, nopHandler)
15211529
s.AddTool(t2, nopHandler)
15221530
})
1523-
defer cs.Close()
1531+
defer cleanup()
15241532

15251533
ctx := context.Background()
15261534
res, err := cs.ListTools(ctx, nil)
@@ -1568,14 +1576,15 @@ func TestSynchronousNotifications(t *testing.T) {
15681576
},
15691577
}
15701578
server := NewServer(testImpl, serverOpts)
1571-
cs, ss := basicClientServerConnection(t, client, server, func(s *Server) {
1579+
cs, ss, cleanup := basicClientServerConnection(t, client, server, func(s *Server) {
15721580
AddTool(s, &Tool{Name: "tool"}, func(ctx context.Context, req *CallToolRequest, args any) (*CallToolResult, any, error) {
15731581
if !rootsChanged.Load() {
15741582
return nil, nil, fmt.Errorf("didn't get root change notification")
15751583
}
15761584
return new(CallToolResult), nil, nil
15771585
})
15781586
})
1587+
defer cleanup()
15791588

15801589
t.Run("from client", func(t *testing.T) {
15811590
client.AddRoots(&Root{Name: "myroot", URI: "file://foo"})
@@ -1617,7 +1626,7 @@ func TestNoDistributedDeadlock(t *testing.T) {
16171626
},
16181627
}
16191628
client := NewClient(testImpl, clientOpts)
1620-
cs, _ := basicClientServerConnection(t, client, nil, func(s *Server) {
1629+
cs, _, cleanup := basicClientServerConnection(t, client, nil, func(s *Server) {
16211630
AddTool(s, &Tool{Name: "tool1"}, func(ctx context.Context, req *CallToolRequest, args any) (*CallToolResult, any, error) {
16221631
req.Session.CreateMessage(ctx, new(CreateMessageParams))
16231632
return new(CallToolResult), nil, nil
@@ -1627,7 +1636,7 @@ func TestNoDistributedDeadlock(t *testing.T) {
16271636
return new(CallToolResult), nil, nil
16281637
})
16291638
})
1630-
defer cs.Close()
1639+
defer cleanup()
16311640

16321641
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
16331642
defer cancel()
@@ -1651,7 +1660,7 @@ func TestPointerArgEquivalence(t *testing.T) {
16511660
type output struct {
16521661
Out string
16531662
}
1654-
cs, _ := basicConnection(t, func(s *Server) {
1663+
cs, _, cleanup := basicConnection(t, func(s *Server) {
16551664
// Add two equivalent tools, one of which operates in the 'pointer' realm,
16561665
// the other of which does not.
16571666
//
@@ -1686,7 +1695,7 @@ func TestPointerArgEquivalence(t *testing.T) {
16861695
}
16871696
})
16881697
})
1689-
defer cs.Close()
1698+
defer cleanup()
16901699

16911700
ctx := context.Background()
16921701
tools, err := cs.ListTools(ctx, nil)
@@ -1758,7 +1767,9 @@ func TestComplete(t *testing.T) {
17581767
},
17591768
}
17601769
server := NewServer(testImpl, serverOpts)
1761-
cs, _ := basicClientServerConnection(t, nil, server, func(s *Server) {})
1770+
cs, _, cleanup := basicClientServerConnection(t, nil, server, func(s *Server) {})
1771+
defer cleanup()
1772+
17621773
result, err := cs.Complete(context.Background(), &CompleteParams{
17631774
Argument: CompleteParamsArgument{
17641775
Name: "language",

mcp/server_example_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ func Example_prompts() {
5555
if err != nil {
5656
log.Fatal(err)
5757
}
58+
defer cs.Close()
5859

5960
// List the prompts.
6061
for p, err := range cs.Prompts(ctx, nil) {

mcp/transport_example_test.go

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,16 +24,21 @@ func ExampleLoggingTransport() {
2424
ctx := context.Background()
2525
t1, t2 := mcp.NewInMemoryTransports()
2626
server := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.0.1"}, nil)
27-
if _, err := server.Connect(ctx, t1, nil); err != nil {
27+
serverSession, err := server.Connect(ctx, t1, nil)
28+
if err != nil {
2829
log.Fatal(err)
2930
}
31+
defer serverSession.Wait()
3032

3133
client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, nil)
3234
var b bytes.Buffer
3335
logTransport := &mcp.LoggingTransport{Transport: t2, Writer: &b}
34-
if _, err := client.Connect(ctx, logTransport, nil); err != nil {
36+
clientSession, err := client.Connect(ctx, logTransport, nil)
37+
if err != nil {
3538
log.Fatal(err)
3639
}
40+
defer clientSession.Close()
41+
3742
// Sort for stability: reads are concurrent to writes.
3843
for _, line := range slices.Sorted(strings.SplitSeq(b.String(), "\n")) {
3944
fmt.Println(line)

mcp/transport_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ func TestBatchFraming(t *testing.T) {
2525
r, w := io.Pipe()
2626
tport := newIOConn(rwc{r, w})
2727
tport.outgoingBatch = make([]jsonrpc.Message, 0, 2)
28+
defer tport.Close()
2829

2930
// Read the two messages into a channel, for easy testing later.
3031
read := make(chan jsonrpc.Message)

0 commit comments

Comments
 (0)