Skip to content

Commit a2a3df9

Browse files
committed
stream: chnage Stream interface methods arg to like io.ReadWriter
1 parent ed2543a commit a2a3df9

File tree

1 file changed

+28
-28
lines changed

1 file changed

+28
-28
lines changed

stream.go

Lines changed: 28 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,10 @@ import (
1919
// Stream abstracts the transport mechanics from the JSON RPC protocol.
2020
type Stream interface {
2121
// Read gets the next message from the stream.
22-
Read(context.Context) ([]byte, error)
22+
Read(ctx context.Context, p []byte) (n int, err error)
23+
2324
// Write sends a message to the stream.
24-
Write(context.Context, []byte) error
25+
Write(ctx context.Context, p []byte) (n int, err error)
2526
}
2627

2728
type stream struct {
@@ -30,26 +31,25 @@ type stream struct {
3031
sync.Mutex
3132
}
3233

33-
// NewStream returns the new Stream.
3434
func NewStream(in io.Reader, out io.Writer) Stream {
3535
return &stream{
3636
in: bufio.NewReader(in),
3737
out: out,
3838
}
3939
}
4040

41-
func (s *stream) Read(ctx context.Context) ([]byte, error) {
41+
func (s *stream) Read(ctx context.Context, p []byte) (n int, err error) {
4242
select {
4343
case <-ctx.Done():
44-
return nil, ctx.Err()
44+
return 0, ctx.Err()
4545
default:
4646
}
4747

4848
var length int64
4949
for {
5050
line, err := s.in.ReadString('\n')
5151
if err != nil {
52-
return nil, xerrors.Errorf("failed reading header line: %w", err)
52+
return 0, xerrors.Errorf("failed reading header line: %w", err)
5353
}
5454

5555
line = strings.TrimSpace(line)
@@ -59,49 +59,49 @@ func (s *stream) Read(ctx context.Context) ([]byte, error) {
5959

6060
colon := strings.IndexRune(line, ':')
6161
if colon < 0 {
62-
return nil, xerrors.Errorf("invalid header line: %q", line)
62+
return 0, xerrors.Errorf("invalid header line: %q", line)
6363
}
6464

6565
name, value := line[:colon], strings.TrimSpace(line[colon+1:])
66-
switch name {
67-
case "Content-Length":
68-
if length, err = strconv.ParseInt(value, 10, 32); err != nil {
69-
return nil, xerrors.Errorf("failed parsing Content-Length: %v", value)
70-
}
71-
72-
if length <= 0 {
73-
return nil, xerrors.Errorf("invalid Content-Length: %v", length)
74-
}
75-
default:
76-
// ignoring unknown headers
66+
if name != "Content-Length" {
67+
continue
68+
}
69+
70+
if length, err = strconv.ParseInt(value, 10, 32); err != nil {
71+
return 0, xerrors.Errorf("failed parsing Content-Length: %v", value)
72+
}
73+
74+
if length <= 0 {
75+
return 0, xerrors.Errorf("invalid Content-Length: %v", length)
7776
}
7877
}
7978

8079
if length == 0 {
81-
return nil, xerrors.New("missing Content-Length header")
80+
return 0, xerrors.New("missing Content-Length header")
8281
}
8382

84-
data := make([]byte, length)
85-
if _, err := io.ReadFull(s.in, data); err != nil {
86-
return nil, xerrors.Errorf("failed reading data: %w", err)
83+
p = make([]byte, length)
84+
n, err = io.ReadFull(s.in, p)
85+
if err != nil {
86+
return 0, xerrors.Errorf("failed reading data: %w", err)
8787
}
8888

89-
return data, nil
89+
return n, nil
9090
}
9191

92-
func (s *stream) Write(ctx context.Context, data []byte) error {
92+
func (s *stream) Write(ctx context.Context, p []byte) (n int, err error) {
9393
select {
9494
case <-ctx.Done():
95-
return ctx.Err()
95+
return 0, ctx.Err()
9696
default:
9797
}
9898

9999
s.Lock()
100-
_, err := fmt.Fprintf(s.out, "Content-Length: %v\r\n\r\n", len(data))
100+
n, err = fmt.Fprintf(s.out, "Content-Length: %v\r\n\r\n", len(p))
101101
if err == nil {
102-
_, err = s.out.Write(data)
102+
n, err = s.out.Write(p)
103103
}
104104
s.Unlock()
105105

106-
return err
106+
return n, err
107107
}

0 commit comments

Comments
 (0)