diff --git a/auth/auth.go b/auth/auth.go index 68873b48..14ad28c7 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -13,22 +13,41 @@ import ( "time" ) +// TokenInfo holds information from a bearer token. type TokenInfo struct { Scopes []string Expiration time.Time + // TODO: add standard JWT fields + Extra map[string]any } +// The error that a TokenVerifier should return if the token cannot be verified. +var ErrInvalidToken = errors.New("invalid token") + +// A TokenVerifier checks the validity of a bearer token, and extracts information +// from it. If verification fails, it should return an error that unwraps to ErrInvalidToken. type TokenVerifier func(ctx context.Context, token string) (*TokenInfo, error) +// RequireBearerTokenOptions are options for [RequireBearerToken]. type RequireBearerTokenOptions struct { - Scopes []string + // The URL for the resource server metadata OAuth flow, to be returned as part + // of the WWW-Authenticate header. ResourceMetadataURL string + // The required scopes. + Scopes []string } -var ErrInvalidToken = errors.New("invalid token") - type tokenInfoKey struct{} +// TokenInfoFromContext returns the [TokenInfo] stored in ctx, or nil if none. +func TokenInfoFromContext(ctx context.Context) *TokenInfo { + ti := ctx.Value(tokenInfoKey{}) + if ti == nil { + return nil + } + return ti.(*TokenInfo) +} + // RequireBearerToken returns a piece of middleware that verifies a bearer token using the verifier. // If verification succeeds, the [TokenInfo] is added to the request's context and the request proceeds. // If verification fails, the request fails with a 401 Unauthenticated, and the WWW-Authenticate header @@ -75,7 +94,7 @@ func verify(ctx context.Context, verifier TokenVerifier, opts *RequireBearerToke return nil, err.Error(), http.StatusInternalServerError } - // Check scopes. + // Check scopes. All must be present. if opts != nil { // Note: quadratic, but N is small. for _, s := range opts.Scopes { diff --git a/internal/jsonrpc2/messages.go b/internal/jsonrpc2/messages.go index 2de3d4f0..9c0d5d69 100644 --- a/internal/jsonrpc2/messages.go +++ b/internal/jsonrpc2/messages.go @@ -56,6 +56,9 @@ type Request struct { Method string // Params is either a struct or an array with the parameters of the method. Params json.RawMessage + // Extra is additional information that does not appear on the wire. It can be + // used to pass information from the application to the underlying transport. + Extra any } // Response is a Message used as a reply to a call Request. @@ -67,6 +70,9 @@ type Response struct { Error error // id of the request this is a response to. ID ID + // Extra is additional information that does not appear on the wire. It can be + // used to pass information from the underlying transport to the application. + Extra any } // StringID creates a new string request identifier. diff --git a/mcp/shared.go b/mcp/shared.go index 608e2aaf..bda631fe 100644 --- a/mcp/shared.go +++ b/mcp/shared.go @@ -19,6 +19,7 @@ import ( "strings" "time" + "github.com/modelcontextprotocol/go-sdk/auth" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" "github.com/modelcontextprotocol/go-sdk/jsonrpc" ) @@ -126,7 +127,8 @@ func handleReceive[S Session](ctx context.Context, session S, jreq *jsonrpc.Requ } mh := session.receivingMethodHandler() - req := info.newRequest(session, params) + re, _ := jreq.Extra.(*RequestExtra) + req := info.newRequest(session, params, re) // mh might be user code, so ensure that it returns the right values for the jsonrpc2 protocol. res, err := mh(ctx, jreq.Method, req) if err != nil { @@ -173,7 +175,7 @@ type methodInfo struct { // Unmarshal params from the wire into a Params struct. // Used on the receive side. unmarshalParams func(json.RawMessage) (Params, error) - newRequest func(Session, Params) Request + newRequest func(Session, Params, *RequestExtra) Request // Run the code when a call to the method is received. // Used on the receive side. handleMethod MethodHandler @@ -208,7 +210,7 @@ const ( func newClientMethodInfo[P paramsPtr[T], R Result, T any](d typedClientMethodHandler[P, R], flags methodFlags) methodInfo { mi := newMethodInfo[P, R](flags) - mi.newRequest = func(s Session, p Params) Request { + mi.newRequest = func(s Session, p Params, _ *RequestExtra) Request { r := &ClientRequest[P]{Session: s.(*ClientSession)} if p != nil { r.Params = p.(P) @@ -223,19 +225,15 @@ func newClientMethodInfo[P paramsPtr[T], R Result, T any](d typedClientMethodHan func newServerMethodInfo[P paramsPtr[T], R Result, T any](d typedServerMethodHandler[P, R], flags methodFlags) methodInfo { mi := newMethodInfo[P, R](flags) - mi.newRequest = func(s Session, p Params) Request { - r := &ServerRequest[P]{Session: s.(*ServerSession)} + mi.newRequest = func(s Session, p Params, re *RequestExtra) Request { + r := &ServerRequest[P]{Session: s.(*ServerSession), Extra: re} if p != nil { r.Params = p.(P) } return r } mi.handleMethod = MethodHandler(func(ctx context.Context, _ string, req Request) (Result, error) { - rf := &ServerRequest[P]{Session: req.GetSession().(*ServerSession)} - if req.GetParams() != nil { - rf.Params = req.GetParams().(P) - } - return d(ctx, rf) + return d(ctx, req.(*ServerRequest[P])) }) return mi } @@ -391,6 +389,13 @@ type ClientRequest[P Params] struct { type ServerRequest[P Params] struct { Session *ServerSession Params P + Extra *RequestExtra +} + +// RequestExtra is extra information included in requests, typically from +// the transport layer. +type RequestExtra struct { + TokenInfo *auth.TokenInfo // bearer token info (e.g. from OAuth) if any } func (*ClientRequest[P]) isRequest() {} diff --git a/mcp/streamable.go b/mcp/streamable.go index 572fe5de..c51f3cc4 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -21,6 +21,7 @@ import ( "sync/atomic" "time" + "github.com/modelcontextprotocol/go-sdk/auth" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" "github.com/modelcontextprotocol/go-sdk/jsonrpc" ) @@ -579,12 +580,17 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques // This also requires access to the negotiated version, which would either be // set by the MCP-Protocol-Version header, or would require peeking into the // session. + if err != nil { + http.Error(w, fmt.Sprintf("malformed payload: %v", err), http.StatusBadRequest) + return + } incoming, _, err := readBatch(body) if err != nil { http.Error(w, fmt.Sprintf("malformed payload: %v", err), http.StatusBadRequest) return } requests := make(map[jsonrpc.ID]struct{}) + tokenInfo := auth.TokenInfoFromContext(req.Context()) for _, msg := range incoming { if req, ok := msg.(*jsonrpc.Request); ok { // Preemptively check that this is a valid request, so that we can fail @@ -594,6 +600,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques http.Error(w, err.Error(), http.StatusBadRequest) return } + req.Extra = &RequestExtra{TokenInfo: tokenInfo} if req.IsCall() { requests[req.ID] = struct{}{} } @@ -1182,6 +1189,10 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e return nil } +// testAuth controls whether a fake Authorization header is added to outgoing requests. +// TODO: replace with a better mechanism when client-side auth is in place. +var testAuth = false + func (c *streamableClientConn) setMCPHeaders(req *http.Request) { c.mu.Lock() defer c.mu.Unlock() @@ -1192,6 +1203,9 @@ func (c *streamableClientConn) setMCPHeaders(req *http.Request) { if c.sessionID != "" { req.Header.Set(sessionIDHeader, c.sessionID) } + if testAuth { + req.Header.Set("Authorization", "Bearer foo") + } } func (c *streamableClientConn) handleJSON(resp *http.Response) { diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 8334bc0d..93eafb4a 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -26,6 +26,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/google/jsonschema-go/jsonschema" + "github.com/modelcontextprotocol/go-sdk/auth" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" "github.com/modelcontextprotocol/go-sdk/jsonrpc" ) @@ -1098,3 +1099,51 @@ func textContent(t *testing.T, res *CallToolResult) string { } return text.Text } + +func TestTokenInfo(t *testing.T) { + defer func(b bool) { testAuth = b }(testAuth) + testAuth = true + ctx := context.Background() + + // Create a server with a tool that returns TokenInfo. + tokenInfo := func(ctx context.Context, req *ServerRequest[*CallToolParamsFor[struct{}]]) (*CallToolResultFor[any], error) { + return &CallToolResultFor[any]{Content: []Content{&TextContent{Text: fmt.Sprintf("%v", req.Extra.TokenInfo)}}}, nil + } + server := NewServer(testImpl, nil) + AddTool(server, &Tool{Name: "tokenInfo", Description: "return token info"}, tokenInfo) + + streamHandler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) + verifier := func(context.Context, string) (*auth.TokenInfo, error) { + return &auth.TokenInfo{ + Scopes: []string{"scope"}, + // Expiration is far, far in the future. + Expiration: time.Date(5000, 1, 2, 3, 4, 5, 0, time.UTC), + }, nil + } + handler := auth.RequireBearerToken(verifier, nil)(streamHandler) + httpServer := httptest.NewServer(handler) + defer httpServer.Close() + + transport := NewStreamableClientTransport(httpServer.URL, nil) + client := NewClient(testImpl, nil) + session, err := client.Connect(ctx, transport, nil) + if err != nil { + t.Fatalf("client.Connect() failed: %v", err) + } + defer session.Close() + + res, err := session.CallTool(ctx, &CallToolParams{Name: "tokenInfo"}) + if err != nil { + t.Fatal(err) + } + if len(res.Content) == 0 { + t.Fatal("missing content") + } + tc, ok := res.Content[0].(*TextContent) + if !ok { + t.Fatal("not TextContent") + } + if g, w := tc.Text, "&{[scope] 5000-01-02 03:04:05 +0000 UTC map[]}"; g != w { + t.Errorf("got %q, want %q", g, w) + } +} diff --git a/mcp/tool.go b/mcp/tool.go index 15f17e11..7173b8a8 100644 --- a/mcp/tool.go +++ b/mcp/tool.go @@ -68,6 +68,7 @@ func newServerTool[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*serverTool res, err := h(ctx, &ServerRequest[*CallToolParamsFor[In]]{ Session: req.Session, Params: params, + Extra: req.Extra, }) // TODO(rfindley): investigate why server errors are embedded in this strange way, // rather than returned as jsonrpc2 server errors. diff --git a/mcp/transport.go b/mcp/transport.go index 8018910b..2bcd8d7d 100644 --- a/mcp/transport.go +++ b/mcp/transport.go @@ -86,8 +86,7 @@ type serverConnection interface { // A StdioTransport is a [Transport] that communicates over stdin/stdout using // newline-delimited JSON. -type StdioTransport struct { -} +type StdioTransport struct{} // Connect implements the [Transport] interface. func (*StdioTransport) Connect(context.Context) (Connection, error) {