Skip to content

Commit 508b132

Browse files
committed
mcp: pass TokenInfo to server handler
If there is a TokenInfo in the request context of a StreamableServerTransport, then propagate it through to the ServerRequest that is passed to server methods like callTool.
1 parent 1afdb1f commit 508b132

File tree

6 files changed

+107
-17
lines changed

6 files changed

+107
-17
lines changed

auth/auth.go

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,22 +13,41 @@ import (
1313
"time"
1414
)
1515

16+
// TokenInfo holds information from a bearer token.
1617
type TokenInfo struct {
1718
Scopes []string
1819
Expiration time.Time
20+
// TODO: add standard JWT fields
21+
Extra map[string]any
1922
}
2023

24+
// The error that a TokenVerifier should return if the token cannot be verified.
25+
var ErrInvalidToken = errors.New("invalid token")
26+
27+
// A TokenVerifier checks the validity of a bearer token, and extracts information
28+
// from it. If verification fails, it should return an error that unwraps to ErrInvalidToken.
2129
type TokenVerifier func(ctx context.Context, token string) (*TokenInfo, error)
2230

31+
// RequireBearerTokenOptions are options for [RequireBearerToken].
2332
type RequireBearerTokenOptions struct {
24-
Scopes []string
33+
// The URL for the resource server metadata OAuth flow, to be returned as part
34+
// of the WWW-Authenticate header.
2535
ResourceMetadataURL string
36+
// The required scopes.
37+
Scopes []string
2638
}
2739

28-
var ErrInvalidToken = errors.New("invalid token")
29-
3040
type tokenInfoKey struct{}
3141

42+
// TokenInfoFromContext returns the [TokenInfo] stored in ctx, or nil if none.
43+
func TokenInfoFromContext(ctx context.Context) *TokenInfo {
44+
ti := ctx.Value(tokenInfoKey{})
45+
if ti == nil {
46+
return nil
47+
}
48+
return ti.(*TokenInfo)
49+
}
50+
3251
// RequireBearerToken returns a piece of middleware that verifies a bearer token using the verifier.
3352
// If verification succeeds, the [TokenInfo] is added to the request's context and the request proceeds.
3453
// If verification fails, the request fails with a 401 Unauthenticated, and the WWW-Authenticate header
@@ -75,7 +94,7 @@ func verify(ctx context.Context, verifier TokenVerifier, opts *RequireBearerToke
7594
return nil, err.Error(), http.StatusInternalServerError
7695
}
7796

