Skip to content

Commit 02f0b25

Browse files
authored
Move GetSessionID closure into ServerOptions (#488)
Fixes: #478.
1 parent 353d46f commit 02f0b25

File tree

3 files changed

+34
-23
lines changed

3 files changed

+34
-23
lines changed

mcp/server.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,16 @@ type ServerOptions struct {
8383
// If true, advertises the tools capability during initialization,
8484
// even if no tools have been registered.
8585
HasTools bool
86+
87+
// GetSessionID provides the next session ID to use for an incoming request.
88+
// If nil, a default randomly generated ID will be used.
89+
//
90+
// Session IDs should be globally unique across the scope of the server,
91+
// which may span multiple processes in the case of distributed servers.
92+
//
93+
// As a special case, if GetSessionID returns the empty string, the
94+
// Mcp-Session-Id header will not be set.
95+
GetSessionID func() string
8696
}
8797

8898
// NewServer creates a new MCP server. The resulting server has no features:
@@ -114,6 +124,11 @@ func NewServer(impl *Implementation, options *ServerOptions) *Server {
114124
if opts.UnsubscribeHandler != nil && opts.SubscribeHandler == nil {
115125
panic("UnsubscribeHandler requires SubscribeHandler")
116126
}
127+
128+
if opts.GetSessionID == nil {
129+
opts.GetSessionID = randText
130+
}
131+
117132
return &Server{
118133
impl: impl,
119134
opts: opts,

mcp/streamable.go

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -50,16 +50,6 @@ type StreamableHTTPHandler struct {
5050

5151
// StreamableHTTPOptions configures the StreamableHTTPHandler.
5252
type StreamableHTTPOptions struct {
53-
// GetSessionID provides the next session ID to use for an incoming request.
54-
// If nil, a default randomly generated ID will be used.
55-
//
56-
// Session IDs should be globally unique across the scope of the server,
57-
// which may span multiple processes in the case of distributed servers.
58-
//
59-
// As a special case, if GetSessionID returns the empty string, the
60-
// Mcp-Session-Id header will not be set.
61-
GetSessionID func() string
62-
6353
// Stateless controls whether the session is 'stateless'.
6454
//
6555
// A stateless server does not validate the Mcp-Session-Id header, and uses a
@@ -92,9 +82,6 @@ func NewStreamableHTTPHandler(getServer func(*http.Request) *Server, opts *Strea
9282
if opts != nil {
9383
h.opts = *opts
9484
}
95-
if h.opts.GetSessionID == nil {
96-
h.opts.GetSessionID = randText
97-
}
9885
return h
9986
}
10087

@@ -233,7 +220,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
233220
if sessionID == "" {
234221
// In stateless mode, sessionID may be nonempty even if there's no
235222
// existing transport.
236-
sessionID = h.opts.GetSessionID()
223+
sessionID = server.opts.GetSessionID()
237224
}
238225
transport = &StreamableServerTransport{
239226
SessionID: sessionID,

mcp/streamable_test.go

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -409,15 +409,14 @@ func testClientReplay(t *testing.T, test clientReplayTest) {
409409
}
410410

411411
func TestServerTransportCleanup(t *testing.T) {
412-
server := NewServer(testImpl, &ServerOptions{KeepAlive: 10 * time.Millisecond})
413-
414412
nClient := 3
415413

416414
var mu sync.Mutex
417415
var id int = -1 // session id starting from "0", "1", "2"...
418416
chans := make(map[string]chan struct{}, nClient)
419417

420-
handler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, &StreamableHTTPOptions{
418+
server := NewServer(testImpl, &ServerOptions{
419+
KeepAlive: 10 * time.Millisecond,
421420
GetSessionID: func() string {
422421
mu.Lock()
423422
defer mu.Unlock()
@@ -430,6 +429,7 @@ func TestServerTransportCleanup(t *testing.T) {
430429
},
431430
})
432431

432+
handler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, nil)
433433
handler.onTransportDeletion = func(sessionID string) {
434434
chans[sessionID] <- struct{}{}
435435
}
@@ -1199,8 +1199,6 @@ func TestStreamableStateless(t *testing.T) {
11991199
}
12001200
return &CallToolResult{Content: []Content{&TextContent{Text: "hi " + args.Name}}}, nil, nil
12011201
}
1202-
server := NewServer(testImpl, nil)
1203-
AddTool(server, &Tool{Name: "greet", Description: "say hi"}, sayHi)
12041202

12051203
requests := []streamableRequest{
12061204
{
@@ -1263,9 +1261,15 @@ func TestStreamableStateless(t *testing.T) {
12631261
}
12641262
}
12651263

1266-
sessionlessHandler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, &StreamableHTTPOptions{
1267-
GetSessionID: func() string { return "" },
1268-
Stateless: true,
1264+
sessionlessHandler := NewStreamableHTTPHandler(func(*http.Request) *Server {
1265+
// Return a stateless server which never assigns a session ID.
1266+
server := NewServer(testImpl, &ServerOptions{
1267+
GetSessionID: func() string { return "" },
1268+
})
1269+
AddTool(server, &Tool{Name: "greet", Description: "say hi"}, sayHi)
1270+
return server
1271+
}, &StreamableHTTPOptions{
1272+
Stateless: true,
12691273
})
12701274

12711275
// First, test the "sessionless" stateless mode, where there is no session ID.
@@ -1279,7 +1283,12 @@ func TestStreamableStateless(t *testing.T) {
12791283
// This can be used by tools to look up application state preserved across
12801284
// subsequent requests.
12811285
requests[0].wantSessionID = true // now expect a session ID for initialize
1282-
statelessHandler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, &StreamableHTTPOptions{
1286+
statelessHandler := NewStreamableHTTPHandler(func(*http.Request) *Server {
1287+
// Return a server with default options which should assign a random session ID.
1288+
server := NewServer(testImpl, nil)
1289+
AddTool(server, &Tool{Name: "greet", Description: "say hi"}, sayHi)
1290+
return server
1291+
}, &StreamableHTTPOptions{
12831292
Stateless: true,
12841293
})
12851294
t.Run("stateless", func(t *testing.T) {

0 commit comments

Comments
 (0)