diff --git a/examples/client/loadtest/main.go b/examples/client/loadtest/main.go
index d5a04c2e..cbf7cd67 100644
--- a/examples/client/loadtest/main.go
+++ b/examples/client/loadtest/main.go
@@ -34,6 +34,7 @@ var (
 	timeout = flag.Duration("timeout", 1*time.Second, "request timeout")
 	qps     = flag.Int("qps", 100, "tool calls per second, per worker")
 	verbose = flag.Bool("v", false, "if set, enable verbose logging")
+	cleanup = flag.Bool("cleanup", true, "whether to clean up sessions at the end of the test")
 )
 
 func main() {
@@ -76,7 +77,9 @@ func main() {
 		if err != nil {
 			log.Fatal(err)
 		}
-		defer cs.Close()
+		if *cleanup {
+			defer cs.Close()
+		}
 
 		ticker := time.NewTicker(1 * time.Second / time.Duration(*qps))
 		defer ticker.Stop()
diff --git a/mcp/streamable.go b/mcp/streamable.go
index 7204e2b1..994ebd6a 100644
--- a/mcp/streamable.go
+++ b/mcp/streamable.go
@@ -450,22 +450,14 @@ type streamableServerConn struct {
 	// handled.
 
 	// streams holds the logical streams for this session, keyed by their ID.
-	// TODO: streams are never deleted, so the memory for a connection grows without
-	// bound. If we deleted a stream when the response is sent, we would lose the ability
-	// to replay if there was a cut just before the response was transmitted.
-	// Perhaps we could have a TTL for streams that starts just after the response.
 	//
-	// TODO(rfindley): Once all responses have been received for a stream, we can
-	// remove it as it is no longer necessary, even if the client wants to replay
-	// messages.
+	// Lifecycle: streams persist until all of their responses have been
+	// received from the server.
 	streams map[string]*stream
 
 	// requestStreams maps incoming requests to their logical stream ID.
 	//
-	// Lifecycle: requestStreams persist for the duration of the session.
-	//
-	// TODO(rfindley): clean up once requests are handled. See the TODO for
-	// streams above.
+	// Lifecycle: requestStreams persist until their response has been received.
 	requestStreams map[jsonrpc.ID]string
 }
 
@@ -641,17 +633,39 @@ func (c *streamableServerConn) writeEvent(w http.ResponseWriter, stream *stream,
 // all messages, so that no delivery or storage of new messages occurs while
 // the stream is still replaying.
 func (c *streamableServerConn) acquireStream(ctx context.Context, w http.ResponseWriter, streamID string, lastIdx *int) (*stream, chan struct{}) {
+	// If tempStream is set, the stream is done and we're just replaying messages.
+	//
+	// We record a temporary stream to claim exclusive replay rights.
+	tempStream := false
 	c.mu.Lock()
-	stream, ok := c.streams[streamID]
-	c.mu.Unlock()
+	s, ok := c.streams[streamID]
 	if !ok {
-		http.Error(w, "unknown stream", http.StatusBadRequest)
-		return nil, nil
+		// The stream is logically done, but claim exclusive rights to replay it by
+		// adding a temporary entry in the streams map.
+		//
+		// We create this entry with a non-nil deliver function, to ensure it isn't
+		// claimed by another request before we lock it below.
+		tempStream = true
+		s = &stream{
+			id:      streamID,
+			deliver: func([]byte, bool) error { return nil },
+		}
+		c.streams[streamID] = s
+
+		// Since this stream is transient, we must clean up after replaying.
+		defer func() {
+			c.mu.Lock()
+			delete(c.streams, streamID)
+			c.mu.Unlock()
+		}()
 	}
+	c.mu.Unlock()
 
-	stream.mu.Lock()
-	defer stream.mu.Unlock()
-	if stream.deliver != nil {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	// Check that this stream wasn't claimed by another request.
+	if !tempStream && s.deliver != nil {
 		http.Error(w, "stream ID conflicts with ongoing stream", http.StatusConflict)
 		return nil, nil
 	}
@@ -664,7 +678,7 @@ func (c *streamableServerConn) acquireStream(ctx context.Context, w http.Respons
 	// messages, and registered our delivery function.
 	var toReplay [][]byte
 	if c.eventStore != nil {
-		for data, err := range c.eventStore.After(ctx, c.SessionID(), stream.id, *lastIdx) {
+		for data, err := range c.eventStore.After(ctx, c.SessionID(), s.id, *lastIdx) {
 			if err != nil {
 				// We can't replay events, perhaps because the underlying event store
 				// has garbage collected its storage.
@@ -685,7 +699,7 @@ func (c *streamableServerConn) acquireStream(ctx context.Context, w http.Respons
 	w.Header().Set("Content-Type", "text/event-stream") // Accept checked in [StreamableHTTPHandler]
 	w.Header().Set("Connection", "keep-alive")
 
-	if stream.id == "" {
+	if s.id == "" {
 		// Issue #410: the standalone SSE stream is likely not to receive messages
 		// for a long time. Ensure that headers are flushed.
 		w.WriteHeader(http.StatusOK)
@@ -695,30 +709,30 @@ func (c *streamableServerConn) acquireStream(ctx context.Context, w http.Respons
 	}
 
 	for _, data := range toReplay {
-		if err := c.writeEvent(w, stream, data, lastIdx); err != nil {
+		if err := c.writeEvent(w, s, data, lastIdx); err != nil {
 			return nil, nil
 		}
 	}
 
-	if stream.doneLocked() {
+	if tempStream || s.doneLocked() {
 		// Nothing more to do.
 		return nil, nil
 	}
 
-	// Finally register a delivery function and unlock the stream, allowing the
-	// connection to write new events.
+	// The stream is not done: register a delivery function before the stream is
+	// unlocked, allowing the connection to write new events.
 	done := make(chan struct{})
-	stream.deliver = func(data []byte, final bool) error {
+	s.deliver = func(data []byte, final bool) error {
 		if err := ctx.Err(); err != nil {
 			return err
 		}
-		err := c.writeEvent(w, stream, data, lastIdx)
+		err := c.writeEvent(w, s, data, lastIdx)
 		if final {
 			close(done)
 		}
 		return err
 	}
-	return stream, done
+	return s, done
 }
 
 // servePOST handles an incoming message, and replies with either an outgoing
@@ -1009,13 +1023,23 @@ func (c *streamableServerConn) Write(ctx context.Context, msg jsonrpc.Message) e
 			s = c.streams[streamID]
 		}
 	} else {
-		s = c.streams[""]
+		s = c.streams[""] // standalone SSE stream
+	}
+	if responseTo.IsValid() {
+		// Once we've responded to a request, disallow related messages by removing
+		// the stream association. This also releases memory.
+		delete(c.requestStreams, responseTo)
 	}
 	sessionClosed := c.isDone
 	c.mu.Unlock()
 
 	if s == nil {
-		return fmt.Errorf("%w: no stream for request", jsonrpc2.ErrRejected)
+		// The message was sent in the context of an incoming request, but that
+		// request has already completed.
+		//
+		// In the future, we could be less strict and allow the message to land on
+		// the standalone SSE stream.
+		return fmt.Errorf("%w: write to closed stream", jsonrpc2.ErrRejected)
 	}
 	if sessionClosed {
 		return errors.New("session is closed")
@@ -1024,10 +1048,28 @@ func (c *streamableServerConn) Write(ctx context.Context, msg jsonrpc.Message) e
 	s.mu.Lock()
 	defer s.mu.Unlock()
 	if s.doneLocked() {
+		// It's possible that the stream completed between fetching s above and
+		// acquiring the stream lock. Since we avoid acquiring s.mu while holding
+		// c.mu, we must check the terminal condition again.
 		return fmt.Errorf("%w: write to closed stream", jsonrpc2.ErrRejected)
 	}
+	// Perform accounting on responses.
 	if responseTo.IsValid() {
+		if _, ok := s.requests[responseTo]; !ok {
+			panic(fmt.Sprintf("internal error: stream %v: response to untracked request %v", s.id, responseTo))
+		}
+		if s.id == "" {
+			// This should be guaranteed not to happen by the stream resolution logic
+			// above, but be defensive: we don't ever want to delete the standalone
+			// stream.
+			panic("internal error: response on standalone stream")
+		}
 		delete(s.requests, responseTo)
+		if len(s.requests) == 0 {
+			c.mu.Lock()
+			delete(c.streams, s.id)
+			c.mu.Unlock()
+		}
 	}
 
 	delivered := false
diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go
index a0893689..436665f2 100644
--- a/mcp/streamable_test.go
+++ b/mcp/streamable_test.go
@@ -73,23 +73,53 @@ func TestStreamableTransports(t *testing.T) {
 			return nil, nil, nil
 		}
 		AddTool(server, &Tool{Name: "hang"}, hang)
+		// We use sampling to test server->client requests, both before and after
+		// the related client->server request completes.
+		sampleDone := make(chan struct{})
+		var sampleWG sync.WaitGroup
 		AddTool(server, &Tool{Name: "sample"}, func(ctx context.Context, req *CallToolRequest, args any) (*CallToolResult, any, error) {
+			type testCase struct {
+				label       string
+				ctx         context.Context
+				wantSuccess bool
+			}
+			testSample := func(tc testCase) {
+				res, err := req.Session.CreateMessage(tc.ctx, &CreateMessageParams{})
+				if gotSuccess := err == nil; gotSuccess != tc.wantSuccess {
+					t.Errorf("%s: CreateMessage success=%v, want %v", tc.label, gotSuccess, tc.wantSuccess)
+				}
+				if err != nil {
+					return
+				}
+				if g, w := res.Model, "aModel"; g != w {
+					t.Errorf("%s: got model %q, want %q", tc.label, g, w)
+				}
+			}
 			// Test that we can make sampling requests during tool handling.
 			//
 			// Try this on both the request context and a background context, so
 			// that messages may be delivered on either the POST or GET connection.
-			for _, ctx := range map[string]context.Context{
-				"request context":    ctx,
-				"background context": context.Background(),
+			for _, test := range []testCase{
+				{"request context", ctx, true},
+				{"background context", context.Background(), true},
 			} {
-				res, err := req.Session.CreateMessage(ctx, &CreateMessageParams{})
-				if err != nil {
-					return nil, nil, err
-				}
-				if g, w := res.Model, "aModel"; g != w {
-					return nil, nil, fmt.Errorf("got %q, want %q", g, w)
-				}
+				testSample(test)
 			}
+			// Now spin off a goroutine that waits until the tool call has returned,
+			// to check behavior after the client request has completed.
+			sampleWG.Add(1)
+			go func() {
+				defer sampleWG.Done()
+				<-sampleDone
+				// Test that sampling requests using the tool's request context fail
+				// after tool handling, but still succeed on a background context.
+				for _, test := range []testCase{
+					{"request context", ctx, false},
+					{"background context", context.Background(), true},
+				} {
+					testSample(test)
+				}
+			}()
 			return &CallToolResult{}, nil, nil
 		})
 
@@ -191,8 +221,8 @@ func TestStreamableTransports(t *testing.T) {
 			t.Fatal("timeout waiting for cancellation")
 		}
 
-		// The "sampling" tool should be able to issue sampling requests during
-		// tool operation.
+		// The "sample" tool checks the validity of server->client requests both
+		// during and after handling of the originating request.
 		result, err := session.CallTool(ctx, &CallToolParams{
 			Name:      "sample",
 			Arguments: map[string]any{},
@@ -200,6 +230,10 @@ func TestStreamableTransports(t *testing.T) {
 		if err != nil {
 			t.Fatal(err)
 		}
+		// Run the out-of-band checks.
+		close(sampleDone)
+		sampleWG.Wait()
+
 		if result.IsError {
 			t.Fatalf("tool failed: %s", result.Content[0].(*TextContent).Text)
 		}
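---

Reviewer notes (illustration only; not part of the patch):

The subtlest piece of the acquireStream change is how replay of a finished stream is made exclusive: a temporary placeholder entry (with a non-nil deliver function) is installed in c.streams, and a defer removes it once replay completes. The standalone sketch below shows that claim-by-placeholder pattern in isolation. The names registry, claimReplay, and release are invented for illustration and are not part of the SDK; the real code claims by setting deliver rather than a boolean.

package main

import (
	"errors"
	"fmt"
	"sync"
)

// registry stands in for streamableServerConn: a mutex-guarded map in which
// the mere presence of an entry means "some request holds this stream".
type registry struct {
	mu      sync.Mutex
	streams map[string]bool
}

// claimReplay installs a placeholder entry for id, so a concurrent request
// replaying the same completed stream is rejected rather than replayed twice.
// The returned release func mirrors the patch's deferred delete(c.streams, streamID).
func (r *registry) claimReplay(id string) (release func(), err error) {
	r.mu.Lock()
	defer r.mu.Unlock()
	if r.streams[id] {
		return nil, errors.New("stream ID conflicts with ongoing stream")
	}
	r.streams[id] = true
	return func() {
		r.mu.Lock()
		delete(r.streams, id)
		r.mu.Unlock()
	}, nil
}

func main() {
	r := &registry{streams: make(map[string]bool)}
	release, err := r.claimReplay("s1")
	if err != nil {
		panic(err)
	}
	if _, err := r.claimReplay("s1"); err != nil {
		fmt.Println("concurrent claim rejected:", err)
	}
	release() // replay finished; the transient entry is removed
	if _, err := r.claimReplay("s1"); err == nil {
		fmt.Println("claim succeeds again after release")
	}
}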
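Similarly, the Write change ties stream lifetime to pending responses: each stream tracks its outstanding requests, and once the final response is recorded the stream is deleted from the session-wide map, resolving the unbounded-growth TODO. A minimal sketch of that accounting follows, again with invented names (conn, stream, respond); it is a simplification, not the SDK's API.

package main

import (
	"fmt"
	"sync"
)

type stream struct {
	mu       sync.Mutex
	id       string
	requests map[int]bool // IDs of requests still awaiting a response
}

type conn struct {
	mu      sync.Mutex
	streams map[string]*stream
}

// respond records the response for req on s, and deletes s from the
// connection-wide map once no requests remain, so per-session memory no
// longer grows without bound. Note the lock ordering: c.mu is acquired
// while holding s.mu, never the reverse, matching the patch's invariant.
func (c *conn) respond(s *stream, req int) {
	s.mu.Lock()
	defer s.mu.Unlock()
	delete(s.requests, req)
	if len(s.requests) == 0 {
		c.mu.Lock()
		delete(c.streams, s.id)
		c.mu.Unlock()
	}
}

func main() {
	s := &stream{id: "s1", requests: map[int]bool{1: true, 2: true}}
	c := &conn{streams: map[string]*stream{"s1": s}}
	c.respond(s, 1)
	fmt.Println("streams left:", len(c.streams)) // 1: request 2 still pending
	c.respond(s, 2)
	fmt.Println("streams left:", len(c.streams)) // 0: cleaned up on final response
}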