Skip to content

Commit 94d7a12

Browse files
committed
mcp: pass TokenInfo to server handler
If there is a TokenInfo in the request context of a StreamableServerTransport, then propagate it through to the ServerRequest that is passed to server methods like callTool.
1 parent 6e03217 commit 94d7a12

File tree

6 files changed

+103
-19
lines changed

6 files changed

+103
-19
lines changed

auth/auth.go

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,22 +13,41 @@ import (
1313
"time"
1414
)
1515

16+
// TokenInfo holds information from a bearer token.
1617
type TokenInfo struct {
1718
Scopes []string
1819
Expiration time.Time
20+
// TODO: add standard JWT fields
21+
Extra map[string]any
1922
}
2023

24+
// The error that a TokenVerifier should return if the token cannot be verified.
25+
var ErrInvalidToken = errors.New("invalid token")
26+
27+
// A TokenVerifier checks the validity of a bearer token, and extracts information
28+
// from it. If verification fails, it should return an error that unwraps to ErrInvalidToken.
2129
type TokenVerifier func(ctx context.Context, token string) (*TokenInfo, error)
2230

31+
// RequireBearerTokenOptions are options for [RequireBearerToken].
2332
type RequireBearerTokenOptions struct {
24-
Scopes []string
33+
// The URL for the resource server metadata OAuth flow, to be returned as part
34+
// of the WWW-Authenticate header.
2535
ResourceMetadataURL string
36+
// The required scopes.
37+
Scopes []string
2638
}
2739

28-
var ErrInvalidToken = errors.New("invalid token")
29-
3040
type tokenInfoKey struct{}
3141

42+
// TokenInfoFromContext returns the [TokenInfo] stored in ctx, or nil if none.
43+
func TokenInfoFromContext(ctx context.Context) *TokenInfo {
44+
ti := ctx.Value(tokenInfoKey{})
45+
if ti == nil {
46+
return nil
47+
}
48+
return ti.(*TokenInfo)
49+
}
50+
3251
// RequireBearerToken returns a piece of middleware that verifies a bearer token using the verifier.
3352
// If verification succeeds, the [TokenInfo] is added to the request's context and the request proceeds.
3453
// If verification fails, the request fails with a 401 Unauthenticated, and the WWW-Authenticate header
@@ -75,7 +94,7 @@ func verify(ctx context.Context, verifier TokenVerifier, opts *RequireBearerToke
7594
return nil, err.Error(), http.StatusInternalServerError
7695
}
7796

