Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 23 additions & 4 deletions auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,41 @@ import (
"time"
)

// TokenInfo holds information from a bearer token.
type TokenInfo struct {
Scopes []string
Expiration time.Time
// TODO: add standard JWT fields
Extra map[string]any
}

// The error that a TokenVerifier should return if the token cannot be verified.
var ErrInvalidToken = errors.New("invalid token")

// A TokenVerifier checks the validity of a bearer token, and extracts information
// from it. If verification fails, it should return an error that unwraps to ErrInvalidToken.
type TokenVerifier func(ctx context.Context, token string) (*TokenInfo, error)

// RequireBearerTokenOptions are options for [RequireBearerToken].
type RequireBearerTokenOptions struct {
Scopes []string
// The URL for the resource server metadata OAuth flow, to be returned as part
// of the WWW-Authenticate header.
ResourceMetadataURL string
// The required scopes.
Scopes []string
}

var ErrInvalidToken = errors.New("invalid token")

type tokenInfoKey struct{}

// TokenInfoFromContext returns the [TokenInfo] stored in ctx, or nil if none.
func TokenInfoFromContext(ctx context.Context) *TokenInfo {
ti := ctx.Value(tokenInfoKey{})
if ti == nil {
return nil
}
return ti.(*TokenInfo)
}

// RequireBearerToken returns a piece of middleware that verifies a bearer token using the verifier.
// If verification succeeds, the [TokenInfo] is added to the request's context and the request proceeds.
// If verification fails, the request fails with a 401 Unauthenticated, and the WWW-Authenticate header
Expand Down Expand Up @@ -75,7 +94,7 @@ func verify(ctx context.Context, verifier TokenVerifier, opts *RequireBearerToke
return nil, err.Error(), http.StatusInternalServerError
}

