diff --git a/mcp/transport.go b/mcp/transport.go index 024863de..a2492bc7 100644 --- a/mcp/transport.go +++ b/mcp/transport.go @@ -86,11 +86,23 @@ type serverConnection interface { // A StdioTransport is a [Transport] that communicates over stdin/stdout using // newline-delimited JSON. -type StdioTransport struct{} +type StdioTransport struct { + In io.ReadCloser + Out io.WriteCloser +} // Connect implements the [Transport] interface. -func (*StdioTransport) Connect(context.Context) (Connection, error) { - return newIOConn(rwc{os.Stdin, os.Stdout}), nil +func (t *StdioTransport) Connect(context.Context) (Connection, error) { + in := t.In + out := t.Out + + if in == nil { + in = os.Stdin + } + if out == nil { + out = os.Stdout + } + return newIOConn(rwc{in, out}), nil } // An InMemoryTransport is a [Transport] that communicates over an in-memory diff --git a/mcp/transport_test.go b/mcp/transport_test.go index d40ce10f..edc2a476 100644 --- a/mcp/transport_test.go +++ b/mcp/transport_test.go @@ -117,3 +117,119 @@ func TestIOConnRead(t *testing.T) { }) } } + +func TestStdioTransport(t *testing.T) { + tests := []struct { + name string + setupIn func() io.ReadCloser + setupOut func() io.WriteCloser + wantErr bool + }{ + { + name: "defaults_use_stdin_stdout", + setupIn: func() io.ReadCloser { return nil }, + setupOut: func() io.WriteCloser { return nil }, + wantErr: false, + }, + { + name: "custom_streams", + setupIn: func() io.ReadCloser { r, _ := io.Pipe(); return r }, + setupOut: func() io.WriteCloser { _, w := io.Pipe(); return w }, + wantErr: false, + }, + { + name: "partial_custom_in_only", + setupIn: func() io.ReadCloser { return io.NopCloser(strings.NewReader("")) }, + setupOut: func() io.WriteCloser { return nil }, + wantErr: false, + }, + { + name: "partial_custom_out_only", + setupIn: func() io.ReadCloser { return nil }, + setupOut: func() io.WriteCloser { _, w := io.Pipe(); return w }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + transport := &StdioTransport{ + In: tt.setupIn(), + Out: tt.setupOut(), + } + + conn, err := transport.Connect(context.Background()) + if (err != nil) != tt.wantErr { + t.Errorf("StdioTransport.Connect() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if conn == nil { + t.Error("StdioTransport.Connect() returned nil connection") + return + } + + defer conn.Close() + }) + } +} + +func TestStdioTransportDefaults(t *testing.T) { + transport := &StdioTransport{} + + if transport.In != nil { + t.Error("StdioTransport{}.In should be nil (uses default)") + } + + if transport.Out != nil { + t.Error("StdioTransport{}.Out should be nil (uses default)") + } + + conn, err := transport.Connect(context.Background()) + if err != nil { + t.Fatalf("StdioTransport{}.Connect() failed: %v", err) + } + defer conn.Close() +} + +func TestStdioTransportReadWrite(t *testing.T) { + ctx := context.Background() + r, w := io.Pipe() + defer r.Close() + defer w.Close() + + transport := &StdioTransport{ + In: r, + Out: w, + } + + conn, err := transport.Connect(ctx) + if err != nil { + t.Fatalf("StdioTransport.Connect() failed: %v", err) + } + defer conn.Close() + + // Test that we can write a message and it gets transmitted + testMsg := &jsonrpc.Request{ + ID: jsonrpc2.Int64ID(1), + Method: "test", + Params: nil, + } + + // Write message in a goroutine since pipe may block + go func() { + if err := conn.Write(ctx, testMsg); err != nil { + t.Errorf("conn.Write() failed: %v", err) + } + }() + + // Read the message back + receivedMsg, err := conn.Read(ctx) + if err != nil { + t.Fatalf("conn.Read() failed: %v", err) + } + + if req, ok := receivedMsg.(*jsonrpc.Request); !ok || req.Method != "test" { + t.Errorf("Expected request with method 'test', got %v", receivedMsg) + } +}