Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions mcp/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"iter"
"slices"
"sync"
"time"

"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
)
Expand Down Expand Up @@ -56,6 +57,10 @@ type ClientOptions struct {
ResourceListChangedHandler func(context.Context, *ClientSession, *ResourceListChangedParams)
LoggingMessageHandler func(context.Context, *ClientSession, *LoggingMessageParams)
ProgressNotificationHandler func(context.Context, *ClientSession, *ProgressNotificationParams)
// If non-zero, defines an interval for regular "ping" requests.
// If the peer fails to respond to pings originating from the keepalive check,
// the session is automatically closed.
KeepAlive time.Duration
}

// bind implements the binder[*ClientSession] interface, so that Clients can
Expand Down Expand Up @@ -118,6 +123,11 @@ func (c *Client) Connect(ctx context.Context, t Transport) (cs *ClientSession, e
_ = cs.Close()
return nil, err
}

if c.opts.KeepAlive > 0 {
cs.startKeepalive(c.opts.KeepAlive)
}

return cs, nil
}

Expand All @@ -131,12 +141,21 @@ type ClientSession struct {
conn *jsonrpc2.Connection
client *Client
initializeResult *InitializeResult
keepaliveCancel context.CancelFunc
}

// Close performs a graceful close of the connection, preventing new requests
// from being handled, and waiting for ongoing requests to return. Close then
// terminates the connection.
func (cs *ClientSession) Close() error {
// Note: keepaliveCancel access is safe without a mutex because:
// 1. keepaliveCancel is only written once during startKeepalive (happens-before all Close calls)
// 2. context.CancelFunc is safe to call multiple times and from multiple goroutines
// 3. The keepalive goroutine calls Close on ping failure, but this is safe since
// Close is idempotent and conn.Close() handles concurrent calls correctly
if cs.keepaliveCancel != nil {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like a user Close and the Close from the startKeepalive goroutine could race.
I don't think they can, but it's worth an explanation, in terms of happens-before relations, why no mutex is needed for keepaliveCancel. (If you don't understand what I'm talking about, I can explain further.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ya, I agree, I'm pretty sure that we do not need a mutex here, I've added the following comment explaining the flow and why it's safe:

	// Note: keepaliveCancel access is safe without a mutex because:
	// 1. keepaliveCancel is only written once during startKeepalive (happens-before user Close calls)
	// 2. context.CancelFunc is safe to call multiple times and from multiple goroutines
	// 3. The keepalive goroutine calls Close on ping failure, but this is safe since
	//    Close is idempotent and conn.Close() handles concurrent calls correctly

Please double check my reasoning though!

cs.keepaliveCancel()
}
return cs.conn.Close()
}

Expand All @@ -146,6 +165,11 @@ func (cs *ClientSession) Wait() error {
return cs.conn.Wait()
}

// startKeepalive starts the keepalive mechanism for this client session.
// It delegates to the shared startKeepalive helper, which stores the cancel
// func in cs.keepaliveCancel so that Close can stop the ping goroutine.
func (cs *ClientSession) startKeepalive(interval time.Duration) {
	startKeepalive(cs, interval, &cs.keepaliveCancel)
}

// AddRoots adds the given roots to the client,
// replacing any with the same URIs,
// and notifies any connected servers.
Expand Down
98 changes: 98 additions & 0 deletions mcp/mcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -838,3 +838,101 @@ func falseSchema() *jsonschema.Schema { return &jsonschema.Schema{Not: &jsonsche
func nopHandler(context.Context, *ServerSession, *CallToolParamsFor[map[string]any]) (*CallToolResult, error) {
return nil, nil
}

func TestKeepAlive(t *testing.T) {
	// TODO: port this to the synctest package once we upgrade to Go 1.24+.
	// Controlling time directly would remove the time.Sleep calls, making the
	// test faster and more deterministic.
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	clientTransport, serverTransport := NewInMemoryTransports()

	server := NewServer("testServer", "v1.0.0", &ServerOptions{
		KeepAlive: 100 * time.Millisecond,
	})
	server.AddTools(NewServerTool("greet", "say hi", sayHi))

	serverSession, err := server.Connect(ctx, serverTransport)
	if err != nil {
		t.Fatal(err)
	}
	defer serverSession.Close()

	client := NewClient("testClient", "v1.0.0", &ClientOptions{
		KeepAlive: 100 * time.Millisecond,
	})
	clientSession, err := client.Connect(ctx, clientTransport)
	if err != nil {
		t.Fatal(err)
	}
	defer clientSession.Close()

	// Let several keepalive cycles elapse so that pings are actually exercised.
	time.Sleep(300 * time.Millisecond)

	// The session must still be usable after those pings.
	result, err := clientSession.CallTool(ctx, &CallToolParams{
		Name:      "greet",
		Arguments: map[string]any{"Name": "user"},
	})
	if err != nil {
		t.Fatalf("call failed after keepalive: %v", err)
	}
	if len(result.Content) == 0 {
		t.Fatal("expected content in result")
	}
	tc, ok := result.Content[0].(*TextContent)
	if !ok || tc.Text != "hi user" {
		t.Fatalf("unexpected result: %v", result.Content[0])
	}
}

func TestKeepAliveFailure(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	clientTransport, serverTransport := NewInMemoryTransports()

	// The server runs without keepalive, so only the client side pings.
	server := NewServer("testServer", "v1.0.0", nil)
	server.AddTools(NewServerTool("greet", "say hi", sayHi))
	serverSession, err := server.Connect(ctx, serverTransport)
	if err != nil {
		t.Fatal(err)
	}

	// A short client keepalive lets the test detect the failure quickly.
	client := NewClient("testClient", "v1.0.0", &ClientOptions{
		KeepAlive: 50 * time.Millisecond,
	})
	clientSession, err := client.Connect(ctx, clientTransport)
	if err != nil {
		t.Fatal(err)
	}
	defer clientSession.Close()

	// Give the connection a moment to establish properly first.
	time.Sleep(30 * time.Millisecond)

	// Closing the server simulates a ping failure on the client side.
	serverSession.Close()

	// Poll until the keepalive notices the dead peer and closes the client
	// session, rather than sleeping for one fixed period.
	deadline := time.Now().Add(1 * time.Second)
	for time.Now().Before(deadline) {
		_, err = clientSession.CallTool(ctx, &CallToolParams{
			Name:      "greet",
			Arguments: map[string]any{"Name": "user"},
		})
		if errors.Is(err, ErrConnectionClosed) {
			return // success: keepalive closed the session
		}
		time.Sleep(25 * time.Millisecond)
	}

	t.Errorf("expected connection to be closed by keepalive, but it wasn't. Last error: %v", err)
}
22 changes: 22 additions & 0 deletions mcp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"path/filepath"
"slices"
"sync"
"time"

"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
"github.com/modelcontextprotocol/go-sdk/internal/util"
Expand Down Expand Up @@ -57,6 +58,10 @@ type ServerOptions struct {
RootsListChangedHandler func(context.Context, *ServerSession, *RootsListChangedParams)
// If non-nil, called when "notifications/progress" is received.
ProgressNotificationHandler func(context.Context, *ServerSession, *ProgressNotificationParams)
// If non-zero, defines an interval for regular "ping" requests.
// If the peer fails to respond to pings originating from the keepalive check,
// the session is automatically closed.
KeepAlive time.Duration
}

// NewServer creates a new MCP server. The resulting server has no features:
Expand Down Expand Up @@ -460,6 +465,9 @@ func (s *Server) Connect(ctx context.Context, t Transport) (*ServerSession, erro
}

// callInitializedHandler handles the "notifications/initialized" notification,
// starting the keepalive loop (if configured) before invoking any user handler.
func (s *Server) callInitializedHandler(ctx context.Context, ss *ServerSession, params *InitializedParams) (Result, error) {
	// Keepalive pings begin only once the client reports it is initialized.
	if interval := s.opts.KeepAlive; interval > 0 {
		ss.startKeepalive(interval)
	}
	return callNotificationHandler(ctx, s.opts.InitializedHandler, ss, params)
}

Expand Down Expand Up @@ -492,6 +500,7 @@ type ServerSession struct {
logLevel LoggingLevel
initializeParams *InitializeParams
initialized bool
keepaliveCancel context.CancelFunc
}

// Ping pings the client.
Expand Down Expand Up @@ -680,6 +689,14 @@ func (ss *ServerSession) setLevel(_ context.Context, params *SetLevelParams) (*e
// Close performs a graceful close of the connection, preventing new
// requests from being handled, and waiting for ongoing requests to return.
// Close then terminates the connection.
func (ss *ServerSession) Close() error {
	// Reading keepaliveCancel without a mutex is safe: it is written exactly
	// once in startKeepalive (which happens-before any Close), a
	// context.CancelFunc may be called repeatedly from multiple goroutines,
	// and the keepalive goroutine's own Close on ping failure is harmless
	// because Close is idempotent and conn.Close handles concurrent calls.
	if cancel := ss.keepaliveCancel; cancel != nil {
		cancel()
	}
	return ss.conn.Close()
}

Expand All @@ -688,6 +705,11 @@ func (ss *ServerSession) Wait() error {
return ss.conn.Wait()
}

// startKeepalive starts the keepalive mechanism for this server session.
// It delegates to the shared startKeepalive helper, which stores the cancel
// func in ss.keepaliveCancel so that Close can stop the ping goroutine.
func (ss *ServerSession) startKeepalive(interval time.Duration) {
	startKeepalive(ss, interval, &ss.keepaliveCancel)
}

// pageToken is the internal structure for the opaque pagination cursor.
// It will be Gob-encoded and then Base64-encoded for use as a string token.
type pageToken struct {
Expand Down
38 changes: 38 additions & 0 deletions mcp/shared.go
Original file line number Diff line number Diff line change
Expand Up @@ -308,3 +308,41 @@ type listResult[T any] interface {
// Returns a pointer to the param's NextCursor field.
nextCursorPtr() *string
}

// keepaliveSession represents a session that supports keepalive functionality.
// Both ClientSession and ServerSession satisfy it (each passes itself to the
// shared startKeepalive helper).
type keepaliveSession interface {
	// Ping sends a "ping" request to the peer, reporting any failure.
	Ping(ctx context.Context, params *PingParams) error
	// Close terminates the session; invoked by the keepalive loop on ping failure.
	Close() error
}

// startKeepalive starts the keepalive mechanism for a session.
// It assigns the cancel function to the provided cancelPtr and starts a goroutine
// that sends ping messages at the specified interval.
func startKeepalive(session keepaliveSession, interval time.Duration, cancelPtr *context.CancelFunc) {
ctx, cancel := context.WithCancel(context.Background())
// Assign cancel function before starting goroutine to avoid race condition.
// We cannot return it because the caller may need to cancel during the
// window between goroutine scheduling and function return.
*cancelPtr = cancel
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a comment saying that we must assign this before starting the goroutine, so we cannot return it.

(I first thought "why not just return the cancel func instead of passing in a pointer?" But it would be racy.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good suggestion, thank you! Updated with this comment:

	// Assign cancel function before starting goroutine to avoid race condition.
	// We cannot return it because the caller may need to cancel during the
	// window between goroutine scheduling and function return.


go func() {
ticker := time.NewTicker(interval)
defer ticker.Stop()

for {
select {
case <-ctx.Done():
return
case <-ticker.C:
pingCtx, pingCancel := context.WithTimeout(context.Background(), interval/2)
err := session.Ping(pingCtx, nil)
pingCancel()
if err != nil {
// Ping failed, close the session
_ = session.Close()
return
}
}
}
}()
}
Loading