78-
// Check scopes.
97+
// Check scopes. All must be present.
7998
if opts != nil {
8099
// Note: quadratic, but N is small.
81100
for _, s := range opts.Scopes {

internal/jsonrpc2/messages.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@ type Request struct {
5656
Method string
5757
// Params is either a struct or an array with the parameters of the method.
5858
Params json.RawMessage
59+
// Meta is additional information that does not appear on the wire. It can be
60+
// used to pass information from the application to the underlying transport.
61+
Meta any
5962
}
6063

6164
// Response is a Message used as a reply to a call Request.
@@ -67,6 +70,9 @@ type Response struct {
6770
Error error
6871
// id of the request this is a response to.
6972
ID ID
73+
// Meta is additional information that does not appear on the wire. It can be
74+
// used to pass information from the underlying transport to the application.
75+
Meta any
7076
}
7177

7278
// StringID creates a new string request identifier.

mcp/shared.go

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
"strings"
2020
"time"
2121

22+
"github.com/modelcontextprotocol/go-sdk/auth"
2223
"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
2324
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
2425
)
@@ -132,7 +133,8 @@ func handleReceive[S Session](ctx context.Context, session S, jreq *jsonrpc.Requ
132133
}
133134

134135
mh := session.receivingMethodHandler().(MethodHandler)
135-
req := info.newRequest(session, params)
136+
ti, _ := jreq.Meta.(*auth.TokenInfo)
137+
req := info.newRequest(session, params, ti)
136138
// mh might be user code, so ensure that it returns the right values for the jsonrpc2 protocol.
137139
res, err := mh(ctx, jreq.Method, req)
138140
if err != nil {
@@ -179,7 +181,7 @@ type methodInfo struct {
179181
// Unmarshal params from the wire into a Params struct.
180182
// Used on the receive side.
181183
unmarshalParams func(json.RawMessage) (Params, error)
182-
newRequest func(Session, Params) Request
184+
newRequest func(Session, Params, *auth.TokenInfo) Request
183185
// Run the code when a call to the method is received.
184186
// Used on the receive side.
185187
handleMethod methodHandler
@@ -214,7 +216,7 @@ const (
214216

215217
func newClientMethodInfo[P paramsPtr[T], R Result, T any](d typedClientMethodHandler[P, R], flags methodFlags) methodInfo {
216218
mi := newMethodInfo[P, R](flags)
217-
mi.newRequest = func(s Session, p Params) Request {
219+
mi.newRequest = func(s Session, p Params, _ *auth.TokenInfo) Request {
218220
r := &ClientRequest[P]{Session: s.(*ClientSession)}
219221
if p != nil {
220222
r.Params = p.(P)
@@ -229,19 +231,15 @@ func newClientMethodInfo[P paramsPtr[T], R Result, T any](d typedClientMethodHan
229231

230232
func newServerMethodInfo[P paramsPtr[T], R Result, T any](d typedServerMethodHandler[P, R], flags methodFlags) methodInfo {
231233
mi := newMethodInfo[P, R](flags)
232-
mi.newRequest = func(s Session, p Params) Request {
233-
r := &ServerRequest[P]{Session: s.(*ServerSession)}
234+
mi.newRequest = func(s Session, p Params, ti *auth.TokenInfo) Request {
235+
r := &ServerRequest[P]{Session: s.(*ServerSession), TokenInfo: ti}
234236
if p != nil {
235237
r.Params = p.(P)
236238
}
237239
return r
238240
}
239241
mi.handleMethod = MethodHandler(func(ctx context.Context, _ string, req Request) (Result, error) {
240-
rf := &ServerRequest[P]{Session: req.GetSession().(*ServerSession)}
241-
if req.GetParams() != nil {
242-
rf.Params = req.GetParams().(P)
243-
}
244-
return d(ctx, rf)
242+
return d(ctx, req.(*ServerRequest[P]))
245243
})
246244
return mi
247245
}
@@ -395,8 +393,9 @@ type ClientRequest[P Params] struct {
395393

396394
// A ServerRequest is a request to a server.
397395
type ServerRequest[P Params] struct {
398-
Session *ServerSession
399-
Params P
396+
Session *ServerSession
397+
Params P
398+
TokenInfo *auth.TokenInfo // bearer token info (e.g. from OAuth) if any
400399
}
401400

402401
func (*ClientRequest[P]) isRequest() {}

mcp/streamable.go

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"sync/atomic"
2222
"time"
2323

24+
"github.com/modelcontextprotocol/go-sdk/auth"
2425
"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
2526
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
2627
)
@@ -490,12 +491,13 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques
490491
// This also requires access to the negotiated version, which would either be
491492
// set by the MCP-Protocol-Version header, or would require peeking into the
492493
// session.
493-
incoming, _, err := readBatch(body)
494494
if err != nil {
495495
http.Error(w, fmt.Sprintf("malformed payload: %v", err), http.StatusBadRequest)
496496
return
497497
}
498+
incoming, _, err := readBatch(body)
498499
requests := make(map[jsonrpc.ID]struct{})
500+
tokenInfo := auth.TokenInfoFromContext(req.Context())
499501
for _, msg := range incoming {
500502
if req, ok := msg.(*jsonrpc.Request); ok {
501503
// Preemptively check that this is a valid request, so that we can fail
@@ -505,6 +507,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques
505507
http.Error(w, err.Error(), http.StatusBadRequest)
506508
return
507509
}
510+
req.Meta = tokenInfo
508511
if req.ID.IsValid() {
509512
requests[req.ID] = struct{}{}
510513
}
@@ -1036,6 +1039,10 @@ func (c *streamableClientConn) Read(ctx context.Context) (jsonrpc.Message, error
10361039
}
10371040
}
10381041

1042+
// testAuth controls whether a fake Authorization header is added to outgoing requests.
1043+
// TODO: replace with a better mechanism when client-side auth is in place.
1044+
var testAuth = false
1045+
10391046
// Write implements the [Connection] interface.
10401047
func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) error {
10411048
if err := c.failure(); err != nil {
@@ -1053,6 +1060,9 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
10531060
}
10541061
req.Header.Set("Content-Type", "application/json")
10551062
req.Header.Set("Accept", "application/json, text/event-stream")
1063+
if testAuth {
1064+
req.Header.Set("Authorization", "Bearer foo")
1065+
}
10561066
c.setMCPHeaders(req)
10571067

10581068
resp, err := c.client.Do(req)

mcp/streamable_test.go

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import (
2525
"github.com/google/go-cmp/cmp"
2626
"github.com/google/go-cmp/cmp/cmpopts"
2727
"github.com/google/jsonschema-go/jsonschema"
28+
"github.com/modelcontextprotocol/go-sdk/auth"
2829
"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
2930
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
3031
)
@@ -894,3 +895,51 @@ func TestStreamableStateless(t *testing.T) {
894895
// Verify we can make another request without session ID
895896
checkRequest(`{"jsonrpc":"2.0","method":"tools/call","id":2,"params":{"name":"greet","arguments":{"name":"World"}}}`)
896897
}
898+
899+
func TestTokenInfo(t *testing.T) {
900+
defer func(b bool) { testAuth = b }(testAuth)
901+
testAuth = true
902+
ctx := context.Background()
903+
904+
// Create a server with a tool that returns TokenInfo.
905+
tokenInfo := func(ctx context.Context, req *ServerRequest[*CallToolParamsFor[struct{}]]) (*CallToolResultFor[any], error) {
906+
return &CallToolResultFor[any]{Content: []Content{&TextContent{Text: fmt.Sprintf("%v", req.TokenInfo)}}}, nil
907+
}
908+
server := NewServer(testImpl, nil)
909+
AddTool(server, &Tool{Name: "tokenInfo", Description: "return token info"}, tokenInfo)
910+
911+
streamHandler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil)
912+
verifier := func(context.Context, string) (*auth.TokenInfo, error) {
913+
return &auth.TokenInfo{
914+
Scopes: []string{"scope"},
915+
// Expiration is far, far in the future.
916+
Expiration: time.Date(5000, 1, 2, 3, 4, 5, 0, time.Local),
917+
}, nil
918+
}
919+
handler := auth.RequireBearerToken(verifier, nil)(streamHandler)
920+
httpServer := httptest.NewServer(handler)
921+
defer httpServer.Close()
922+
923+
transport := NewStreamableClientTransport(httpServer.URL, nil)
924+
client := NewClient(testImpl, nil)
925+
session, err := client.Connect(ctx, transport, nil)
926+
if err != nil {
927+
t.Fatalf("client.Connect() failed: %v", err)
928+
}
929+
defer session.Close()
930+
931+
res, err := session.CallTool(ctx, &CallToolParams{Name: "tokenInfo"})
932+
if err != nil {
933+
t.Fatal(err)
934+
}
935+
if len(res.Content) == 0 {
936+
t.Fatal("missing content")
937+
}
938+
tc, ok := res.Content[0].(*TextContent)
939+
if !ok {
940+
t.Fatal("not TextContent")
941+
}
942+
if g, w := tc.Text, "&{[scope] 5000-01-02 03:04:05 -0500 EST map[]}"; g != w {
943+
t.Errorf("got %q, want %q", g, w)
944+
}
945+
}

mcp/tool.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,9 @@ func newServerTool[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*serverTool
6666
}
6767
// TODO(jba): improve copy
6868
res, err := h(ctx, &ServerRequest[*CallToolParamsFor[In]]{
69-
Session: req.Session,
70-
Params: params,
69+
Session: req.Session,
70+
Params: params,
71+
TokenInfo: req.TokenInfo,
7172
})
7273
// TODO(rfindley): investigate why server errors are embedded in this strange way,
7374
// rather than returned as jsonrpc2 server errors.

0 commit comments

Comments
 (0)