diff --git a/mcp/server.go b/mcp/server.go index ed4ec720..966cf7ed 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -602,9 +602,6 @@ func (ss *ServerSession) initialized(ctx context.Context, params *InitializedPar // params are non-nil. params = new(InitializedParams) } - if ss.server.opts.KeepAlive > 0 { - ss.startKeepalive(ss.server.opts.KeepAlive) - } var wasInit, wasInitd bool ss.updateState(func(state *ServerSessionState) { wasInit = state.InitializeParams != nil @@ -620,6 +617,9 @@ func (ss *ServerSession) initialized(ctx context.Context, params *InitializedPar if wasInitd { return nil, fmt.Errorf("duplicate %q received", notificationInitialized) } + if ss.server.opts.KeepAlive > 0 { + ss.startKeepalive(ss.server.opts.KeepAlive) + } if h := ss.server.opts.InitializedHandler; h != nil { h(ctx, serverRequestFor(ss, params)) } diff --git a/mcp/server_test.go b/mcp/server_test.go index 202ab5d9..bf325a95 100644 --- a/mcp/server_test.go +++ b/mcp/server_test.go @@ -9,6 +9,7 @@ import ( "log" "slices" "testing" + "time" "github.com/google/go-cmp/cmp" "github.com/google/jsonschema-go/jsonschema" @@ -371,3 +372,51 @@ func TestServerCapabilities(t *testing.T) { }) } } + +// TestServerSessionkeepaliveCancelOverwritten is to verify that `ServerSession.keepaliveCancel` is assigned exactly once, +// ensuring that only a single goroutine is responsible for the session's keepalive ping mechanism. +func TestServerSessionkeepaliveCancelOverwritten(t *testing.T) { + // Set KeepAlive to a long duration to ensure the keepalive + // goroutine stays alive for the duration of the test without actually sending + // ping requests, since we don't have a real client connection established. + server := NewServer(testImpl, &ServerOptions{KeepAlive: 5 * time.Second}) + ss := &ServerSession{server: server} + + // 1. Initialize the session. + _, err := ss.initialize(context.Background(), &InitializeParams{}) + if err != nil { + t.Fatalf("ServerSession initialize failed: %v", err) + } + + // 2. Call 'initialized' for the first time. This should start the keepalive mechanism. + _, err = ss.initialized(context.Background(), &InitializedParams{}) + if err != nil { + t.Fatalf("First initialized call failed: %v", err) + } + if ss.keepaliveCancel == nil { + t.Fatalf("expected ServerSession.keepaliveCancel to be set after the first call of initialized") + } + + // Save the cancel function and use defer to ensure resources are cleaned up. + firstCancel := ss.keepaliveCancel + defer firstCancel() + + // 3. Manually set the field to nil. + // Do this to facilitate the test's core assertion. The goal is to verify that + // 'ss.keepaliveCancel' is not assigned a second time. By setting it to nil, + // we can easily check after the next call if a new keepalive goroutine was started. + ss.keepaliveCancel = nil + + // 4. Call 'initialized' for the second time. This should return an error. + _, err = ss.initialized(context.Background(), &InitializedParams{}) + if err == nil { + t.Fatalf("Expected 'duplicate initialized received' error on second call, got nil") + } + + // 5. Re-check the field to ensure it remains nil. + // Since 'initialized' correctly returned an error and did not call + // 'startKeepalive', the field should remain unchanged. + if ss.keepaliveCancel != nil { + t.Fatal("expected ServerSession.keepaliveCancel to be nil after we manually niled it and re-initialized") + } +}