Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions mcp/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,18 @@ type ClientSession struct {
client *Client
initializeResult *InitializeResult
keepaliveCancel context.CancelFunc
mcpConn Connection
}

func (cs *ClientSession) setConn(c Connection) {
cs.mcpConn = c
}

func (cs *ClientSession) ID() string {
if cs.mcpConn == nil {
return ""
}
return cs.mcpConn.SessionID()
}

// Close performs a graceful close of the connection, preventing new requests
Expand Down
12 changes: 12 additions & 0 deletions mcp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -496,13 +496,25 @@ func (ss *ServerSession) NotifyProgress(ctx context.Context, params *ProgressNot
type ServerSession struct {
server *Server
conn *jsonrpc2.Connection
mcpConn Connection
mu sync.Mutex
logLevel LoggingLevel
initializeParams *InitializeParams
initialized bool
keepaliveCancel context.CancelFunc
}

func (ss *ServerSession) setConn(c Connection) {
ss.mcpConn = c
}

func (ss *ServerSession) ID() string {
if ss.mcpConn == nil {
return ""
}
return ss.mcpConn.SessionID()
}

// Ping pings the client.
func (ss *ServerSession) Ping(ctx context.Context, params *PingParams) error {
_, err := handleSend[*emptyResult](ctx, ss, methodPing, params)
Expand Down
3 changes: 3 additions & 0 deletions mcp/shared.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ type methodHandler any // MethodHandler[*ClientSession] | MethodHandler[*ServerS
// A Session is either a ClientSession or a ServerSession.
type Session interface {
*ClientSession | *ServerSession
// ID returns the session ID, or the empty string if there is none.
ID() string

sendingMethodInfos() map[string]methodInfo
receivingMethodInfos() map[string]methodInfo
sendingMethodHandler() methodHandler
Expand Down
6 changes: 6 additions & 0 deletions mcp/sse.go
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,9 @@ type sseServerConn struct {
t *SSEServerTransport
}

// TODO(jba): get the session ID. (Not urgent because SSE transports have been removed from the spec.)
func (s sseServerConn) SessionID() string { return "" }

// Read implements jsonrpc2.Reader.
func (s sseServerConn) Read(ctx context.Context) (JSONRPCMessage, error) {
select {
Expand Down Expand Up @@ -518,6 +521,9 @@ type sseClientConn struct {
done chan struct{} // closed when the stream is closed
}

// TODO(jba): get the session ID. (Not urgent because SSE transports have been removed from the spec.)
func (c *sseClientConn) SessionID() string { return "" }

func (c *sseClientConn) isDone() bool {
c.mu.Lock()
defer c.mu.Unlock()
Expand Down
30 changes: 20 additions & 10 deletions mcp/streamable.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,10 @@ func NewStreamableServerTransport(sessionID string) *StreamableServerTransport {
}
}

func (t *StreamableServerTransport) SessionID() string {
return t.id
}

// A StreamableServerTransport implements the [Transport] interface for a
// single session.
type StreamableServerTransport struct {
Expand Down Expand Up @@ -331,7 +335,7 @@ func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.R
http.Error(w, fmt.Sprintf("malformed payload: %v", err), http.StatusBadRequest)
return
}
var requests = make(map[JSONRPCID]struct{})
requests := make(map[JSONRPCID]struct{})
for _, msg := range incoming {
if req, ok := msg.(*JSONRPCRequest); ok && req.ID.IsValid() {
requests[req.ID] = struct{}{}
Expand Down Expand Up @@ -624,20 +628,26 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er
}

type streamableClientConn struct {
url string
sessionID string
client *http.Client
incoming chan []byte
done chan struct{}
url string
client *http.Client
incoming chan []byte
done chan struct{}

closeOnce sync.Once
closeErr error

mu sync.Mutex
mu sync.Mutex
_sessionID string
// bodies map[*http.Response]io.Closer
err error
}

func (c *streamableClientConn) SessionID() string {
c.mu.Lock()
defer c.mu.Unlock()
return c._sessionID
}

// Read implements the [Connection] interface.
func (s *streamableClientConn) Read(ctx context.Context) (JSONRPCMessage, error) {
select {
Expand All @@ -658,7 +668,7 @@ func (s *streamableClientConn) Write(ctx context.Context, msg JSONRPCMessage) er
return s.err
}

sessionID := s.sessionID
sessionID := s._sessionID
if sessionID == "" {
// Hold lock for the first request.
defer s.mu.Unlock()
Expand All @@ -681,7 +691,7 @@ func (s *streamableClientConn) Write(ctx context.Context, msg JSONRPCMessage) er

if sessionID == "" {
// locked
s.sessionID = gotSessionID
s._sessionID = gotSessionID
}

return nil
Expand Down Expand Up @@ -753,7 +763,7 @@ func (s *streamableClientConn) Close() error {
if err != nil {
s.closeErr = err
} else {
req.Header.Set("Mcp-Session-Id", s.sessionID)
req.Header.Set("Mcp-Session-Id", s._sessionID)
if _, err := s.client.Do(req); err != nil {
s.closeErr = err
}
Expand Down
8 changes: 7 additions & 1 deletion mcp/streamable_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,10 @@ func TestStreamableTransports(t *testing.T) {
t.Fatalf("client.Connect() failed: %v", err)
}
defer session.Close()

sid := session.ID()
if sid == "" {
t.Error("empty session ID")
}
// 4. The client calls the "greet" tool.
params := &CallToolParams{
Name: "greet",
Expand All @@ -79,6 +82,9 @@ func TestStreamableTransports(t *testing.T) {
if err != nil {
t.Fatalf("CallTool() failed: %v", err)
}
if g := session.ID(); g != sid {
t.Errorf("session ID: got %q, want %q", g, sid)
}

// 5. Verify that the correct response is received.
want := &CallToolResult{
Expand Down
7 changes: 7 additions & 0 deletions mcp/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ type Connection interface {
Read(context.Context) (JSONRPCMessage, error)
Write(context.Context, JSONRPCMessage) error
Close() error // may be called concurrently by both peers
SessionID() string
}

// A StdioTransport is a [Transport] that communicates over stdin/stdout using
Expand Down Expand Up @@ -94,6 +95,7 @@ type binder[T handler] interface {

type handler interface {
handle(ctx context.Context, req *JSONRPCRequest) (any, error)
setConn(Connection)
}

func connect[H handler](ctx context.Context, t Transport, b binder[H]) (H, error) {
Expand Down Expand Up @@ -124,6 +126,7 @@ func connect[H handler](ctx context.Context, t Transport, b binder[H]) (H, error
},
})
assert(preempter.conn != nil, "unbound preempter")
h.setConn(conn)
return h, nil
}

Expand Down Expand Up @@ -200,6 +203,8 @@ type loggingConn struct {
w io.Writer
}

func (c *loggingConn) SessionID() string { return c.delegate.SessionID() }

// loggingReader is a stream middleware that logs incoming messages.
func (s *loggingConn) Read(ctx context.Context) (JSONRPCMessage, error) {
msg, err := s.delegate.Read(ctx)
Expand Down Expand Up @@ -285,6 +290,8 @@ func newIOConn(rwc io.ReadWriteCloser) *ioConn {
}
}

func (c *ioConn) SessionID() string { return "" }

// addBatch records a msgBatch for an incoming batch payload.
// It returns an error if batch is malformed, containing previously seen IDs.
//
Expand Down
Loading