Skip to content

Commit 5d9c799

Browse files
committed
feat: Implement KeepAlive for client and server
1 parent abb650e commit 5d9c799

File tree

3 files changed

+175
-0
lines changed

3 files changed

+175
-0
lines changed

mcp/client.go

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"iter"
1010
"slices"
1111
"sync"
12+
"time"
1213

1314
"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
1415
)
@@ -56,6 +57,10 @@ type ClientOptions struct {
5657
ResourceListChangedHandler func(context.Context, *ClientSession, *ResourceListChangedParams)
5758
LoggingMessageHandler func(context.Context, *ClientSession, *LoggingMessageParams)
5859
ProgressNotificationHandler func(context.Context, *ClientSession, *ProgressNotificationParams)
60+
// If non-zero, defines an interval for regular "ping" requests.
61+
// If the peer fails to respond to pings originating from the keepalive check,
62+
// the session is automatically closed.
63+
KeepAlive time.Duration
5964
}
6065

6166
// bind implements the binder[*ClientSession] interface, so that Clients can
@@ -118,6 +123,11 @@ func (c *Client) Connect(ctx context.Context, t Transport) (cs *ClientSession, e
118123
_ = cs.Close()
119124
return nil, err
120125
}
126+
127+
if c.opts.KeepAlive > 0 {
128+
cs.startKeepalive(c.opts.KeepAlive)
129+
}
130+
121131
return cs, nil
122132
}
123133

@@ -131,12 +141,16 @@ type ClientSession struct {
131141
conn *jsonrpc2.Connection
132142
client *Client
133143
initializeResult *InitializeResult
144+
keepaliveCancel context.CancelFunc
134145
}
135146

136147
// Close performs a graceful close of the connection, preventing new requests
137148
// from being handled, and waiting for ongoing requests to return. Close then
138149
// terminates the connection.
139150
func (cs *ClientSession) Close() error {
151+
if cs.keepaliveCancel != nil {
152+
cs.keepaliveCancel()
153+
}
140154
return cs.conn.Close()
141155
}
142156

@@ -146,6 +160,33 @@ func (cs *ClientSession) Wait() error {
146160
return cs.conn.Wait()
147161
}
148162

163+
// startKeepalive starts the keepalive mechanism for this client session.
164+
func (cs *ClientSession) startKeepalive(interval time.Duration) {
165+
ctx, cancel := context.WithCancel(context.Background())
166+
cs.keepaliveCancel = cancel
167+
168+
go func() {
169+
ticker := time.NewTicker(interval)
170+
defer ticker.Stop()
171+
172+
for {
173+
select {
174+
case <-ctx.Done():
175+
return
176+
case <-ticker.C:
177+
pingCtx, pingCancel := context.WithTimeout(context.Background(), interval/2)
178+
err := cs.Ping(pingCtx, nil)
179+
pingCancel()
180+
if err != nil {
181+
// Ping failed, close the session
182+
_ = cs.Close()
183+
return
184+
}
185+
}
186+
}
187+
}()
188+
}
189+
149190
// AddRoots adds the given roots to the client,
150191
// replacing any with the same URIs,
151192
// and notifies any connected servers.

