diff --git a/mcp/streamable.go b/mcp/streamable.go index e7777eb0..25efe31a 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -1130,7 +1130,7 @@ func (c *streamableClientConn) sessionUpdated(state clientSessionState) { // ยง 2.5: A server using the Streamable HTTP transport MAY assign a session // ID at initialization time, by including it in an Mcp-Session-Id header // on the HTTP response containing the InitializeResult. - go c.handleSSE(nil, true, nil) + go c.handleSSE("hanging GET", nil, true, nil) } // fail handles an asynchronous error while reading. @@ -1224,17 +1224,27 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e return nil } + var requestSummary string + switch msg := msg.(type) { + case *jsonrpc.Request: + requestSummary = fmt.Sprintf("sending %q", msg.Method) + case *jsonrpc.Response: + requestSummary = fmt.Sprintf("sending jsonrpc response #%d", msg.ID) + default: + panic("unreachable") + } + switch ct := resp.Header.Get("Content-Type"); ct { case "application/json": - go c.handleJSON(resp) + go c.handleJSON(requestSummary, resp) case "text/event-stream": jsonReq, _ := msg.(*jsonrpc.Request) - go c.handleSSE(resp, false, jsonReq) + go c.handleSSE(requestSummary, resp, false, jsonReq) default: resp.Body.Close() - return fmt.Errorf("unsupported content type %q", ct) + return fmt.Errorf("%s: unsupported content type %q", requestSummary, ct) } return nil } @@ -1258,16 +1268,16 @@ func (c *streamableClientConn) setMCPHeaders(req *http.Request) { } } -func (c *streamableClientConn) handleJSON(resp *http.Response) { +func (c *streamableClientConn) handleJSON(requestSummary string, resp *http.Response) { body, err := io.ReadAll(resp.Body) resp.Body.Close() if err != nil { - c.fail(err) + c.fail(fmt.Errorf("%s: failed to read body: %v", requestSummary, err)) return } msg, err := jsonrpc.DecodeMessage(body) if err != nil { - c.fail(fmt.Errorf("failed to decode response: %v", err)) + c.fail(fmt.Errorf("%s: failed to decode response: %v", requestSummary, err)) return } select { @@ -1282,12 +1292,12 @@ func (c *streamableClientConn) handleJSON(resp *http.Response) { // // If forReq is set, it is the request that initiated the stream, and the // stream is complete when we receive its response. -func (c *streamableClientConn) handleSSE(initialResp *http.Response, persistent bool, forReq *jsonrpc2.Request) { +func (c *streamableClientConn) handleSSE(requestSummary string, initialResp *http.Response, persistent bool, forReq *jsonrpc2.Request) { resp := initialResp var lastEventID string for { if resp != nil { - eventID, clientClosed := c.processStream(resp, forReq) + eventID, clientClosed := c.processStream(requestSummary, resp, forReq) lastEventID = eventID // If the connection was closed by the client, we're done. @@ -1305,7 +1315,7 @@ func (c *streamableClientConn) handleSSE(initialResp *http.Response, persistent newResp, err := c.reconnect(lastEventID) if err != nil { // All reconnection attempts failed: fail the connection. - c.fail(err) + c.fail(fmt.Errorf("%s: failed to reconnect: %v", requestSummary, err)) return } resp = newResp @@ -1316,7 +1326,7 @@ func (c *streamableClientConn) handleSSE(initialResp *http.Response, persistent } if resp.StatusCode < 200 || resp.StatusCode >= 300 { resp.Body.Close() - c.fail(fmt.Errorf("failed to reconnect: %v", http.StatusText(resp.StatusCode))) + c.fail(fmt.Errorf("%s: failed to reconnect: %v", requestSummary, http.StatusText(resp.StatusCode))) return } // Reconnection was successful. Continue the loop with the new response. @@ -1327,7 +1337,7 @@ func (c *streamableClientConn) handleSSE(initialResp *http.Response, persistent // incoming channel. It returns the ID of the last processed event and a flag // indicating if the connection was closed by the client. If resp is nil, it // returns "", false. -func (c *streamableClientConn) processStream(resp *http.Response, forReq *jsonrpc.Request) (lastEventID string, clientClosed bool) { +func (c *streamableClientConn) processStream(requestSummary string, resp *http.Response, forReq *jsonrpc.Request) (lastEventID string, clientClosed bool) { defer resp.Body.Close() for evt, err := range scanEvents(resp.Body) { if err != nil { @@ -1340,7 +1350,7 @@ func (c *streamableClientConn) processStream(resp *http.Response, forReq *jsonrp msg, err := jsonrpc.DecodeMessage(evt.Data) if err != nil { - c.fail(fmt.Errorf("failed to decode event: %v", err)) + c.fail(fmt.Errorf("%s: failed to decode event: %v", requestSummary, err)) return "", true } diff --git a/mcp/streamable_client_test.go b/mcp/streamable_client_test.go new file mode 100644 index 00000000..fe87b21c --- /dev/null +++ b/mcp/streamable_client_test.go @@ -0,0 +1,275 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp + +import ( + "context" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" +) + +type streamableRequestKey struct { + httpMethod string // http method + sessionID string // session ID header + jsonrpcMethod string // jsonrpc method, or "" for non-requests +} + +type header map[string]string + +type streamableResponse struct { + header header + status int // or http.StatusOK + body string // or "" + optional bool // if set, request need not be sent + wantProtocolVersion string // if "", unchecked + callback func() // if set, called after the request is handled +} + +type fakeResponses map[streamableRequestKey]*streamableResponse + +type fakeStreamableServer struct { + t *testing.T + responses fakeResponses + + callMu sync.Mutex + calls map[streamableRequestKey]int +} + +func (s *fakeStreamableServer) missingRequests() []streamableRequestKey { + s.callMu.Lock() + defer s.callMu.Unlock() + + var unused []streamableRequestKey + for k, resp := range s.responses { + if s.calls[k] == 0 && !resp.optional { + unused = append(unused, k) + } + } + return unused +} + +func (s *fakeStreamableServer) ServeHTTP(w http.ResponseWriter, req *http.Request) { + key := streamableRequestKey{ + httpMethod: req.Method, + sessionID: req.Header.Get(sessionIDHeader), + } + if req.Method == http.MethodPost { + body, err := io.ReadAll(req.Body) + if err != nil { + s.t.Errorf("failed to read body: %v", err) + http.Error(w, "failed to read body", http.StatusInternalServerError) + return + } + msg, err := jsonrpc.DecodeMessage(body) + if err != nil { + s.t.Errorf("invalid body: %v", err) + http.Error(w, "invalid body", http.StatusInternalServerError) + return + } + if r, ok := msg.(*jsonrpc.Request); ok { + key.jsonrpcMethod = r.Method + } + } + + s.callMu.Lock() + if s.calls == nil { + s.calls = make(map[streamableRequestKey]int) + } + s.calls[key]++ + s.callMu.Unlock() + + resp, ok := s.responses[key] + if !ok { + s.t.Errorf("missing response for %v", key) + http.Error(w, "no response", http.StatusInternalServerError) + return + } + if resp.callback != nil { + defer resp.callback() + } + for k, v := range resp.header { + w.Header().Set(k, v) + } + status := resp.status + if status == 0 { + status = http.StatusOK + } + w.WriteHeader(status) + + if v := req.Header.Get(protocolVersionHeader); v != resp.wantProtocolVersion && resp.wantProtocolVersion != "" { + s.t.Errorf("%v: bad protocol version header: got %q, want %q", key, v, resp.wantProtocolVersion) + } + w.Write([]byte(resp.body)) +} + +var ( + initResult = &InitializeResult{ + Capabilities: &ServerCapabilities{ + Completions: &CompletionCapabilities{}, + Logging: &LoggingCapabilities{}, + Tools: &ToolCapabilities{ListChanged: true}, + }, + ProtocolVersion: latestProtocolVersion, + ServerInfo: &Implementation{Name: "testServer", Version: "v1.0.0"}, + } + initResp = resp(1, initResult, nil) +) + +func jsonBody(t *testing.T, msg jsonrpc2.Message) string { + data, err := jsonrpc2.EncodeMessage(msg) + if err != nil { + t.Fatalf("encoding failed: %v", err) + } + return string(data) +} + +func TestStreamableClientTransportLifecycle(t *testing.T) { + ctx := context.Background() + + // The lifecycle test verifies various behavior of the streamable client + // initialization: + // - check that it can handle application/json responses + // - check that it sends the negotiated protocol version + fake := &fakeStreamableServer{ + t: t, + responses: fakeResponses{ + {"POST", "", methodInitialize}: { + header: header{ + "Content-Type": "application/json", + sessionIDHeader: "123", + }, + body: jsonBody(t, initResp), + }, + {"POST", "123", notificationInitialized}: { + status: http.StatusAccepted, + wantProtocolVersion: latestProtocolVersion, + }, + {"GET", "123", ""}: { + header: header{ + "Content-Type": "text/event-stream", + }, + optional: true, + wantProtocolVersion: latestProtocolVersion, + }, + {"DELETE", "123", ""}: {}, + }, + } + + httpServer := httptest.NewServer(fake) + defer httpServer.Close() + + transport := &StreamableClientTransport{Endpoint: httpServer.URL} + client := NewClient(testImpl, nil) + session, err := client.Connect(ctx, transport, nil) + if err != nil { + t.Fatalf("client.Connect() failed: %v", err) + } + if err := session.Close(); err != nil { + t.Errorf("closing session: %v", err) + } + if missing := fake.missingRequests(); len(missing) > 0 { + t.Errorf("did not receive expected requests: %v", missing) + } + if diff := cmp.Diff(initResult, session.state.InitializeResult); diff != "" { + t.Errorf("mismatch (-want, +got):\n%s", diff) + } +} + +func TestStreamableClientGETHandling(t *testing.T) { + ctx := context.Background() + + tests := []struct { + status int + wantErrorContaining string + }{ + {http.StatusOK, ""}, + {http.StatusMethodNotAllowed, ""}, + {http.StatusBadRequest, "hanging GET"}, + } + + for _, test := range tests { + t.Run(fmt.Sprintf("status=%d", test.status), func(t *testing.T) { + fake := &fakeStreamableServer{ + t: t, + responses: fakeResponses{ + {"POST", "", methodInitialize}: { + header: header{ + "Content-Type": "application/json", + sessionIDHeader: "123", + }, + body: jsonBody(t, initResp), + }, + {"POST", "123", notificationInitialized}: { + status: http.StatusAccepted, + wantProtocolVersion: latestProtocolVersion, + }, + {"GET", "123", ""}: { + header: header{ + "Content-Type": "text/event-stream", + }, + status: test.status, + wantProtocolVersion: latestProtocolVersion, + }, + {"POST", "123", methodListTools}: { + header: header{ + "Content-Type": "application/json", + sessionIDHeader: "123", + }, + body: jsonBody(t, resp(2, &ListToolsResult{Tools: []*Tool{}}, nil)), + optional: true, + }, + {"DELETE", "123", ""}: {optional: true}, + }, + } + httpServer := httptest.NewServer(fake) + defer httpServer.Close() + + transport := &StreamableClientTransport{Endpoint: httpServer.URL} + client := NewClient(testImpl, nil) + session, err := client.Connect(ctx, transport, nil) + if err != nil { + t.Fatalf("client.Connect() failed: %v", err) + } + + // wait for all required requests to be handled, with exponential + // backoff. + start := time.Now() + delay := 1 * time.Millisecond + for range 10 { + if len(fake.missingRequests()) == 0 { + break + } + time.Sleep(delay) + delay *= 2 + } + if missing := fake.missingRequests(); len(missing) > 0 { + t.Errorf("did not receive expected requests after %s: %v", time.Since(start), missing) + } + + _, err = session.ListTools(ctx, nil) + if (err != nil) != (test.wantErrorContaining != "") { + t.Errorf("After initialization, got error %v, want %v", err, test.wantErrorContaining) + } else if err != nil { + if !strings.Contains(err.Error(), test.wantErrorContaining) { + t.Errorf("After initialization, got error %s, want containing %q", err, test.wantErrorContaining) + } + } + + if err := session.Close(); err != nil { + t.Errorf("closing session: %v", err) + } + }) + } +} diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 0d171d83..2963a04d 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -1035,77 +1035,6 @@ func mustMarshal(v any) json.RawMessage { return data } -func TestStreamableClientTransport(t *testing.T) { - // This test verifies various behavior of the streamable client transport: - // - check that it can handle application/json responses - // - check that it sends the negotiated protocol version - // - // TODO(rfindley): make this test more comprehensive, similar to - // [TestStreamableServerTransport]. - ctx := context.Background() - resp := func(id int64, result any, err error) *jsonrpc.Response { - return &jsonrpc.Response{ - ID: jsonrpc2.Int64ID(id), - Result: mustMarshal(result), - Error: err, - } - } - initResult := &InitializeResult{ - Capabilities: &ServerCapabilities{ - Completions: &CompletionCapabilities{}, - Logging: &LoggingCapabilities{}, - Tools: &ToolCapabilities{ListChanged: true}, - }, - ProtocolVersion: latestProtocolVersion, - ServerInfo: &Implementation{Name: "testServer", Version: "v1.0.0"}, - } - initResp := resp(1, initResult, nil) - - var reqN atomic.Int32 // request count - serverHandler := func(w http.ResponseWriter, r *http.Request) { - rN := reqN.Add(1) - - // TODO(rfindley): if the status code is NoContent or Accepted, we should - // probably be tolerant of when the content type is not application/json. - w.Header().Set("Content-Type", "application/json") - if rN == 1 { - data, err := jsonrpc2.EncodeMessage(initResp) - if err != nil { - t.Errorf("encoding failed: %v", err) - } - w.Header().Set("Mcp-Session-Id", "123") - w.Write(data) - } else { - if v := r.Header.Get(protocolVersionHeader); v != latestProtocolVersion { - t.Errorf("bad protocol version header: got %q, want %q", v, latestProtocolVersion) - } - } - } - - httpServer := httptest.NewServer(http.HandlerFunc(serverHandler)) - defer httpServer.Close() - - transport := &StreamableClientTransport{Endpoint: httpServer.URL} - client := NewClient(testImpl, nil) - session, err := client.Connect(ctx, transport, nil) - if err != nil { - t.Fatalf("client.Connect() failed: %v", err) - } - if err := session.Close(); err != nil { - t.Errorf("closing session: %v", err) - } - - if got, want := reqN.Load(), int32(3); got < want { - // Expect at least 3 requests: initialize, initialized, and DELETE. - // We may or may not observe the GET, depending on timing. - t.Errorf("unexpected number of requests: got %d, want at least %d", got, want) - } - - if diff := cmp.Diff(initResult, session.state.InitializeResult); diff != "" { - t.Errorf("mismatch (-want, +got):\n%s", diff) - } -} - func TestEventID(t *testing.T) { tests := []struct { sid StreamID diff --git a/mcp/transport.go b/mcp/transport.go index fac640a6..5c7ca130 100644 --- a/mcp/transport.go +++ b/mcp/transport.go @@ -194,7 +194,7 @@ func call(ctx context.Context, conn *jsonrpc2.Connection, method string, params err := call.Await(ctx, result) switch { case errors.Is(err, jsonrpc2.ErrClientClosing), errors.Is(err, jsonrpc2.ErrServerClosing): - return fmt.Errorf("calling %q: %w", method, ErrConnectionClosed) + return fmt.Errorf("%w: calling %q: %v", ErrConnectionClosed, method, err) case ctx.Err() != nil: // Notify the peer of cancellation. err := conn.Notify(xcontext.Detach(ctx), notificationCancelled, &CancelledParams{