Skip to content

Commit 0657788

Browse files
committed
refactor: code review changes
1 parent d22b2b0 commit 0657788

File tree

2 files changed

+35
-39
lines changed

2 files changed

+35
-39
lines changed

mcp/transport.go

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -422,15 +422,17 @@ func (t *ioConn) Read(ctx context.Context) (jsonrpc.Message, error) {
422422
}
423423

424424
// Read the next byte to check if there is trailing data.
425-
tr := make([]byte, 1)
426-
_, err := in.Buffered().Read(tr)
427-
if err != nil {
428-
return nil, err
425+
var tr [1]byte
426+
n, err := in.Buffered().Read(tr[:])
427+
if n > 0 {
428+
// If read byte is not a newline, it is an error.
429+
if tr[0] != '\n' {
430+
return nil, fmt.Errorf("invalid trailing data at the end of stream")
431+
}
429432
}
430-
431-
// If the next byte is not a newline, it is an error.
432-
if tr[0] != '\n' {
433-
return nil, fmt.Errorf("invalid trailing data at the end of stream")
433+
// Return error except for EOF
434+
if err != nil && err != io.EOF {
435+
return nil, err
434436
}
435437

436438
msgs, batch, err := readBatch(raw)

mcp/transport_test.go

Lines changed: 25 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ package mcp
77
import (
88
"context"
99
"io"
10-
"reflect"
1110
"strings"
1211
"testing"
1312

@@ -54,43 +53,38 @@ func TestBatchFraming(t *testing.T) {
5453
}
5554
}
5655

57-
func Test_ioConn_Read_BadTrailingData(t *testing.T) {
58-
type fields struct {
59-
rwc io.ReadWriteCloser
60-
}
61-
type args struct {
62-
ctx context.Context
63-
}
56+
func TestIOConnRead(t *testing.T) {
6457
tests := []struct {
65-
name string
66-
fields fields
67-
args args
68-
want string
69-
wantErr bool
58+
name string
59+
input string
60+
want string
7061
}{
62+
7163
{
72-
name: "bad data at the end of first valid json",
73-
fields: fields{
74-
rwc: rwc{
75-
rc: io.NopCloser(strings.NewReader(`{"jsonrpc":"2.0","id":1,"method":"test","params":{}},`)),
76-
},
77-
},
78-
args: args{
79-
ctx: context.Background(),
80-
},
81-
want: "invalid trailing data at the end of stream",
82-
wantErr: true,
64+
name: "valid json input",
65+
input: `{"jsonrpc":"2.0","id":1,"method":"test","params":{}}`,
66+
want: "",
67+
},
68+
69+
{
70+
name: "newline at the end of first valid json input",
71+
input: `{"jsonrpc":"2.0","id":1,"method":"test","params":{}}
72+
`,
73+
want: "",
74+
},
75+
{
76+
name: "bad data at the end of first valid json input",
77+
input: `{"jsonrpc":"2.0","id":1,"method":"test","params":{}},`,
78+
want: "invalid trailing data at the end of stream",
8379
},
8480
}
8581
for _, tt := range tests {
8682
t.Run(tt.name, func(t *testing.T) {
87-
tr := newIOConn(tt.fields.rwc)
88-
_, err := tr.Read(tt.args.ctx)
89-
if (err != nil) != tt.wantErr {
90-
t.Errorf("ioConn.Read() error = %v, wantErr %v", err, tt.wantErr)
91-
return
92-
}
93-
if !reflect.DeepEqual(err.Error(), tt.want) {
83+
tr := newIOConn(rwc{
84+
rc: io.NopCloser(strings.NewReader(tt.input)),
85+
})
86+
_, err := tr.Read(context.Background())
87+
if err != nil && err.Error() != tt.want {
9488
t.Errorf("ioConn.Read() = %v, want %v", err.Error(), tt.want)
9589
}
9690
})

0 commit comments

Comments
 (0)