@@ -167,16 +167,18 @@ func httpClient(ctx context.Context, addr string, namespace string, outs []inter
167167 defer httpResp .Body .Close ()
168168
169169 var resp clientResponse
170- if err := json .NewDecoder (httpResp .Body ).Decode (& resp ); err != nil {
171- return clientResponse {}, xerrors .Errorf ("http status %s unmarshaling response: %w" , httpResp .Status , err )
172- }
170+ if cr .req .ID != nil { // non-notification
171+ if err := json .NewDecoder (httpResp .Body ).Decode (& resp ); err != nil {
172+ return clientResponse {}, xerrors .Errorf ("http status %s unmarshaling response: %w" , httpResp .Status , err )
173+ }
173174
174- if resp .ID , err = normalizeID (resp .ID ); err != nil {
175- return clientResponse {}, xerrors .Errorf ("failed to response ID: %w" , err )
176- }
175+ if resp .ID , err = normalizeID (resp .ID ); err != nil {
176+ return clientResponse {}, xerrors .Errorf ("failed to response ID: %w" , err )
177+ }
177178
178- if resp .ID != cr .req .ID {
179- return clientResponse {}, xerrors .New ("request and response id didn't match" )
179+ if resp .ID != cr .req .ID {
180+ return clientResponse {}, xerrors .New ("request and response id didn't match" )
181+ }
180182 }
181183
182184 return resp , nil
@@ -220,6 +222,45 @@ func websocketClient(ctx context.Context, addr string, namespace string, outs []
220222 errors : config .errors ,
221223 }
222224
225+ requests := c .setupRequestChan ()
226+
227+ stop := make (chan struct {})
228+ exiting := make (chan struct {})
229+ c .exiting = exiting
230+
231+ var hnd reqestHandler
232+ if len (config .reverseHandlers ) > 0 {
233+ h := makeHandler (defaultServerConfig ())
234+ h .aliasedMethods = config .aliasedHandlerMethods
235+ for _ , reverseHandler := range config .reverseHandlers {
236+ h .register (reverseHandler .ns , reverseHandler .hnd )
237+ }
238+ hnd = h
239+ }
240+
241+ go (& wsConn {
242+ conn : conn ,
243+ connFactory : connFactory ,
244+ reconnectBackoff : config .reconnectBackoff ,
245+ pingInterval : config .pingInterval ,
246+ timeout : config .timeout ,
247+ handler : hnd ,
248+ requests : requests ,
249+ stop : stop ,
250+ exiting : exiting ,
251+ }).handleWsConn (ctx )
252+
253+ if err := c .provide (outs ); err != nil {
254+ return nil , err
255+ }
256+
257+ return func () {
258+ close (stop )
259+ <- exiting
260+ }, nil
261+ }
262+
263+ func (c * client ) setupRequestChan () chan clientRequest {
223264 requests := make (chan clientRequest )
224265
225266 c .doRequest = func (ctx context.Context , cr clientRequest ) (clientResponse , error ) {
@@ -245,12 +286,18 @@ func websocketClient(ctx context.Context, addr string, namespace string, outs []
245286 case <- ctxDone : // send cancel request
246287 ctxDone = nil
247288
289+ rp , err := json .Marshal ([]param {{v : reflect .ValueOf (cr .req .ID )}})
290+ if err != nil {
291+ return clientResponse {}, xerrors .Errorf ("marshalling cancel request: %w" , err )
292+ }
293+
248294 cancelReq := clientRequest {
249295 req : request {
250296 Jsonrpc : "2.0" ,
251297 Method : wsCancel ,
252- Params : [] param {{ v : reflect . ValueOf ( cr . req . ID )}} ,
298+ Params : rp ,
253299 },
300+ ready : make (chan clientResponse , 1 ),
254301 }
255302 select {
256303 case requests <- cancelReq :
@@ -264,30 +311,7 @@ func websocketClient(ctx context.Context, addr string, namespace string, outs []
264311 return resp , nil
265312 }
266313
267- stop := make (chan struct {})
268- exiting := make (chan struct {})
269- c .exiting = exiting
270-
271- go (& wsConn {
272- conn : conn ,
273- connFactory : connFactory ,
274- reconnectBackoff : config .reconnectBackoff ,
275- pingInterval : config .pingInterval ,
276- timeout : config .timeout ,
277- handler : nil ,
278- requests : requests ,
279- stop : stop ,
280- exiting : exiting ,
281- }).handleWsConn (ctx )
282-
283- if err := c .provide (outs ); err != nil {
284- return nil , err
285- }
286-
287- return func () {
288- close (stop )
289- <- exiting
290- }, nil
314+ return requests
291315}
292316
293317func (c * client ) provide (outs []interface {}) error {
@@ -433,10 +457,15 @@ type rpcFunc struct {
433457 valOut int
434458 errOut int
435459
436- hasCtx int
460+ // hasCtx is 1 if the function has a context.Context as its first argument.
461+ // Used as the number of the first non-context argument.
462+ hasCtx int
463+
464+ hasRawParams bool
437465 returnValueIsChannel bool
438466
439- retry bool
467+ retry bool
468+ notify bool
440469}
441470
442471func (fn * rpcFunc ) processResponse (resp clientResponse , rval reflect.Value ) []reflect.Value {
@@ -471,21 +500,47 @@ func (fn *rpcFunc) processError(err error) []reflect.Value {
471500}
472501
473502func (fn * rpcFunc ) handleRpcCall (args []reflect.Value ) (results []reflect.Value ) {
474- var id interface {} = atomic .AddInt64 (& fn .client .idCtr , 1 )
475- params := make ([]param , len (args )- fn .hasCtx )
476- for i , arg := range args [fn .hasCtx :] {
477- enc , found := fn .client .paramEncoders [arg .Type ()]
478- if found {
479- // custom param encoder
480- var err error
481- arg , err = enc (arg )
482- if err != nil {
483- return fn .processError (fmt .Errorf ("sendRequest failed: %w" , err ))
484- }
503+ var id interface {}
504+ if ! fn .notify {
505+ id = atomic .AddInt64 (& fn .client .idCtr , 1 )
506+
507+ // Prepare the ID to send on the wire.
508+ // We track int64 ids as float64 in the inflight map (because that's what
509+ // they'll be decoded to). encoding/json outputs numbers with their minimal
510+ // encoding, avoding the decimal point when possible, i.e. 3 will never get
511+ // converted to 3.0.
512+ var err error
513+ id , err = normalizeID (id )
514+ if err != nil {
515+ return fn .processError (fmt .Errorf ("failed to normalize id" )) // should probably panic
485516 }
517+ }
486518
487- params [i ] = param {
488- v : arg ,
519+ var serializedParams json.RawMessage
520+
521+ if fn .hasRawParams {
522+ serializedParams = json .RawMessage (args [fn .hasCtx ].Interface ().(RawParams ))
523+ } else {
524+ params := make ([]param , len (args )- fn .hasCtx )
525+ for i , arg := range args [fn .hasCtx :] {
526+ enc , found := fn .client .paramEncoders [arg .Type ()]
527+ if found {
528+ // custom param encoder
529+ var err error
530+ arg , err = enc (arg )
531+ if err != nil {
532+ return fn .processError (fmt .Errorf ("sendRequest failed: %w" , err ))
533+ }
534+ }
535+
536+ params [i ] = param {
537+ v : arg ,
538+ }
539+ }
540+ var err error
541+ serializedParams , err = json .Marshal (params )
542+ if err != nil {
543+ return fn .processError (fmt .Errorf ("marshaling params failed: %w" , err ))
489544 }
490545 }
491546
@@ -506,21 +561,11 @@ func (fn *rpcFunc) handleRpcCall(args []reflect.Value) (results []reflect.Value)
506561 retVal , chCtor = fn .client .makeOutChan (ctx , fn .ftyp , fn .valOut )
507562 }
508563
509- // Prepare the ID to send on the wire.
510- // We track int64 ids as float64 in the inflight map (because that's what
511- // they'll be decoded to). encoding/json outputs numbers with their minimal
512- // encoding, avoding the decimal point when possible, i.e. 3 will never get
513- // converted to 3.0.
514- id , err := normalizeID (id )
515- if err != nil {
516- return fn .processError (fmt .Errorf ("failed to normalize id" )) // should probably panic
517- }
518-
519564 req := request {
520565 Jsonrpc : "2.0" ,
521566 ID : id ,
522- Method : fn .client . namespace + "." + fn . name ,
523- Params : params ,
567+ Method : fn .name ,
568+ Params : serializedParams ,
524569 }
525570
526571 if span != nil {
@@ -538,6 +583,7 @@ func (fn *rpcFunc) handleRpcCall(args []reflect.Value) (results []reflect.Value)
538583 minDelay : methodMinRetryDelay ,
539584 }
540585
586+ var err error
541587 var resp clientResponse
542588 // keep retrying if got a forced closed websocket conn and calling method
543589 // has retry annotation
@@ -547,7 +593,7 @@ func (fn *rpcFunc) handleRpcCall(args []reflect.Value) (results []reflect.Value)
547593 return fn .processError (fmt .Errorf ("sendRequest failed: %w" , err ))
548594 }
549595
550- if resp .ID != req .ID {
596+ if ! fn . notify && resp .ID != req .ID {
551597 return fn .processError (xerrors .New ("request and response id didn't match" ))
552598 }
553599
@@ -575,24 +621,48 @@ func (fn *rpcFunc) handleRpcCall(args []reflect.Value) (results []reflect.Value)
575621 return fn .processResponse (resp , retVal ())
576622}
577623
624+ const (
625+ ProxyTagRetry = "retry"
626+ ProxyTagNotify = "notify"
627+ ProxyTagRPCMethod = "rpc_method"
628+ )
629+
578630func (c * client ) makeRpcFunc (f reflect.StructField ) (reflect.Value , error ) {
579631 ftyp := f .Type
580632 if ftyp .Kind () != reflect .Func {
581633 return reflect.Value {}, xerrors .New ("handler field not a func" )
582634 }
583635
636+ name := c .namespace + "." + f .Name
637+ if tag , ok := f .Tag .Lookup (ProxyTagRPCMethod ); ok {
638+ name = tag
639+ }
640+
584641 fun := & rpcFunc {
585642 client : c ,
586643 ftyp : ftyp ,
587- name : f .Name ,
588- retry : f .Tag .Get ("retry" ) == "true" ,
644+ name : name ,
645+ retry : f .Tag .Get (ProxyTagRetry ) == "true" ,
646+ notify : f .Tag .Get (ProxyTagNotify ) == "true" ,
589647 }
590648 fun .valOut , fun .errOut , fun .nout = processFuncOut (ftyp )
591649
650+ if fun .valOut != - 1 && fun .notify {
651+ return reflect.Value {}, xerrors .New ("notify methods cannot return values" )
652+ }
653+
654+ fun .returnValueIsChannel = fun .valOut != - 1 && ftyp .Out (fun .valOut ).Kind () == reflect .Chan
655+
592656 if ftyp .NumIn () > 0 && ftyp .In (0 ) == contextType {
593657 fun .hasCtx = 1
594658 }
595- fun .returnValueIsChannel = fun .valOut != - 1 && ftyp .Out (fun .valOut ).Kind () == reflect .Chan
659+ // note: hasCtx is also the number of the first non-context argument
660+ if ftyp .NumIn () > fun .hasCtx && ftyp .In (fun .hasCtx ) == rtRawParams {
661+ if ftyp .NumIn () > fun .hasCtx + 1 {
662+ return reflect.Value {}, xerrors .New ("raw params can't be mixed with other arguments" )
663+ }
664+ fun .hasRawParams = true
665+ }
596666
597667 return reflect .MakeFunc (ftyp , fun .handleRpcCall ), nil
598668}
0 commit comments