diff --git a/mcp/client.go b/mcp/client.go index 17875c38..82fbb01c 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -9,6 +9,7 @@ import ( "iter" "slices" "sync" + "time" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" ) @@ -56,6 +57,10 @@ type ClientOptions struct { ResourceListChangedHandler func(context.Context, *ClientSession, *ResourceListChangedParams) LoggingMessageHandler func(context.Context, *ClientSession, *LoggingMessageParams) ProgressNotificationHandler func(context.Context, *ClientSession, *ProgressNotificationParams) + // If non-zero, defines an interval for regular "ping" requests. + // If the peer fails to respond to pings originating from the keepalive check, + // the session is automatically closed. + KeepAlive time.Duration } // bind implements the binder[*ClientSession] interface, so that Clients can @@ -118,6 +123,11 @@ func (c *Client) Connect(ctx context.Context, t Transport) (cs *ClientSession, e _ = cs.Close() return nil, err } + + if c.opts.KeepAlive > 0 { + cs.startKeepalive(c.opts.KeepAlive) + } + return cs, nil } @@ -131,12 +141,21 @@ type ClientSession struct { conn *jsonrpc2.Connection client *Client initializeResult *InitializeResult + keepaliveCancel context.CancelFunc } // Close performs a graceful close of the connection, preventing new requests // from being handled, and waiting for ongoing requests to return. Close then // terminates the connection. func (cs *ClientSession) Close() error { + // Note: keepaliveCancel access is safe without a mutex because: + // 1. keepaliveCancel is only written once during startKeepalive (happens-before all Close calls) + // 2. context.CancelFunc is safe to call multiple times and from multiple goroutines + // 3. The keepalive goroutine calls Close on ping failure, but this is safe since + // Close is idempotent and conn.Close() handles concurrent calls correctly + if cs.keepaliveCancel != nil { + cs.keepaliveCancel() + } return cs.conn.Close() } @@ -146,6 +165,11 @@ func (cs *ClientSession) Wait() error { return cs.conn.Wait() } +// startKeepalive starts the keepalive mechanism for this client session. +func (cs *ClientSession) startKeepalive(interval time.Duration) { + startKeepalive(cs, interval, &cs.keepaliveCancel) +} + // AddRoots adds the given roots to the client, // replacing any with the same URIs, // and notifies any connected servers. diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 71b99a4a..dc7f80bf 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -838,3 +838,101 @@ func falseSchema() *jsonschema.Schema { return &jsonschema.Schema{Not: &jsonsche func nopHandler(context.Context, *ServerSession, *CallToolParamsFor[map[string]any]) (*CallToolResult, error) { return nil, nil } + +func TestKeepAlive(t *testing.T) { + // TODO: try to use the new synctest package for this test once we upgrade to Go 1.24+. + // synctest would allow us to control time and avoid the time.Sleep calls, making the test + // faster and more deterministic. + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + ct, st := NewInMemoryTransports() + + serverOpts := &ServerOptions{ + KeepAlive: 100 * time.Millisecond, + } + s := NewServer("testServer", "v1.0.0", serverOpts) + s.AddTools(NewServerTool("greet", "say hi", sayHi)) + + ss, err := s.Connect(ctx, st) + if err != nil { + t.Fatal(err) + } + defer ss.Close() + + clientOpts := &ClientOptions{ + KeepAlive: 100 * time.Millisecond, + } + c := NewClient("testClient", "v1.0.0", clientOpts) + cs, err := c.Connect(ctx, ct) + if err != nil { + t.Fatal(err) + } + defer cs.Close() + + // Wait for a few keepalive cycles to ensure pings are working + time.Sleep(300 * time.Millisecond) + + // Test that the connection is still alive by making a call + result, err := cs.CallTool(ctx, &CallToolParams{ + Name: "greet", + Arguments: map[string]any{"Name": "user"}, + }) + if err != nil { + t.Fatalf("call failed after keepalive: %v", err) + } + if len(result.Content) == 0 { + t.Fatal("expected content in result") + } + if textContent, ok := result.Content[0].(*TextContent); !ok || textContent.Text != "hi user" { + t.Fatalf("unexpected result: %v", result.Content[0]) + } +} + +func TestKeepAliveFailure(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + ct, st := NewInMemoryTransports() + + // Server without keepalive (to test one-sided keepalive) + s := NewServer("testServer", "v1.0.0", nil) + s.AddTools(NewServerTool("greet", "say hi", sayHi)) + ss, err := s.Connect(ctx, st) + if err != nil { + t.Fatal(err) + } + + // Client with short keepalive + clientOpts := &ClientOptions{ + KeepAlive: 50 * time.Millisecond, + } + c := NewClient("testClient", "v1.0.0", clientOpts) + cs, err := c.Connect(ctx, ct) + if err != nil { + t.Fatal(err) + } + defer cs.Close() + + // Let the connection establish properly first + time.Sleep(30 * time.Millisecond) + + // simulate ping failure + ss.Close() + + // Wait for keepalive to detect the failure and close the client + // check periodically instead of just waiting + deadline := time.Now().Add(1 * time.Second) + for time.Now().Before(deadline) { + _, err = cs.CallTool(ctx, &CallToolParams{ + Name: "greet", + Arguments: map[string]any{"Name": "user"}, + }) + if errors.Is(err, ErrConnectionClosed) { + return // Test passed + } + time.Sleep(25 * time.Millisecond) + } + + t.Errorf("expected connection to be closed by keepalive, but it wasn't. Last error: %v", err) +} diff --git a/mcp/server.go b/mcp/server.go index c4aeeed9..843fa5ce 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -16,6 +16,7 @@ import ( "path/filepath" "slices" "sync" + "time" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" "github.com/modelcontextprotocol/go-sdk/internal/util" @@ -57,6 +58,10 @@ type ServerOptions struct { RootsListChangedHandler func(context.Context, *ServerSession, *RootsListChangedParams) // If non-nil, called when "notifications/progress" is received. ProgressNotificationHandler func(context.Context, *ServerSession, *ProgressNotificationParams) + // If non-zero, defines an interval for regular "ping" requests. + // If the peer fails to respond to pings originating from the keepalive check, + // the session is automatically closed. + KeepAlive time.Duration } // NewServer creates a new MCP server. The resulting server has no features: @@ -460,6 +465,9 @@ func (s *Server) Connect(ctx context.Context, t Transport) (*ServerSession, erro } func (s *Server) callInitializedHandler(ctx context.Context, ss *ServerSession, params *InitializedParams) (Result, error) { + if s.opts.KeepAlive > 0 { + ss.startKeepalive(s.opts.KeepAlive) + } return callNotificationHandler(ctx, s.opts.InitializedHandler, ss, params) } @@ -492,6 +500,7 @@ type ServerSession struct { logLevel LoggingLevel initializeParams *InitializeParams initialized bool + keepaliveCancel context.CancelFunc } // Ping pings the client. @@ -680,6 +689,14 @@ func (ss *ServerSession) setLevel(_ context.Context, params *SetLevelParams) (*e // requests from being handled, and waiting for ongoing requests to return. // Close then terminates the connection. func (ss *ServerSession) Close() error { + if ss.keepaliveCancel != nil { + // Note: keepaliveCancel access is safe without a mutex because: + // 1. keepaliveCancel is only written once during startKeepalive (happens-before all Close calls) + // 2. context.CancelFunc is safe to call multiple times and from multiple goroutines + // 3. The keepalive goroutine calls Close on ping failure, but this is safe since + // Close is idempotent and conn.Close() handles concurrent calls correctly + ss.keepaliveCancel() + } return ss.conn.Close() } @@ -688,6 +705,11 @@ func (ss *ServerSession) Wait() error { return ss.conn.Wait() } +// startKeepalive starts the keepalive mechanism for this server session. +func (ss *ServerSession) startKeepalive(interval time.Duration) { + startKeepalive(ss, interval, &ss.keepaliveCancel) +} + // pageToken is the internal structure for the opaque pagination cursor. // It will be Gob-encoded and then Base64-encoded for use as a string token. type pageToken struct { diff --git a/mcp/shared.go b/mcp/shared.go index 07ea5dff..c281965a 100644 --- a/mcp/shared.go +++ b/mcp/shared.go @@ -308,3 +308,41 @@ type listResult[T any] interface { // Returns a pointer to the param's NextCursor field. nextCursorPtr() *string } + +// keepaliveSession represents a session that supports keepalive functionality. +type keepaliveSession interface { + Ping(ctx context.Context, params *PingParams) error + Close() error +} + +// startKeepalive starts the keepalive mechanism for a session. +// It assigns the cancel function to the provided cancelPtr and starts a goroutine +// that sends ping messages at the specified interval. +func startKeepalive(session keepaliveSession, interval time.Duration, cancelPtr *context.CancelFunc) { + ctx, cancel := context.WithCancel(context.Background()) + // Assign cancel function before starting goroutine to avoid race condition. + // We cannot return it because the caller may need to cancel during the + // window between goroutine scheduling and function return. + *cancelPtr = cancel + + go func() { + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + pingCtx, pingCancel := context.WithTimeout(context.Background(), interval/2) + err := session.Ping(pingCtx, nil) + pingCancel() + if err != nil { + // Ping failed, close the session + _ = session.Close() + return + } + } + } + }() +}