Skip to content

Commit c145907

Browse files
committed
jsonrpc2: add io.ReadWriter interface to Interface
1 parent a2a3df9 commit c145907

File tree

1 file changed

+46
-30
lines changed

1 file changed

+46
-30
lines changed

jsonrpc2.go

Lines changed: 46 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ package jsonrpc2
66

77
import (
88
"context"
9+
"io"
910
"time"
1011

1112
"github.com/francoispqt/gojay"
@@ -16,15 +17,17 @@ import (
1617

1718
// Interface represents an interface for issuing requests that speak the JSON-RPC 2 protocol.
1819
type Interface interface {
19-
Call(ctx context.Context, method string, params, result interface{}) error
20+
io.ReadWriter
21+
22+
Call(ctx context.Context, method string, params, result interface{}) (err error)
2023

2124
Reply(ctx context.Context, req *Request, result interface{}, err error) error
2225

23-
Notify(ctx context.Context, method string, params interface{}) error
26+
Notify(ctx context.Context, method string, params interface{}) (err error)
2427

2528
Cancel(id ID)
2629

27-
Run(ctx context.Context) error
30+
Run(ctx context.Context) (err error)
2831
}
2932

3033
// Handler is an option you can pass to NewConn to handle incoming requests.
@@ -54,9 +57,10 @@ type Conn struct {
5457
stream Stream
5558
done chan struct{}
5659
err error
57-
seq atomic.Int64 // must only be accessed using atomic operations
58-
pending atomic.Value // map[ID]chan *Response
59-
handling atomic.Value // map[ID]handling
60+
ctx context.Context // for Read and Write only
61+
seq atomic.Int64 // must only be accessed using atomic operations
62+
pending atomic.Value // map[ID]chan *Response
63+
handling atomic.Value // map[ID]handling
6064
}
6165

6266
var _ Interface = (*Conn)(nil)
@@ -99,15 +103,13 @@ func WithOverloaded(overloaded bool) Options {
99103
}
100104
}
101105

102-
var (
103-
defaultHandler = func(ctx context.Context, c *Conn, r *Request) {
104-
if r.IsNotify() {
105-
c.Reply(ctx, r, nil, Errorf(CodeMethodNotFound, "method %q not found", r.Method))
106-
}
106+
var defaultHandler = func(ctx context.Context, c *Conn, r *Request) {
107+
if r.IsNotify() {
108+
c.Reply(ctx, r, nil, Errorf(CodeMethodNotFound, "method %q not found", r.Method))
107109
}
110+
}
108111

109-
defaultCanceler = func(context.Context, *Conn, *Request) {}
110-
)
112+
var defaultCanceler = func(context.Context, *Conn, *Request) {}
111113

112114
type handling struct {
113115
request *Request
@@ -154,6 +156,16 @@ func NewConn(ctx context.Context, s Stream, options ...Options) *Conn {
154156
return conn
155157
}
156158

159+
// Read implements io.Reader.
160+
func (c *Conn) Read(p []byte) (n int, err error) {
161+
return c.stream.Read(c.ctx, p)
162+
}
163+
164+
// Write implements io.Write.
165+
func (c *Conn) Write(p []byte) (n int, err error) {
166+
return c.stream.Write(c.ctx, p)
167+
}
168+
157169
// Call sends a request over the connection and then waits for a response.
158170
func (c *Conn) Call(ctx context.Context, method string, params, result interface{}) error {
159171
jsonParams, err := marshalToEmbedded(params)
@@ -196,7 +208,7 @@ func (c *Conn) Call(ctx context.Context, method string, params, result interface
196208
zap.String("req.method", req.Method),
197209
zap.Any("req.params", req.Params),
198210
)
199-
if err := c.stream.Write(ctx, data); err != nil {
211+
if _, err := c.stream.Write(ctx, data); err != nil {
200212
return err
201213
}
202214

@@ -216,10 +228,11 @@ func (c *Conn) Call(ctx context.Context, method string, params, result interface
216228
if result == nil || resp.Result == nil {
217229
return nil
218230
}
219-
if err := gojay.Unsafe.Unmarshal(*resp.Result, result); err != nil {
231+
if err := gojay.Unsafe.Unmarshal(*resp.Result.EmbeddedJSON, result); err != nil {
220232
return xerrors.Errorf("failed to unmarshalling result: %v", err)
221233
}
222234
return nil
235+
223236
case <-ctx.Done():
224237
c.canceler(ctx, c, req)
225238
return ctx.Err()
@@ -242,14 +255,14 @@ func (c *Conn) Reply(ctx context.Context, req *Request, result interface{}, err
242255
}
243256

244257
elapsed := time.Since(handling.start)
245-
var jsonParams *gojay.EmbeddedJSON
258+
var raw *RawMessage
246259
if err == nil {
247-
jsonParams, err = marshalToEmbedded(result)
260+
raw, err = marshalToEmbedded(result)
248261
}
249262

250263
resp := &Response{
251264
ID: req.ID,
252-
Result: jsonParams,
265+
Result: raw,
253266
}
254267

255268
if err != nil {
@@ -269,7 +282,7 @@ func (c *Conn) Reply(ctx context.Context, req *Request, result interface{}, err
269282
zap.Error(resp.Error),
270283
)
271284

272-
if err := c.stream.Write(ctx, data); err != nil {
285+
if _, err := c.stream.Write(ctx, data); err != nil {
273286
return err
274287
}
275288

@@ -297,7 +310,9 @@ func (c *Conn) Notify(ctx context.Context, method string, params interface{}) er
297310
zap.Any("req.Params", req.Params),
298311
)
299312

300-
return c.stream.Write(ctx, data)
313+
_, err = c.stream.Write(ctx, data)
314+
315+
return err
301316
}
302317

303318
// Cancel cancels a pending Call on the server side.
@@ -335,16 +350,16 @@ func (c *Conn) deliver(ctx context.Context, q chan queue, request *Request) bool
335350
// combined has all the fields of both Request and Response.
336351
// We can decode this and then work out which it is.
337352
type combined struct {
338-
VersionTag Message `json:"jsonrpc"`
339-
ID *ID `json:"id,omitempty"`
340-
Method string `json:"method"`
341-
Params *gojay.EmbeddedJSON `json:"params,omitempty"`
342-
Result *gojay.EmbeddedJSON `json:"result,omitempty"`
343-
Error *Error `json:"error,omitempty"`
353+
VersionTag Message `json:"jsonrpc"`
354+
ID *ID `json:"id,omitempty"`
355+
Method string `json:"method"`
356+
Params *RawMessage `json:"params,omitempty"`
357+
Result *RawMessage `json:"result,omitempty"`
358+
Error *Error `json:"error,omitempty"`
344359
}
345360

346361
// Run run the jsonrpc2 server.
347-
func (c *Conn) Run(ctx context.Context) error {
362+
func (c *Conn) Run(ctx context.Context) (err error) {
348363
q := make(chan queue, c.capacity)
349364
defer close(q)
350365

@@ -359,8 +374,9 @@ func (c *Conn) Run(ctx context.Context) error {
359374
}()
360375

361376
for {
377+
var data []byte
362378
// get the data for a message
363-
data, err := c.stream.Read(ctx)
379+
_, err = c.stream.Read(ctx, data)
364380
if err != nil {
365381
// the stream failed, we cannot continue
366382
return err
@@ -458,12 +474,12 @@ func (d Direction) String() string {
458474
}
459475
}
460476

461-
func marshalToEmbedded(obj interface{}) (*gojay.EmbeddedJSON, error) {
477+
func marshalToEmbedded(obj interface{}) (*RawMessage, error) {
462478
data, err := gojay.Marshal(obj)
463479
if err != nil {
464480
return nil, err
465481
}
466482
raw := gojay.EmbeddedJSON(data)
467483

468-
return &raw, nil
484+
return &RawMessage{EmbeddedJSON: &raw}, nil
469485
}

0 commit comments

Comments
 (0)