examples/client/loadtest/main.go (4 additions, 1 deletion)

@@ -34,6 +34,7 @@ var (
 	timeout = flag.Duration("timeout", 1*time.Second, "request timeout")
 	qps     = flag.Int("qps", 100, "tool calls per second, per worker")
 	verbose = flag.Bool("v", false, "if set, enable verbose logging")
+	cleanup = flag.Bool("cleanup", true, "whether to clean up sessions at the end of the test")
 )

 func main() {
@@ -76,7 +77,9 @@ func main() {
 	if err != nil {
 		log.Fatal(err)
 	}
-	defer cs.Close()
+	if *cleanup {
+		defer cs.Close()
+	}

 	ticker := time.NewTicker(1 * time.Second / time.Duration(*qps))
 	defer ticker.Stop()
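As a usage example, the new flag can be exercised like so (the package path is taken from the diff header; the other flag values are illustrative assumptions):

	# Leave sessions open at the end of the run, e.g. to observe server-side session behavior.
	go run ./examples/client/loadtest -cleanup=false -qps 200 -timeout 2s -v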
mcp/streamable.go (71 additions, 29 deletions)

@@ -450,22 +450,14 @@ type streamableServerConn struct {
 	// handled.

 	// streams holds the logical streams for this session, keyed by their ID.
-	// TODO: streams are never deleted, so the memory for a connection grows without
-	// bound. If we deleted a stream when the response is sent, we would lose the ability
-	// to replay if there was a cut just before the response was transmitted.
-	// Perhaps we could have a TTL for streams that starts just after the response.
-	//
-	// TODO(rfindley): Once all responses have been received for a stream, we can
-	// remove it as it is no longer necessary, even if the client wants to replay
-	// messages.
+	// Lifecycle: streams persist until all of their responses are received from
+	// the server.
 	streams map[string]*stream

 	// requestStreams maps incoming requests to their logical stream ID.
 	//
-	// Lifecycle: requestStreams persist for the duration of the session.
-	//
-	// TODO(rfindley): clean up once requests are handled. See the TODO for
-	// streams above.
+	// Lifecycle: requestStreams persist until their response is received.
 	requestStreams map[jsonrpc.ID]string
 }

@@ -641,17 +633,39 @@ func (c *streamableServerConn) writeEvent(w http.ResponseWriter, stream *stream,
 // all messages, so that no delivery or storage of new messages occurs while
 // the stream is still replaying.
 func (c *streamableServerConn) acquireStream(ctx context.Context, w http.ResponseWriter, streamID string, lastIdx *int) (*stream, chan struct{}) {
+	// if tempStream is set, the stream is done and we're just replaying messages.
+	//
+	// We record a temporary stream to claim exclusive replay rights.
+	tempStream := false
 	c.mu.Lock()
-	stream, ok := c.streams[streamID]
-	c.mu.Unlock()
+	s, ok := c.streams[streamID]
 	if !ok {
-		http.Error(w, "unknown stream", http.StatusBadRequest)
-		return nil, nil
+		// The stream is logically done, but claim exclusive rights to replay it by
+		// adding a temporary entry in the streams map.
+		//
+		// We create this entry with a non-nil deliver function, to ensure it isn't
+		// claimed by another request before we lock it below.
+		tempStream = true
+		s = &stream{
+			id:      streamID,
+			deliver: func([]byte, bool) error { return nil },
+		}
+		c.streams[streamID] = s
+
+		// Since this stream is transient, we must clean up after replaying.
+		defer func() {
+			c.mu.Lock()
+			delete(c.streams, streamID)
+			c.mu.Unlock()
+		}()
 	}
+	c.mu.Unlock()

-	stream.mu.Lock()
-	defer stream.mu.Unlock()
-	if stream.deliver != nil {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	// Check that this stream wasn't claimed by another request.
+	if !tempStream && s.deliver != nil {
 		http.Error(w, "stream ID conflicts with ongoing stream", http.StatusConflict)
 		return nil, nil
 	}
@@ -664,7 +678,7 @@ func (c *streamableServerConn) acquireStream(ctx context.Context, w http.ResponseWriter, streamID string, lastIdx *int) (*stream, chan struct{}) {
 	// messages, and registered our delivery function.
 	var toReplay [][]byte
 	if c.eventStore != nil {
-		for data, err := range c.eventStore.After(ctx, c.SessionID(), stream.id, *lastIdx) {
+		for data, err := range c.eventStore.After(ctx, c.SessionID(), s.id, *lastIdx) {
 			if err != nil {
 				// We can't replay events, perhaps because the underlying event store
 				// has garbage collected its storage.
@@ -685,7 +699,7 @@ func (c *streamableServerConn) acquireStream(ctx context.Context, w http.ResponseWriter, streamID string, lastIdx *int) (*stream, chan struct{}) {
 	w.Header().Set("Content-Type", "text/event-stream") // Accept checked in [StreamableHTTPHandler]
 	w.Header().Set("Connection", "keep-alive")

-	if stream.id == "" {
+	if s.id == "" {
 		// Issue #410: the standalone SSE stream is likely not to receive messages
 		// for a long time. Ensure that headers are flushed.
 		w.WriteHeader(http.StatusOK)
@@ -695,30 +709,30 @@ func (c *streamableServerConn) acquireStream(ctx context.Context, w http.ResponseWriter, streamID string, lastIdx *int) (*stream, chan struct{}) {
 	}

 	for _, data := range toReplay {
-		if err := c.writeEvent(w, stream, data, lastIdx); err != nil {
+		if err := c.writeEvent(w, s, data, lastIdx); err != nil {
 			return nil, nil
 		}
 	}

-	if stream.doneLocked() {
+	if tempStream || s.doneLocked() {
 		// Nothing more to do.
 		return nil, nil
 	}

-	// Finally register a delivery function and unlock the stream, allowing the
-	// connection to write new events.
+	// The stream is not done: register a delivery function before the stream is
+	// unlocked, allowing the connection to write new events.
 	done := make(chan struct{})
-	stream.deliver = func(data []byte, final bool) error {
+	s.deliver = func(data []byte, final bool) error {
 		if err := ctx.Err(); err != nil {
 			return err
 		}
-		err := c.writeEvent(w, stream, data, lastIdx)
+		err := c.writeEvent(w, s, data, lastIdx)
 		if final {
 			close(done)
 		}
 		return err
 	}
-	return stream, done
+	return s, done
 }

 // servePOST handles an incoming message, and replies with either an outgoing
@@ -1009,13 +1023,23 @@ func (c *streamableServerConn) Write(ctx context.Context, msg jsonrpc.Message) error {
 			s = c.streams[streamID]
 		}
 	} else {
-		s = c.streams[""]
+		s = c.streams[""] // standalone SSE stream
 	}
+	if responseTo.IsValid() {
+		// Once we've responded to a request, disallow related messages by removing
+		// the stream association. This also releases memory.
+		delete(c.requestStreams, responseTo)
+	}
 	sessionClosed := c.isDone
 	c.mu.Unlock()

 	if s == nil {
-		return fmt.Errorf("%w: no stream for request", jsonrpc2.ErrRejected)
+		// The request was made in the context of an ongoing request, but that
+		// request is complete.
+		//
+		// In the future, we could be less strict and allow the request to land on
+		// the standalone SSE stream.
+		return fmt.Errorf("%w: write to closed stream", jsonrpc2.ErrRejected)
 	}
 	if sessionClosed {
 		return errors.New("session is closed")
@@ -1024,10 +1048,28 @@ func (c *streamableServerConn) Write(ctx context.Context, msg jsonrpc.Message) error {
 	s.mu.Lock()
 	defer s.mu.Unlock()
 	if s.doneLocked() {
+		// It's possible that the stream was completed in between getting s above,
+		// and acquiring the stream lock. In order to avoid acquiring s.mu while
+		// holding c.mu, we check the terminal condition again.
 		return fmt.Errorf("%w: write to closed stream", jsonrpc2.ErrRejected)
 	}
+	// Perform accounting on responses.
+	if responseTo.IsValid() {
+		if _, ok := s.requests[responseTo]; !ok {
+			panic(fmt.Sprintf("internal error: stream %v: response to untracked request %v", s.id, responseTo))
+		}
+		if s.id == "" {
+			// This should be guaranteed not to happen by the stream resolution logic
+			// above, but be defensive: we don't ever want to delete the standalone
+			// stream.
+			panic("internal error: response on standalone stream")
+		}
+		delete(s.requests, responseTo)
+		if len(s.requests) == 0 {
+			c.mu.Lock()
+			delete(c.streams, s.id)
+			c.mu.Unlock()
+		}
+	}

 	delivered := false
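Taken together, the streamable.go changes tie a stream's lifetime to its outstanding requests: each response removes its requestStreams entry, and the stream itself is deleted once its request set drains. Below is a minimal sketch of that accounting pattern in isolation; the types and field names are simplified stand-ins for the diff's streamableServerConn and stream, not the SDK's actual API:

	package main

	import (
		"fmt"
		"sync"
	)

	// conn and stream stand in for the diff's streamableServerConn and stream.
	type conn struct {
		mu      sync.Mutex
		streams map[string]*stream
	}

	type stream struct {
		mu       sync.Mutex
		id       string
		requests map[int]bool // outstanding request IDs (jsonrpc.ID in the real code)
	}

	// respond mirrors the accounting added to Write: it marks id as answered and
	// frees the stream once no requests remain. The standalone SSE stream
	// (id == "") is never deleted. Lock order is s.mu, then c.mu, matching the
	// diff's rule of never acquiring s.mu while holding c.mu.
	func (c *conn) respond(s *stream, id int) {
		s.mu.Lock()
		defer s.mu.Unlock()
		delete(s.requests, id)
		if len(s.requests) == 0 && s.id != "" {
			c.mu.Lock()
			delete(c.streams, s.id)
			c.mu.Unlock()
		}
	}

	func main() {
		c := &conn{streams: make(map[string]*stream)}
		s := &stream{id: "s1", requests: map[int]bool{1: true, 2: true}}
		c.streams[s.id] = s
		c.respond(s, 1)
		c.respond(s, 2)
		fmt.Println(len(c.streams)) // 0: the stream is freed with its last response
	}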
mcp/streamable_test.go (46 additions, 12 deletions)

@@ -73,23 +73,53 @@ func TestStreamableTransports(t *testing.T) {
 		return nil, nil, nil
 	}
 	AddTool(server, &Tool{Name: "hang"}, hang)
+	// We use sampling to test server->client requests, both before and after
+	// the related client->server request completes.
+	sampleDone := make(chan struct{})
+	var sampleWG sync.WaitGroup
 	AddTool(server, &Tool{Name: "sample"}, func(ctx context.Context, req *CallToolRequest, args any) (*CallToolResult, any, error) {
+		type testCase struct {
+			label       string
+			ctx         context.Context
+			wantSuccess bool
+		}
+		testSample := func(tc testCase) {
+			res, err := req.Session.CreateMessage(tc.ctx, &CreateMessageParams{})
+			if gotSuccess := err == nil; gotSuccess != tc.wantSuccess {
+				t.Errorf("%s: CreateMessage success=%v, want %v", tc.label, gotSuccess, tc.wantSuccess)
+			}
+			if err != nil {
+				return
+			}
+			if g, w := res.Model, "aModel"; g != w {
+				t.Errorf("%s: got model %q, want %q", tc.label, g, w)
+			}
+		}
 		// Test that we can make sampling requests during tool handling.
 		//
 		// Try this on both the request context and a background context, so
 		// that messages may be delivered on either the POST or GET connection.
-		for _, ctx := range map[string]context.Context{
-			"request context":    ctx,
-			"background context": context.Background(),
+		for _, test := range []testCase{
+			{"request context", ctx, true},
+			{"background context", context.Background(), true},
 		} {
-			res, err := req.Session.CreateMessage(ctx, &CreateMessageParams{})
-			if err != nil {
-				return nil, nil, err
-			}
-			if g, w := res.Model, "aModel"; g != w {
-				return nil, nil, fmt.Errorf("got %q, want %q", g, w)
-			}
+			testSample(test)
 		}
+		// Now, spin off a goroutine that runs after the sampling request, to
+		// check behavior when the client request has completed.
+		sampleWG.Add(1)
+		go func() {
+			defer sampleWG.Done()
+			<-sampleDone
+			// Test that sampling requests in the tool context fail outside of
+			// tool handling, but succeed on the background context.
+			for _, test := range []testCase{
+				{"request context", ctx, false},
+				{"background context", context.Background(), true},
+			} {
+				testSample(test)
+			}
+		}()
 		return &CallToolResult{}, nil, nil
 	})

@@ -191,15 +221,19 @@ func TestStreamableTransports(t *testing.T) {
 		t.Fatal("timeout waiting for cancellation")
 	}

-	// The "sampling" tool should be able to issue sampling requests during
-	// tool operation.
+	// The "sampling" tool checks the validity of server->client requests
+	// both within and without the tool context.
 	result, err := session.CallTool(ctx, &CallToolParams{
 		Name:      "sample",
 		Arguments: map[string]any{},
 	})
 	if err != nil {
 		t.Fatal(err)
 	}
+	// Run the out-of-band checks.
+	close(sampleDone)
+	sampleWG.Wait()

 	if result.IsError {
 		t.Fatalf("tool failed: %s", result.Content[0].(*TextContent).Text)
 	}
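The test's ordering relies on a standard Go handshake: the tool handler parks a goroutine on a channel, the test closes that channel only after CallTool returns, then waits on the WaitGroup so the goroutine's assertions land before the test exits. The same handshake in isolation (a sketch, independent of the MCP types):

	package main

	import (
		"fmt"
		"sync"
	)

	func main() {
		done := make(chan struct{})
		var wg sync.WaitGroup

		wg.Add(1)
		go func() {
			defer wg.Done()
			<-done // block until the in-band phase is over
			fmt.Println("out-of-band checks run after the call completed")
		}()

		fmt.Println("in-band checks run during the call")
		close(done) // the call has returned; release the goroutine
		wg.Wait()   // make sure its checks finish before we exit
	}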