From 73f668cca650597d1a297e229941bbbed91ec96f Mon Sep 17 00:00:00 2001 From: Robert Terhaar Date: Sun, 14 Sep 2025 19:02:53 +0300 Subject: [PATCH] mcp: allow injecting custom streams in StdioTransport Add optional In and Out fields to StdioTransport to allow injection of custom io.ReadCloser and io.WriteCloser streams. This enables testing with io.Pipe() instead of requiring os.Stdin/os.Stdout. The Connect method now checks for nil fields and uses os.Stdin/os.Stdout as defaults. --- mcp/transport.go | 18 +++++-- mcp/transport_test.go | 116 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 131 insertions(+), 3 deletions(-) diff --git a/mcp/transport.go b/mcp/transport.go index 024863de..a2492bc7 100644 --- a/mcp/transport.go +++ b/mcp/transport.go @@ -86,11 +86,23 @@ type serverConnection interface { // A StdioTransport is a [Transport] that communicates over stdin/stdout using // newline-delimited JSON. -type StdioTransport struct{} +type StdioTransport struct { + In io.ReadCloser + Out io.WriteCloser +} // Connect implements the [Transport] interface. -func (*StdioTransport) Connect(context.Context) (Connection, error) { - return newIOConn(rwc{os.Stdin, os.Stdout}), nil +func (t *StdioTransport) Connect(context.Context) (Connection, error) { + in := t.In + out := t.Out + + if in == nil { + in = os.Stdin + } + if out == nil { + out = os.Stdout + } + return newIOConn(rwc{in, out}), nil } // An InMemoryTransport is a [Transport] that communicates over an in-memory diff --git a/mcp/transport_test.go b/mcp/transport_test.go index d40ce10f..edc2a476 100644 --- a/mcp/transport_test.go +++ b/mcp/transport_test.go @@ -117,3 +117,119 @@ func TestIOConnRead(t *testing.T) { }) } } + +func TestStdioTransport(t *testing.T) { + tests := []struct { + name string + setupIn func() io.ReadCloser + setupOut func() io.WriteCloser + wantErr bool + }{ + { + name: "defaults_use_stdin_stdout", + setupIn: func() io.ReadCloser { return nil }, + setupOut: func() io.WriteCloser { return nil }, + wantErr: false, + }, + { + name: "custom_streams", + setupIn: func() io.ReadCloser { r, _ := io.Pipe(); return r }, + setupOut: func() io.WriteCloser { _, w := io.Pipe(); return w }, + wantErr: false, + }, + { + name: "partial_custom_in_only", + setupIn: func() io.ReadCloser { return io.NopCloser(strings.NewReader("")) }, + setupOut: func() io.WriteCloser { return nil }, + wantErr: false, + }, + { + name: "partial_custom_out_only", + setupIn: func() io.ReadCloser { return nil }, + setupOut: func() io.WriteCloser { _, w := io.Pipe(); return w }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + transport := &StdioTransport{ + In: tt.setupIn(), + Out: tt.setupOut(), + } + + conn, err := transport.Connect(context.Background()) + if (err != nil) != tt.wantErr { + t.Errorf("StdioTransport.Connect() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if conn == nil { + t.Error("StdioTransport.Connect() returned nil connection") + return + } + + defer conn.Close() + }) + } +} + +func TestStdioTransportDefaults(t *testing.T) { + transport := &StdioTransport{} + + if transport.In != nil { + t.Error("StdioTransport{}.In should be nil (uses default)") + } + + if transport.Out != nil { + t.Error("StdioTransport{}.Out should be nil (uses default)") + } + + conn, err := transport.Connect(context.Background()) + if err != nil { + t.Fatalf("StdioTransport{}.Connect() failed: %v", err) + } + defer conn.Close() +} + +func TestStdioTransportReadWrite(t *testing.T) { + ctx := context.Background() + r, w := io.Pipe() + defer r.Close() + defer w.Close() + + transport := &StdioTransport{ + In: r, + Out: w, + } + + conn, err := transport.Connect(ctx) + if err != nil { + t.Fatalf("StdioTransport.Connect() failed: %v", err) + } + defer conn.Close() + + // Test that we can write a message and it gets transmitted + testMsg := &jsonrpc.Request{ + ID: jsonrpc2.Int64ID(1), + Method: "test", + Params: nil, + } + + // Write message in a goroutine since pipe may block + go func() { + if err := conn.Write(ctx, testMsg); err != nil { + t.Errorf("conn.Write() failed: %v", err) + } + }() + + // Read the message back + receivedMsg, err := conn.Read(ctx) + if err != nil { + t.Fatalf("conn.Read() failed: %v", err) + } + + if req, ok := receivedMsg.(*jsonrpc.Request); !ok || req.Method != "test" { + t.Errorf("Expected request with method 'test', got %v", receivedMsg) + } +}