Skip to content

Commit baf68c3

Browse files
committed
jsonrpc2: fix jsonrpc2 logic
1 parent 3437b5a commit baf68c3

File tree

1 file changed

+86
-74
lines changed

1 file changed

+86
-74
lines changed

jsonrpc2.go

Lines changed: 86 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@ package jsonrpc2
66

77
import (
88
"context"
9+
"encoding/json"
910
"sync"
1011
"time"
1112

12-
"github.com/francoispqt/gojay"
1313
"go.uber.org/atomic"
1414
"go.uber.org/zap"
1515
)
@@ -53,29 +53,22 @@ type Canceler func(context.Context, *Conn, *Request)
5353
// Conn is a JSON RPC 2 client server connection.
5454
// Conn is bidirectional; it does not have a designated server or client end.
5555
type Conn struct {
56-
Handler Handler
57-
Canceler Canceler
58-
logger *zap.Logger
59-
capacity int
60-
overloaded bool
61-
stream Stream
62-
done chan struct{}
63-
err error
64-
seq atomic.Int64 // must only be accessed using atomic operations
65-
pendingMu sync.Mutex // protects the pending map
66-
pending map[ID]chan *Response
67-
handlingMu sync.Mutex // protects the handling map
68-
handling map[ID]handling
56+
seq *atomic.Int64 // must only be accessed using atomic operations
57+
Handler Handler
58+
Canceler Canceler
59+
Logger *zap.Logger
60+
Capacity int
61+
RejectIfOverloaded bool
62+
stream Stream
63+
err error
64+
pendingMu sync.Mutex // protects the pending map
65+
pending map[ID]chan *Response
66+
handlingMu sync.Mutex // protects the handling map
67+
handling map[ID]handling
6968
}
7069

7170
var _ Interface = (*Conn)(nil)
7271

73-
type handling struct {
74-
request *Request
75-
cancel context.CancelFunc
76-
start time.Time
77-
}
78-
7972
type queueEntry struct {
8073
ctx context.Context
8174
conn *Conn
@@ -99,42 +92,45 @@ func WithCanceler(canceler Canceler) Options {
9992
}
10093
}
10194

102-
// WithLogger apply custom logger to Conn.
95+
// WithLogger apply custom Logger to Conn.
10396
func WithLogger(logger *zap.Logger) Options {
10497
return func(c *Conn) {
105-
c.logger = logger
98+
c.Logger = logger
10699
}
107100
}
108101

109102
// WithCapacity apply custom capacity to Conn.
110103
func WithCapacity(capacity int) Options {
111104
return func(c *Conn) {
112-
c.capacity = capacity
105+
c.Capacity = capacity
113106
}
114107
}
115108

