Skip to content

Commit eb12710

Browse files
committed
Support custom object-params
1 parent 44a9f01 commit eb12710

File tree

4 files changed

+204
-63
lines changed

4 files changed

+204
-63
lines changed

client.go

Lines changed: 44 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -286,11 +286,16 @@ func (c *client) setupRequestChan() chan clientRequest {
286286
case <-ctxDone: // send cancel request
287287
ctxDone = nil
288288

289+
rp, err := json.Marshal([]param{{v: reflect.ValueOf(cr.req.ID)}})
290+
if err != nil {
291+
return clientResponse{}, xerrors.Errorf("marshalling cancel request: %w", err)
292+
}
293+
289294
cancelReq := clientRequest{
290295
req: request{
291296
Jsonrpc: "2.0",
292297
Method: wsCancel,
293-
Params: []param{{v: reflect.ValueOf(cr.req.ID)}},
298+
Params: rp,
294299
},
295300
ready: make(chan clientResponse, 1),
296301
}
@@ -452,7 +457,11 @@ type rpcFunc struct {
452457
valOut int
453458
errOut int
454459

455-
hasCtx int
460+
// hasCtx is 1 if the function has a context.Context as its first argument.
461+
// Used as the number of the first non-context argument.
462+
hasCtx int
463+
464+
hasRawParams bool
456465
returnValueIsChannel bool
457466

458467
retry bool
@@ -507,20 +516,31 @@ func (fn *rpcFunc) handleRpcCall(args []reflect.Value) (results []reflect.Value)
507516
}
508517
}
509518

