Skip to content

Commit 73b8a7f

Browse files
authored
mcp: make request headers available to tools (#333)
Add HTTP request headers to RequestExtra, so tools and other user-defined handlers can access them. Fixes #331.
1 parent 48abccb commit 73b8a7f

File tree

2 files changed

+10
-9
lines changed

2 files changed

+10
-9
lines changed

mcp/shared.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616
"encoding/json"
1717
"fmt"
1818
"log"
19+
"net/http"
1920
"reflect"
2021
"slices"
2122
"strings"
@@ -400,6 +401,7 @@ type ServerRequest[P Params] struct {
400401
// the transport layer.
401402
type RequestExtra struct {
402403
TokenInfo *auth.TokenInfo // bearer token info (e.g. from OAuth) if any
404+
Header http.Header // header from HTTP request, if any
403405
}
404406

405407
func (*ClientRequest[P]) isRequest() {}

mcp/streamable.go

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -580,10 +580,6 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques
580580
// This also requires access to the negotiated version, which would either be
581581
// set by the MCP-Protocol-Version header, or would require peeking into the
582582
// session.
583-
if err != nil {
584-
http.Error(w, fmt.Sprintf("malformed payload: %v", err), http.StatusBadRequest)
585-
return
586-
}
587583
incoming, _, err := readBatch(body)
588584
if err != nil {
589585
http.Error(w, fmt.Sprintf("malformed payload: %v", err), http.StatusBadRequest)
@@ -592,17 +588,20 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques
592588
requests := make(map[jsonrpc.ID]struct{})
593589
tokenInfo := auth.TokenInfoFromContext(req.Context())
594590
for _, msg := range incoming {
595-
if req, ok := msg.(*jsonrpc.Request); ok {
591+
if jreq, ok := msg.(*jsonrpc.Request); ok {
596592
// Preemptively check that this is a valid request, so that we can fail
597593
// the HTTP request. If we didn't do this, a request with a bad method or
598594
// missing ID could be silently swallowed.
599-
if _, err := checkRequest(req, serverMethodInfos); err != nil {
595+
if _, err := checkRequest(jreq, serverMethodInfos); err != nil {
600596
http.Error(w, err.Error(), http.StatusBadRequest)
601597
return
602598
}
603-
req.Extra = &RequestExtra{TokenInfo: tokenInfo}
604-
if req.IsCall() {
605-
requests[req.ID] = struct{}{}
599+
jreq.Extra = &RequestExtra{
600+
TokenInfo: tokenInfo,
601+
Header: req.Header,
602+
}
603+
if jreq.IsCall() {
604+
requests[jreq.ID] = struct{}{}
606605
}
607606
}
608607
}

0 commit comments

Comments
 (0)