78-
// Check scopes.
97+
// Check scopes. All must be present.
7998
if opts != nil {
8099
// Note: quadratic, but N is small.
81100
for _, s := range opts.Scopes {

internal/jsonrpc2/messages.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@ type Request struct {
5656
Method string
5757
// Params is either a struct or an array with the parameters of the method.
5858
Params json.RawMessage
59+
// Extra is additional information that does not appear on the wire. It can be
60+
// used to pass information from the application to the underlying transport.
61+
Extra any
5962
}
6063

6164
// Response is a Message used as a reply to a call Request.
@@ -67,6 +70,9 @@ type Response struct {
6770
Error error
6871
// id of the request this is a response to.
6972
ID ID
73+
// Extra is additional information that does not appear on the wire. It can be
74+
// used to pass information from the underlying transport to the application.
75+
Extra any
7076
}
7177

7278
// StringID creates a new string request identifier.

mcp/shared.go

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
"strings"
2020
"time"
2121

22+
"github.com/modelcontextprotocol/go-sdk/auth"
2223
"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
2324
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
2425
)
@@ -132,7 +133,8 @@ func handleReceive[S Session](ctx context.Context, session S, jreq *jsonrpc.Requ
132133
}
133134

134135
mh := session.receivingMethodHandler().(MethodHandler)
135-
req := info.newRequest(session, params)
136+
ti, _ := jreq.Extra.(*RequestExtra)
137+
req := info.newRequest(session, params, ti)
136138
// mh might be user code, so ensure that it returns the right values for the jsonrpc2 protocol.
137139
res, err := mh(ctx, jreq.Method, req)
138140
if err != nil {
@@ -179,7 +181,7 @@ type methodInfo struct {
179181
// Unmarshal params from the wire into a Params struct.
180182
// Used on the receive side.
181183
unmarshalParams func(json.RawMessage) (Params, error)
182-
newRequest func(Session, Params) Request
184+
newRequest func(Session, Params, *auth.TokenInfo) Request
183185
// Run the code when a call to the method is received.
184186
// Used on the receive side.
185187
handleMethod methodHandler
@@ -214,7 +216,7 @@ const (
214216

215217
func newClientMethodInfo[P paramsPtr[T], R Result, T any](d typedClientMethodHandler[P, R], flags methodFlags) methodInfo {
216218
mi := newMethodInfo[P, R](flags)
217-
mi.newRequest = func(s Session, p Params) Request {
219+
mi.newRequest = func(s Session, p Params, _ *auth.TokenInfo) Request {
218220
r := &ClientRequest[P]{Session: s.(*ClientSession)}
219221
if p != nil {
220222
r.Params = p.(P)
@@ -229,19 +231,15 @@ func newClientMethodInfo[P paramsPtr[T], R Result, T any](d typedClientMethodHan
229231

230232
func newServerMethodInfo[P paramsPtr[T], R Result, T any](d typedServerMethodHandler[P, R], flags methodFlags) methodInfo {
231233
mi := newMethodInfo[P, R](flags)
232-
mi.newRequest = func(s Session, p Params) Request {
233-
r := &ServerRequest[P]{Session: s.(*ServerSession)}
234+
mi.newRequest = func(s Session, p Params, ti *auth.TokenInfo) Request {
235+
r := &ServerRequest[P]{Session: s.(*ServerSession), Extra: RequestExtra{TokenInfo: ti}}
234236
if p != nil {
235237
r.Params = p.(P)
236238
}
237239
return r
238240
}
239241
mi.handleMethod = MethodHandler(func(ctx context.Context, _ string, req Request) (Result, error) {
240-
rf := &ServerRequest[P]{Session: req.GetSession().(*ServerSession)}
241-
if req.GetParams() != nil {
242-
rf.Params = req.GetParams().(P)
243-
}
244-
return d(ctx, rf)
242+
return d(ctx, req.(*ServerRequest[P]))
245243
})
246244
return mi
247245
}
@@ -397,6 +395,13 @@ type ClientRequest[P Params] struct {
397395
type ServerRequest[P Params] struct {
398396
Session *ServerSession
399397
Params P
398+
Extra RequestExtra
399+
}
400+
401+
// RequestExtra is extra information included in requests, typically from
402+
// the transport layer.
403+
type RequestExtra struct {
404+
TokenInfo *auth.TokenInfo // bearer token info (e.g. from OAuth) if any
400405
}
401406

402407
func (*ClientRequest[P]) isRequest() {}

mcp/streamable.go

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"sync/atomic"
2222
"time"
2323

24+
"github.com/modelcontextprotocol/go-sdk/auth"
2425
"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
2526
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
2627
)
@@ -494,12 +495,13 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques
494495
// This also requires access to the negotiated version, which would either be
495496
// set by the MCP-Protocol-Version header, or would require peeking into the
496497
// session.
497-
incoming, _, err := readBatch(body)
498498
if err != nil {
499499
http.Error(w, fmt.Sprintf("malformed payload: %v", err), http.StatusBadRequest)
500500
return
501501
}
502+
incoming, _, err := readBatch(body)
502503
requests := make(map[jsonrpc.ID]struct{})
504+
tokenInfo := auth.TokenInfoFromContext(req.Context())
503505
for _, msg := range incoming {
504506
if req, ok := msg.(*jsonrpc.Request); ok {
505507
// Preemptively check that this is a valid request, so that we can fail
@@ -509,6 +511,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques
509511
http.Error(w, err.Error(), http.StatusBadRequest)
510512
return
511513
}
514+
req.Extra = tokenInfo
512515
if req.ID.IsValid() {
513516
requests[req.ID] = struct{}{}
514517
}
@@ -1038,6 +1041,10 @@ func (c *streamableClientConn) Read(ctx context.Context) (jsonrpc.Message, error
10381041
}
10391042
}
10401043

1044+
// testAuth controls whether a fake Authorization header is added to outgoing requests.
1045+
// TODO: replace with a better mechanism when client-side auth is in place.
1046+
var testAuth = false
1047+
10411048
// Write implements the [Connection] interface.
10421049
func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) error {
10431050
if err := c.failure(); err != nil {
@@ -1055,6 +1062,9 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
10551062
}
10561063
req.Header.Set("Content-Type", "application/json")
10571064
req.Header.Set("Accept", "application/json, text/event-stream")
1065+
if testAuth {
1066+
req.Header.Set("Authorization", "Bearer foo")
1067+
}
10581068
c.setMCPHeaders(req)
10591069

10601070
resp, err := c.client.Do(req)

mcp/streamable_test.go

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import (
2626
"github.com/google/go-cmp/cmp"
2727
"github.com/google/go-cmp/cmp/cmpopts"
2828
"github.com/google/jsonschema-go/jsonschema"
29+
"github.com/modelcontextprotocol/go-sdk/auth"
2930
"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
3031
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
3132
)
@@ -1038,3 +1039,51 @@ func TestStreamableStateless(t *testing.T) {
10381039
// Verify we can make another request without session ID
10391040
checkRequest(`{"jsonrpc":"2.0","method":"tools/call","id":2,"params":{"name":"greet","arguments":{"name":"World"}}}`)
10401041
}
1042+
1043+
func TestTokenInfo(t *testing.T) {
1044+
defer func(b bool) { testAuth = b }(testAuth)
1045+
testAuth = true
1046+
ctx := context.Background()
1047+
1048+
// Create a server with a tool that returns TokenInfo.
1049+
tokenInfo := func(ctx context.Context, req *ServerRequest[*CallToolParamsFor[struct{}]]) (*CallToolResultFor[any], error) {
1050+
return &CallToolResultFor[any]{Content: []Content{&TextContent{Text: fmt.Sprintf("%v", req.TokenInfo)}}}, nil
1051+
}
1052+
server := NewServer(testImpl, nil)
1053+
AddTool(server, &Tool{Name: "tokenInfo", Description: "return token info"}, tokenInfo)
1054+
1055+
streamHandler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil)
1056+
verifier := func(context.Context, string) (*auth.TokenInfo, error) {
1057+
return &auth.TokenInfo{
1058+
Scopes: []string{"scope"},
1059+
// Expiration is far, far in the future.
1060+
Expiration: time.Date(5000, 1, 2, 3, 4, 5, 0, time.UTC),
1061+
}, nil
1062+
}
1063+
handler := auth.RequireBearerToken(verifier, nil)(streamHandler)
1064+
httpServer := httptest.NewServer(handler)
1065+
defer httpServer.Close()
1066+
1067+
transport := NewStreamableClientTransport(httpServer.URL, nil)
1068+
client := NewClient(testImpl, nil)
1069+
session, err := client.Connect(ctx, transport, nil)
1070+
if err != nil {
1071+
t.Fatalf("client.Connect() failed: %v", err)
1072+
}
1073+
defer session.Close()
1074+
1075+
res, err := session.CallTool(ctx, &CallToolParams{Name: "tokenInfo"})
1076+
if err != nil {
1077+
t.Fatal(err)
1078+
}
1079+
if len(res.Content) == 0 {
1080+
t.Fatal("missing content")
1081+
}
1082+
tc, ok := res.Content[0].(*TextContent)
1083+
if !ok {
1084+
t.Fatal("not TextContent")
1085+
}
1086+
if g, w := tc.Text, "&{[scope] 5000-01-02 03:04:05 +0000 UTC map[]}"; g != w {
1087+
t.Errorf("got %q, want %q", g, w)
1088+
}
1089+
}

mcp/tool.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,9 @@ func newServerTool[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*serverTool
6666
}
6767
// TODO(jba): improve copy
6868
res, err := h(ctx, &ServerRequest[*CallToolParamsFor[In]]{
69-
Session: req.Session,
70-
Params: params,
69+
Session: req.Session,
70+
Params: params,
71+
TokenInfo: req.TokenInfo,
7172
})
7273
// TODO(rfindley): investigate why server errors are embedded in this strange way,
7374
// rather than returned as jsonrpc2 server errors.

0 commit comments

Comments
 (0)