510-
params := make([]param, len(args)-fn.hasCtx)
511-
for i, arg := range args[fn.hasCtx:] {
512-
enc, found := fn.client.paramEncoders[arg.Type()]
513-
if found {
514-
// custom param encoder
515-
var err error
516-
arg, err = enc(arg)
517-
if err != nil {
518-
return fn.processError(fmt.Errorf("sendRequest failed: %w", err))
519+
var serializedParams json.RawMessage
520+
521+
if fn.hasRawParams {
522+
serializedParams = json.RawMessage(args[fn.hasCtx].Interface().(RawParams))
523+
} else {
524+
params := make([]param, len(args)-fn.hasCtx)
525+
for i, arg := range args[fn.hasCtx:] {
526+
enc, found := fn.client.paramEncoders[arg.Type()]
527+
if found {
528+
// custom param encoder
529+
var err error
530+
arg, err = enc(arg)
531+
if err != nil {
532+
return fn.processError(fmt.Errorf("sendRequest failed: %w", err))
533+
}
519534
}
520-
}
521535

522-
params[i] = param{
523-
v: arg,
536+
params[i] = param{
537+
v: arg,
538+
}
539+
}
540+
var err error
541+
serializedParams, err = json.Marshal(params)
542+
if err != nil {
543+
return fn.processError(fmt.Errorf("marshaling params failed: %w", err))
524544
}
525545
}
526546

@@ -545,7 +565,7 @@ func (fn *rpcFunc) handleRpcCall(args []reflect.Value) (results []reflect.Value)
545565
Jsonrpc: "2.0",
546566
ID: id,
547567
Method: fn.name,
548-
Params: params,
568+
Params: serializedParams,
549569
}
550570

551571
if span != nil {
@@ -631,10 +651,18 @@ func (c *client) makeRpcFunc(f reflect.StructField) (reflect.Value, error) {
631651
return reflect.Value{}, xerrors.New("notify methods cannot return values")
632652
}
633653

654+
fun.returnValueIsChannel = fun.valOut != -1 && ftyp.Out(fun.valOut).Kind() == reflect.Chan
655+
634656
if ftyp.NumIn() > 0 && ftyp.In(0) == contextType {
635657
fun.hasCtx = 1
636658
}
637-
fun.returnValueIsChannel = fun.valOut != -1 && ftyp.Out(fun.valOut).Kind() == reflect.Chan
659+
// note: hasCtx is also the number of the first non-context argument
660+
if ftyp.NumIn() > fun.hasCtx && ftyp.In(fun.hasCtx) == rtRawParams {
661+
if ftyp.NumIn() > fun.hasCtx+1 {
662+
return reflect.Value{}, xerrors.New("raw params can't be mixed with other arguments")
663+
}
664+
fun.hasRawParams = true
665+
}
638666

639667
return reflect.MakeFunc(ftyp, fun.handleRpcCall), nil
640668
}

handler.go

Lines changed: 72 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,20 @@ import (
2020
"github.com/filecoin-project/go-jsonrpc/metrics"
2121
)
2222

23+
type RawParams json.RawMessage
24+
25+
var rtRawParams = reflect.TypeOf(RawParams{})
26+
27+
// todo is there a better way to tell 'struct with any number of fields'?
28+
func DecodeParams[T any](p RawParams) (T, error) {
29+
var t T
30+
err := json.Unmarshal(p, &t)
31+
32+
// todo also handle list-encoding automagically (json.Unmarshal doesn't do that, does it?)
33+
34+
return t, err
35+
}
36+
2337
// methodHandler is a handler for a single method
2438
type methodHandler struct {
2539
paramReceivers []reflect.Type
@@ -28,7 +42,8 @@ type methodHandler struct {
2842
receiver reflect.Value
2943
handlerFunc reflect.Value
3044

31-
hasCtx int
45+
hasCtx int
46+
hasRawParams bool
3247

3348
errOut int
3449
valOut int
@@ -40,7 +55,7 @@ type request struct {
4055
Jsonrpc string `json:"jsonrpc"`
4156
ID interface{} `json:"id,omitempty"`
4257
Method string `json:"method"`
43-
Params []param `json:"params"`
58+
Params json.RawMessage `json:"params"`
4459
Meta map[string]string `json:"meta,omitempty"`
4560
}
4661

@@ -135,9 +150,16 @@ func (s *handler) register(namespace string, r interface{}) {
135150
hasCtx = 1
136151
}
137152

153+
hasRawParams := false
138154
ins := funcType.NumIn() - 1 - hasCtx
139155
recvs := make([]reflect.Type, ins)
140156
for i := 0; i < ins; i++ {
157+
if hasRawParams && i > 0 {
158+
panic("raw params must be the last parameter")
159+
}
160+
if funcType.In(i+1+hasCtx) == rtRawParams {
161+
hasRawParams = true
162+
}
141163
recvs[i] = method.Type.In(i + 1 + hasCtx)
142164
}
143165

@@ -150,7 +172,8 @@ func (s *handler) register(namespace string, r interface{}) {
150172
handlerFunc: method.Func,
151173
receiver: val,
152174

153-
hasCtx: hasCtx,
175+
hasCtx: hasCtx,
176+
hasRawParams: hasRawParams,
154177

155178
errOut: errOut,
156179
valOut: valOut,
@@ -291,13 +314,6 @@ func (s *handler) handle(ctx context.Context, req request, w func(func(io.Writer
291314
}
292315
}
293316

294-
if len(req.Params) != handler.nParams {
295-
rpcError(w, &req, rpcInvalidParams, fmt.Errorf("wrong param count (method '%s'): %d != %d", req.Method, len(req.Params), handler.nParams))
296-
stats.Record(ctx, metrics.RPCRequestError.M(1))
297-
done(false)
298-
return
299-
}
300-
301317
outCh := handler.valOut != -1 && handler.handlerFunc.Type().Out(handler.valOut).Kind() == reflect.Chan
302318
defer done(outCh)
303319

@@ -313,30 +329,54 @@ func (s *handler) handle(ctx context.Context, req request, w func(func(io.Writer
313329
callParams[1] = reflect.ValueOf(ctx)
314330
}
315331

316-
for i := 0; i < handler.nParams; i++ {
317-
var rp reflect.Value
318-
319-
typ := handler.paramReceivers[i]
320-
dec, found := s.paramDecoders[typ]
321-
if !found {
322-
rp = reflect.New(typ)
323-
if err := json.NewDecoder(bytes.NewReader(req.Params[i].data)).Decode(rp.Interface()); err != nil {
324-
rpcError(w, &req, rpcParseError, xerrors.Errorf("unmarshaling params for '%s' (param: %T): %w", req.Method, rp.Interface(), err))
325-
stats.Record(ctx, metrics.RPCRequestError.M(1))
326-
return
327-
}
328-
rp = rp.Elem()
329-
} else {
330-
var err error
331-
rp, err = dec(ctx, req.Params[i].data)
332-
if err != nil {
333-
rpcError(w, &req, rpcParseError, xerrors.Errorf("decoding params for '%s' (param: %d; custom decoder): %w", req.Method, i, err))
334-
stats.Record(ctx, metrics.RPCRequestError.M(1))
335-
return
336-
}
332+
if handler.hasRawParams {
333+
// When hasRawParams is true, there is only one parameter and it is a
334+
// json.RawMessage.
335+
336+
callParams[1+handler.hasCtx] = reflect.ValueOf(RawParams(req.Params))
337+
} else {
338+
// "normal" param list; no good way to do named params in Golang
339+
340+
var ps []param
341+
err := json.Unmarshal(req.Params, &ps)
342+
if err != nil {
343+
rpcError(w, &req, rpcParseError, xerrors.Errorf("unmarshaling param array: %w", err))
344+
stats.Record(ctx, metrics.RPCRequestError.M(1))
345+
return
346+
}
347+
348+
if len(ps) != handler.nParams {
349+
rpcError(w, &req, rpcInvalidParams, fmt.Errorf("wrong param count (method '%s'): %d != %d", req.Method, len(ps), handler.nParams))
350+
stats.Record(ctx, metrics.RPCRequestError.M(1))
351+
done(false)
352+
return
337353
}
338354

339-
callParams[i+1+handler.hasCtx] = reflect.ValueOf(rp.Interface())
355+
for i := 0; i < handler.nParams; i++ {
356+
var rp reflect.Value
357+
358+
typ := handler.paramReceivers[i]
359+
dec, found := s.paramDecoders[typ]
360+
if !found {
361+
rp = reflect.New(typ)
362+
if err := json.NewDecoder(bytes.NewReader(ps[i].data)).Decode(rp.Interface()); err != nil {
363+
rpcError(w, &req, rpcParseError, xerrors.Errorf("unmarshaling params for '%s' (param: %T): %w", req.Method, rp.Interface(), err))
364+
stats.Record(ctx, metrics.RPCRequestError.M(1))
365+
return
366+
}
367+
rp = rp.Elem()
368+
} else {
369+
var err error
370+
rp, err = dec(ctx, ps[i].data)
371+
if err != nil {
372+
rpcError(w, &req, rpcParseError, xerrors.Errorf("decoding params for '%s' (param: %d; custom decoder): %w", req.Method, i, err))
373+
stats.Record(ctx, metrics.RPCRequestError.M(1))
374+
return
375+
}
376+
}
377+
378+
callParams[i+1+handler.hasCtx] = reflect.ValueOf(rp.Interface())
379+
}
340380
}
341381

342382
// /////////////////

rpc_test.go

Lines changed: 43 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1220,12 +1220,49 @@ func TestNotif(t *testing.T) {
12201220
t.Run("http", tc("http"))
12211221
}
12221222

1223-
// 1. make server call on client **
1224-
// 2. make client handle **
1225-
// 3. alias on client **
1226-
// 4. alias call on server **
1227-
// 6. custom/object param type
1228-
// 7. notif mode proxy tag
1223+
type RawParamHandler struct {
1224+
}
1225+
1226+
type CustomParams struct {
1227+
I int
1228+
}
1229+
1230+
func (h *RawParamHandler) Call(ctx context.Context, ps RawParams) (int, error) {
1231+
p, err := DecodeParams[CustomParams](ps)
1232+
if err != nil {
1233+
return 0, err
1234+
}
1235+
return p.I + 1, nil
1236+
}
1237+
1238+
func TestCallWithRawParams(t *testing.T) {
1239+
// setup server
1240+
1241+
rpcServer := NewServer()
1242+
rpcServer.Register("Raw", &RawParamHandler{})
1243+
1244+
// httptest stuff
1245+
testServ := httptest.NewServer(rpcServer)
1246+
defer testServ.Close()
1247+
1248+
// setup client
1249+
var client struct {
1250+
Call func(ctx context.Context, ps RawParams) (int, error)
1251+
}
1252+
closer, err := NewMergeClient(context.Background(), "ws://"+testServ.Listener.Addr().String(), "Raw", []interface{}{
1253+
&client,
1254+
}, nil)
1255+
require.NoError(t, err)
1256+
1257+
// do the call!
1258+
1259+
// this will block if it's not sent as a notification
1260+
n, err := client.Call(context.Background(), []byte(`{"I": 1}`))
1261+
require.NoError(t, err)
1262+
require.Equal(t, 2, n)
1263+
1264+
closer()
1265+
}
12291266

12301267
type RevCallTestServerHandler struct {
12311268
}

0 commit comments

Comments
 (0)