Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion mcp/client_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,12 @@ func Example_roots() {
if _, err := s.Connect(ctx, t1, nil); err != nil {
log.Fatal(err)
}
if _, err := c.Connect(ctx, t2, nil); err != nil {

clientSession, err := c.Connect(ctx, t2, nil)
if err != nil {
log.Fatal(err)
}
defer clientSession.Close()

// ...and add a root. The server is notified about the change.
c.AddRoots(&mcp.Root{URI: "file://b"})
Expand Down
49 changes: 30 additions & 19 deletions mcp/mcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -592,7 +592,10 @@ func errorCode(err error) int64 {
//
// The caller should cancel either the client connection or server connection
// when the connections are no longer needed.
func basicConnection(t *testing.T, config func(*Server)) (*ClientSession, *ServerSession) {
//
// The returned func cleans up by closing the client and waiting for the server
// to shut down.
func basicConnection(t *testing.T, config func(*Server)) (*ClientSession, *ServerSession, func()) {
return basicClientServerConnection(t, nil, nil, config)
}

Expand All @@ -604,7 +607,10 @@ func basicConnection(t *testing.T, config func(*Server)) (*ClientSession, *Serve
//
// The caller should cancel either the client connection or server connection
// when the connections are no longer needed.
func basicClientServerConnection(t *testing.T, client *Client, server *Server, config func(*Server)) (*ClientSession, *ServerSession) {
//
// The returned func cleans up by closing the client and waiting for the server
// to shut down.
func basicClientServerConnection(t *testing.T, client *Client, server *Server, config func(*Server)) (*ClientSession, *ServerSession, func()) {
t.Helper()

ctx := context.Background()
Expand All @@ -628,14 +634,17 @@ func basicClientServerConnection(t *testing.T, client *Client, server *Server, c
if err != nil {
t.Fatal(err)
}
return cs, ss
return cs, ss, func() {
cs.Close()
ss.Wait()
}
}

func TestServerClosing(t *testing.T) {
cs, ss := basicConnection(t, func(s *Server) {
cs, ss, cleanup := basicConnection(t, func(s *Server) {
AddTool(s, greetTool(), sayHi)
})
defer cs.Close()
defer cleanup()

ctx := context.Background()
var wg sync.WaitGroup
Expand Down Expand Up @@ -715,10 +724,10 @@ func TestCancellation(t *testing.T) {
}
return nil, nil, nil
}
cs, _ := basicConnection(t, func(s *Server) {
cs, _, cleanup := basicConnection(t, func(s *Server) {
AddTool(s, &Tool{Name: "slow", InputSchema: &jsonschema.Schema{Type: "object"}}, slowTool)
})
defer cs.Close()
defer cleanup()

ctx, cancel := context.WithCancel(context.Background())
go cs.CallTool(ctx, &CallToolParams{Name: "slow"})
Expand All @@ -741,13 +750,10 @@ func TestMiddleware(t *testing.T) {
t.Fatal(err)
}
// Wait for the server to exit after the client closes its connection.
var clientWG sync.WaitGroup
clientWG.Add(1)
go func() {
defer func() {
if err := ss.Wait(); err != nil {
t.Errorf("server failed: %v", err)
}
clientWG.Done()
}()

var sbuf, cbuf bytes.Buffer
Expand All @@ -767,6 +773,8 @@ func TestMiddleware(t *testing.T) {
if err != nil {
t.Fatal(err)
}
defer cs.Close()

if _, err := cs.ListTools(ctx, nil); err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -1511,7 +1519,7 @@ func TestKeepAliveFailure(t *testing.T) {
func TestAddTool_DuplicateNoPanicAndNoDuplicate(t *testing.T) {
// Adding the same tool pointer twice should not panic and should not
// produce duplicates in the server's tool list.
cs, _ := basicConnection(t, func(s *Server) {
cs, _, cleanup := basicConnection(t, func(s *Server) {
// Use two distinct Tool instances with the same name but different
// descriptions to ensure the second replaces the first
// This case was written specifically to reproduce a bug where duplicate tools where causing jsonschema errors
Expand All @@ -1520,7 +1528,7 @@ func TestAddTool_DuplicateNoPanicAndNoDuplicate(t *testing.T) {
s.AddTool(t1, nopHandler)
s.AddTool(t2, nopHandler)
})
defer cs.Close()
defer cleanup()

ctx := context.Background()
res, err := cs.ListTools(ctx, nil)
Expand Down Expand Up @@ -1568,14 +1576,15 @@ func TestSynchronousNotifications(t *testing.T) {
},
}
server := NewServer(testImpl, serverOpts)
cs, ss := basicClientServerConnection(t, client, server, func(s *Server) {
cs, ss, cleanup := basicClientServerConnection(t, client, server, func(s *Server) {
AddTool(s, &Tool{Name: "tool"}, func(ctx context.Context, req *CallToolRequest, args any) (*CallToolResult, any, error) {
if !rootsChanged.Load() {
return nil, nil, fmt.Errorf("didn't get root change notification")
}
return new(CallToolResult), nil, nil
})
})
defer cleanup()

t.Run("from client", func(t *testing.T) {
client.AddRoots(&Root{Name: "myroot", URI: "file://foo"})
Expand Down Expand Up @@ -1617,7 +1626,7 @@ func TestNoDistributedDeadlock(t *testing.T) {
},
}
client := NewClient(testImpl, clientOpts)
cs, _ := basicClientServerConnection(t, client, nil, func(s *Server) {
cs, _, cleanup := basicClientServerConnection(t, client, nil, func(s *Server) {
AddTool(s, &Tool{Name: "tool1"}, func(ctx context.Context, req *CallToolRequest, args any) (*CallToolResult, any, error) {
req.Session.CreateMessage(ctx, new(CreateMessageParams))
return new(CallToolResult), nil, nil
Expand All @@ -1627,7 +1636,7 @@ func TestNoDistributedDeadlock(t *testing.T) {
return new(CallToolResult), nil, nil
})
})
defer cs.Close()
defer cleanup()

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
Expand All @@ -1651,7 +1660,7 @@ func TestPointerArgEquivalence(t *testing.T) {
type output struct {
Out string
}
cs, _ := basicConnection(t, func(s *Server) {
cs, _, cleanup := basicConnection(t, func(s *Server) {
// Add two equivalent tools, one of which operates in the 'pointer' realm,
// the other of which does not.
//
Expand Down Expand Up @@ -1686,7 +1695,7 @@ func TestPointerArgEquivalence(t *testing.T) {
}
})
})
defer cs.Close()
defer cleanup()

ctx := context.Background()
tools, err := cs.ListTools(ctx, nil)
Expand Down Expand Up @@ -1758,7 +1767,9 @@ func TestComplete(t *testing.T) {
},
}
server := NewServer(testImpl, serverOpts)
cs, _ := basicClientServerConnection(t, nil, server, func(s *Server) {})
cs, _, cleanup := basicClientServerConnection(t, nil, server, func(s *Server) {})
defer cleanup()

result, err := cs.Complete(context.Background(), &CompleteParams{
Argument: CompleteParamsArgument{
Name: "language",
Expand Down
1 change: 1 addition & 0 deletions mcp/server_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ func Example_prompts() {
if err != nil {
log.Fatal(err)
}
defer cs.Close()

// List the prompts.
for p, err := range cs.Prompts(ctx, nil) {
Expand Down
9 changes: 7 additions & 2 deletions mcp/transport_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,21 @@ func ExampleLoggingTransport() {
ctx := context.Background()
t1, t2 := mcp.NewInMemoryTransports()
server := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.0.1"}, nil)
if _, err := server.Connect(ctx, t1, nil); err != nil {
serverSession, err := server.Connect(ctx, t1, nil)
if err != nil {
log.Fatal(err)
}
defer serverSession.Wait()

client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, nil)
var b bytes.Buffer
logTransport := &mcp.LoggingTransport{Transport: t2, Writer: &b}
if _, err := client.Connect(ctx, logTransport, nil); err != nil {
clientSession, err := client.Connect(ctx, logTransport, nil)
if err != nil {
log.Fatal(err)
}
defer clientSession.Close()

// Sort for stability: reads are concurrent to writes.
for _, line := range slices.Sorted(strings.SplitSeq(b.String(), "\n")) {
fmt.Println(line)
Expand Down
1 change: 1 addition & 0 deletions mcp/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ func TestBatchFraming(t *testing.T) {
r, w := io.Pipe()
tport := newIOConn(rwc{r, w})
tport.outgoingBatch = make([]jsonrpc.Message, 0, 2)
defer tport.Close()

// Read the two messages into a channel, for easy testing later.
read := make(chan jsonrpc.Message)
Expand Down