Skip to content

Commit dc6ec72

Browse files
committed
mcp: move GetSessionID closure into ServerOptions
fixes: #478
1 parent 22f86c4 commit dc6ec72

File tree

3 files changed

+34
-23
lines changed

3 files changed

+34
-23
lines changed

mcp/server.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,16 @@ type ServerOptions struct {
8383
// If true, advertises the tools capability during initialization,
8484
// even if no tools have been registered.
8585
HasTools bool
86+
87+
// GetSessionID provides the next session ID to use for an incoming request.
88+
// If nil, a default randomly generated ID will be used.
89+
//
90+
// Session IDs should be globally unique across the scope of the server,
91+
// which may span multiple processes in the case of distributed servers.
92+
//
93+
// As a special case, if GetSessionID returns the empty string, the
94+
// Mcp-Session-Id header will not be set.
95+
GetSessionID func() string
8696
}
8797

8898
// NewServer creates a new MCP server. The resulting server has no features:
@@ -114,6 +124,11 @@ func NewServer(impl *Implementation, options *ServerOptions) *Server {
114124
if opts.UnsubscribeHandler != nil && opts.SubscribeHandler == nil {
115125
panic("UnsubscribeHandler requires SubscribeHandler")
116126
}
127+
128+
if opts.GetSessionID == nil {
129+
opts.GetSessionID = randText
130+
}
131+
117132
return &Server{
118133
impl: impl,
119134
opts: opts,

mcp/streamable.go

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -50,16 +50,6 @@ type StreamableHTTPHandler struct {
5050

5151
// StreamableHTTPOptions configures the StreamableHTTPHandler.
5252
type StreamableHTTPOptions struct {
53-
// GetSessionID provides the next session ID to use for an incoming request.
54-
// If nil, a default randomly generated ID will be used.
55-
//
56-
// Session IDs should be globally unique across the scope of the server,
57-
// which may span multiple processes in the case of distributed servers.
58-
//
59-
// As a special case, if GetSessionID returns the empty string, the
60-
// Mcp-Session-Id header will not be set.
61-
GetSessionID func() string
62-
6353
// Stateless controls whether the session is 'stateless'.
6454
//
6555
// A stateless server does not validate the Mcp-Session-Id header, and uses a
@@ -92,9 +82,6 @@ func NewStreamableHTTPHandler(getServer func(*http.Request) *Server, opts *Strea
9282
if opts != nil {
9383
h.opts = *opts
9484
}
95-
if h.opts.GetSessionID == nil {
96-
h.opts.GetSessionID = randText
97-
}
9885
return h
9986
}
10087

@@ -233,7 +220,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
233220
if sessionID == "" {
234221
// In stateless mode, sessionID may be nonempty even if there's no
235222
// existing transport.
236-
sessionID = h.opts.GetSessionID()
223+
sessionID = server.opts.GetSessionID()
237224
}
238225
transport = &StreamableServerTransport{
239226
SessionID: sessionID,

mcp/streamable_test.go

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -409,15 +409,14 @@ func testClientReplay(t *testing.T, test clientReplayTest) {
409409
}
410410

411411
func TestServerTransportCleanup(t *testing.T) {
412-
server := NewServer(testImpl, &ServerOptions{KeepAlive: 10 * time.Millisecond})
413-
414412
nClient := 3
415413

416414
var mu sync.Mutex
417415
var id int = -1 // session id starting from "0", "1", "2"...
418416
chans := make(map[string]chan struct{}, nClient)
419417

420-
handler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, &StreamableHTTPOptions{
418+
server := NewServer(testImpl, &ServerOptions{
419+
KeepAlive: 10 * time.Millisecond,
421420
GetSessionID: func() string {
422421
mu.Lock()
423422
defer mu.Unlock()
@@ -430,6 +429,7 @@ func TestServerTransportCleanup(t *testing.T) {
430429
},
431430
})
432431

432+
handler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, nil)
433433
handler.onTransportDeletion = func(sessionID string) {
434434
chans[sessionID] <- struct{}{}
435435
}
@@ -1201,8 +1201,6 @@ func TestStreamableStateless(t *testing.T) {
12011201
}
12021202
return &CallToolResult{Content: []Content{&TextContent{Text: "hi " + args.Name}}}, nil, nil
12031203
}
1204-
server := NewServer(testImpl, nil)
1205-
AddTool(server, &Tool{Name: "greet", Description: "say hi"}, sayHi)
12061204

12071205
requests := []streamableRequest{
12081206
{
@@ -1265,9 +1263,15 @@ func TestStreamableStateless(t *testing.T) {
12651263
}
12661264
}
12671265

1268-
sessionlessHandler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, &StreamableHTTPOptions{
1269-
GetSessionID: func() string { return "" },
1270-
Stateless: true,
1266+
sessionlessHandler := NewStreamableHTTPHandler(func(*http.Request) *Server {
1267+
// Return a stateless server which never assigns a session ID.
1268+
server := NewServer(testImpl, &ServerOptions{
1269+
GetSessionID: func() string { return "" },
1270+
})
1271+
AddTool(server, &Tool{Name: "greet", Description: "say hi"}, sayHi)
1272+
return server
1273+
}, &StreamableHTTPOptions{
1274+
Stateless: true,
12711275
})
12721276

12731277
// First, test the "sessionless" stateless mode, where there is no session ID.
@@ -1281,7 +1285,12 @@ func TestStreamableStateless(t *testing.T) {
12811285
// This can be used by tools to look up application state preserved across
12821286
// subsequent requests.
12831287
requests[0].wantSessionID = true // now expect a session ID for initialize
1284-
statelessHandler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, &StreamableHTTPOptions{
1288+
statelessHandler := NewStreamableHTTPHandler(func(*http.Request) *Server {
1289+
// Return a server with default options which should assign a random session ID.
1290+
server := NewServer(testImpl, nil)
1291+
AddTool(server, &Tool{Name: "greet", Description: "say hi"}, sayHi)
1292+
return server
1293+
}, &StreamableHTTPOptions{
12851294
Stateless: true,
12861295
})
12871296
t.Run("stateless", func(t *testing.T) {

0 commit comments

Comments
 (0)