Skip to content

Commit 962f31b

Browse files
FlameHost10Fedor Bushlyacoderabbitai[bot]
authored
fix: use custom session id generator when provided (#715)
* fix: use custom session id generator when provided * fix: respect custom session ID generator for GET requests * Update server/sse.go add sessionID check Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --------- Co-authored-by: Fedor Bushlya <f.bushlya@centraluniversity.ru> Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
1 parent 7ce32bf commit 962f31b

File tree

2 files changed

+33
-6
lines changed

2 files changed

+33
-6
lines changed

server/sse.go

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,11 @@ type SSEContextFunc func(ctx context.Context, r *http.Request) context.Context
4747
// function should return the base path (e.g., "/mcp/tenant123").
4848
type DynamicBasePathFunc func(r *http.Request, sessionID string) string
4949

50+
// SessionIDGenFunc is a function that produces a session ID for a new SSE connection.
51+
// It receives the request context and the HTTP request, and should return a session
52+
// identifier (string) or an error.
53+
type SessionIDGenFunc func(ctx context.Context, r *http.Request) (string, error)
54+
5055
func (s *sseSession) SessionID() string {
5156
return s.sessionID
5257
}
@@ -189,6 +194,7 @@ type SSEServer struct {
189194
srv *http.Server
190195
contextFunc SSEContextFunc
191196
dynamicBasePathFunc DynamicBasePathFunc
197+
sessionIDGenFunc SessionIDGenFunc
192198

193199
keepAlive bool
194200
keepAliveInterval time.Duration
@@ -317,6 +323,15 @@ func WithSSEContextFunc(fn SSEContextFunc) SSEOption {
317323
}
318324
}
319325

326+
// WithSessionIDGenerator sets a custom session ID generator. If fn == nil the call is ignored.
327+
func WithSessionIDGenerator(fn SessionIDGenFunc) SSEOption {
328+
return func(s *SSEServer) {
329+
if fn != nil {
330+
s.sessionIDGenFunc = fn
331+
}
332+
}
333+
}
334+
320335
// NewSSEServer creates a new SSE server instance with the given MCP server and options.
321336
func NewSSEServer(server *MCPServer, opts ...SSEOption) *SSEServer {
322337
s := &SSEServer{
@@ -326,6 +341,9 @@ func NewSSEServer(server *MCPServer, opts ...SSEOption) *SSEServer {
326341
useFullURLForMessageEndpoint: true,
327342
keepAlive: false,
328343
keepAliveInterval: 10 * time.Second,
344+
sessionIDGenFunc: func(ctx context.Context, r *http.Request) (string, error) {
345+
return uuid.New().String(), nil
346+
},
329347
}
330348

331349
// Apply all options
@@ -407,7 +425,16 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) {
407425
return
408426
}
409427

410-
sessionID := uuid.New().String()
428+
sessionID, err := s.sessionIDGenFunc(r.Context(), r)
429+
if err != nil {
430+
http.Error(w, "Failed to create session ID", http.StatusInternalServerError)
431+
return
432+
}
433+
if sessionID == "" {
434+
http.Error(w, "Failed to create session ID", http.StatusInternalServerError)
435+
return
436+
}
437+
411438
session := &sseSession{
412439
done: make(chan struct{}),
413440
eventQueue: make(chan string, 100), // Buffer for events

server/streamable_http.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -531,12 +531,12 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request)
531531
}
532532

533533
sessionID := r.Header.Get(HeaderKeySessionID)
534-
// the specification didn't say we should validate the session id
535-
534+
// The MCP specification doesn't require validating session ID for GET requests.
535+
// If no session ID is provided by the client, generate one using the configured SessionIdManager
536+
// so that custom session id generators are honored consistently across POST/GET flows.
536537
if sessionID == "" {
537-
// It's a stateless server,
538-
// but the MCP server requires a unique ID for registering, so we use a random one
539-
sessionID = uuid.New().String()
538+
sessionIdManager := s.sessionIdManagerResolver.ResolveSessionIdManager(r)
539+
sessionID = sessionIdManager.Generate()
540540
}
541541

542542
// Get or create session atomically to prevent TOCTOU races

0 commit comments

Comments
 (0)