mcp/mcp_test.go

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -838,3 +838,98 @@ func falseSchema() *jsonschema.Schema { return &jsonschema.Schema{Not: &jsonsche
838838
func nopHandler(context.Context, *ServerSession, *CallToolParamsFor[map[string]any]) (*CallToolResult, error) {
839839
return nil, nil
840840
}
841+
842+
func TestKeepAlive(t *testing.T) {
843+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
844+
defer cancel()
845+
846+
ct, st := NewInMemoryTransports()
847+
848+
serverOpts := &ServerOptions{
849+
KeepAlive: 100 * time.Millisecond,
850+
}
851+
s := NewServer("testServer", "v1.0.0", serverOpts)
852+
s.AddTools(NewServerTool("greet", "say hi", sayHi))
853+
854+
ss, err := s.Connect(ctx, st)
855+
if err != nil {
856+
t.Fatal(err)
857+
}
858+
defer ss.Close()
859+
860+
clientOpts := &ClientOptions{
861+
KeepAlive: 100 * time.Millisecond,
862+
}
863+
c := NewClient("testClient", "v1.0.0", clientOpts)
864+
cs, err := c.Connect(ctx, ct)
865+
if err != nil {
866+
t.Fatal(err)
867+
}
868+
defer cs.Close()
869+
870+
// Wait for a few keepalive cycles to ensure pings are working
871+
time.Sleep(300 * time.Millisecond)
872+
873+
// Test that the connection is still alive by making a call
874+
result, err := cs.CallTool(ctx, &CallToolParams{
875+
Name: "greet",
876+
Arguments: map[string]any{"Name": "user"},
877+
})
878+
if err != nil {
879+
t.Fatalf("call failed after keepalive: %v", err)
880+
}
881+
if len(result.Content) == 0 {
882+
t.Fatal("expected content in result")
883+
}
884+
if textContent, ok := result.Content[0].(*TextContent); !ok || textContent.Text != "hi user" {
885+
t.Fatalf("unexpected result: %v", result.Content[0])
886+
}
887+
}
888+
889+
func TestKeepAliveFailure(t *testing.T) {
890+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
891+
defer cancel()
892+
893+
ct, st := NewInMemoryTransports()
894+
895+
// Server without keepalive (to test one-sided keepalive)
896+
s := NewServer("testServer", "v1.0.0", nil)
897+
s.AddTools(NewServerTool("greet", "say hi", sayHi))
898+
ss, err := s.Connect(ctx, st)
899+
if err != nil {
900+
t.Fatal(err)
901+
}
902+
903+
// Client with short keepalive
904+
clientOpts := &ClientOptions{
905+
KeepAlive: 50 * time.Millisecond,
906+
}
907+
c := NewClient("testClient", "v1.0.0", clientOpts)
908+
cs, err := c.Connect(ctx, ct)
909+
if err != nil {
910+
t.Fatal(err)
911+
}
912+
defer cs.Close()
913+
914+
// Let the connection establish properly first
915+
time.Sleep(30 * time.Millisecond)
916+
917+
// simulate ping failure
918+
ss.Close()
919+
920+
// Wait for keepalive to detect the failure and close the client
921+
// check periodically instead of just waiting
922+
deadline := time.Now().Add(1 * time.Second)
923+
for time.Now().Before(deadline) {
924+
_, err = cs.CallTool(ctx, &CallToolParams{
925+
Name: "greet",
926+
Arguments: map[string]any{"Name": "user"},
927+
})
928+
if errors.Is(err, ErrConnectionClosed) {
929+
return // Test passed
930+
}
931+
time.Sleep(25 * time.Millisecond)
932+
}
933+
934+
t.Errorf("expected connection to be closed by keepalive, but it wasn't. Last error: %v", err)
935+
}

mcp/server.go

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616
"path/filepath"
1717
"slices"
1818
"sync"
19+
"time"
1920

2021
"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
2122
"github.com/modelcontextprotocol/go-sdk/internal/util"
@@ -57,6 +58,10 @@ type ServerOptions struct {
5758
RootsListChangedHandler func(context.Context, *ServerSession, *RootsListChangedParams)
5859
// If non-nil, called when "notifications/progress" is received.
5960
ProgressNotificationHandler func(context.Context, *ServerSession, *ProgressNotificationParams)
61+
// If non-zero, defines an interval for regular "ping" requests.
62+
// If the peer fails to respond to pings originating from the keepalive check,
63+
// the session is automatically closed.
64+
KeepAlive time.Duration
6065
}
6166

6267
// NewServer creates a new MCP server. The resulting server has no features:
@@ -460,6 +465,9 @@ func (s *Server) Connect(ctx context.Context, t Transport) (*ServerSession, erro
460465
}
461466

462467
func (s *Server) callInitializedHandler(ctx context.Context, ss *ServerSession, params *InitializedParams) (Result, error) {
468+
if s.opts.KeepAlive > 0 {
469+
ss.startKeepalive(s.opts.KeepAlive)
470+
}
463471
return callNotificationHandler(ctx, s.opts.InitializedHandler, ss, params)
464472
}
465473

@@ -492,6 +500,7 @@ type ServerSession struct {
492500
logLevel LoggingLevel
493501
initializeParams *InitializeParams
494502
initialized bool
503+
keepaliveCancel context.CancelFunc
495504
}
496505

497506
// Ping pings the client.
@@ -680,6 +689,9 @@ func (ss *ServerSession) setLevel(_ context.Context, params *SetLevelParams) (*e
680689
// requests from being handled, and waiting for ongoing requests to return.
681690
// Close then terminates the connection.
682691
func (ss *ServerSession) Close() error {
692+
if ss.keepaliveCancel != nil {
693+
ss.keepaliveCancel()
694+
}
683695
return ss.conn.Close()
684696
}
685697

@@ -688,6 +700,33 @@ func (ss *ServerSession) Wait() error {
688700
return ss.conn.Wait()
689701
}
690702

703+
// startKeepalive starts the keepalive mechanism for this server session.
704+
func (ss *ServerSession) startKeepalive(interval time.Duration) {
705+
ctx, cancel := context.WithCancel(context.Background())
706+
ss.keepaliveCancel = cancel
707+
708+
go func() {
709+
ticker := time.NewTicker(interval)
710+
defer ticker.Stop()
711+
712+
for {
713+
select {
714+
case <-ctx.Done():
715+
return
716+
case <-ticker.C:
717+
pingCtx, pingCancel := context.WithTimeout(context.Background(), interval/2)
718+
err := ss.Ping(pingCtx, nil)
719+
pingCancel()
720+
if err != nil {
721+
// Ping failed, close the session
722+
_ = ss.Close()
723+
return
724+
}
725+
}
726+
}
727+
}()
728+
}
729+
691730
// pageToken is the internal structure for the opaque pagination cursor.
692731
// It will be Gob-encoded and then Base64-encoded for use as a string token.
693732
type pageToken struct {

0 commit comments

Comments
 (0)