diff --git a/mcp/server.go b/mcp/server.go index 69808ac7..1621e4bd 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -83,6 +83,16 @@ type ServerOptions struct { // If true, advertises the tools capability during initialization, // even if no tools have been registered. HasTools bool + + // GetSessionID provides the next session ID to use for an incoming request. + // If nil, a default randomly generated ID will be used. + // + // Session IDs should be globally unique across the scope of the server, + // which may span multiple processes in the case of distributed servers. + // + // As a special case, if GetSessionID returns the empty string, the + // Mcp-Session-Id header will not be set. + GetSessionID func() string } // NewServer creates a new MCP server. The resulting server has no features: @@ -114,6 +124,11 @@ func NewServer(impl *Implementation, options *ServerOptions) *Server { if opts.UnsubscribeHandler != nil && opts.SubscribeHandler == nil { panic("UnsubscribeHandler requires SubscribeHandler") } + + if opts.GetSessionID == nil { + opts.GetSessionID = randText + } + return &Server{ impl: impl, opts: opts, diff --git a/mcp/streamable.go b/mcp/streamable.go index 8ac6f59a..f359b7cc 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -50,16 +50,6 @@ type StreamableHTTPHandler struct { // StreamableHTTPOptions configures the StreamableHTTPHandler. type StreamableHTTPOptions struct { - // GetSessionID provides the next session ID to use for an incoming request. - // If nil, a default randomly generated ID will be used. - // - // Session IDs should be globally unique across the scope of the server, - // which may span multiple processes in the case of distributed servers. - // - // As a special case, if GetSessionID returns the empty string, the - // Mcp-Session-Id header will not be set. - GetSessionID func() string - // Stateless controls whether the session is 'stateless'. // // A stateless server does not validate the Mcp-Session-Id header, and uses a @@ -92,9 +82,6 @@ func NewStreamableHTTPHandler(getServer func(*http.Request) *Server, opts *Strea if opts != nil { h.opts = *opts } - if h.opts.GetSessionID == nil { - h.opts.GetSessionID = randText - } return h } @@ -233,7 +220,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque if sessionID == "" { // In stateless mode, sessionID may be nonempty even if there's no // existing transport. - sessionID = h.opts.GetSessionID() + sessionID = server.opts.GetSessionID() } transport = &StreamableServerTransport{ SessionID: sessionID, diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index e077308c..eb822071 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -409,15 +409,14 @@ func testClientReplay(t *testing.T, test clientReplayTest) { } func TestServerTransportCleanup(t *testing.T) { - server := NewServer(testImpl, &ServerOptions{KeepAlive: 10 * time.Millisecond}) - nClient := 3 var mu sync.Mutex var id int = -1 // session id starting from "0", "1", "2"... chans := make(map[string]chan struct{}, nClient) - handler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, &StreamableHTTPOptions{ + server := NewServer(testImpl, &ServerOptions{ + KeepAlive: 10 * time.Millisecond, GetSessionID: func() string { mu.Lock() defer mu.Unlock() @@ -430,6 +429,7 @@ func TestServerTransportCleanup(t *testing.T) { }, }) + handler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, nil) handler.onTransportDeletion = func(sessionID string) { chans[sessionID] <- struct{}{} } @@ -1201,8 +1201,6 @@ func TestStreamableStateless(t *testing.T) { } return &CallToolResult{Content: []Content{&TextContent{Text: "hi " + args.Name}}}, nil, nil } - server := NewServer(testImpl, nil) - AddTool(server, &Tool{Name: "greet", Description: "say hi"}, sayHi) requests := []streamableRequest{ { @@ -1265,9 +1263,15 @@ func TestStreamableStateless(t *testing.T) { } } - sessionlessHandler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, &StreamableHTTPOptions{ - GetSessionID: func() string { return "" }, - Stateless: true, + sessionlessHandler := NewStreamableHTTPHandler(func(*http.Request) *Server { + // Return a stateless server which never assigns a session ID. + server := NewServer(testImpl, &ServerOptions{ + GetSessionID: func() string { return "" }, + }) + AddTool(server, &Tool{Name: "greet", Description: "say hi"}, sayHi) + return server + }, &StreamableHTTPOptions{ + Stateless: true, }) // First, test the "sessionless" stateless mode, where there is no session ID. @@ -1281,7 +1285,12 @@ func TestStreamableStateless(t *testing.T) { // This can be used by tools to look up application state preserved across // subsequent requests. requests[0].wantSessionID = true // now expect a session ID for initialize - statelessHandler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, &StreamableHTTPOptions{ + statelessHandler := NewStreamableHTTPHandler(func(*http.Request) *Server { + // Return a server with default options which should assign a random session ID. + server := NewServer(testImpl, nil) + AddTool(server, &Tool{Name: "greet", Description: "say hi"}, sayHi) + return server + }, &StreamableHTTPOptions{ Stateless: true, }) t.Run("stateless", func(t *testing.T) {