Skip to content

Commit 5774a37

Browse files
authored
mcp: properly handle missing session ID header (#1366)
**Description** Return a 400 Bad Request when the client does not send the session ID header, according to the spec [1]. **Related Issues/PRs (if applicable)** Fixes #1364 **Special notes for reviewers (if applicable)** N/A 1: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#session-management --------- Signed-off-by: Ignasi Barrera <[email protected]>
1 parent 29905d5 commit 5774a37

File tree

2 files changed

+31
-1
lines changed

2 files changed

+31
-1
lines changed

internal/mcpproxy/handlers.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,12 +154,20 @@ func (m *MCPProxy) servePOST(w http.ResponseWriter, r *http.Request) {
154154
onErrorResponse(w, http.StatusBadRequest, fmt.Sprintf("invalid JSON-RPC message: %v", err))
155155
return
156156
}
157+
157158
switch msg := rawMsg.(type) {
158159
case *jsonrpc.Response:
159160
if str, ok := msg.ID.Raw().(string); ok && strings.HasPrefix(str, envoyAIGatewayServerToClientPingRequestIDPrefix) {
160161
w.Header().Set(sessionIDHeader, string(s.clientGatewaySessionID()))
161162
w.WriteHeader(http.StatusAccepted)
162163
} else {
164+
// We do require a Session ID. If it is not present, a 400 Bad Request response should be returned:
165+
// https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#session-management
166+
if s == nil {
167+
errType = metrics.MCPErrorInvalidSessionID
168+
onErrorResponse(w, http.StatusBadRequest, "missing session ID")
169+
return
170+
}
163171
m.l.Debug("Decoded MCP response", slog.Any("response", msg))
164172
err = m.handleClientToServerResponse(ctx, s, w, msg)
165173
}
@@ -169,6 +177,16 @@ func (m *MCPProxy) servePOST(w http.ResponseWriter, r *http.Request) {
169177
m.l.Debug("Decoded MCP request",
170178
slog.Any("id", msg.ID), slog.String("method", msg.Method), slog.String("params", string(msg.Params)))
171179
}
180+
181+
// We do require a Session ID. If it is not present for requests other than initialize,
182+
// a 400 Bad Request response should be returned:
183+
// https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#session-management
184+
if s == nil && msg.Method != "initialize" {
185+
errType = metrics.MCPErrorInvalidSessionID
186+
onErrorResponse(w, http.StatusBadRequest, "missing session ID")
187+
return
188+
}
189+
172190
switch msg.Method {
173191
case "notifications/roots/list_changed":
174192
p := &mcp.RootsListChangedParams{}

internal/mcpproxy/handlers_test.go

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ import (
3232
"github.com/envoyproxy/ai-gateway/internal/internalapi"
3333
"github.com/envoyproxy/ai-gateway/internal/metrics"
3434
"github.com/envoyproxy/ai-gateway/internal/testing/testotel"
35-
tracing "github.com/envoyproxy/ai-gateway/internal/tracing"
35+
"github.com/envoyproxy/ai-gateway/internal/tracing"
3636
tracingapi "github.com/envoyproxy/ai-gateway/internal/tracing/api"
3737
)
3838

@@ -164,6 +164,15 @@ func TestServePOST_InvalidSessionID(t *testing.T) {
164164
require.Contains(t, rr.Body.String(), "invalid session ID")
165165
}
166166

167+
func TestServePOST_MissingSessionID(t *testing.T) {
168+
proxy := newTestMCPProxy()
169+
req := httptest.NewRequest(http.MethodPost, "/mcp", strings.NewReader(`{"jsonrpc":"2.0","method":"tools/call","params":{"name":"test-tool"},"id":"1"}`))
170+
rr := httptest.NewRecorder()
171+
proxy.servePOST(rr, req)
172+
require.Equal(t, http.StatusBadRequest, rr.Code)
173+
require.Contains(t, rr.Body.String(), "missing session ID")
174+
}
175+
167176
func TestServePOST_InitializeRequest(t *testing.T) {
168177
// Create a test server to simulate the mcp backend listener.
169178
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -611,6 +620,7 @@ func TestServePOST_UnsupportedMethod(t *testing.T) {
611620
require.NoError(t, err)
612621

613622
httpReq := httptest.NewRequest(http.MethodPost, "/mcp", bytes.NewReader(body))
623+
httpReq.Header.Set(sessionIDHeader, secureID(t, proxy, "test-route@@backend1:dGVzdC1zZXNzaW9u")) // "test-session" base64 encoded.
614624
rr := httptest.NewRecorder()
615625

616626
proxy.servePOST(rr, httpReq)
@@ -803,6 +813,7 @@ func TestServePOST_NotificationsInitialized(t *testing.T) {
803813
require.NoError(t, err)
804814

805815
httpReq := httptest.NewRequest(http.MethodPost, "/mcp", bytes.NewReader(body))
816+
httpReq.Header.Set(sessionIDHeader, secureID(t, proxy, "test-route@@backend1:dGVzdC1zZXNzaW9u")) // "test-session" base64 encoded.
806817
rr := httptest.NewRecorder()
807818

808819
proxy.servePOST(rr, httpReq)
@@ -883,6 +894,7 @@ func TestServePOST_InvalidPromptsGetParams(t *testing.T) {
883894
require.NoError(t, err)
884895

885896
httpReq := httptest.NewRequest(http.MethodPost, "/mcp", bytes.NewReader(body))
897+
httpReq.Header.Set(sessionIDHeader, secureID(t, proxy, "test-route@@backend1:dGVzdC1zZXNzaW9u")) // "test-session" base64 encoded.
886898
rr := httptest.NewRecorder()
887899

888900
proxy.servePOST(rr, httpReq)

0 commit comments

Comments
 (0)