Skip to content
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ require (
github.com/google/go-cmp v0.7.0
github.com/google/jsonschema-go v0.3.0
github.com/yosida95/uritemplate/v3 v3.0.2
go.uber.org/goleak v1.3.0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, I don't think we should add an additional dependency just for this purpose. It seems like handling these as a one-off, every once in a while, is sufficient for now.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see your point and removed this from the PR so we can move ahead.

It would have been beneficial to keep an automated test because it will be very easy to introduce regressions without it, but of course that decision is up to you.

golang.org/x/oauth2 v0.30.0
golang.org/x/tools v0.34.0
)
10 changes: 10 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,12 +1,22 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/jsonschema-go v0.3.0 h1:6AH2TxVNtk3IlvkkhjrtbUc4S8AvO0Xii0DxIygDg+Q=
github.com/google/jsonschema-go v0.3.0/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo=
golang.org/x/tools v0.34.0/go.mod h1:pAP9OwEaY1CAW3HOmg3hLZC5Z0CCmzjAF2UQMSqNARg=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
8 changes: 5 additions & 3 deletions mcp/client_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,17 @@ func Example_roots() {

// Connect the server and client...
t1, t2 := mcp.NewInMemoryTransports()
if _, err := s.Connect(ctx, t1, nil); err != nil {
sess1, err := s.Connect(ctx, t1, nil)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

s/sess1/serverSession (or ss)
s/sess2/clientSession (or cs)

sess1 and sess2 obscures the fact that these variables have different types.

if err != nil {
log.Fatal(err)
}
defer sess1.Close()

clientSession, err := c.Connect(ctx, t2, nil)
sess2, err := c.Connect(ctx, t2, nil)
if err != nil {
log.Fatal(err)
}
defer clientSession.Close()
defer sess2.Close()

// ...and add a root. The server is notified about the change.
c.AddRoots(&mcp.Root{URI: "file://b"})
Expand Down
46 changes: 28 additions & 18 deletions mcp/cmd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package mcp_test
import (
"context"
"errors"
"flag"
"log"
"os"
"os/exec"
Expand All @@ -18,10 +19,15 @@ import (

"github.com/google/go-cmp/cmp"
"github.com/modelcontextprotocol/go-sdk/mcp"
"go.uber.org/goleak"
)

const runAsServer = "_MCP_RUN_AS_SERVER"

// TODO: remove this flag and always check for goroutine leaks once
// . https://github.com/modelcontextprotocol/go-sdk/issues/499 is fixed
var leakCheck = flag.Bool("leak", false, "enable goroutine leak checking")

type SayHiParams struct {
Name string `json:"name"`
}
Expand All @@ -46,6 +52,13 @@ func TestMain(m *testing.M) {
run()
return
}

flag.Parse()
if *leakCheck {
goleak.VerifyTestMain(m)
return
}

os.Exit(m.Run())
}

Expand Down Expand Up @@ -97,6 +110,8 @@ func TestServerRunContextCancel(t *testing.T) {
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() { session.Close() })

if err := session.Ping(context.Background(), nil); err != nil {
t.Fatal(err)
}
Expand All @@ -122,35 +137,30 @@ func TestServerInterrupt(t *testing.T) {
}
requireExec(t)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

t.Log("Starting server command")
cmd := createServerCommand(t, "default")

client := mcp.NewClient(testImpl, nil)
_, err := client.Connect(ctx, &mcp.CommandTransport{Command: cmd}, nil)
t.Log("Connecting to server")

ctx := context.Background()
session, err := client.Connect(ctx, &mcp.CommandTransport{Command: cmd}, nil)
if err != nil {
t.Fatal(err)
}

// get a signal when the server process exits
onExit := make(chan struct{})
go func() {
cmd.Process.Wait()
close(onExit)
}()

// send a signal to the server process to terminate it
t.Log("Send a signal to the server process to terminate it")
if err := cmd.Process.Signal(os.Interrupt); err != nil {
t.Fatal(err)
}

// wait for the server to exit
// TODO: use synctest when available
select {
case <-time.After(5 * time.Second):
t.Fatal("server did not exit after SIGINT")
case <-onExit:
t.Log("Closing client session so server can exit immediately")
session.Close()

t.Log("Wait for process to terminate after interrupt signal")
_, err = cmd.Process.Wait()
if err == nil {
t.Errorf("unexpected error: %v", err)
}
}

Expand Down
25 changes: 8 additions & 17 deletions mcp/mcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,16 +118,6 @@ func TestEndToEnd(t *testing.T) {
t.Errorf("after connection, Clients() has length %d, want 1", len(got))
}

// Wait for the server to exit after the client closes its connection.
var clientWG sync.WaitGroup
clientWG.Add(1)
go func() {
if err := ss.Wait(); err != nil {
t.Errorf("server failed: %v", err)
}
clientWG.Done()
}()

loggingMessages := make(chan *LoggingMessageParams, 100) // big enough for all logging
opts := &ClientOptions{
CreateMessageHandler: func(context.Context, *CreateMessageRequest) (*CreateMessageResult, error) {
Expand Down Expand Up @@ -518,7 +508,9 @@ func TestEndToEnd(t *testing.T) {

// Disconnect.
cs.Close()
clientWG.Wait()
if err := ss.Wait(); err != nil {
t.Errorf("server failed: %v", err)
}

// After disconnecting, neither client nor server should have any
// connections.
Expand Down Expand Up @@ -626,6 +618,7 @@ func basicClientServerConnection(t *testing.T, client *Client, server *Server, c
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() { _ = ss.Close() })

if client == nil {
client = NewClient(testImpl, nil)
Expand All @@ -634,6 +627,8 @@ func basicClientServerConnection(t *testing.T, client *Client, server *Server, c
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() { _ = cs.Close() })

return cs, ss, func() {
cs.Close()
ss.Wait()
Expand Down Expand Up @@ -750,11 +745,7 @@ func TestMiddleware(t *testing.T) {
t.Fatal(err)
}
// Wait for the server to exit after the client closes its connection.
defer func() {
if err := ss.Wait(); err != nil {
t.Errorf("server failed: %v", err)
}
}()
t.Cleanup(func() { _ = ss.Close() })

var sbuf, cbuf bytes.Buffer
sbuf.WriteByte('\n')
Expand All @@ -773,7 +764,7 @@ func TestMiddleware(t *testing.T) {
if err != nil {
t.Fatal(err)
}
defer cs.Close()
t.Cleanup(func() { _ = cs.Close() })

if _, err := cs.ListTools(ctx, nil); err != nil {
t.Fatal(err)
Expand Down
1 change: 1 addition & 0 deletions mcp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -776,6 +776,7 @@ func (s *Server) Run(ctx context.Context, t Transport) error {
select {
case <-ctx.Done():
ss.Close()
<-ssClosed // wait until waiting go routine above actually completes
s.opts.Logger.Error("server run cancelled", "error", ctx.Err())
return ctx.Err()
case err := <-ssClosed:
Expand Down
4 changes: 2 additions & 2 deletions mcp/streamable_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func ExampleStreamableHTTPHandler() {
server := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.1.0"}, nil)
handler := mcp.NewStreamableHTTPHandler(func(r *http.Request) *mcp.Server {
return server
}, &mcp.StreamableHTTPOptions{JSONResponse: true})
}, &mcp.StreamableHTTPOptions{JSONResponse: true, Stateless: true})
httpServer := httptest.NewServer(handler)
defer httpServer.Close()

Expand All @@ -45,7 +45,7 @@ func ExampleStreamableHTTPHandler_middleware() {
server := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.1.0"}, nil)
handler := mcp.NewStreamableHTTPHandler(func(r *http.Request) *mcp.Server {
return server
}, nil)
}, &mcp.StreamableHTTPOptions{Stateless: true})
loggingHandler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
// Example debugging; you could also capture the response.
body, err := io.ReadAll(req.Body)
Expand Down
46 changes: 32 additions & 14 deletions mcp/streamable_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,8 @@ func TestStreamableServerShutdown(t *testing.T) {
// network failure and receive replayed messages (if replay is configured). It
// uses a proxy that is killed and restarted to simulate a recoverable network
// outage.
//
// TODO: Until we have a way to clean up abandoned sessions, this test will leak goroutines (see #499)
func TestClientReplay(t *testing.T) {
for _, test := range []clientReplayTest{
{"default", 0, true},
Expand Down Expand Up @@ -316,7 +318,10 @@ func testClientReplay(t *testing.T, test clientReplayTest) {
})

realServer := httptest.NewServer(NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, nil))
defer realServer.Close()
t.Cleanup(func() {
t.Log("Closing real HTTP server")
realServer.Close()
})
realServerURL, err := url.Parse(realServer.URL)
if err != nil {
t.Fatalf("Failed to parse real server URL: %v", err)
Expand All @@ -342,21 +347,20 @@ func testClientReplay(t *testing.T, test clientReplayTest) {
if err != nil {
t.Fatalf("client.Connect() failed: %v", err)
}
defer clientSession.Close()
t.Cleanup(func() {
t.Log("Closing clientSession")
clientSession.Close()
})

var (
wg sync.WaitGroup
callErr error
)
wg.Add(1)
toolCallResult := make(chan error, 1)
go func() {
defer wg.Done()
_, callErr = clientSession.CallTool(ctx, &CallToolParams{Name: "multiMessageTool"})
_, callErr := clientSession.CallTool(ctx, &CallToolParams{Name: "multiMessageTool"})
toolCallResult <- callErr
}()

select {
case <-serverReadyToKillProxy:
// Server has sent the first two messages and is paused.
t.Log("Server has sent the first two messages and is paused.")
case <-ctx.Done():
t.Fatalf("Context timed out before server was ready to kill proxy")
}
Expand Down Expand Up @@ -384,9 +388,9 @@ func testClientReplay(t *testing.T, test clientReplayTest) {

restartedProxy := &http.Server{Handler: proxyHandler}
go restartedProxy.Serve(listener)
defer restartedProxy.Close()
t.Cleanup(func() { restartedProxy.Close() })

wg.Wait()
callErr := <-toolCallResult

if test.wantRecovered {
// If we've recovered, we should get all 4 notifications and the tool call
Expand Down Expand Up @@ -460,14 +464,15 @@ func TestServerTransportCleanup(t *testing.T) {
if err != nil {
t.Fatalf("client.Connect() failed: %v", err)
}
defer clientSession.Close()
t.Cleanup(func() { _ = clientSession.Close() })
}

for _, ch := range chans {
select {
case <-ctx.Done():
t.Errorf("did not capture transport deletion event from all session in 10 seconds")
case <-ch: // Received transport deletion signal of this session
case <-ch:
t.Log("Received session transport deletion signal")
}
}

Expand Down Expand Up @@ -1253,6 +1258,7 @@ func TestStreamableStateless(t *testing.T) {
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() { cs.Close() })
res, err := cs.CallTool(ctx, &CallToolParams{Name: "greet", Arguments: hiParams{Name: "bar"}})
if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -1431,6 +1437,18 @@ func TestStreamableGET(t *testing.T) {
if got, want := resp.StatusCode, http.StatusOK; got != want {
t.Errorf("GET with session ID: got status %d, want %d", got, want)
}

t.Log("Sending final DELETE request to close session and release resources")
del := newReq("DELETE", nil)
del.Header.Set(sessionIDHeader, sessionID)
resp, err = http.DefaultClient.Do(del)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
if got, want := resp.StatusCode, http.StatusNoContent; got != want {
t.Errorf("DELETE with session ID: got status %d, want %d", got, want)
}
}

func TestStreamableClientContextPropagation(t *testing.T) {
Expand Down
9 changes: 8 additions & 1 deletion mcp/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,14 @@ func (r rwc) Write(p []byte) (n int, err error) {
}

func (r rwc) Close() error {
return errors.Join(r.rc.Close(), r.wc.Close())
rcErr := r.rc.Close()

var wcErr error
if r.wc != nil { // we only allow a nil writer in unit tests
wcErr = r.wc.Close()
}

return errors.Join(rcErr, wcErr)
}

// An ioConn is a transport that delimits messages with newlines across
Expand Down
2 changes: 1 addition & 1 deletion mcp/transport_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func ExampleLoggingTransport() {
if err != nil {
log.Fatal(err)
}
defer serverSession.Wait()
defer serverSession.Close()

client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, nil)
var b bytes.Buffer
Expand Down
3 changes: 2 additions & 1 deletion mcp/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func TestBatchFraming(t *testing.T) {
r, w := io.Pipe()
tport := newIOConn(rwc{r, w})
tport.outgoingBatch = make([]jsonrpc.Message, 0, 2)
defer tport.Close()
t.Cleanup(func() { tport.Close() })

// Read the two messages into a channel, for easy testing later.
read := make(chan jsonrpc.Message)
Expand Down Expand Up @@ -101,6 +101,7 @@ func TestIOConnRead(t *testing.T) {
tr := newIOConn(rwc{
rc: io.NopCloser(strings.NewReader(tt.input)),
})
t.Cleanup(func() { tr.Close() })
if tt.protocolVersion != "" {
tr.sessionUpdated(ServerSessionState{
InitializeParams: &InitializeParams{
Expand Down