Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions mcp/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,7 @@ var clientMethodInfos = map[string]methodInfo{
methodPing: newClientMethodInfo(clientSessionMethod((*ClientSession).ping), missingParamsOK),
methodListRoots: newClientMethodInfo(clientMethod((*Client).listRoots), missingParamsOK),
methodCreateMessage: newClientMethodInfo(clientMethod((*Client).createMessage), 0),
notificationCancelled: newClientMethodInfo(clientSessionMethod((*ClientSession).cancel), notification|missingParamsOK),
notificationToolListChanged: newClientMethodInfo(clientMethod((*Client).callToolChangedHandler), notification|missingParamsOK),
notificationPromptListChanged: newClientMethodInfo(clientMethod((*Client).callPromptChangedHandler), notification|missingParamsOK),
notificationResourceListChanged: newClientMethodInfo(clientMethod((*Client).callResourceChangedHandler), notification|missingParamsOK),
Expand Down Expand Up @@ -344,6 +345,15 @@ func (*ClientSession) ping(context.Context, *PingParams) (*emptyResult, error) {
return &emptyResult{}, nil
}

// cancel is a placeholder: cancellation is handled the jsonrpc2 package.
//
// It should never be invoked in practice because cancellation is preempted,
// but having its signature here facilitates the construction of methodInfo
// that can be used to validate incoming cancellation notifications.
func (*ClientSession) cancel(context.Context, *CancelledParams) (Result, error) {
return nil, nil
}

func newClientRequest[P Params](cs *ClientSession, params P) *ClientRequest[P] {
return &ClientRequest[P]{Session: cs, Params: params}
}
Expand Down
5 changes: 2 additions & 3 deletions mcp/mcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -646,8 +646,7 @@ func TestCancellation(t *testing.T) {
start = make(chan struct{})
cancelled = make(chan struct{}, 1) // don't block the request
)

slowRequest := func(ctx context.Context, req *ServerRequest[*CallToolParamsFor[map[string]any]]) (*CallToolResult, error) {
slowRequest := func(ctx context.Context, req *ServerRequest[*CallToolParams]) (*CallToolResult, error) {
start <- struct{}{}
select {
case <-ctx.Done():
Expand All @@ -658,7 +657,7 @@ func TestCancellation(t *testing.T) {
return nil, nil
}
_, cs := basicConnection(t, func(s *Server) {
s.AddTool(&Tool{Name: "slow", InputSchema: &jsonschema.Schema{}}, slowRequest)
AddTool(s, &Tool{Name: "slow"}, slowRequest)
})
defer cs.Close()

Expand Down
10 changes: 10 additions & 0 deletions mcp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -760,6 +760,7 @@ var serverMethodInfos = map[string]methodInfo{
methodSetLevel: newServerMethodInfo(serverSessionMethod((*ServerSession).setLevel), 0),
methodSubscribe: newServerMethodInfo(serverMethod((*Server).subscribe), 0),
methodUnsubscribe: newServerMethodInfo(serverMethod((*Server).unsubscribe), 0),
notificationCancelled: newServerMethodInfo(serverSessionMethod((*ServerSession).cancel), notification|missingParamsOK),
notificationInitialized: newServerMethodInfo(serverSessionMethod((*ServerSession).initialized), notification|missingParamsOK),
notificationRootsListChanged: newServerMethodInfo(serverMethod((*Server).callRootsListChangedHandler), notification|missingParamsOK),
notificationProgress: newServerMethodInfo(serverSessionMethod((*ServerSession).callProgressNotificationHandler), notification),
Expand Down Expand Up @@ -838,6 +839,15 @@ func (ss *ServerSession) ping(context.Context, *PingParams) (*emptyResult, error
return &emptyResult{}, nil
}

// cancel is a placeholder: cancellation is handled the jsonrpc2 package.
//
// It should never be invoked in practice because cancellation is preempted,
// but having its signature here facilitates the construction of methodInfo
// that can be used to validate incoming cancellation notifications.
func (ss *ServerSession) cancel(context.Context, *CancelledParams) (Result, error) {
return nil, nil
}

func (ss *ServerSession) setLevel(_ context.Context, params *SetLevelParams) (*emptyResult, error) {
ss.updateState(func(state *ServerSessionState) {
state.LogLevel = params.Level
Expand Down
20 changes: 10 additions & 10 deletions mcp/streamable.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,12 +115,12 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
return
}

var session *StreamableServerTransport
var transport *StreamableServerTransport
if id := req.Header.Get(sessionIDHeader); id != "" {
h.mu.Lock()
session, _ = h.transports[id]
transport, _ = h.transports[id]
h.mu.Unlock()
if session == nil {
if transport == nil {
http.Error(w, "session not found", http.StatusNotFound)
return
}
Expand All @@ -129,22 +129,22 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
// TODO(rfindley): simplify the locking so that each request has only one
// critical section.
if req.Method == http.MethodDelete {
if session == nil {
if transport == nil {
// => Mcp-Session-Id was not set; else we'd have returned NotFound above.
http.Error(w, "DELETE requires an Mcp-Session-Id header", http.StatusBadRequest)
return
}
h.mu.Lock()
delete(h.transports, session.SessionID)
delete(h.transports, transport.SessionID)
h.mu.Unlock()
session.connection.Close()
transport.connection.Close()
w.WriteHeader(http.StatusNoContent)
return
}

switch req.Method {
case http.MethodPost, http.MethodGet:
if req.Method == http.MethodGet && session == nil {
if req.Method == http.MethodGet && transport == nil {
http.Error(w, "GET requires an active session", http.StatusMethodNotAllowed)
return
}
Expand All @@ -154,7 +154,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
return
}

if session == nil {
if transport == nil {
server := h.getServer(req)
if server == nil {
// The getServer argument to NewStreamableHTTPHandler returned nil.
Expand Down Expand Up @@ -191,10 +191,10 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
h.transports[s.SessionID] = s
h.mu.Unlock()
}
session = s
transport = s
}

session.ServeHTTP(w, req)
transport.ServeHTTP(w, req)
}

// StreamableServerTransportOptions configures the stramable server transport.
Expand Down
51 changes: 39 additions & 12 deletions mcp/streamable_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,26 @@ func TestStreamableTransports(t *testing.T) {

for _, useJSON := range []bool{false, true} {
t.Run(fmt.Sprintf("JSONResponse=%v", useJSON), func(t *testing.T) {
// 1. Create a server with a simple "greet" tool.
// Create a server with some simple tools.
server := NewServer(testImpl, nil)
AddTool(server, &Tool{Name: "greet", Description: "say hi"}, sayHi)
// The "hang" tool checks that context cancellation is propagated.
// It hangs until the context is cancelled.
var (
start = make(chan struct{})
cancelled = make(chan struct{}, 1) // don't block the request
)
hang := func(ctx context.Context, req *ServerRequest[*CallToolParams]) (*CallToolResult, error) {
start <- struct{}{}
select {
case <-ctx.Done():
cancelled <- struct{}{}
case <-time.After(5 * time.Second):
return nil, nil
}
return nil, nil
}
AddTool(server, &Tool{Name: "hang"}, hang)
AddTool(server, &Tool{Name: "sample"}, func(ctx context.Context, req *ServerRequest[*CallToolParams]) (*CallToolResult, error) {
// Test that we can make sampling requests during tool handling.
//
Expand All @@ -60,7 +77,7 @@ func TestStreamableTransports(t *testing.T) {
return &CallToolResultFor[any]{}, nil
})

// 2. Start an httptest.Server with the StreamableHTTPHandler, wrapped in a
// Start an httptest.Server with the StreamableHTTPHandler, wrapped in a
// cookie-checking middleware.
handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, &StreamableHTTPOptions{
jsonResponse: useJSON,
Expand All @@ -84,7 +101,7 @@ func TestStreamableTransports(t *testing.T) {
}))
defer httpServer.Close()

// 3. Create a client and connect it to the server using our StreamableClientTransport.
// Create a client and connect it to the server using our StreamableClientTransport.
// Check that all requests honor a custom client.
jar, err := cookiejar.New(nil)
if err != nil {
Expand Down Expand Up @@ -117,10 +134,13 @@ func TestStreamableTransports(t *testing.T) {
if g, w := session.mcpConn.(*streamableClientConn).initializedResult.ProtocolVersion, latestProtocolVersion; g != w {
t.Fatalf("got protocol version %q, want %q", g, w)
}
// 4. The client calls the "greet" tool.

// Verify the behavior of various tools.

// The "greet" tool should just work.
params := &CallToolParams{
Name: "greet",
Arguments: map[string]any{"name": "streamy"},
Arguments: map[string]any{"name": "foo"},
}
got, err := session.CallTool(ctx, params)
if err != nil {
Expand All @@ -132,19 +152,26 @@ func TestStreamableTransports(t *testing.T) {
if g, w := lastHeader.Get(protocolVersionHeader), latestProtocolVersion; g != w {
t.Errorf("got protocol version header %q, want %q", g, w)
}

// 5. Verify that the correct response is received.
want := &CallToolResult{
Content: []Content{
&TextContent{Text: "hi streamy"},
},
Content: []Content{&TextContent{Text: "hi foo"}},
}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("CallTool() returned unexpected content (-want +got):\n%s", diff)
}

// 6. Run the "sampling" tool and verify that the streamable server can
// call tools.
// The "hang" tool should be cancellable.
ctx2, cancel := context.WithCancel(context.Background())
go session.CallTool(ctx2, &CallToolParams{Name: "hang"})
<-start
cancel()
select {
case <-cancelled:
case <-time.After(5 * time.Second):
t.Fatal("timeout waiting for cancellation")
}

// The "sampling" tool should be able to issue sampling requests during
// tool operation.
result, err := session.CallTool(ctx, &CallToolParams{
Name: "sample",
Arguments: map[string]any{},
Expand Down
2 changes: 1 addition & 1 deletion mcp/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ type canceller struct {

// Preempt implements [jsonrpc2.Preempter].
func (c *canceller) Preempt(ctx context.Context, req *jsonrpc.Request) (result any, err error) {
if req.Method == "notifications/cancelled" {
if req.Method == notificationCancelled {
var params CancelledParams
if err := json.Unmarshal(req.Params, &params); err != nil {
return nil, err
Expand Down