diff --git a/mcp/streamable.go b/mcp/streamable.go index e5ffa642..1e1d8579 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -800,6 +800,8 @@ func (s *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e return nil } +// postMessage POSTs msg to the server and reads the response. +// It returns the session ID from the response. func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string, msg jsonrpc.Message) (string, error) { data, err := jsonrpc2.EncodeMessage(msg) if err != nil { @@ -836,9 +838,17 @@ func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string // Section 2.1: The SSE stream is initiated after a POST. go s.handleSSE(resp) case "application/json": - // TODO: read the body and send to s.incoming (in a select that also receives from s.done). + body, err := io.ReadAll(resp.Body) resp.Body.Close() - return "", fmt.Errorf("streamable HTTP client does not yet support raw JSON responses") + if err != nil { + return "", err + } + select { + case s.incoming <- body: + case <-s.done: + // The connection was closed by the client; exit gracefully. + } + return sessionID, nil default: resp.Body.Close() return "", fmt.Errorf("unsupported content type %q", ct) diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 864265e5..185bc638 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -158,7 +158,8 @@ func TestClientReplay(t *testing.T) { client := NewClient(testImpl, &ClientOptions{ ProgressNotificationHandler: func(ctx context.Context, cc *ClientSession, params *ProgressNotificationParams) { notifications <- params.Message - }}) + }, + }) clientSession, err := client.Connect(ctx, NewStreamableClientTransport(proxy.URL, nil)) if err != nil { t.Fatalf("client.Connect() failed: %v", err) @@ -678,6 +679,51 @@ func mustMarshal(t *testing.T, v any) json.RawMessage { return data } +func TestStreamableClientTransportApplicationJSON(t *testing.T) { + // Test handling of application/json responses. + ctx := context.Background() + resp := func(id int64, result any, err error) *jsonrpc.Response { + return &jsonrpc.Response{ + ID: jsonrpc2.Int64ID(id), + Result: mustMarshal(t, result), + Error: err, + } + } + initResult := &InitializeResult{ + Capabilities: &serverCapabilities{ + Completions: &completionCapabilities{}, + Logging: &loggingCapabilities{}, + Tools: &toolCapabilities{ListChanged: true}, + }, + ProtocolVersion: latestProtocolVersion, + ServerInfo: &Implementation{Name: "testServer", Version: "v1.0.0"}, + } + initResp := resp(1, initResult, nil) + + serverHandler := func(w http.ResponseWriter, r *http.Request) { + data, err := jsonrpc2.EncodeMessage(initResp) + if err != nil { + t.Fatal(err) + } + w.Header().Set("Content-Type", "application/json") + w.Write(data) + } + + httpServer := httptest.NewServer(http.HandlerFunc(serverHandler)) + defer httpServer.Close() + + transport := NewStreamableClientTransport(httpServer.URL, nil) + client := NewClient(testImpl, nil) + session, err := client.Connect(ctx, transport) + if err != nil { + t.Fatalf("client.Connect() failed: %v", err) + } + defer session.Close() + if diff := cmp.Diff(initResult, session.initializeResult); diff != "" { + t.Errorf("mismatch (-want, +got):\n%s", diff) + } +} + func TestEventID(t *testing.T) { tests := []struct { sid StreamID