Skip to content

Commit 64b5b91

Browse files
authored
mcp: support application/json in streamable client (#181)
The client handles POST responses with content type application/json. For: #10 Fixes #129
1 parent 2b6f7b5 commit 64b5b91

File tree

2 files changed

+59
-3
lines changed

2 files changed

+59
-3
lines changed

mcp/streamable.go

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -813,6 +813,8 @@ func (s *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
813813
return nil
814814
}
815815

816+
// postMessage POSTs msg to the server and reads the response.
817+
// It returns the session ID from the response.
816818
func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string, msg jsonrpc.Message) (string, error) {
817819
data, err := jsonrpc2.EncodeMessage(msg)
818820
if err != nil {
@@ -849,9 +851,17 @@ func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string
849851
// Section 2.1: The SSE stream is initiated after a POST.
850852
go s.handleSSE(resp)
851853
case "application/json":
852-
// TODO: read the body and send to s.incoming (in a select that also receives from s.done).
854+
body, err := io.ReadAll(resp.Body)
853855
resp.Body.Close()
854-
return "", fmt.Errorf("streamable HTTP client does not yet support raw JSON responses")
856+
if err != nil {
857+
return "", err
858+
}
859+
select {
860+
case s.incoming <- body:
861+
case <-s.done:
862+
// The connection was closed by the client; exit gracefully.
863+
}
864+
return sessionID, nil
855865
default:
856866
resp.Body.Close()
857867
return "", fmt.Errorf("unsupported content type %q", ct)

mcp/streamable_test.go

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,8 @@ func TestClientReplay(t *testing.T) {
158158
client := NewClient(testImpl, &ClientOptions{
159159
ProgressNotificationHandler: func(ctx context.Context, cc *ClientSession, params *ProgressNotificationParams) {
160160
notifications <- params.Message
161-
}})
161+
},
162+
})
162163
clientSession, err := client.Connect(ctx, NewStreamableClientTransport(proxy.URL, nil))
163164
if err != nil {
164165
t.Fatalf("client.Connect() failed: %v", err)
@@ -678,6 +679,51 @@ func mustMarshal(t *testing.T, v any) json.RawMessage {
678679
return data
679680
}
680681

682+
func TestStreamableClientTransportApplicationJSON(t *testing.T) {
683+
// Test handling of application/json responses.
684+
ctx := context.Background()
685+
resp := func(id int64, result any, err error) *jsonrpc.Response {
686+
return &jsonrpc.Response{
687+
ID: jsonrpc2.Int64ID(id),
688+
Result: mustMarshal(t, result),
689+
Error: err,
690+
}
691+
}
692+
initResult := &InitializeResult{
693+
Capabilities: &serverCapabilities{
694+
Completions: &completionCapabilities{},
695+
Logging: &loggingCapabilities{},
696+
Tools: &toolCapabilities{ListChanged: true},
697+
},
698+
ProtocolVersion: latestProtocolVersion,
699+
ServerInfo: &Implementation{Name: "testServer", Version: "v1.0.0"},
700+
}
701+
initResp := resp(1, initResult, nil)
702+
703+
serverHandler := func(w http.ResponseWriter, r *http.Request) {
704+
data, err := jsonrpc2.EncodeMessage(initResp)
705+
if err != nil {
706+
t.Fatal(err)
707+
}
708+
w.Header().Set("Content-Type", "application/json")
709+
w.Write(data)
710+
}
711+
712+
httpServer := httptest.NewServer(http.HandlerFunc(serverHandler))
713+
defer httpServer.Close()
714+
715+
transport := NewStreamableClientTransport(httpServer.URL, nil)
716+
client := NewClient(testImpl, nil)
717+
session, err := client.Connect(ctx, transport)
718+
if err != nil {
719+
t.Fatalf("client.Connect() failed: %v", err)
720+
}
721+
defer session.Close()
722+
if diff := cmp.Diff(initResult, session.initializeResult); diff != "" {
723+
t.Errorf("mismatch (-want, +got):\n%s", diff)
724+
}
725+
}
726+
681727
func TestEventID(t *testing.T) {
682728
tests := []struct {
683729
sid StreamID

0 commit comments

Comments
 (0)