Skip to content

Commit f3f9a24

Browse files
committed
mcp: Implement KeepAlive for client and server
1 parent abb650e commit f3f9a24

File tree

4 files changed

+182
-0
lines changed

4 files changed

+182
-0
lines changed

mcp/client.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"iter"
1010
"slices"
1111
"sync"
12+
"time"
1213

1314
"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
1415
)
@@ -56,6 +57,10 @@ type ClientOptions struct {
5657
ResourceListChangedHandler func(context.Context, *ClientSession, *ResourceListChangedParams)
5758
LoggingMessageHandler func(context.Context, *ClientSession, *LoggingMessageParams)
5859
ProgressNotificationHandler func(context.Context, *ClientSession, *ProgressNotificationParams)
60+
// If non-zero, defines an interval for regular "ping" requests.
61+
// If the peer fails to respond to pings originating from the keepalive check,
62+
// the session is automatically closed.
63+
KeepAlive time.Duration
5964
}
6065

6166
// bind implements the binder[*ClientSession] interface, so that Clients can
@@ -118,6 +123,11 @@ func (c *Client) Connect(ctx context.Context, t Transport) (cs *ClientSession, e
118123
_ = cs.Close()
119124
return nil, err
120125
}
126+
127+
if c.opts.KeepAlive > 0 {
128+
cs.startKeepalive(c.opts.KeepAlive)
129+
}
130+
121131
return cs, nil
122132
}
123133

@@ -131,12 +141,21 @@ type ClientSession struct {
131141
conn *jsonrpc2.Connection
132142
client *Client
133143
initializeResult *InitializeResult
144+
keepaliveCancel context.CancelFunc
134145
}
135146

136147
// Close performs a graceful close of the connection, preventing new requests
137148
// from being handled, and waiting for ongoing requests to return. Close then
138149
// terminates the connection.
139150
func (cs *ClientSession) Close() error {
151+
// Note: keepaliveCancel access is safe without a mutex because:
152+
// 1. keepaliveCancel is only written once during startKeepalive (happens-before all Close calls)
153+
// 2. context.CancelFunc is safe to call multiple times and from multiple goroutines
154+
// 3. The keepalive goroutine calls Close on ping failure, but this is safe since
155+
// Close is idempotent and conn.Close() handles concurrent calls correctly
156+
if cs.keepaliveCancel != nil {
157+
cs.keepaliveCancel()
158+
}
140159
return cs.conn.Close()
141160
}
142161

@@ -146,6 +165,11 @@ func (cs *ClientSession) Wait() error {
146165
return cs.conn.Wait()
147166
}
148167

168+
// startKeepalive starts the keepalive mechanism for this client session.
169+
// It delegates to the package-level startKeepalive, passing the session as
// the ping/close target and &cs.keepaliveCancel as the place where the loop's
// cancel function is stored (Close later calls it).
func (cs *ClientSession) startKeepalive(interval time.Duration) {
170+
startKeepalive(cs, interval, &cs.keepaliveCancel)
171+
}
172+
149173
// AddRoots adds the given roots to the client,
150174
// replacing any with the same URIs,
151175
// and notifies any connected servers.

