Skip to content

Commit bec773f

Browse files
committed
mcp: add Session.ID
Support retrieving the session ID from client and server sessions. For modelcontextprotocol#65.
1 parent c47dbcd commit bec773f

File tree

7 files changed

+67
-11
lines changed

7 files changed

+67
-11
lines changed

mcp/client.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,18 @@ type ClientSession struct {
142142
client *Client
143143
initializeResult *InitializeResult
144144
keepaliveCancel context.CancelFunc
145+
sessionIDFunc func() string
146+
}
147+
148+
func (cs *ClientSession) setSessionIDFunc(f func() string) {
149+
cs.sessionIDFunc = f
150+
}
151+
152+
func (cs *ClientSession) ID() string {
153+
if cs.sessionIDFunc == nil {
154+
return ""
155+
}
156+
return cs.sessionIDFunc()
145157
}
146158

147159
// Close performs a graceful close of the connection, preventing new requests

mcp/server.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -496,13 +496,25 @@ func (ss *ServerSession) NotifyProgress(ctx context.Context, params *ProgressNot
496496
type ServerSession struct {
497497
server *Server
498498
conn *jsonrpc2.Connection
499+
sessionIDFunc func() string
499500
mu sync.Mutex
500501
logLevel LoggingLevel
501502
initializeParams *InitializeParams
502503
initialized bool
503504
keepaliveCancel context.CancelFunc
504505
}
505506

507+
func (ss *ServerSession) setSessionIDFunc(f func() string) {
508+
ss.sessionIDFunc = f
509+
}
510+
511+
func (ss *ServerSession) ID() string {
512+
if ss.sessionIDFunc == nil {
513+
return ""
514+
}
515+
return ss.sessionIDFunc()
516+
}
517+
506518
// Ping pings the client.
507519
func (ss *ServerSession) Ping(ctx context.Context, params *PingParams) error {
508520
_, err := handleSend[*emptyResult](ctx, ss, methodPing, params)

mcp/shared.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ type methodHandler any // MethodHandler[*ClientSession] | MethodHandler[*ServerS
3636
// A Session is either a ClientSession or a ServerSession.
3737
type Session interface {
3838
*ClientSession | *ServerSession
39+
// ID returns the session ID, or the empty string if there is none.
40+
ID() string
41+
3942
sendingMethodInfos() map[string]methodInfo
4043
receivingMethodInfos() map[string]methodInfo
4144
sendingMethodHandler() methodHandler

mcp/sse.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,9 @@ type sseServerConn struct {
263263
t *SSEServerTransport
264264
}
265265

266+
// TODO(jba): get the session ID. (Not urgent because SSE transports have been removed from the spec.)
267+
func (s sseServerConn) sessionID() string { return "" }
268+
266269
// Read implements jsonrpc2.Reader.
267270
func (s sseServerConn) Read(ctx context.Context) (JSONRPCMessage, error) {
268271
select {
@@ -518,6 +521,9 @@ type sseClientConn struct {
518521
done chan struct{} // closed when the stream is closed
519522
}
520523

524+
// TODO(jba): get the session ID. (Not urgent because SSE transports have been removed from the spec.)
525+
func (c *sseClientConn) sessionID() string { return "" }
526+
521527
func (c *sseClientConn) isDone() bool {
522528
c.mu.Lock()
523529
defer c.mu.Unlock()

mcp/streamable.go

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,10 @@ func NewStreamableServerTransport(sessionID string) *StreamableServerTransport {
161161
}
162162
}
163163

164+
func (t *StreamableServerTransport) sessionID() string {
165+
return t.id
166+
}
167+
164168
// A StreamableServerTransport implements the [Transport] interface for a
165169
// single session.
166170
type StreamableServerTransport struct {
@@ -331,7 +335,7 @@ func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.R
331335
http.Error(w, fmt.Sprintf("malformed payload: %v", err), http.StatusBadRequest)
332336
return
333337
}
334-
var requests = make(map[JSONRPCID]struct{})
338+
requests := make(map[JSONRPCID]struct{})
335339
for _, msg := range incoming {
336340
if req, ok := msg.(*JSONRPCRequest); ok && req.ID.IsValid() {
337341
requests[req.ID] = struct{}{}
@@ -624,20 +628,26 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er
624628
}
625629

626630
type streamableClientConn struct {
627-
url string
628-
sessionID string
629-
client *http.Client
630-
incoming chan []byte
631-
done chan struct{}
631+
url string
632+
client *http.Client
633+
incoming chan []byte
634+
done chan struct{}
632635

633636
closeOnce sync.Once
634637
closeErr error
635638

636-
mu sync.Mutex
639+
mu sync.Mutex
640+
_sessionID string
637641
// bodies map[*http.Response]io.Closer
638642
err error
639643
}
640644

645+
func (c *streamableClientConn) sessionID() string {
646+
c.mu.Lock()
647+
defer c.mu.Unlock()
648+
return c._sessionID
649+
}
650+
641651
// Read implements the [Connection] interface.
642652
func (s *streamableClientConn) Read(ctx context.Context) (JSONRPCMessage, error) {
643653
select {
@@ -658,7 +668,7 @@ func (s *streamableClientConn) Write(ctx context.Context, msg JSONRPCMessage) er
658668
return s.err
659669
}
660670

661-
sessionID := s.sessionID
671+
sessionID := s._sessionID
662672
if sessionID == "" {
663673
// Hold lock for the first request.
664674
defer s.mu.Unlock()
@@ -681,7 +691,7 @@ func (s *streamableClientConn) Write(ctx context.Context, msg JSONRPCMessage) er
681691

682692
if sessionID == "" {
683693
// locked
684-
s.sessionID = gotSessionID
694+
s._sessionID = gotSessionID
685695
}
686696

687697
return nil
@@ -753,7 +763,7 @@ func (s *streamableClientConn) Close() error {
753763
if err != nil {
754764
s.closeErr = err
755765
} else {
756-
req.Header.Set("Mcp-Session-Id", s.sessionID)
766+
req.Header.Set("Mcp-Session-Id", s._sessionID)
757767
if _, err := s.client.Do(req); err != nil {
758768
s.closeErr = err
759769
}

mcp/streamable_test.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,10 @@ func TestStreamableTransports(t *testing.T) {
6969
t.Fatalf("client.Connect() failed: %v", err)
7070
}
7171
defer session.Close()
72-
72+
sid := session.ID()
73+
if sid == "" {
74+
t.Error("empty session ID")
75+
}
7376
// 4. The client calls the "greet" tool.
7477
params := &CallToolParams{
7578
Name: "greet",
@@ -79,6 +82,9 @@ func TestStreamableTransports(t *testing.T) {
7982
if err != nil {
8083
t.Fatalf("CallTool() failed: %v", err)
8184
}
85+
if g := session.ID(); g != sid {
86+
t.Errorf("session ID: got %q, want %q", g, sid)
87+
}
8288

8389
// 5. Verify that the correct response is received.
8490
want := &CallToolResult{

mcp/transport.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ type Connection interface {
5050
Read(context.Context) (JSONRPCMessage, error)
5151
Write(context.Context, JSONRPCMessage) error
5252
Close() error // may be called concurrently by both peers
53+
sessionID() string
5354
}
5455

5556
// A StdioTransport is a [Transport] that communicates over stdin/stdout using
@@ -94,6 +95,7 @@ type binder[T handler] interface {
9495

9596
type handler interface {
9697
handle(ctx context.Context, req *JSONRPCRequest) (any, error)
98+
setSessionIDFunc(func() string) // so Sessions can get the session ID
9799
}
98100

99101
func connect[H handler](ctx context.Context, t Transport, b binder[H]) (H, error) {
@@ -124,6 +126,7 @@ func connect[H handler](ctx context.Context, t Transport, b binder[H]) (H, error
124126
},
125127
})
126128
assert(preempter.conn != nil, "unbound preempter")
129+
h.setSessionIDFunc(conn.sessionID)
127130
return h, nil
128131
}
129132

@@ -200,6 +203,8 @@ type loggingConn struct {
200203
w io.Writer
201204
}
202205

206+
func (c *loggingConn) sessionID() string { return c.delegate.sessionID() }
207+
203208
// loggingReader is a stream middleware that logs incoming messages.
204209
func (s *loggingConn) Read(ctx context.Context) (JSONRPCMessage, error) {
205210
msg, err := s.delegate.Read(ctx)
@@ -285,6 +290,8 @@ func newIOConn(rwc io.ReadWriteCloser) *ioConn {
285290
}
286291
}
287292

293+
func (c *ioConn) sessionID() string { return "" }
294+
288295
// addBatch records a msgBatch for an incoming batch payload.
289296
// It returns an error if batch is malformed, containing previously seen IDs.
290297
//

0 commit comments

Comments
 (0)