diff --git a/docs/protocol.md b/docs/protocol.md index 6d22c4b5..c002ac3b 100644 --- a/docs/protocol.md +++ b/docs/protocol.md @@ -221,6 +221,21 @@ _See [examples/server/distributed](../examples/server/distributed/main.go) for an example using statless mode to implement a server distributed across multiple processes._ +#### Serverless Deployments + +For serverless or short-lived processes, configure +[`StreamableHTTPOptions.SessionStateStore`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#StreamableHTTPOptions.SessionStateStore) +with an implementation of +[`ServerSessionStateStore`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerSessionStateStore). +The handler will persist [`ServerSessionState`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerSessionState) +whenever it changes, and will automatically restore prior state when a request +arrives carrying an existing `Mcp-Session-Id`. This allows one invocation to +handle initialization while subsequent invocations resume the conversation +without re-running a long-lived server. The SDK provides an in-memory +[`MemoryServerSessionStateStore`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#MemoryServerSessionStateStore) +for testing; production deployments should supply a durable store (for example, +backed by a database or object storage). + ### Custom transports The SDK supports [custom diff --git a/internal/docs/protocol.src.md b/internal/docs/protocol.src.md index b0874ccf..82081ac7 100644 --- a/internal/docs/protocol.src.md +++ b/internal/docs/protocol.src.md @@ -147,6 +147,21 @@ _See [examples/server/distributed](../examples/server/distributed/main.go) for an example using statless mode to implement a server distributed across multiple processes._ +#### Serverless Deployments + +For serverless or short-lived processes, configure +[`StreamableHTTPOptions.SessionStateStore`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#StreamableHTTPOptions.SessionStateStore) +with an implementation of +[`ServerSessionStateStore`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerSessionStateStore). +The handler will persist [`ServerSessionState`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerSessionState) +whenever it changes, and will automatically restore prior state when a request +arrives carrying an existing `Mcp-Session-Id`. This allows one invocation to +handle initialization while subsequent invocations resume the conversation +without re-running a long-lived server. The SDK provides an in-memory +[`MemoryServerSessionStateStore`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#MemoryServerSessionStateStore) +for testing; production deployments should supply a durable store (for example, +backed by a database or object storage). + ### Custom transports The SDK supports [custom diff --git a/mcp/session_store.go b/mcp/session_store.go new file mode 100644 index 00000000..ba245954 --- /dev/null +++ b/mcp/session_store.go @@ -0,0 +1,91 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp + +import ( + "context" + "encoding/json" + "fmt" + "sync" +) + +// ServerSessionStateStore persists server session state across process +// restarts. +// +// Implementations must be safe for concurrent use. +type ServerSessionStateStore interface { + // Load returns the previously saved state for sessionID. A nil result + // indicates that no state is available. + Load(ctx context.Context, sessionID string) (*ServerSessionState, error) + // Save persists the provided state. The state must not be modified after the + // call returns. Passing a nil state is equivalent to Delete. + Save(ctx context.Context, sessionID string, state *ServerSessionState) error + // Delete forgets any state associated with sessionID. This method must not + // return an error if no state is recorded. + Delete(ctx context.Context, sessionID string) error +} + +// MemoryServerSessionStateStore is an in-memory implementation of +// ServerSessionStateStore. +// +// It is primarily intended for testing or simple deployments. +type MemoryServerSessionStateStore struct { + mu sync.RWMutex + states map[string][]byte +} + +// NewMemoryServerSessionStateStore returns a MemoryServerSessionStateStore. +func NewMemoryServerSessionStateStore() *MemoryServerSessionStateStore { + return &MemoryServerSessionStateStore{ + states: make(map[string][]byte), + } +} + +// Load implements ServerSessionStateStore. +func (s *MemoryServerSessionStateStore) Load(ctx context.Context, sessionID string) (*ServerSessionState, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + s.mu.RLock() + data, ok := s.states[sessionID] + s.mu.RUnlock() + if !ok { + return nil, nil + } + var state ServerSessionState + if err := json.Unmarshal(data, &state); err != nil { + return nil, fmt.Errorf("decode server session state: %w", err) + } + return &state, nil +} + +// Save implements ServerSessionStateStore. +func (s *MemoryServerSessionStateStore) Save(ctx context.Context, sessionID string, state *ServerSessionState) error { + if err := ctx.Err(); err != nil { + return err + } + if state == nil { + return s.Delete(ctx, sessionID) + } + data, err := json.Marshal(state) + if err != nil { + return fmt.Errorf("encode server session state: %w", err) + } + s.mu.Lock() + defer s.mu.Unlock() + s.states[sessionID] = data + return nil +} + +// Delete implements ServerSessionStateStore. +func (s *MemoryServerSessionStateStore) Delete(ctx context.Context, sessionID string) error { + if err := ctx.Err(); err != nil { + return err + } + s.mu.Lock() + delete(s.states, sessionID) + s.mu.Unlock() + return nil +} diff --git a/mcp/streamable.go b/mcp/streamable.go index 12e24ffa..281204df 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -150,6 +150,12 @@ type StreamableHTTPOptions struct { // // If SessionTimeout is the zero value, idle sessions are never closed. SessionTimeout time.Duration + + // SessionStateStore enables persisting session state across process + // restarts. When configured, the handler will attempt to restore server + // sessions whose identifiers are unknown to the current process, allowing + // serverless deployments that spin up per-request. + SessionStateStore ServerSessionStateStore } // NewStreamableHTTPHandler returns a new [StreamableHTTPHandler]. @@ -223,19 +229,38 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque return } + logger := ensureLogger(h.opts.Logger) + sessionID := req.Header.Get(sessionIDHeader) - var sessInfo *sessionInfo + var ( + sessInfo *sessionInfo + restoredState *ServerSessionState + ) if sessionID != "" { h.mu.Lock() sessInfo = h.sessions[sessionID] h.mu.Unlock() if sessInfo == nil && !h.opts.Stateless { - // Unless we're in 'stateless' mode, which doesn't perform any Session-ID - // validation, we require that the session ID matches a known session. - // - // In stateless mode, a temporary transport is be created below. - http.Error(w, "session not found", http.StatusNotFound) - return + if store := h.opts.SessionStateStore; store != nil { + state, err := store.Load(req.Context(), sessionID) + if err != nil { + logger.Error("session state load failed", "session_id", sessionID, "error", err) + http.Error(w, "failed to load session state", http.StatusInternalServerError) + return + } + restoredState = state + if state == nil { + http.Error(w, "session not found", http.StatusNotFound) + return + } + } else { + // Unless we're in 'stateless' mode, which doesn't perform any Session-ID + // validation, we require that the session ID matches a known session. + // + // In stateless mode, a temporary transport is be created below. + http.Error(w, "session not found", http.StatusNotFound) + return + } } } @@ -248,6 +273,16 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque // Closing the session also removes it from h.sessions, due to the // onClose callback. sessInfo.session.Close() + } else if restoredState != nil { + // There is no running session, but persisted state exists. Delete it so + // that clients can terminate resumed sessions without an active server. + if store := h.opts.SessionStateStore; store != nil { + if err := store.Delete(req.Context(), sessionID); err != nil && !errors.Is(err, context.Canceled) { + logger.Error("session state delete failed", "session_id", sessionID, "error", err) + http.Error(w, "failed to delete session state", http.StatusInternalServerError) + return + } + } } w.WriteHeader(http.StatusNoContent) return @@ -322,6 +357,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque EventStore: h.opts.EventStore, jsonResponse: h.opts.JSONResponse, logger: h.opts.Logger, + StateStore: h.opts.SessionStateStore, } // Sessions without a session ID are also stateless: there's no way to @@ -329,7 +365,9 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque stateless := h.opts.Stateless || sessionID == "" // To support stateless mode, we initialize the session with a default // state, so that it doesn't reject subsequent requests. - var connectOpts *ServerSessionOptions + connectOpts := &ServerSessionOptions{ + State: restoredState, + } if stateless { // Peek at the body to see if it is initialize or initialized. // We want those to be handled as usual. @@ -374,13 +412,12 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque state.InitializedParams = new(InitializedParams) } state.LogLevel = "info" - connectOpts = &ServerSessionOptions{ - State: state, - } + connectOpts.State = state } else { // Cleanup is only required in stateful mode, as transportation is // not stored in the map otherwise. connectOpts = &ServerSessionOptions{ + State: restoredState, onClose: func() { h.mu.Lock() defer h.mu.Unlock() @@ -391,6 +428,11 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque h.onTransportDeletion(transport.SessionID) } } + if store := h.opts.SessionStateStore; store != nil && transport.SessionID != "" { + if err := store.Delete(context.Background(), transport.SessionID); err != nil { + logger.Error("session state delete failed", "session_id", transport.SessionID, "error", err) + } + } }, } } @@ -487,6 +529,10 @@ type StreamableServerTransport struct { // upon stream resumption. EventStore EventStore + // StateStore receives session state updates so that callers can resume + // sessions across processes. + StateStore ServerSessionStateStore + // jsonResponse, if set, tells the server to prefer to respond to requests // using application/json responses rather than text/event-stream. // @@ -519,6 +565,7 @@ func (t *StreamableServerTransport) Connect(ctx context.Context) (Connection, er sessionID: t.SessionID, stateless: t.Stateless, eventStore: t.EventStore, + stateStore: t.StateStore, jsonResponse: t.jsonResponse, logger: ensureLogger(t.logger), // see #556: must be non-nil incoming: make(chan jsonrpc.Message, 10), @@ -543,6 +590,7 @@ type streamableServerConn struct { stateless bool jsonResponse bool eventStore EventStore + stateStore ServerSessionStateStore logger *slog.Logger @@ -579,6 +627,16 @@ func (c *streamableServerConn) SessionID() string { return c.sessionID } +func (c *streamableServerConn) sessionUpdated(state ServerSessionState) { + if c.stateStore == nil || c.sessionID == "" || c.stateless { + return + } + stateCopy := state + if err := c.stateStore.Save(context.Background(), c.sessionID, &stateCopy); err != nil { + c.logger.Error("session state save failed", "session_id", c.sessionID, "error", err) + } +} + // A stream is a single logical stream of SSE events within a server session. // A stream begins with a client request, or with a client GET that has // no Last-Event-ID header. diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 0579f0cb..b8f8968a 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -1444,6 +1444,158 @@ func TestStreamableStateless(t *testing.T) { }) } +func TestStreamableSessionStatePersistence(t *testing.T) { + type noopArgs struct{} + + newServer := func() *Server { + server := NewServer(testImpl, nil) + AddTool(server, &Tool{Name: "noop", Description: "no op"}, func(ctx context.Context, req *CallToolRequest, _ noopArgs) (*CallToolResult, any, error) { + return &CallToolResult{Content: []Content{&TextContent{Text: "ok"}}}, nil, nil + }) + return server + } + + store := NewMemoryServerSessionStateStore() + + server1 := newServer() + handler1 := NewStreamableHTTPHandler(func(*http.Request) *Server { return server1 }, &StreamableHTTPOptions{ + SessionStateStore: store, + }) + httpServer1 := httptest.NewServer(mustNotPanic(t, handler1)) + defer httpServer1.Close() + + call := func(serverURL, sessionID string, request streamableRequest) (string, int, []jsonrpc.Message, []byte, error) { + out := make(chan jsonrpc.Message, 10) + newSessionID, status, body, err := request.do(context.Background(), serverURL, sessionID, out) + var msgs []jsonrpc.Message + for msg := range out { + msgs = append(msgs, msg) + } + return newSessionID, status, msgs, body, err + } + + initializeRequest := streamableRequest{ + method: http.MethodPost, + messages: []jsonrpc.Message{req(1, methodInitialize, &InitializeParams{})}, + } + + sessionID, status, msgs, _, err := call(httpServer1.URL, "", initializeRequest) + if err != nil { + t.Fatalf("initialize request failed: %v", err) + } + if status != http.StatusOK { + t.Fatalf("initialize status = %d, want %d", status, http.StatusOK) + } + if len(msgs) != 1 { + t.Fatalf("initialize response count = %d, want 1", len(msgs)) + } + initResp, ok := msgs[0].(*jsonrpc.Response) + if !ok { + t.Fatalf("initialize response is %T, want *jsonrpc.Response", msgs[0]) + } + if initResp.Error != nil { + t.Fatalf("initialize response returned error: %+v", initResp.Error) + } + if sessionID == "" { + t.Fatal("initialize response missing session id") + } + + waitForState := func(check func(*ServerSessionState) bool) *ServerSessionState { + deadline := time.Now().Add(time.Second) + for time.Now().Before(deadline) { + st, err := store.Load(context.Background(), sessionID) + if err != nil { + t.Fatalf("Load failed: %v", err) + } + if check(st) { + return st + } + time.Sleep(10 * time.Millisecond) + } + st, _ := store.Load(context.Background(), sessionID) + t.Fatalf("timed out waiting for session state, last=%+v", st) + return nil + } + + waitForState(func(state *ServerSessionState) bool { + return state != nil && state.InitializeParams != nil + }) + + initializedRequest := streamableRequest{ + method: http.MethodPost, + messages: []jsonrpc.Message{req(0, notificationInitialized, &InitializedParams{})}, + } + _, status, _, _, err = call(httpServer1.URL, sessionID, initializedRequest) + if err != nil { + t.Fatalf("initialized notification failed: %v", err) + } + if status != http.StatusAccepted { + t.Fatalf("initialized status = %d, want %d", status, http.StatusAccepted) + } + + waitForState(func(state *ServerSessionState) bool { + return state != nil && state.InitializedParams != nil + }) + + httpServer1.Close() + + server2 := newServer() + handler2 := NewStreamableHTTPHandler(func(*http.Request) *Server { return server2 }, &StreamableHTTPOptions{ + SessionStateStore: store, + }) + httpServer2 := httptest.NewServer(mustNotPanic(t, handler2)) + defer httpServer2.Close() + + listRequest := streamableRequest{ + method: http.MethodPost, + messages: []jsonrpc.Message{req(2, "tools/list", &ListToolsParams{})}, + } + _, status, msgs, _, err = call(httpServer2.URL, sessionID, listRequest) + if err != nil { + t.Fatalf("list request failed: %v", err) + } + if status != http.StatusOK { + t.Fatalf("list status = %d, want %d", status, http.StatusOK) + } + if len(msgs) != 1 { + t.Fatalf("list response count = %d, want 1", len(msgs)) + } + listResp, ok := msgs[0].(*jsonrpc.Response) + if !ok { + t.Fatalf("list response is %T, want *jsonrpc.Response", msgs[0]) + } + if listResp.Error != nil { + t.Fatalf("list response returned error: %+v", listResp.Error) + } + var listResult ListToolsResult + if err := json.Unmarshal(listResp.Result, &listResult); err != nil { + t.Fatalf("decoding list result: %v", err) + } + if len(listResult.Tools) != 1 { + t.Fatalf("list result tools len = %d, want 1", len(listResult.Tools)) + } + if listResult.Tools[0].Name != "noop" { + t.Fatalf("list result tool name = %q, want %q", listResult.Tools[0].Name, "noop") + } + + deleteReq, err := http.NewRequest(http.MethodDelete, httpServer2.URL, nil) + if err != nil { + t.Fatalf("creating delete request: %v", err) + } + deleteReq.Header.Set(sessionIDHeader, sessionID) + deleteReq.Header.Set("Accept", "application/json, text/event-stream") + deleteResp, err := http.DefaultClient.Do(deleteReq) + if err != nil { + t.Fatalf("delete request failed: %v", err) + } + defer deleteResp.Body.Close() + if deleteResp.StatusCode != http.StatusNoContent { + t.Fatalf("delete status = %d, want %d", deleteResp.StatusCode, http.StatusNoContent) + } + + waitForState(func(state *ServerSessionState) bool { return state == nil }) +} + func textContent(t *testing.T, res *CallToolResult) string { t.Helper() if len(res.Content) != 1 {