@@ -6,10 +6,10 @@ package jsonrpc2
6
6
7
7
import (
8
8
"context"
9
+ "encoding/json"
9
10
"sync"
10
11
"time"
11
12
12
- "github.com/francoispqt/gojay"
13
13
"go.uber.org/atomic"
14
14
"go.uber.org/zap"
15
15
)
@@ -53,29 +53,22 @@ type Canceler func(context.Context, *Conn, *Request)
53
53
// Conn is a JSON RPC 2 client server connection.
54
54
// Conn is bidirectional; it does not have a designated server or client end.
55
55
type Conn struct {
56
- Handler Handler
57
- Canceler Canceler
58
- logger * zap.Logger
59
- capacity int
60
- overloaded bool
61
- stream Stream
62
- done chan struct {}
63
- err error
64
- seq atomic.Int64 // must only be accessed using atomic operations
65
- pendingMu sync.Mutex // protects the pending map
66
- pending map [ID ]chan * Response
67
- handlingMu sync.Mutex // protects the handling map
68
- handling map [ID ]handling
56
+ seq * atomic.Int64 // must only be accessed using atomic operations
57
+ Handler Handler
58
+ Canceler Canceler
59
+ Logger * zap.Logger
60
+ Capacity int
61
+ RejectIfOverloaded bool
62
+ stream Stream
63
+ err error
64
+ pendingMu sync.Mutex // protects the pending map
65
+ pending map [ID ]chan * Response
66
+ handlingMu sync.Mutex // protects the handling map
67
+ handling map [ID ]handling
69
68
}
70
69
71
70
var _ Interface = (* Conn )(nil )
72
71
73
- type handling struct {
74
- request * Request
75
- cancel context.CancelFunc
76
- start time.Time
77
- }
78
-
79
72
type queueEntry struct {
80
73
ctx context.Context
81
74
conn * Conn
@@ -99,42 +92,45 @@ func WithCanceler(canceler Canceler) Options {
99
92
}
100
93
}
101
94
102
- // WithLogger apply custom logger to Conn.
95
+ // WithLogger apply custom Logger to Conn.
103
96
func WithLogger (logger * zap.Logger ) Options {
104
97
return func (c * Conn ) {
105
- c .logger = logger
98
+ c .Logger = logger
106
99
}
107
100
}
108
101
109
102
// WithCapacity apply custom capacity to Conn.
110
103
func WithCapacity (capacity int ) Options {
111
104
return func (c * Conn ) {
112
- c .capacity = capacity
105
+ c .Capacity = capacity
113
106
}
114
107
}
115
108
116
- // WithOverloaded apply overloaded boolean to Conn.
117
- func WithOverloaded (overloaded bool ) Options {
109
+ // WithOverloaded apply RejectIfOverloaded boolean to Conn.
110
+ func WithOverloaded (rejectIfOverloaded bool ) Options {
118
111
return func (c * Conn ) {
119
- c .overloaded = overloaded
112
+ c .RejectIfOverloaded = rejectIfOverloaded
120
113
}
121
114
}
122
115
123
- var defaultHandler = func (ctx context.Context , c * Conn , r * Request ) {
124
- if r .IsNotify () {
125
- c .Reply (ctx , r , nil , Errorf (CodeMethodNotFound , "method %q not found" , r .Method ))
116
+ var defaultHandler = func (ctx context.Context , conn * Conn , req * Request ) {
117
+ if req .IsNotify () {
118
+ conn .Reply (ctx , req , nil , Errorf (CodeMethodNotFound , "method %q not found" , req .Method ))
126
119
}
127
120
}
128
121
129
122
var defaultCanceler = func (context.Context , * Conn , * Request ) {}
130
123
124
+ var defaultLogger = zap .NewNop ()
125
+
131
126
// NewConn creates a new connection object that reads and writes messages from
132
127
// the supplied stream and dispatches incoming messages to the supplied handler.
133
- func NewConn (ctx context. Context , s Stream , options ... Options ) * Conn {
128
+ func NewConn (s Stream , options ... Options ) * Conn {
134
129
conn := & Conn {
130
+ seq : new (atomic.Int64 ),
135
131
Handler : defaultHandler , // the default handler reports a method error
136
132
Canceler : defaultCanceler , // the default canceller does nothing
137
- logger : zap . NewNop () , // the default logger does nothing
133
+ Logger : defaultLogger , // the default Logger does nothing
138
134
stream : s ,
139
135
pending : make (map [ID ]chan * Response ),
140
136
handling : make (map [ID ]handling ),
@@ -149,6 +145,7 @@ func NewConn(ctx context.Context, s Stream, options ...Options) *Conn {
149
145
150
146
// Cancel cancels a pending Call on the server side.
151
147
func (c * Conn ) Cancel (id ID ) {
148
+ c .Logger .Debug ("Cancel" )
152
149
c .handlingMu .Lock ()
153
150
handling , found := c .handling [id ]
154
151
c .handlingMu .Unlock ()
@@ -160,6 +157,7 @@ func (c *Conn) Cancel(id ID) {
160
157
161
158
// Notify is called to send a notification request over the connection.
162
159
func (c * Conn ) Notify (ctx context.Context , method string , params interface {}) error {
160
+ c .Logger .Debug ("Notify" )
163
161
p , err := c .marshalInterface (params )
164
162
if err != nil {
165
163
return Errorf (CodeParseError , "failed to marshaling notify parameters: %w" , err )
@@ -170,12 +168,12 @@ func (c *Conn) Notify(ctx context.Context, method string, params interface{}) er
170
168
Method : method ,
171
169
Params : p ,
172
170
}
173
- data , err := gojay . MarshalJSONObject (req )
171
+ data , err := json . Marshal (req ) // TODO(zchee): use gojay
174
172
if err != nil {
175
173
return Errorf (CodeParseError , "failed to marshaling notify request: %w" , err )
176
174
}
177
175
178
- c .logger .Debug (Send ,
176
+ c .Logger .Debug (Send ,
179
177
zap .String ("req.Method" , req .Method ),
180
178
zap .Any ("req.Params" , req .Params ),
181
179
)
@@ -190,6 +188,7 @@ func (c *Conn) Notify(ctx context.Context, method string, params interface{}) er
190
188
191
189
// Call sends a request over the connection and then waits for a response.
192
190
func (c * Conn ) Call (ctx context.Context , method string , params , result interface {}) error {
191
+ c .Logger .Debug ("Call" )
193
192
p , err := c .marshalInterface (params )
194
193
if err != nil {
195
194
return Errorf (CodeParseError , "failed to marshaling call parameters: %w" , err )
@@ -204,11 +203,10 @@ func (c *Conn) Call(ctx context.Context, method string, params, result interface
204
203
}
205
204
206
205
// marshal the request now it is complete
207
- data , err := gojay . MarshalJSONObject (req )
206
+ data , err := json . Marshal (req ) // TODO(zchee): use gojay
208
207
if err != nil {
209
208
return Errorf (CodeParseError , "failed to marshaling call request: %w" , err )
210
209
}
211
- c .logger .Debug ("gojay.MarshalJSONObject(req)" , zap .ByteString ("data" , data ))
212
210
213
211
rchan := make (chan * Response )
214
212
@@ -222,7 +220,7 @@ func (c *Conn) Call(ctx context.Context, method string, params, result interface
222
220
}()
223
221
224
222
start := time .Now ()
225
- c .logger .Debug (Send ,
223
+ c .Logger .Debug (Send ,
226
224
zap .String ("req.JSONRPC" , req .JSONRPC ),
227
225
zap .String ("id" , id .String ()),
228
226
zap .String ("req.method" , req .Method ),
@@ -233,31 +231,35 @@ func (c *Conn) Call(ctx context.Context, method string, params, result interface
233
231
return Errorf (CodeInternalError , "failed to write call request data to steam: %w" , err )
234
232
}
235
233
234
+ // wait for the response
236
235
select {
237
236
case resp := <- rchan :
238
- c . logger . Debug ( Receive ,
239
- zap . String ( "req.JSONRPC" , req . JSONRPC ) ,
240
- zap .String ("id " , id .String ()),
241
- zap .Duration ("elapsed" , time . Since ( start ) ),
237
+ elapsed := time . Since ( start )
238
+ c . Logger . Debug ( Receive ,
239
+ zap .String ("resp.ID " , resp . ID .String ()),
240
+ zap .Duration ("elapsed" , elapsed ),
242
241
zap .String ("req.method" , req .Method ),
243
242
zap .Any ("resp.Result" , resp .Result ),
244
- zap .Error (resp .Error ),
245
243
)
246
244
245
+ // is it an error response?
247
246
if resp .Error != nil {
248
247
return resp .Error
249
248
}
249
+
250
250
if result == nil || resp .Result == nil {
251
251
return nil
252
252
}
253
253
254
- if err := gojay .Unsafe .Unmarshal (* resp .Result , result ); err != nil {
254
+ if err := json .Unmarshal (* resp .Result , result ); err != nil {
255
+ // if err := gojay.Unsafe.Unmarshal(*resp.Result, result); err != nil {
255
256
return Errorf (CodeParseError , "failed to unmarshalling result: %w" , err )
256
257
}
257
258
258
259
return nil
259
260
260
261
case <- ctx .Done ():
262
+ // allow the handler to propagate the cancel
261
263
c .Canceler (ctx , c , req )
262
264
263
265
return ctx .Err ()
@@ -266,6 +268,7 @@ func (c *Conn) Call(ctx context.Context, method string, params, result interface
266
268
267
269
// Reply sends a reply to the given request.
268
270
func (c * Conn ) Reply (ctx context.Context , req * Request , result interface {}, err error ) error {
271
+ c .Logger .Debug ("Reply" )
269
272
if req .IsNotify () {
270
273
return NewError (CodeInvalidRequest , "reply not invoked with a valid call" )
271
274
}
@@ -282,54 +285,66 @@ func (c *Conn) Reply(ctx context.Context, req *Request, result interface{}, err
282
285
283
286
elapsed := time .Since (handling .start )
284
287
288
+ var raw * json.RawMessage
289
+ if err == nil {
290
+ raw , err = c .marshalInterface (result )
291
+ }
292
+
285
293
resp := & Response {
286
294
JSONRPC : Version ,
287
295
ID : req .ID ,
296
+ Result : raw ,
288
297
}
289
298
290
- if err == nil {
291
- if resp .Result , err = c .marshalInterface (result ); err != nil {
292
- return err
299
+ if err != nil {
300
+ if callErr , ok := err .(* Error ); ok {
301
+ resp .Error = callErr
302
+ } else {
303
+ resp .Error = Errorf (0 , "%s" , err )
293
304
}
294
- } else {
295
- resp .Error = NewError (CodeParseError , err )
296
305
}
297
306
298
- data , err := gojay . MarshalJSONObject (resp )
307
+ data , err := json . Marshal (resp ) // TODO(zchee): use gojay
299
308
if err != nil {
300
- c .logger .Error (Send ,
309
+ c .Logger .Error (Send ,
301
310
zap .String ("resp.ID" , resp .ID .String ()),
302
311
zap .Duration ("elapsed" , elapsed ),
303
312
zap .String ("req.Method" , req .Method ),
304
313
zap .Any ("resp.Result" , resp .Result ),
305
314
zap .Error (err ),
306
315
)
307
316
return Errorf (CodeParseError , "failed to marshaling reply response: %w" , err )
317
+ // return err
308
318
}
309
319
310
- c .logger .Debug (Send ,
320
+ c .Logger .Debug (Send ,
311
321
zap .String ("resp.ID" , resp .ID .String ()),
312
322
zap .Duration ("elapsed" , elapsed ),
313
323
zap .String ("req.Method" , req .Method ),
314
324
zap .Any ("resp.Result" , resp .Result ),
315
- zap .Error (resp .Error ),
316
325
)
317
326
318
327
if err := c .stream .Write (ctx , data ); err != nil {
328
+ // TODO(iancottrell): if a stream write fails, we really need to shut down
329
+ // the whole stream
319
330
return Errorf (CodeInternalError , "failed to write response data to steam: %w" , err )
320
331
}
321
332
322
333
return nil
323
334
}
324
335
336
+ type handling struct {
337
+ request * Request
338
+ cancel context.CancelFunc
339
+ start time.Time
340
+ }
341
+
325
342
func (c * Conn ) deliver (ctx context.Context , q chan queueEntry , request * Request ) bool {
326
- e := queueEntry {
327
- ctx : ctx ,
328
- conn : c ,
329
- request : request ,
330
- }
343
+ c .Logger .Debug ("deliver" )
344
+
345
+ e := queueEntry {ctx : ctx , conn : c , request : request }
331
346
332
- if ! c .overloaded {
347
+ if ! c .RejectIfOverloaded {
333
348
q <- e
334
349
return true
335
350
}
@@ -344,7 +359,7 @@ func (c *Conn) deliver(ctx context.Context, q chan queueEntry, request *Request)
344
359
345
360
// Run run the jsonrpc2 server.
346
361
func (c * Conn ) Run (ctx context.Context ) (err error ) {
347
- q := make (chan queueEntry , c .capacity )
362
+ q := make (chan queueEntry , c .Capacity )
348
363
defer close (q )
349
364
350
365
// start the queue processor
@@ -353,6 +368,7 @@ func (c *Conn) Run(ctx context.Context) (err error) {
353
368
if e .ctx .Err () != nil {
354
369
continue
355
370
}
371
+ c .Logger .Debug ("c.Handler" , zap .Reflect ("e" , e .conn ), zap .Reflect ("e.request" , e .request ))
356
372
c .Handler (e .ctx , e .conn , e .request )
357
373
}
358
374
}()
@@ -363,17 +379,13 @@ func (c *Conn) Run(ctx context.Context) (err error) {
363
379
return err // read the stream failed, cannot continue
364
380
}
365
381
366
- c .logger .Debug (Receive , zap .ByteString ("data" , data ), zap .Int ("len(data)" , len (data )))
367
- // if len(data) == 0 {
368
- // continue // stream is empty, continue
369
- // }
382
+ c .Logger .Debug (Receive , zap .ByteString ("data" , data ), zap .Int ("len(data)" , len (data )))
370
383
371
- // read a combined message
372
384
msg := & Combined {}
373
- if err := gojay . Unsafe . UnmarshalJSONObject (data , msg ); err != nil {
385
+ if err := json . Unmarshal (data , msg ); err != nil { // TODO(zchee): use gojay
374
386
// a badly formed message arrived, log it and continue
375
387
// we trust the stream to have isolated the error to just this message
376
- c .logger .Debug (Receive ,
388
+ c .Logger .Debug (Receive ,
377
389
zap .Error (Errorf (CodeParseError , "unmarshal failed: %v" , err )),
378
390
)
379
391
continue
@@ -391,7 +403,7 @@ func (c *Conn) Run(ctx context.Context) (err error) {
391
403
392
404
if req .IsNotify () {
393
405
// handle the Notify because msg.ID is nil
394
- c .logger .Debug (Receive ,
406
+ c .Logger .Debug (Receive ,
395
407
zap .String ("req.ID" , req .ID .String ()),
396
408
zap .String ("req.Method" , req .Method ),
397
409
zap .Any ("req.Params" , req .Params ),
@@ -409,7 +421,7 @@ func (c *Conn) Run(ctx context.Context) (err error) {
409
421
start : time .Now (),
410
422
}
411
423
c .handlingMu .Unlock ()
412
- c .logger .Debug (Receive ,
424
+ c .Logger .Debug (Receive ,
413
425
zap .String ("req.ID" , req .ID .String ()),
414
426
zap .String ("req.Method" , req .Method ),
415
427
zap .Any ("req.Params" , req .Params ),
@@ -441,19 +453,19 @@ func (c *Conn) Run(ctx context.Context) (err error) {
441
453
close (rchan ) // for the range channel loop
442
454
443
455
default :
444
- c .logger .Warn (Receive , zap .Error (NewError (CodeInvalidParams , "ignoring because message not a call, notify or response" )))
456
+ c .Logger .Warn (Receive , zap .Error (NewError (CodeInvalidParams , "ignoring because message not a call, notify or response" )))
445
457
}
446
458
}
447
459
}
448
460
449
461
// marshalInterface marshal obj to RawMessage.
450
- func (c * Conn ) marshalInterface (obj interface {}) (* RawMessage , error ) {
451
- data , err := gojay .MarshalAny (obj )
462
+ // TODO(zchee): use gojay
463
+ func (c * Conn ) marshalInterface (obj interface {}) (* json.RawMessage , error ) {
464
+ data , err := json .Marshal (obj )
452
465
if err != nil {
453
466
return nil , err
454
467
}
455
- msg := RawMessage (gojay .EmbeddedJSON (data ))
456
- c .logger .Debug ("marshalInterface" , zap .String ("msg" , msg .String ()))
457
-
458
- return & msg , nil
468
+ raw := json .RawMessage (data )
469
+ c .Logger .Debug ("marshalInterface" , zap .ByteString ("raw" , raw ))
470
+ return & raw , nil
459
471
}
0 commit comments