Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion server/sse.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ type SSEContextFunc func(ctx context.Context, r *http.Request) context.Context
// function should return the base path (e.g., "/mcp/tenant123").
type DynamicBasePathFunc func(r *http.Request, sessionID string) string

// SessionIDGenFunc is a function that produces a session ID for a new SSE connection.
// It receives the request context and the HTTP request, and should return a session
// identifier (string) or an error.
type SessionIDGenFunc func(ctx context.Context, r *http.Request) (string, error)

func (s *sseSession) SessionID() string {
return s.sessionID
}
Expand Down Expand Up @@ -189,6 +194,7 @@ type SSEServer struct {
srv *http.Server
contextFunc SSEContextFunc
dynamicBasePathFunc DynamicBasePathFunc
sessionIDGenFunc SessionIDGenFunc

keepAlive bool
keepAliveInterval time.Duration
Expand Down Expand Up @@ -317,6 +323,15 @@ func WithSSEContextFunc(fn SSEContextFunc) SSEOption {
}
}

// WithSessionIDGenerator sets a custom session ID generator. If fn == nil the call is ignored.
func WithSessionIDGenerator(fn SessionIDGenFunc) SSEOption {
return func(s *SSEServer) {
if fn != nil {
s.sessionIDGenFunc = fn
}
}
}

// NewSSEServer creates a new SSE server instance with the given MCP server and options.
func NewSSEServer(server *MCPServer, opts ...SSEOption) *SSEServer {
s := &SSEServer{
Expand All @@ -326,6 +341,9 @@ func NewSSEServer(server *MCPServer, opts ...SSEOption) *SSEServer {
useFullURLForMessageEndpoint: true,
keepAlive: false,
keepAliveInterval: 10 * time.Second,
sessionIDGenFunc: func(ctx context.Context, r *http.Request) (string, error) {
return uuid.New().String(), nil
},
}

// Apply all options
Expand Down Expand Up @@ -407,7 +425,16 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) {
return
}

sessionID := uuid.New().String()
sessionID, err := s.sessionIDGenFunc(r.Context(), r)
if err != nil {
http.Error(w, "Failed to create session ID", http.StatusInternalServerError)
return
}
if sessionID == "" {
http.Error(w, "Failed to create session ID", http.StatusInternalServerError)
return
}

session := &sseSession{
done: make(chan struct{}),
eventQueue: make(chan string, 100), // Buffer for events
Expand Down
10 changes: 5 additions & 5 deletions server/streamable_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -531,12 +531,12 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request)
}

sessionID := r.Header.Get(HeaderKeySessionID)
// the specification didn't say we should validate the session id

// The MCP specification doesn't require validating session ID for GET requests.
// If no session ID is provided by the client, generate one using the configured SessionIdManager
// so that custom session id generators are honored consistently across POST/GET flows.
if sessionID == "" {
// It's a stateless server,
// but the MCP server requires a unique ID for registering, so we use a random one
sessionID = uuid.New().String()
sessionIdManager := s.sessionIdManagerResolver.ResolveSessionIdManager(r)
sessionID = sessionIdManager.Generate()
}

// Get or create session atomically to prevent TOCTOU races
Expand Down
Loading