Skip to content

Commit 30d2449

Browse files
committed
chore: address review comments
1 parent 5d9c799 commit 30d2449

File tree

4 files changed

+48
-46
lines changed

4 files changed

+48
-46
lines changed

mcp/client.go

Lines changed: 6 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,11 @@ type ClientSession struct {
148148
// from being handled, and waiting for ongoing requests to return. Close then
149149
// terminates the connection.
150150
func (cs *ClientSession) Close() error {
151+
// Note: keepaliveCancel access is safe without a mutex because:
152+
// 1. keepaliveCancel is only written once during startKeepalive (happens-before user Close calls)
153+
// 2. context.CancelFunc is safe to call multiple times and from multiple goroutines
154+
// 3. The keepalive goroutine calls Close on ping failure, but this is safe since
155+
// Close is idempotent and conn.Close() handles concurrent calls correctly
151156
if cs.keepaliveCancel != nil {
152157
cs.keepaliveCancel()
153158
}
@@ -162,29 +167,7 @@ func (cs *ClientSession) Wait() error {
162167

163168
// startKeepalive starts the keepalive mechanism for this client session.
164169
func (cs *ClientSession) startKeepalive(interval time.Duration) {
165-
ctx, cancel := context.WithCancel(context.Background())
166-
cs.keepaliveCancel = cancel
167-
168-
go func() {
169-
ticker := time.NewTicker(interval)
170-
defer ticker.Stop()
171-
172-
for {
173-
select {
174-
case <-ctx.Done():
175-
return
176-
case <-ticker.C:
177-
pingCtx, pingCancel := context.WithTimeout(context.Background(), interval/2)
178-
err := cs.Ping(pingCtx, nil)
179-
pingCancel()
180-
if err != nil {
181-
// Ping failed, close the session
182-
_ = cs.Close()
183-
return
184-
}
185-
}
186-
}
187-
}()
170+
startKeepalive(cs, interval, &cs.keepaliveCancel)
188171
}
189172

190173
// AddRoots adds the given roots to the client,

mcp/mcp_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -840,6 +840,7 @@ func nopHandler(context.Context, *ServerSession, *CallToolParamsFor[map[string]a
840840
}
841841

842842
func TestKeepAlive(t *testing.T) {
843+
// TODO: try to use the new synctest package for this test.
843844
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
844845
defer cancel()
845846

mcp/server.go

Lines changed: 6 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -690,6 +690,11 @@ func (ss *ServerSession) setLevel(_ context.Context, params *SetLevelParams) (*e
690690
// Close then terminates the connection.
691691
func (ss *ServerSession) Close() error {
692692
if ss.keepaliveCancel != nil {
693+
// Note: keepaliveCancel access is safe without a mutex because:
694+
// 1. keepaliveCancel is only written once during startKeepalive (happens-before user Close calls)
695+
// 2. context.CancelFunc is safe to call multiple times and from multiple goroutines
696+
// 3. The keepalive goroutine calls Close on ping failure, but this is safe since
697+
// Close is idempotent and conn.Close() handles concurrent calls correctly
693698
ss.keepaliveCancel()
694699
}
695700
return ss.conn.Close()
@@ -702,29 +707,7 @@ func (ss *ServerSession) Wait() error {
702707

703708
// startKeepalive starts the keepalive mechanism for this server session.
704709
func (ss *ServerSession) startKeepalive(interval time.Duration) {
705-
ctx, cancel := context.WithCancel(context.Background())
706-
ss.keepaliveCancel = cancel
707-
708-
go func() {
709-
ticker := time.NewTicker(interval)
710-
defer ticker.Stop()
711-
712-
for {
713-
select {
714-
case <-ctx.Done():
715-
return
716-
case <-ticker.C:
717-
pingCtx, pingCancel := context.WithTimeout(context.Background(), interval/2)
718-
err := ss.Ping(pingCtx, nil)
719-
pingCancel()
720-
if err != nil {
721-
// Ping failed, close the session
722-
_ = ss.Close()
723-
return
724-
}
725-
}
726-
}
727-
}()
710+
startKeepalive(ss, interval, &ss.keepaliveCancel)
728711
}
729712

730713
// pageToken is the internal structure for the opaque pagination cursor.

mcp/shared.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,3 +308,38 @@ type listResult[T any] interface {
308308
// Returns a pointer to the param's NextCursor field.
309309
nextCursorPtr() *string
310310
}
311+
312+
// keepaliveSession represents a session that supports keepalive functionality.
313+
type keepaliveSession interface {
314+
Ping(ctx context.Context, params *PingParams) error
315+
Close() error
316+
}
317+
318+
// startKeepalive starts the keepalive mechanism for a session.
319+
// It assigns the cancel function to the provided cancelPtr and starts a goroutine
320+
// that sends ping messages at the specified interval.
321+
func startKeepalive(session keepaliveSession, interval time.Duration, cancelPtr *context.CancelFunc) {
322+
ctx, cancel := context.WithCancel(context.Background())
323+
*cancelPtr = cancel
324+
325+
go func() {
326+
ticker := time.NewTicker(interval)
327+
defer ticker.Stop()
328+
329+
for {
330+
select {
331+
case <-ctx.Done():
332+
return
333+
case <-ticker.C:
334+
pingCtx, pingCancel := context.WithTimeout(context.Background(), interval/2)
335+
err := session.Ping(pingCtx, nil)
336+
pingCancel()
337+
if err != nil {
338+
// Ping failed, close the session
339+
_ = session.Close()
340+
return
341+
}
342+
}
343+
}
344+
}()
345+
}

0 commit comments

Comments
 (0)