Skip to content

Commit 18cd635

Browse files
authored
auth, mcp: add UserID to TokenInfo for session hijacking prevention (#695)
Add a UserID field to auth.TokenInfo that TokenVerifiers can populate from JWT "sub" claims or token introspection. The streamable HTTP transport uses this to bind sessions to users, rejecting requests where the user ID doesn't match the session's original user. Fixes #589
1 parent 6c16aa6 commit 18cd635

File tree

3 files changed

+102
-1
lines changed

3 files changed

+102
-1
lines changed

auth/auth.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,11 @@ import (
1717
type TokenInfo struct {
1818
Scopes []string
1919
Expiration time.Time
20+
// UserID is an optional identifier for the authenticated user.
21+
// If set by a TokenVerifier, it can be used by transports to prevent
22+
// session hijacking by ensuring that all requests for a given session
23+
// come from the same user.
24+
UserID string
2025
// TODO: add standard JWT fields
2126
Extra map[string]any
2227
}

mcp/streamable.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,10 @@ type StreamableHTTPHandler struct {
5151
type sessionInfo struct {
5252
session *ServerSession
5353
transport *StreamableServerTransport
54+
// userID is the user ID from the TokenInfo when the session was created.
55+
// If non-empty, subsequent requests must have the same user ID to prevent
56+
// session hijacking.
57+
userID string
5458

5559
// If timeout is set, automatically close the session after an idle period.
5660
timeout time.Duration
@@ -238,6 +242,15 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
238242
http.Error(w, "session not found", http.StatusNotFound)
239243
return
240244
}
245+
// Prevent session hijacking: if the session was created with a user ID,
246+
// verify that subsequent requests come from the same user.
247+
if sessInfo != nil && sessInfo.userID != "" {
248+
tokenInfo := auth.TokenInfoFromContext(req.Context())
249+
if tokenInfo == nil || tokenInfo.UserID != sessInfo.userID {
250+
http.Error(w, "session user mismatch", http.StatusForbidden)
251+
return
252+
}
253+
}
241254
}
242255

243256
if req.Method == http.MethodDelete {
@@ -404,9 +417,16 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
404417
http.Error(w, "failed connection", http.StatusInternalServerError)
405418
return
406419
}
420+
// Capture the user ID from the token info to enable session hijacking
421+
// prevention on subsequent requests.
422+
var userID string
423+
if tokenInfo := auth.TokenInfoFromContext(req.Context()); tokenInfo != nil {
424+
userID = tokenInfo.UserID
425+
}
407426
sessInfo = &sessionInfo{
408427
session: session,
409428
transport: transport,
429+
userID: userID,
410430
}
411431

412432
if stateless {

mcp/streamable_test.go

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1649,11 +1649,87 @@ func TestTokenInfo(t *testing.T) {
16491649
if !ok {
16501650
t.Fatal("not TextContent")
16511651
}
1652-
if g, w := tc.Text, "&{[scope] 5000-01-02 03:04:05 +0000 UTC map[]}"; g != w {
1652+
if g, w := tc.Text, "&{[scope] 5000-01-02 03:04:05 +0000 UTC map[]}"; g != w {
16531653
t.Errorf("got %q, want %q", g, w)
16541654
}
16551655
}
16561656

1657+
func TestSessionHijackingPrevention(t *testing.T) {
1658+
// This test verifies that sessions bound to a user ID cannot be accessed
1659+
// by a different user (session hijacking prevention).
1660+
ctx := context.Background()
1661+
1662+
server := NewServer(testImpl, nil)
1663+
streamHandler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil)
1664+
1665+
// Use the bearer token directly as the user ID. This simulates how a real
1666+
// verifier might extract a user ID from a JWT "sub" claim or introspection.
1667+
verifier := func(_ context.Context, token string, _ *http.Request) (*auth.TokenInfo, error) {
1668+
return &auth.TokenInfo{
1669+
Scopes: []string{"scope"},
1670+
UserID: token,
1671+
Expiration: time.Date(5000, 1, 2, 3, 4, 5, 0, time.UTC),
1672+
}, nil
1673+
}
1674+
handler := auth.RequireBearerToken(verifier, nil)(streamHandler)
1675+
httpServer := httptest.NewServer(mustNotPanic(t, handler))
1676+
defer httpServer.Close()
1677+
1678+
// Helper to send a JSON-RPC request as a given user.
1679+
doRequest := func(msg jsonrpc.Message, sessionID, userID string) *http.Response {
1680+
t.Helper()
1681+
data, _ := jsonrpc2.EncodeMessage(msg)
1682+
req, _ := http.NewRequestWithContext(ctx, http.MethodPost, httpServer.URL, bytes.NewReader(data))
1683+
req.Header.Set("Content-Type", "application/json")
1684+
req.Header.Set("Accept", "application/json, text/event-stream")
1685+
req.Header.Set("Authorization", "Bearer "+userID)
1686+
if sessionID != "" {
1687+
req.Header.Set("Mcp-Session-Id", sessionID)
1688+
}
1689+
resp, err := http.DefaultClient.Do(req)
1690+
if err != nil {
1691+
t.Fatalf("request failed: %v", err)
1692+
}
1693+
return resp
1694+
}
1695+
1696+
// Create a session as user1.
1697+
initReq := &jsonrpc.Request{Method: "initialize", ID: jsonrpc2.Int64ID(1)}
1698+
initReq.Params, _ = json.Marshal(&InitializeParams{
1699+
ProtocolVersion: protocolVersion20250618,
1700+
ClientInfo: &Implementation{Name: "test", Version: "1.0"},
1701+
})
1702+
resp := doRequest(initReq, "", "user1")
1703+
defer resp.Body.Close()
1704+
if resp.StatusCode != http.StatusOK {
1705+
body, _ := io.ReadAll(resp.Body)
1706+
t.Fatalf("initialize failed with status %d: %s", resp.StatusCode, body)
1707+
}
1708+
sessionID := resp.Header.Get("Mcp-Session-Id")
1709+
if sessionID == "" {
1710+
t.Fatal("no session ID in response")
1711+
}
1712+
1713+
pingReq := &jsonrpc.Request{Method: "ping", ID: jsonrpc2.Int64ID(2)}
1714+
pingReq.Params, _ = json.Marshal(&PingParams{})
1715+
1716+
// Try to access the session as user2 - should fail.
1717+
resp2 := doRequest(pingReq, sessionID, "user2")
1718+
defer resp2.Body.Close()
1719+
if resp2.StatusCode != http.StatusForbidden {
1720+
body, _ := io.ReadAll(resp2.Body)
1721+
t.Errorf("expected status %d for user mismatch, got %d: %s", http.StatusForbidden, resp2.StatusCode, body)
1722+
}
1723+
1724+
// Access as original user1 should succeed.
1725+
resp3 := doRequest(pingReq, sessionID, "user1")
1726+
defer resp3.Body.Close()
1727+
if resp3.StatusCode != http.StatusOK {
1728+
body, _ := io.ReadAll(resp3.Body)
1729+
t.Errorf("expected status %d for matching user, got %d: %s", http.StatusOK, resp3.StatusCode, body)
1730+
}
1731+
}
1732+
16571733
func TestStreamableGET(t *testing.T) {
16581734
// This test checks the fix for problematic behavior described in #410:
16591735
// Hanging GET headers should be written immediately, even if there are no

0 commit comments

Comments
 (0)