// Check scopes.
// Check scopes. All must be present.
if opts != nil {
// Note: quadratic, but N is small.
for _, s := range opts.Scopes {
Expand Down
6 changes: 6 additions & 0 deletions internal/jsonrpc2/messages.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ type Request struct {
Method string
// Params is either a struct or an array with the parameters of the method.
Params json.RawMessage
// Extra is additional information that does not appear on the wire. It can be
// used to pass information from the application to the underlying transport.
Extra any
}

// Response is a Message used as a reply to a call Request.
Expand All @@ -67,6 +70,9 @@ type Response struct {
Error error
// id of the request this is a response to.
ID ID
// Extra is additional information that does not appear on the wire. It can be
// used to pass information from the underlying transport to the application.
Extra any
}

// StringID creates a new string request identifier.
Expand Down
25 changes: 15 additions & 10 deletions mcp/shared.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"strings"
"time"

"github.com/modelcontextprotocol/go-sdk/auth"
"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
)
Expand Down Expand Up @@ -132,7 +133,8 @@ func handleReceive[S Session](ctx context.Context, session S, jreq *jsonrpc.Requ
}

mh := session.receivingMethodHandler().(MethodHandler)
req := info.newRequest(session, params)
re, _ := jreq.Extra.(*RequestExtra)
req := info.newRequest(session, params, re)
// mh might be user code, so ensure that it returns the right values for the jsonrpc2 protocol.
res, err := mh(ctx, jreq.Method, req)
if err != nil {
Expand Down Expand Up @@ -179,7 +181,7 @@ type methodInfo struct {
// Unmarshal params from the wire into a Params struct.
// Used on the receive side.
unmarshalParams func(json.RawMessage) (Params, error)
newRequest func(Session, Params) Request
newRequest func(Session, Params, *RequestExtra) Request
// Run the code when a call to the method is received.
// Used on the receive side.
handleMethod methodHandler
Expand Down Expand Up @@ -214,7 +216,7 @@ const (

func newClientMethodInfo[P paramsPtr[T], R Result, T any](d typedClientMethodHandler[P, R], flags methodFlags) methodInfo {
mi := newMethodInfo[P, R](flags)
mi.newRequest = func(s Session, p Params) Request {
mi.newRequest = func(s Session, p Params, _ *RequestExtra) Request {
r := &ClientRequest[P]{Session: s.(*ClientSession)}
if p != nil {
r.Params = p.(P)
Expand All @@ -229,19 +231,15 @@ func newClientMethodInfo[P paramsPtr[T], R Result, T any](d typedClientMethodHan

func newServerMethodInfo[P paramsPtr[T], R Result, T any](d typedServerMethodHandler[P, R], flags methodFlags) methodInfo {
mi := newMethodInfo[P, R](flags)
mi.newRequest = func(s Session, p Params) Request {
r := &ServerRequest[P]{Session: s.(*ServerSession)}
mi.newRequest = func(s Session, p Params, re *RequestExtra) Request {
r := &ServerRequest[P]{Session: s.(*ServerSession), Extra: re}
if p != nil {
r.Params = p.(P)
}
return r
}
mi.handleMethod = MethodHandler(func(ctx context.Context, _ string, req Request) (Result, error) {
rf := &ServerRequest[P]{Session: req.GetSession().(*ServerSession)}
if req.GetParams() != nil {
rf.Params = req.GetParams().(P)
}
return d(ctx, rf)
return d(ctx, req.(*ServerRequest[P]))
})
return mi
}
Expand Down Expand Up @@ -397,6 +395,13 @@ type ClientRequest[P Params] struct {
type ServerRequest[P Params] struct {
Session *ServerSession
Params P
Extra *RequestExtra
}

// RequestExtra is extra information included in requests, typically from
// the transport layer.
type RequestExtra struct {
TokenInfo *auth.TokenInfo // bearer token info (e.g. from OAuth) if any
}

func (*ClientRequest[P]) isRequest() {}
Expand Down
12 changes: 11 additions & 1 deletion mcp/streamable.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"sync/atomic"
"time"

"github.com/modelcontextprotocol/go-sdk/auth"
"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
)
Expand Down Expand Up @@ -494,12 +495,13 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques
// This also requires access to the negotiated version, which would either be
// set by the MCP-Protocol-Version header, or would require peeking into the
// session.
incoming, _, err := readBatch(body)
if err != nil {
http.Error(w, fmt.Sprintf("malformed payload: %v", err), http.StatusBadRequest)
return
}
incoming, _, err := readBatch(body)
requests := make(map[jsonrpc.ID]struct{})
tokenInfo := auth.TokenInfoFromContext(req.Context())
for _, msg := range incoming {
if req, ok := msg.(*jsonrpc.Request); ok {
// Preemptively check that this is a valid request, so that we can fail
Expand All @@ -509,6 +511,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
req.Extra = &RequestExtra{TokenInfo: tokenInfo}
if req.ID.IsValid() {
requests[req.ID] = struct{}{}
}
Expand Down Expand Up @@ -1038,6 +1041,10 @@ func (c *streamableClientConn) Read(ctx context.Context) (jsonrpc.Message, error
}
}

// testAuth controls whether a fake Authorization header is added to outgoing requests.
// TODO: replace with a better mechanism when client-side auth is in place.
var testAuth = false

// Write implements the [Connection] interface.
func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) error {
if err := c.failure(); err != nil {
Expand All @@ -1055,6 +1062,9 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json, text/event-stream")
if testAuth {
req.Header.Set("Authorization", "Bearer foo")
}
c.setMCPHeaders(req)

resp, err := c.client.Do(req)
Expand Down
49 changes: 49 additions & 0 deletions mcp/streamable_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/google/jsonschema-go/jsonschema"
"github.com/modelcontextprotocol/go-sdk/auth"
"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
)
Expand Down Expand Up @@ -1038,3 +1039,51 @@ func TestStreamableStateless(t *testing.T) {
// Verify we can make another request without session ID
checkRequest(`{"jsonrpc":"2.0","method":"tools/call","id":2,"params":{"name":"greet","arguments":{"name":"World"}}}`)
}

func TestTokenInfo(t *testing.T) {
defer func(b bool) { testAuth = b }(testAuth)
testAuth = true
ctx := context.Background()

// Create a server with a tool that returns TokenInfo.
tokenInfo := func(ctx context.Context, req *ServerRequest[*CallToolParamsFor[struct{}]]) (*CallToolResultFor[any], error) {
return &CallToolResultFor[any]{Content: []Content{&TextContent{Text: fmt.Sprintf("%v", req.Extra.TokenInfo)}}}, nil
}
server := NewServer(testImpl, nil)
AddTool(server, &Tool{Name: "tokenInfo", Description: "return token info"}, tokenInfo)

streamHandler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil)
verifier := func(context.Context, string) (*auth.TokenInfo, error) {
return &auth.TokenInfo{
Scopes: []string{"scope"},
// Expiration is far, far in the future.
Expiration: time.Date(5000, 1, 2, 3, 4, 5, 0, time.UTC),
}, nil
}
handler := auth.RequireBearerToken(verifier, nil)(streamHandler)
httpServer := httptest.NewServer(handler)
defer httpServer.Close()

transport := NewStreamableClientTransport(httpServer.URL, nil)
client := NewClient(testImpl, nil)
session, err := client.Connect(ctx, transport, nil)
if err != nil {
t.Fatalf("client.Connect() failed: %v", err)
}
defer session.Close()

res, err := session.CallTool(ctx, &CallToolParams{Name: "tokenInfo"})
if err != nil {
t.Fatal(err)
}
if len(res.Content) == 0 {
t.Fatal("missing content")
}
tc, ok := res.Content[0].(*TextContent)
if !ok {
t.Fatal("not TextContent")
}
if g, w := tc.Text, "&{[scope] 5000-01-02 03:04:05 +0000 UTC map[]}"; g != w {
t.Errorf("got %q, want %q", g, w)
}
}
1 change: 1 addition & 0 deletions mcp/tool.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ func newServerTool[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*serverTool
res, err := h(ctx, &ServerRequest[*CallToolParamsFor[In]]{
Session: req.Session,
Params: params,
Extra: req.Extra,
})
// TODO(rfindley): investigate why server errors are embedded in this strange way,
// rather than returned as jsonrpc2 server errors.
Expand Down
3 changes: 1 addition & 2 deletions mcp/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,7 @@ type serverConnection interface {

// A StdioTransport is a [Transport] that communicates over stdin/stdout using
// newline-delimited JSON.
type StdioTransport struct {
}
type StdioTransport struct{}

// Connect implements the [Transport] interface.
func (*StdioTransport) Connect(context.Context) (Connection, error) {
Expand Down