116-
// WithOverloaded apply overloaded boolean to Conn.
117-
func WithOverloaded(overloaded bool) Options {
109+
// WithOverloaded apply RejectIfOverloaded boolean to Conn.
110+
func WithOverloaded(rejectIfOverloaded bool) Options {
118111
return func(c *Conn) {
119-
c.overloaded = overloaded
112+
c.RejectIfOverloaded = rejectIfOverloaded
120113
}
121114
}
122115

123-
var defaultHandler = func(ctx context.Context, c *Conn, r *Request) {
124-
if r.IsNotify() {
125-
c.Reply(ctx, r, nil, Errorf(CodeMethodNotFound, "method %q not found", r.Method))
116+
var defaultHandler = func(ctx context.Context, conn *Conn, req *Request) {
117+
if req.IsNotify() {
118+
conn.Reply(ctx, req, nil, Errorf(CodeMethodNotFound, "method %q not found", req.Method))
126119
}
127120
}
128121

129122
var defaultCanceler = func(context.Context, *Conn, *Request) {}
130123

124+
var defaultLogger = zap.NewNop()
125+
131126
// NewConn creates a new connection object that reads and writes messages from
132127
// the supplied stream and dispatches incoming messages to the supplied handler.
133-
func NewConn(ctx context.Context, s Stream, options ...Options) *Conn {
128+
func NewConn(s Stream, options ...Options) *Conn {
134129
conn := &Conn{
130+
seq: new(atomic.Int64),
135131
Handler: defaultHandler, // the default handler reports a method error
136132
Canceler: defaultCanceler, // the default canceller does nothing
137-
logger: zap.NewNop(), // the default logger does nothing
133+
Logger: defaultLogger, // the default Logger does nothing
138134
stream: s,
139135
pending: make(map[ID]chan *Response),
140136
handling: make(map[ID]handling),
@@ -149,6 +145,7 @@ func NewConn(ctx context.Context, s Stream, options ...Options) *Conn {
149145

150146
// Cancel cancels a pending Call on the server side.
151147
func (c *Conn) Cancel(id ID) {
148+
c.Logger.Debug("Cancel")
152149
c.handlingMu.Lock()
153150
handling, found := c.handling[id]
154151
c.handlingMu.Unlock()
@@ -160,6 +157,7 @@ func (c *Conn) Cancel(id ID) {
160157

161158
// Notify is called to send a notification request over the connection.
162159
func (c *Conn) Notify(ctx context.Context, method string, params interface{}) error {
160+
c.Logger.Debug("Notify")
163161
p, err := c.marshalInterface(params)
164162
if err != nil {
165163
return Errorf(CodeParseError, "failed to marshaling notify parameters: %w", err)
@@ -170,12 +168,12 @@ func (c *Conn) Notify(ctx context.Context, method string, params interface{}) er
170168
Method: method,
171169
Params: p,
172170
}
173-
data, err := gojay.MarshalJSONObject(req)
171+
data, err := json.Marshal(req) // TODO(zchee): use gojay
174172
if err != nil {
175173
return Errorf(CodeParseError, "failed to marshaling notify request: %w", err)
176174
}
177175

178-
c.logger.Debug(Send,
176+
c.Logger.Debug(Send,
179177
zap.String("req.Method", req.Method),
180178
zap.Any("req.Params", req.Params),
181179
)
@@ -190,6 +188,7 @@ func (c *Conn) Notify(ctx context.Context, method string, params interface{}) er
190188

191189
// Call sends a request over the connection and then waits for a response.
192190
func (c *Conn) Call(ctx context.Context, method string, params, result interface{}) error {
191+
c.Logger.Debug("Call")
193192
p, err := c.marshalInterface(params)
194193
if err != nil {
195194
return Errorf(CodeParseError, "failed to marshaling call parameters: %w", err)
@@ -204,11 +203,10 @@ func (c *Conn) Call(ctx context.Context, method string, params, result interface
204203
}
205204

206205
// marshal the request now it is complete
207-
data, err := gojay.MarshalJSONObject(req)
206+
data, err := json.Marshal(req) // TODO(zchee): use gojay
208207
if err != nil {
209208
return Errorf(CodeParseError, "failed to marshaling call request: %w", err)
210209
}
211-
c.logger.Debug("gojay.MarshalJSONObject(req)", zap.ByteString("data", data))
212210

213211
rchan := make(chan *Response)
214212

@@ -222,7 +220,7 @@ func (c *Conn) Call(ctx context.Context, method string, params, result interface
222220
}()
223221

224222
start := time.Now()
225-
c.logger.Debug(Send,
223+
c.Logger.Debug(Send,
226224
zap.String("req.JSONRPC", req.JSONRPC),
227225
zap.String("id", id.String()),
228226
zap.String("req.method", req.Method),
@@ -233,31 +231,35 @@ func (c *Conn) Call(ctx context.Context, method string, params, result interface
233231
return Errorf(CodeInternalError, "failed to write call request data to steam: %w", err)
234232
}
235233

234+
// wait for the response
236235
select {
237236
case resp := <-rchan:
238-
c.logger.Debug(Receive,
239-
zap.String("req.JSONRPC", req.JSONRPC),
240-
zap.String("id", id.String()),
241-
zap.Duration("elapsed", time.Since(start)),
237+
elapsed := time.Since(start)
238+
c.Logger.Debug(Receive,
239+
zap.String("resp.ID", resp.ID.String()),
240+
zap.Duration("elapsed", elapsed),
242241
zap.String("req.method", req.Method),
243242
zap.Any("resp.Result", resp.Result),
244-
zap.Error(resp.Error),
245243
)
246244

245+
// is it an error response?
247246
if resp.Error != nil {
248247
return resp.Error
249248
}
249+
250250
if result == nil || resp.Result == nil {
251251
return nil
252252
}
253253

254-
if err := gojay.Unsafe.Unmarshal(*resp.Result, result); err != nil {
254+
if err := json.Unmarshal(*resp.Result, result); err != nil {
255+
// if err := gojay.Unsafe.Unmarshal(*resp.Result, result); err != nil {
255256
return Errorf(CodeParseError, "failed to unmarshalling result: %w", err)
256257
}
257258

258259
return nil
259260

260261
case <-ctx.Done():
262+
// allow the handler to propagate the cancel
261263
c.Canceler(ctx, c, req)
262264

263265
return ctx.Err()
@@ -266,6 +268,7 @@ func (c *Conn) Call(ctx context.Context, method string, params, result interface
266268

267269
// Reply sends a reply to the given request.
268270
func (c *Conn) Reply(ctx context.Context, req *Request, result interface{}, err error) error {
271+
c.Logger.Debug("Reply")
269272
if req.IsNotify() {
270273
return NewError(CodeInvalidRequest, "reply not invoked with a valid call")
271274
}
@@ -282,54 +285,66 @@ func (c *Conn) Reply(ctx context.Context, req *Request, result interface{}, err
282285

283286
elapsed := time.Since(handling.start)
284287

288+
var raw *json.RawMessage
289+
if err == nil {
290+
raw, err = c.marshalInterface(result)
291+
}
292+
285293
resp := &Response{
286294
JSONRPC: Version,
287295
ID: req.ID,
296+
Result: raw,
288297
}
289298

290-
if err == nil {
291-
if resp.Result, err = c.marshalInterface(result); err != nil {
292-
return err
299+
if err != nil {
300+
if callErr, ok := err.(*Error); ok {
301+
resp.Error = callErr
302+
} else {
303+
resp.Error = Errorf(0, "%s", err)
293304
}
294-
} else {
295-
resp.Error = NewError(CodeParseError, err)
296305
}
297306

298-
data, err := gojay.MarshalJSONObject(resp)
307+
data, err := json.Marshal(resp) // TODO(zchee): use gojay
299308
if err != nil {
300-
c.logger.Error(Send,
309+
c.Logger.Error(Send,
301310
zap.String("resp.ID", resp.ID.String()),
302311
zap.Duration("elapsed", elapsed),
303312
zap.String("req.Method", req.Method),
304313
zap.Any("resp.Result", resp.Result),
305314
zap.Error(err),
306315
)
307316
return Errorf(CodeParseError, "failed to marshaling reply response: %w", err)
317+
// return err
308318
}
309319

310-
c.logger.Debug(Send,
320+
c.Logger.Debug(Send,
311321
zap.String("resp.ID", resp.ID.String()),
312322
zap.Duration("elapsed", elapsed),
313323
zap.String("req.Method", req.Method),
314324
zap.Any("resp.Result", resp.Result),
315-
zap.Error(resp.Error),
316325
)
317326

318327
if err := c.stream.Write(ctx, data); err != nil {
328+
// TODO(iancottrell): if a stream write fails, we really need to shut down
329+
// the whole stream
319330
return Errorf(CodeInternalError, "failed to write response data to steam: %w", err)
320331
}
321332

322333
return nil
323334
}
324335

336+
type handling struct {
337+
request *Request
338+
cancel context.CancelFunc
339+
start time.Time
340+
}
341+
325342
func (c *Conn) deliver(ctx context.Context, q chan queueEntry, request *Request) bool {
326-
e := queueEntry{
327-
ctx: ctx,
328-
conn: c,
329-
request: request,
330-
}
343+
c.Logger.Debug("deliver")
344+
345+
e := queueEntry{ctx: ctx, conn: c, request: request}
331346

332-
if !c.overloaded {
347+
if !c.RejectIfOverloaded {
333348
q <- e
334349
return true
335350
}
@@ -344,7 +359,7 @@ func (c *Conn) deliver(ctx context.Context, q chan queueEntry, request *Request)
344359

345360
// Run run the jsonrpc2 server.
346361
func (c *Conn) Run(ctx context.Context) (err error) {
347-
q := make(chan queueEntry, c.capacity)
362+
q := make(chan queueEntry, c.Capacity)
348363
defer close(q)
349364

350365
// start the queue processor
@@ -353,6 +368,7 @@ func (c *Conn) Run(ctx context.Context) (err error) {
353368
if e.ctx.Err() != nil {
354369
continue
355370
}
371+
c.Logger.Debug("c.Handler", zap.Reflect("e", e.conn), zap.Reflect("e.request", e.request))
356372
c.Handler(e.ctx, e.conn, e.request)
357373
}
358374
}()
@@ -363,17 +379,13 @@ func (c *Conn) Run(ctx context.Context) (err error) {
363379
return err // read the stream failed, cannot continue
364380
}
365381

366-
c.logger.Debug(Receive, zap.ByteString("data", data), zap.Int("len(data)", len(data)))
367-
// if len(data) == 0 {
368-
// continue // stream is empty, continue
369-
// }
382+
c.Logger.Debug(Receive, zap.ByteString("data", data), zap.Int("len(data)", len(data)))
370383

371-
// read a combined message
372384
msg := &Combined{}
373-
if err := gojay.Unsafe.UnmarshalJSONObject(data, msg); err != nil {
385+
if err := json.Unmarshal(data, msg); err != nil { // TODO(zchee): use gojay
374386
// a badly formed message arrived, log it and continue
375387
// we trust the stream to have isolated the error to just this message
376-
c.logger.Debug(Receive,
388+
c.Logger.Debug(Receive,
377389
zap.Error(Errorf(CodeParseError, "unmarshal failed: %v", err)),
378390
)
379391
continue
@@ -391,7 +403,7 @@ func (c *Conn) Run(ctx context.Context) (err error) {
391403

392404
if req.IsNotify() {
393405
// handle the Notify because msg.ID is nil
394-
c.logger.Debug(Receive,
406+
c.Logger.Debug(Receive,
395407
zap.String("req.ID", req.ID.String()),
396408
zap.String("req.Method", req.Method),
397409
zap.Any("req.Params", req.Params),
@@ -409,7 +421,7 @@ func (c *Conn) Run(ctx context.Context) (err error) {
409421
start: time.Now(),
410422
}
411423
c.handlingMu.Unlock()
412-
c.logger.Debug(Receive,
424+
c.Logger.Debug(Receive,
413425
zap.String("req.ID", req.ID.String()),
414426
zap.String("req.Method", req.Method),
415427
zap.Any("req.Params", req.Params),
@@ -441,19 +453,19 @@ func (c *Conn) Run(ctx context.Context) (err error) {
441453
close(rchan) // for the range channel loop
442454

443455
default:
444-
c.logger.Warn(Receive, zap.Error(NewError(CodeInvalidParams, "ignoring because message not a call, notify or response")))
456+
c.Logger.Warn(Receive, zap.Error(NewError(CodeInvalidParams, "ignoring because message not a call, notify or response")))
445457
}
446458
}
447459
}
448460

449461
// marshalInterface marshal obj to RawMessage.
450-
func (c *Conn) marshalInterface(obj interface{}) (*RawMessage, error) {
451-
data, err := gojay.MarshalAny(obj)
462+
// TODO(zchee): use gojay
463+
func (c *Conn) marshalInterface(obj interface{}) (*json.RawMessage, error) {
464+
data, err := json.Marshal(obj)
452465
if err != nil {
453466
return nil, err
454467
}
455-
msg := RawMessage(gojay.EmbeddedJSON(data))
456-
c.logger.Debug("marshalInterface", zap.String("msg", msg.String()))
457-
458-
return &msg, nil
468+
raw := json.RawMessage(data)
469+
c.Logger.Debug("marshalInterface", zap.ByteString("raw", raw))
470+
return &raw, nil
459471
}

0 commit comments

Comments
 (0)