mcp/mcp_test.go

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -838,3 +838,101 @@ func falseSchema() *jsonschema.Schema { return &jsonschema.Schema{Not: &jsonsche
838838
// nopHandler is a tool handler that does nothing: it ignores all of its
// arguments and returns a nil result together with a nil error.
func nopHandler(context.Context, *ServerSession, *CallToolParamsFor[map[string]any]) (*CallToolResult, error) {
839839
return nil, nil
840840
}
841+
842+
func TestKeepAlive(t *testing.T) {
843+
// TODO: try to use the new synctest package for this test once we upgrade to Go 1.24+.
844+
// synctest would allow us to control time and avoid the time.Sleep calls, making the test
845+
// faster and more deterministic.
846+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
847+
defer cancel()
848+
849+
ct, st := NewInMemoryTransports()
850+
851+
serverOpts := &ServerOptions{
852+
KeepAlive: 100 * time.Millisecond,
853+
}
854+
s := NewServer("testServer", "v1.0.0", serverOpts)
855+
s.AddTools(NewServerTool("greet", "say hi", sayHi))
856+
857+
ss, err := s.Connect(ctx, st)
858+
if err != nil {
859+
t.Fatal(err)
860+
}
861+
defer ss.Close()
862+
863+
clientOpts := &ClientOptions{
864+
KeepAlive: 100 * time.Millisecond,
865+
}
866+
c := NewClient("testClient", "v1.0.0", clientOpts)
867+
cs, err := c.Connect(ctx, ct)
868+
if err != nil {
869+
t.Fatal(err)
870+
}
871+
defer cs.Close()
872+
873+
// Wait for a few keepalive cycles to ensure pings are working
874+
time.Sleep(300 * time.Millisecond)
875+
876+
// Test that the connection is still alive by making a call
877+
result, err := cs.CallTool(ctx, &CallToolParams{
878+
Name: "greet",
879+
Arguments: map[string]any{"Name": "user"},
880+
})
881+
if err != nil {
882+
t.Fatalf("call failed after keepalive: %v", err)
883+
}
884+
if len(result.Content) == 0 {
885+
t.Fatal("expected content in result")
886+
}
887+
if textContent, ok := result.Content[0].(*TextContent); !ok || textContent.Text != "hi user" {
888+
t.Fatalf("unexpected result: %v", result.Content[0])
889+
}
890+
}
891+
892+
func TestKeepAliveFailure(t *testing.T) {
893+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
894+
defer cancel()
895+
896+
ct, st := NewInMemoryTransports()
897+
898+
// Server without keepalive (to test one-sided keepalive)
899+
s := NewServer("testServer", "v1.0.0", nil)
900+
s.AddTools(NewServerTool("greet", "say hi", sayHi))
901+
ss, err := s.Connect(ctx, st)
902+
if err != nil {
903+
t.Fatal(err)
904+
}
905+
906+
// Client with short keepalive
907+
clientOpts := &ClientOptions{
908+
KeepAlive: 50 * time.Millisecond,
909+
}
910+
c := NewClient("testClient", "v1.0.0", clientOpts)
911+
cs, err := c.Connect(ctx, ct)
912+
if err != nil {
913+
t.Fatal(err)
914+
}
915+
defer cs.Close()
916+
917+
// Let the connection establish properly first
918+
time.Sleep(30 * time.Millisecond)
919+
920+
// simulate ping failure
921+
ss.Close()
922+
923+
// Wait for keepalive to detect the failure and close the client
924+
// check periodically instead of just waiting
925+
deadline := time.Now().Add(1 * time.Second)
926+
for time.Now().Before(deadline) {
927+
_, err = cs.CallTool(ctx, &CallToolParams{
928+
Name: "greet",
929+
Arguments: map[string]any{"Name": "user"},
930+
})
931+
if errors.Is(err, ErrConnectionClosed) {
932+
return // Test passed
933+
}
934+
time.Sleep(25 * time.Millisecond)
935+
}
936+
937+
t.Errorf("expected connection to be closed by keepalive, but it wasn't. Last error: %v", err)
938+
}

mcp/server.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616
"path/filepath"
1717
"slices"
1818
"sync"
19+
"time"
1920

2021
"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
2122
"github.com/modelcontextprotocol/go-sdk/internal/util"
@@ -57,6 +58,10 @@ type ServerOptions struct {
5758
RootsListChangedHandler func(context.Context, *ServerSession, *RootsListChangedParams)
5859
// If non-nil, called when "notifications/progress" is received.
5960
ProgressNotificationHandler func(context.Context, *ServerSession, *ProgressNotificationParams)
61+
// If non-zero, defines an interval for regular "ping" requests.
62+
// If the peer fails to respond to pings originating from the keepalive check,
63+
// the session is automatically closed.
64+
KeepAlive time.Duration
6065
}
6166

6267
// NewServer creates a new MCP server. The resulting server has no features:
@@ -460,6 +465,9 @@ func (s *Server) Connect(ctx context.Context, t Transport) (*ServerSession, erro
460465
}
461466

462467
// callInitializedHandler forwards params to the configured
// ServerOptions.InitializedHandler via callNotificationHandler and returns
// its result. When ServerOptions.KeepAlive is positive, it first starts the
// keepalive loop for ss with that interval.
//
// NOTE(review): nothing visible here prevents this from running more than
// once for the same session, which would start a second keepalive loop —
// confirm that callers guard against repeated invocation.
func (s *Server) callInitializedHandler(ctx context.Context, ss *ServerSession, params *InitializedParams) (Result, error) {
468+
if s.opts.KeepAlive > 0 {
469+
ss.startKeepalive(s.opts.KeepAlive)
470+
}
463471
return callNotificationHandler(ctx, s.opts.InitializedHandler, ss, params)
464472
}
465473

@@ -492,6 +500,7 @@ type ServerSession struct {
492500
logLevel LoggingLevel
493501
initializeParams *InitializeParams
494502
initialized bool
503+
keepaliveCancel context.CancelFunc
495504
}
496505

497506
// Ping pings the client.
@@ -680,6 +689,14 @@ func (ss *ServerSession) setLevel(_ context.Context, params *SetLevelParams) (*e
680689
// requests from being handled, and waiting for ongoing requests to return.
681690
// Close then terminates the connection.
682691
func (ss *ServerSession) Close() error {
692+
if ss.keepaliveCancel != nil {
693+
// Note: keepaliveCancel access is safe without a mutex because:
694+
// 1. keepaliveCancel is only written once during startKeepalive (happens-before all Close calls)
695+
// 2. context.CancelFunc is safe to call multiple times and from multiple goroutines
696+
// 3. The keepalive goroutine calls Close on ping failure, but this is safe since
697+
// Close is idempotent and conn.Close() handles concurrent calls correctly
698+
ss.keepaliveCancel()
699+
}
683700
return ss.conn.Close()
684701
}
685702

@@ -688,6 +705,11 @@ func (ss *ServerSession) Wait() error {
688705
return ss.conn.Wait()
689706
}
690707

708+
// startKeepalive starts the keepalive mechanism for this server session.
709+
// It delegates to the package-level startKeepalive, passing the session as
// the ping/close target and &ss.keepaliveCancel as the place where the loop's
// cancel function is stored (Close later calls it).
func (ss *ServerSession) startKeepalive(interval time.Duration) {
710+
startKeepalive(ss, interval, &ss.keepaliveCancel)
711+
}
712+
691713
// pageToken is the internal structure for the opaque pagination cursor.
692714
// It will be Gob-encoded and then Base64-encoded for use as a string token.
693715
type pageToken struct {

mcp/shared.go

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,3 +308,41 @@ type listResult[T any] interface {
308308
// Returns a pointer to the param's NextCursor field.
309309
nextCursorPtr() *string
310310
}
311+
312+
// keepaliveSession represents a session that supports keepalive functionality.
313+
// It is the minimal surface the package-level startKeepalive loop needs: a
// way to probe the peer and a way to tear the session down.
type keepaliveSession interface {
314+
// Ping probes the peer. startKeepalive calls it once per tick with a nil
// params value and treats a non-nil error as a failed keepalive check.
Ping(ctx context.Context, params *PingParams) error
315+
// Close closes the session. startKeepalive calls it after a failed ping and
// ignores the returned error.
Close() error
316+
}
317+
318+
// startKeepalive starts the keepalive mechanism for a session.
319+
// It assigns the cancel function to the provided cancelPtr and starts a goroutine
320+
// that sends ping messages at the specified interval.
321+
func startKeepalive(session keepaliveSession, interval time.Duration, cancelPtr *context.CancelFunc) {
322+
ctx, cancel := context.WithCancel(context.Background())
323+
// Assign cancel function before starting goroutine to avoid race condition.
324+
// We cannot return it because the caller may need to cancel during the
325+
// window between goroutine scheduling and function return.
326+
*cancelPtr = cancel
327+
328+
go func() {
329+
ticker := time.NewTicker(interval)
330+
defer ticker.Stop()
331+
332+
for {
333+
select {
334+
case <-ctx.Done():
335+
return
336+
case <-ticker.C:
337+
pingCtx, pingCancel := context.WithTimeout(context.Background(), interval/2)
338+
err := session.Ping(pingCtx, nil)
339+
pingCancel()
340+
if err != nil {
341+
// Ping failed, close the session
342+
_ = session.Close()
343+
return
344+
}
345+
}
346+
}
347+
}()
348+
}

0 commit comments

Comments
 (0)