diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 25dd224e..f50e044e 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -318,30 +318,6 @@ func TestStreamableServerTransport(t *testing.T) { // faking the behavior of a streamable client using a sequence of HTTP // requests. - // A step is a single step in the tests below, consisting of a request payload - // and expected response. - type step struct { - // If OnRequest is > 0, this step only executes after a request with the - // given ID is received. - // - // All OnRequest steps must occur before the step that creates the request. - // - // To avoid tests hanging when there's a bug, it's expected that this - // request is received in the course of a *synchronous* request to the - // server (otherwise, we wouldn't be able to terminate the test without - // analyzing a dependency graph). - OnRequest int64 - // If set, Async causes the step to run asynchronously to other steps. - // Redundant with OnRequest: all OnRequest steps are asynchronous. - Async bool - - Method string // HTTP request method - Send []jsonrpc.Message // messages to send - CloseAfter int // if nonzero, close after receiving this many messages - StatusCode int // expected status code - Recv []jsonrpc.Message // expected messages to receive - } - // JSON-RPC message constructors. req := func(id int64, method string, params any) *jsonrpc.Request { r := &jsonrpc.Request{ @@ -372,33 +348,67 @@ func TestStreamableServerTransport(t *testing.T) { ServerInfo: &Implementation{Name: "testServer", Version: "v1.0.0"}, }, nil) initializedMsg := req(0, notificationInitialized, &InitializedParams{}) - initialize := step{ - Method: "POST", - Send: []jsonrpc.Message{initReq}, - StatusCode: http.StatusOK, - Recv: []jsonrpc.Message{initResp}, + initialize := streamableRequest{ + method: "POST", + messages: []jsonrpc.Message{initReq}, + wantStatusCode: http.StatusOK, + wantMessages: []jsonrpc.Message{initResp}, } - initialized := step{ - Method: "POST", - Send: []jsonrpc.Message{initializedMsg}, - StatusCode: http.StatusAccepted, + initialized := streamableRequest{ + method: "POST", + messages: []jsonrpc.Message{initializedMsg}, + wantStatusCode: http.StatusAccepted, } tests := []struct { name string tool func(*testing.T, context.Context, *ServerSession) - steps []step + steps []streamableRequest }{ { name: "basic", - steps: []step{ + steps: []streamableRequest{ + initialize, + initialized, + { + method: "POST", + messages: []jsonrpc.Message{req(2, "tools/call", &CallToolParams{Name: "tool"})}, + wantStatusCode: http.StatusOK, + wantMessages: []jsonrpc.Message{resp(2, &CallToolResult{}, nil)}, + }, + }, + }, + { + name: "accept headers", + steps: []streamableRequest{ initialize, initialized, + // Test various accept headers. + { + method: "POST", + accept: []string{"text/plain", "application/*"}, + messages: []jsonrpc.Message{req(3, "tools/call", &CallToolParams{Name: "tool"})}, + wantStatusCode: http.StatusBadRequest, // missing text/event-stream + }, + { + method: "POST", + accept: []string{"text/event-stream"}, + messages: []jsonrpc.Message{req(3, "tools/call", &CallToolParams{Name: "tool"})}, + wantStatusCode: http.StatusBadRequest, // missing application/json + }, + { + method: "POST", + accept: []string{"text/plain", "*/*"}, + messages: []jsonrpc.Message{req(4, "tools/call", &CallToolParams{Name: "tool"})}, + wantStatusCode: http.StatusOK, + wantMessages: []jsonrpc.Message{resp(4, &CallToolResult{}, nil)}, + }, { - Method: "POST", - Send: []jsonrpc.Message{req(2, "tools/call", &CallToolParams{Name: "tool"})}, - StatusCode: http.StatusOK, - Recv: []jsonrpc.Message{resp(2, &CallToolResult{}, nil)}, + method: "POST", + accept: []string{"text/*, application/*"}, + messages: []jsonrpc.Message{req(4, "tools/call", &CallToolParams{Name: "tool"})}, + wantStatusCode: http.StatusOK, + wantMessages: []jsonrpc.Message{resp(4, &CallToolResult{}, nil)}, }, }, }, @@ -410,16 +420,16 @@ func TestStreamableServerTransport(t *testing.T) { t.Errorf("Notify failed: %v", err) } }, - steps: []step{ + steps: []streamableRequest{ initialize, initialized, { - Method: "POST", - Send: []jsonrpc.Message{ + method: "POST", + messages: []jsonrpc.Message{ req(2, "tools/call", &CallToolParams{Name: "tool"}), }, - StatusCode: http.StatusOK, - Recv: []jsonrpc.Message{ + wantStatusCode: http.StatusOK, + wantMessages: []jsonrpc.Message{ req(0, "notifications/progress", &ProgressNotificationParams{}), resp(2, &CallToolResult{}, nil), }, @@ -434,24 +444,24 @@ func TestStreamableServerTransport(t *testing.T) { t.Errorf("Call failed: %v", err) } }, - steps: []step{ + steps: []streamableRequest{ initialize, initialized, { - Method: "POST", - OnRequest: 1, - Send: []jsonrpc.Message{ + method: "POST", + onRequest: 1, + messages: []jsonrpc.Message{ resp(1, &ListRootsResult{}, nil), }, - StatusCode: http.StatusAccepted, + wantStatusCode: http.StatusAccepted, }, { - Method: "POST", - Send: []jsonrpc.Message{ + method: "POST", + messages: []jsonrpc.Message{ req(2, "tools/call", &CallToolParams{Name: "tool"}), }, - StatusCode: http.StatusOK, - Recv: []jsonrpc.Message{ + wantStatusCode: http.StatusOK, + wantMessages: []jsonrpc.Message{ req(1, "roots/list", &ListRootsParams{}), resp(2, &CallToolResult{}, nil), }, @@ -475,34 +485,34 @@ func TestStreamableServerTransport(t *testing.T) { t.Errorf("Notify failed: %v", err) } }, - steps: []step{ + steps: []streamableRequest{ initialize, initialized, { - Method: "POST", - OnRequest: 1, - Send: []jsonrpc.Message{ + method: "POST", + onRequest: 1, + messages: []jsonrpc.Message{ resp(1, &ListRootsResult{}, nil), }, - StatusCode: http.StatusAccepted, + wantStatusCode: http.StatusAccepted, }, { - Method: "GET", - Async: true, - StatusCode: http.StatusOK, - CloseAfter: 2, - Recv: []jsonrpc.Message{ + method: "GET", + async: true, + wantStatusCode: http.StatusOK, + closeAfter: 2, + wantMessages: []jsonrpc.Message{ req(0, "notifications/progress", &ProgressNotificationParams{}), req(1, "roots/list", &ListRootsParams{}), }, }, { - Method: "POST", - Send: []jsonrpc.Message{ + method: "POST", + messages: []jsonrpc.Message{ req(2, "tools/call", &CallToolParams{Name: "tool"}), }, - StatusCode: http.StatusOK, - Recv: []jsonrpc.Message{ + wantStatusCode: http.StatusOK, + wantMessages: []jsonrpc.Message{ resp(2, &CallToolResult{}, nil), }, }, @@ -510,30 +520,30 @@ func TestStreamableServerTransport(t *testing.T) { }, { name: "errors", - steps: []step{ + steps: []streamableRequest{ { - Method: "PUT", - StatusCode: http.StatusMethodNotAllowed, + method: "PUT", + wantStatusCode: http.StatusMethodNotAllowed, }, { - Method: "DELETE", - StatusCode: http.StatusBadRequest, + method: "DELETE", + wantStatusCode: http.StatusBadRequest, }, { - Method: "POST", - Send: []jsonrpc.Message{req(1, "notamethod", nil)}, - StatusCode: http.StatusBadRequest, // notamethod is an invalid method + method: "POST", + messages: []jsonrpc.Message{req(1, "notamethod", nil)}, + wantStatusCode: http.StatusBadRequest, // notamethod is an invalid method }, { - Method: "POST", - Send: []jsonrpc.Message{req(0, "tools/call", &CallToolParams{Name: "tool"})}, - StatusCode: http.StatusBadRequest, // tools/call must have an ID + method: "POST", + messages: []jsonrpc.Message{req(0, "tools/call", &CallToolParams{Name: "tool"})}, + wantStatusCode: http.StatusBadRequest, // tools/call must have an ID }, { - Method: "POST", - Send: []jsonrpc.Message{req(2, "tools/call", &CallToolParams{Name: "tool"})}, - StatusCode: http.StatusOK, - Recv: []jsonrpc.Message{resp(2, nil, &jsonrpc2.WireError{ + method: "POST", + messages: []jsonrpc.Message{req(2, "tools/call", &CallToolParams{Name: "tool"})}, + wantStatusCode: http.StatusOK, + wantMessages: []jsonrpc.Message{resp(2, nil, &jsonrpc2.WireError{ Message: `method "tools/call" is invalid during session initialization`, })}, }, @@ -567,8 +577,8 @@ func TestStreamableServerTransport(t *testing.T) { var mu sync.Mutex blocks := make(map[int64]chan struct{}) for _, step := range test.steps { - if step.OnRequest > 0 { - blocks[step.OnRequest] = make(chan struct{}) + if step.onRequest > 0 { + blocks[step.onRequest] = make(chan struct{}) } } @@ -582,16 +592,16 @@ func TestStreamableServerTransport(t *testing.T) { sessionID.Store("") // doStep executes a single step. - doStep := func(t *testing.T, step step) { - if step.OnRequest > 0 { + doStep := func(t *testing.T, step streamableRequest) { + if step.onRequest > 0 { // Block the step until we've received the server->client request. mu.Lock() - block := blocks[step.OnRequest] + block := blocks[step.onRequest] mu.Unlock() select { case <-block: case <-syncRequestsDone: - t.Errorf("after all sync requests are complete, request still blocked on %d", step.OnRequest) + t.Errorf("after all sync requests are complete, request still blocked on %d", step.onRequest) return } } @@ -623,14 +633,13 @@ func TestStreamableServerTransport(t *testing.T) { mu.Unlock() } got = append(got, m) - if step.CloseAfter > 0 && len(got) == step.CloseAfter { + if step.closeAfter > 0 && len(got) == step.closeAfter { cancel() } } }() - gotSessionID, gotStatusCode, err := streamingRequest(ctx, - httpServer.URL, sessionID.Load().(string), step.Method, step.Send, out) + gotSessionID, gotStatusCode, err := step.do(ctx, httpServer.URL, sessionID.Load().(string), out) // Don't fail on cancelled requests: error (if any) is handled // elsewhere. @@ -638,13 +647,13 @@ func TestStreamableServerTransport(t *testing.T) { t.Fatal(err) } - if gotStatusCode != step.StatusCode { - t.Errorf("got status %d, want %d", gotStatusCode, step.StatusCode) + if gotStatusCode != step.wantStatusCode { + t.Errorf("got status %d, want %d", gotStatusCode, step.wantStatusCode) } wg.Wait() transform := cmpopts.AcyclicTransformer("jsonrpcid", func(id jsonrpc.ID) any { return id.Raw() }) - if diff := cmp.Diff(step.Recv, got, transform); diff != "" { + if diff := cmp.Diff(step.wantMessages, got, transform); diff != "" { t.Errorf("received unexpected messages (-want +got):\n%s", diff) } sessionID.CompareAndSwap("", gotSessionID) @@ -652,7 +661,7 @@ func TestStreamableServerTransport(t *testing.T) { var wg sync.WaitGroup for _, step := range test.steps { - if step.Async || step.OnRequest > 0 { + if step.async || step.onRequest > 0 { wg.Add(1) go func() { defer wg.Done() @@ -672,6 +681,33 @@ func TestStreamableServerTransport(t *testing.T) { } } +// A streamableRequest describes a single streamable HTTP request, consisting +// of a request payload and expected response. +type streamableRequest struct { + // If onRequest is > 0, this step only executes after a request with the + // given ID is received. + // + // All onRequest steps must occur before the step that creates the request. + // + // To avoid tests hanging when there's a bug, it's expected that this + // request is received in the course of a *synchronous* request to the + // server (otherwise, we wouldn't be able to terminate the test without + // analyzing a dependency graph). + onRequest int64 + // If set, async causes the step to run asynchronously to other steps. + // Redundant with OnRequest: all OnRequest steps are asynchronous. + async bool + + // Request attributes + method string // HTTP request method + accept []string // if non-empty, the Accept header to use; otherwise the default header is used + messages []jsonrpc.Message // messages to send + + closeAfter int // if nonzero, close after receiving this many messages + wantStatusCode int // expected status code + wantMessages []jsonrpc.Message // expected messages to receive +} + // streamingRequest makes a request to the given streamable server with the // given url, sessionID, and method. // @@ -685,19 +721,19 @@ func TestStreamableServerTransport(t *testing.T) { // Returns the sessionID and http status code from the response. If an error is // returned, sessionID and status code may still be set if the error occurs // after the response headers have been received. -func streamingRequest(ctx context.Context, serverURL, sessionID, method string, in []jsonrpc.Message, out chan<- jsonrpc.Message) (string, int, error) { +func (s streamableRequest) do(ctx context.Context, serverURL, sessionID string, out chan<- jsonrpc.Message) (string, int, error) { defer close(out) var body []byte - if len(in) == 1 { - data, err := jsonrpc2.EncodeMessage(in[0]) + if len(s.messages) == 1 { + data, err := jsonrpc2.EncodeMessage(s.messages[0]) if err != nil { return "", 0, fmt.Errorf("encoding message: %w", err) } body = data } else { var rawMsgs []json.RawMessage - for _, msg := range in { + for _, msg := range s.messages { data, err := jsonrpc2.EncodeMessage(msg) if err != nil { return "", 0, fmt.Errorf("encoding message: %w", err) @@ -711,7 +747,7 @@ func streamingRequest(ctx context.Context, serverURL, sessionID, method string, body = data } - req, err := http.NewRequestWithContext(ctx, method, serverURL, bytes.NewReader(body)) + req, err := http.NewRequestWithContext(ctx, s.method, serverURL, bytes.NewReader(body)) if err != nil { return "", 0, fmt.Errorf("creating request: %w", err) } @@ -719,8 +755,13 @@ func streamingRequest(ctx context.Context, serverURL, sessionID, method string, req.Header.Set("Mcp-Session-Id", sessionID) } req.Header.Set("Content-Type", "application/json") - req.Header.Add("Accept", "text/plain") // ensure multiple accept headers are allowed - req.Header.Add("Accept", "application/json, text/event-stream") + if len(s.accept) > 0 { + for _, accept := range s.accept { + req.Header.Add("Accept", accept) + } + } else { + req.Header.Add("Accept", "application/json, text/event-stream") + } resp, err := http.DefaultClient.Do(req) if err != nil {