Skip to content

Commit fb87535

Browse files
committed
rpcmiddleware: add context to request and response handlers
With this commit we add a context to the handlers so we can use that to transport values as well as cancel a request/response if necessary.
1 parent f7c74ef commit fb87535

File tree

3 files changed

+87
-58
lines changed

3 files changed

+87
-58
lines changed

rpcmiddleware/interface.go

Lines changed: 49 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,16 @@ var (
1818
// errorType is the reflection type of the error interface.
1919
errorType = reflect.TypeOf((*error)(nil)).Elem()
2020

21+
// ctxType is the reflection type of the context.Context interface.
22+
ctxType = reflect.TypeOf((*context.Context)(nil)).Elem()
23+
2124
// protoMessageType is the reflection type of the proto.Message
2225
// interface.
2326
protoMessageType = reflect.TypeOf((*proto.Message)(nil)).Elem()
2427

2528
// passThroughMessageHandler is a messageHandler that does not modify
2629
// the message and just passes it through.
27-
passThroughMessageHandler messageHandler = func(
30+
passThroughMessageHandler messageHandler = func(context.Context,
2831
proto.Message) (proto.Message, error) {
2932

3033
return nil, nil
@@ -37,8 +40,8 @@ var (
3740
}
3841

3942
// messageDenyHandler disallows the given message.
40-
messageDenyHandler messageHandler = func(req proto.Message) (
41-
proto.Message, error) {
43+
messageDenyHandler messageHandler = func(context.Context,
44+
proto.Message) (proto.Message, error) {
4245

4346
return nil, ErrNotSupported
4447
}
@@ -72,7 +75,7 @@ type RequestInterceptor interface {
7275
// new message of the same type (=return non-nil message, nil error) or abort
7376
// the call by returning a non-nil error. If the message is a request, then
7477
// returning a non-nil error will reject the request.
75-
type messageHandler func(req proto.Message) (proto.Message, error)
78+
type messageHandler func(context.Context, proto.Message) (proto.Message, error)
7679

7780
// ErrorHandler is a function type for a generic gRPC error handler. It can
7881
// pass through the error unchanged (=return nil, nil), replace the error with
@@ -99,14 +102,14 @@ type RoundTripChecker interface {
99102
// a new message of the same type (=return non-nil message, nil) or
100103
// refuse (=return non-nil error with rejection reason) an incoming
101104
// request.
102-
HandleRequest(proto.Message) (proto.Message, error)
105+
HandleRequest(context.Context, proto.Message) (proto.Message, error)
103106

104107
// HandleResponse is called for each outgoing gRPC response message of
105108
// the type declared to be handled by HandlesResponse. The handler can
106109
// pass through the response (=return nil, nil), replace the response
107110
// with a new message of the same type (=return non-nil message, nil
108111
// error) or abort the call by returning a non-nil error.
109-
HandleResponse(proto.Message) (proto.Message, error)
112+
HandleResponse(context.Context, proto.Message) (proto.Message, error)
110113

111114
// HandleErrorResponse is called for any error response.
112115
// The handler can pass through the error (=return nil, nil), replace
@@ -142,27 +145,26 @@ func (r *DefaultChecker) HandlesResponse(t protoreflect.MessageType) bool {
142145
return t == r.responseType
143146
}
144147

145-
// HandleRequest is called for each incoming gRPC request message of the
146-
// type declared to be accepted by HandlesRequest. The handler can
147-
// accept the request as is (=return nil, nil), replace the request with
148-
// a new message of the same type (=return non-nil message, nil) or
149-
// refuse (=return non-nil error with rejection reason) an incoming
150-
// request.
151-
func (r *DefaultChecker) HandleRequest(req proto.Message) (proto.Message,
152-
error) {
148+
// HandleRequest is called for each incoming gRPC request message of the type
149+
// declared to be accepted by HandlesRequest. The handler can accept the request
150+
// as is (=return nil, nil), replace the request with a new message of the same
151+
// type (=return non-nil message, nil) or refuse (=return non-nil error with
152+
// rejection reason) an incoming request.
153+
func (r *DefaultChecker) HandleRequest(ctx context.Context,
154+
req proto.Message) (proto.Message, error) {
153155

154-
return r.requestHandler(req)
156+
return r.requestHandler(ctx, req)
155157
}
156158

157-
// HandleResponse is called for each outgoing gRPC response message of
158-
// the type declared to be handled by HandlesResponse. The handler can
159-
// pass through the response (=return nil, nil), replace the response
160-
// with a new message of the same type (=return non-nil message, nil
161-
// error) or abort the call by returning a non-nil error.
162-
func (r *DefaultChecker) HandleResponse(resp proto.Message) (proto.Message,
163-
error) {
159+
// HandleResponse is called for each outgoing gRPC response message of the type
160+
// declared to be handled by HandlesResponse. The handler can pass through the
161+
// response (=return nil, nil), replace the response with a new message of the
162+
// same type (=return non-nil message, nil error) or abort the call by returning
163+
// a non-nil error.
164+
func (r *DefaultChecker) HandleResponse(ctx context.Context,
165+
resp proto.Message) (proto.Message, error) {
164166

165-
return r.responseHandler(resp)
167+
return r.responseHandler(ctx, resp)
166168
}
167169

168170
// HandleErrorResponse is called for any error response.
@@ -310,7 +312,9 @@ func newReflectionRequestCheckHandler(requestSample proto.Message,
310312
panic(err)
311313
}
312314

313-
return func(req proto.Message) (proto.Message, error) {
315+
return func(ctx context.Context, req proto.Message) (proto.Message,
316+
error) {
317+
314318
if req.ProtoReflect().Type() != requestProtoType {
315319
return nil, fmt.Errorf("request handler called for "+
316320
"unsupported type %v (expected %v)",
@@ -320,6 +324,7 @@ func newReflectionRequestCheckHandler(requestSample proto.Message,
320324
// We made sure this call would succeed when creating the
321325
// handler.
322326
resp := handlerValue.Call([]reflect.Value{
327+
reflect.ValueOf(ctx),
323328
reflect.ValueOf(req),
324329
})
325330

@@ -352,7 +357,9 @@ func newReflectionMessageHandler(messageSample proto.Message,
352357
panic(err)
353358
}
354359

355-
return func(req proto.Message) (proto.Message, error) {
360+
return func(ctx context.Context, req proto.Message) (proto.Message,
361+
error) {
362+
356363
if req.ProtoReflect().Type() != messageProtoType {
357364
return nil, fmt.Errorf("message handler called for "+
358365
"unsupported type %v (expected %v)",
@@ -362,6 +369,7 @@ func newReflectionMessageHandler(messageSample proto.Message,
362369
// We made sure this call would succeed when creating the
363370
// handler.
364371
resp := handlerValue.Call([]reflect.Value{
372+
reflect.ValueOf(ctx),
365373
reflect.ValueOf(req),
366374
})
367375

@@ -390,12 +398,16 @@ func validateRequestCheckHandler(typedHandlerType reflect.Type,
390398
if typedHandlerType.Kind() != reflect.Func {
391399
return fmt.Errorf("request handler must be a function")
392400
}
393-
if typedHandlerType.NumIn() != 1 || typedHandlerType.NumOut() != 1 {
394-
return fmt.Errorf("request handler must have exactly one " +
401+
if typedHandlerType.NumIn() != 2 || typedHandlerType.NumOut() != 1 {
402+
return fmt.Errorf("request handler must have exactly two " +
395403
"parameter and one return value")
396404
}
397-
if !typedHandlerType.In(0).ConvertibleTo(requestType) {
398-
return fmt.Errorf("request handler must have one parameter " +
405+
if !typedHandlerType.In(0).ConvertibleTo(ctxType) {
406+
return fmt.Errorf("request handler must have first parameter " +
407+
"with a sub type of context.Context")
408+
}
409+
if !typedHandlerType.In(1).ConvertibleTo(requestType) {
410+
return fmt.Errorf("request handler must have second parameter " +
399411
"with a sub type of proto.Message")
400412
}
401413
if typedHandlerType.Out(0) != errorType {
@@ -414,12 +426,16 @@ func validateMessageHandler(typedHandlerType reflect.Type,
414426
if typedHandlerType.Kind() != reflect.Func {
415427
return fmt.Errorf("message handler must be a function")
416428
}
417-
if typedHandlerType.NumIn() != 1 || typedHandlerType.NumOut() != 2 {
418-
return fmt.Errorf("message handler must have exactly one " +
429+
if typedHandlerType.NumIn() != 2 || typedHandlerType.NumOut() != 2 {
430+
return fmt.Errorf("message handler must have exactly two " +
419431
"parameter and two return values")
420432
}
421-
if !typedHandlerType.In(0).ConvertibleTo(messageType) {
422-
return fmt.Errorf("message handler must have one parameter " +
433+
if !typedHandlerType.In(0).ConvertibleTo(ctxType) {
434+
return fmt.Errorf("request handler must have first parameter " +
435+
"with a sub type of context.Context")
436+
}
437+
if !typedHandlerType.In(1).ConvertibleTo(messageType) {
438+
return fmt.Errorf("message handler must have second parameter " +
423439
"with a sub type of proto.Message")
424440
}
425441
outType0 := typedHandlerType.Out(0)

rpcmiddleware/interface_test.go

Lines changed: 37 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package rpcmiddleware
22

33
import (
4+
"context"
45
"fmt"
56
"testing"
67

@@ -10,6 +11,8 @@ import (
1011
)
1112

1213
var (
14+
ctxb = context.Background()
15+
1316
listPeersReq = &lnrpc.ListPeersRequest{}
1417
listPeersReqType = listPeersReq.ProtoReflect().Type()
1518

@@ -37,11 +40,11 @@ func TestPassThrough(t *testing.T) {
3740
require.True(t, peersChecker.HandlesRequest(listPeersReqType))
3841
require.True(t, peersChecker.HandlesResponse(listPeersRespType))
3942

40-
req, err := peersChecker.HandleRequest(listPeersReq)
43+
req, err := peersChecker.HandleRequest(ctxb, listPeersReq)
4144
require.NoError(t, err)
4245
require.Nil(t, req)
4346

44-
resp, err := peersChecker.HandleResponse(listPeersResp)
47+
resp, err := peersChecker.HandleResponse(ctxb, listPeersResp)
4548
require.NoError(t, err)
4649
require.Nil(t, resp)
4750

@@ -57,11 +60,11 @@ func TestRequestDenier(t *testing.T) {
5760
require.True(t, peersChecker.HandlesRequest(listPeersReqType))
5861
require.True(t, peersChecker.HandlesResponse(listPeersRespType))
5962

60-
req, err := peersChecker.HandleRequest(listPeersReq)
63+
req, err := peersChecker.HandleRequest(ctxb, listPeersReq)
6164
require.ErrorIs(t, err, ErrNotSupported)
6265
require.Nil(t, req)
6366

64-
resp, err := peersChecker.HandleResponse(listPeersResp)
67+
resp, err := peersChecker.HandleResponse(ctxb, listPeersResp)
6568
require.ErrorIs(t, err, ErrNotSupported)
6669
require.Nil(t, resp)
6770

@@ -74,19 +77,19 @@ func TestRequestDenier(t *testing.T) {
7477
func TestRequestChecker(t *testing.T) {
7578
peersChecker := NewRequestChecker(
7679
listPeersReq, listPeersResp,
77-
func(peer *lnrpc.ListPeersRequest) error {
80+
func(context.Context, *lnrpc.ListPeersRequest) error {
7881
return nil
7982
},
8083
)
8184

8285
require.True(t, peersChecker.HandlesRequest(listPeersReqType))
8386
require.True(t, peersChecker.HandlesResponse(listPeersRespType))
8487

85-
req, err := peersChecker.HandleRequest(listPeersReq)
88+
req, err := peersChecker.HandleRequest(ctxb, listPeersReq)
8689
require.NoError(t, err)
8790
require.Nil(t, req)
8891

89-
resp, err := peersChecker.HandleResponse(listPeersResp)
92+
resp, err := peersChecker.HandleResponse(ctxb, listPeersResp)
9093
require.NoError(t, err)
9194
require.Nil(t, resp)
9295

@@ -99,19 +102,21 @@ func TestRequestChecker(t *testing.T) {
99102
func TestRequestRewriter(t *testing.T) {
100103
peersChecker := NewRequestRewriter(
101104
listPeersReq, listPeersResp,
102-
func(peer *lnrpc.ListPeersRequest) (proto.Message, error) {
105+
func(ctx context.Context,
106+
peer *lnrpc.ListPeersRequest) (proto.Message, error) {
107+
103108
return peer, nil
104109
},
105110
)
106111

107112
require.True(t, peersChecker.HandlesRequest(listPeersReqType))
108113
require.True(t, peersChecker.HandlesResponse(listPeersRespType))
109114

110-
req, err := peersChecker.HandleRequest(listPeersReq)
115+
req, err := peersChecker.HandleRequest(ctxb, listPeersReq)
111116
require.NoError(t, err)
112117
require.Equal(t, listPeersReq, req)
113118

114-
resp, err := peersChecker.HandleResponse(listPeersResp)
119+
resp, err := peersChecker.HandleResponse(ctxb, listPeersResp)
115120
require.NoError(t, err)
116121
require.Nil(t, resp)
117122

@@ -124,19 +129,21 @@ func TestRequestRewriter(t *testing.T) {
124129
func TestResponseRewriter(t *testing.T) {
125130
peersChecker := NewResponseRewriter(
126131
listPeersReq, listPeersResp,
127-
func(peer *lnrpc.ListPeersResponse) (proto.Message, error) {
132+
func(ctx context.Context,
133+
peer *lnrpc.ListPeersResponse) (proto.Message, error) {
134+
128135
return peer, nil
129136
}, PassThroughErrorHandler,
130137
)
131138

132139
require.True(t, peersChecker.HandlesRequest(listPeersReqType))
133140
require.True(t, peersChecker.HandlesResponse(listPeersRespType))
134141

135-
req, err := peersChecker.HandleRequest(listPeersReq)
142+
req, err := peersChecker.HandleRequest(ctxb, listPeersReq)
136143
require.NoError(t, err)
137144
require.Nil(t, req)
138145

139-
resp, err := peersChecker.HandleResponse(listPeersResp)
146+
resp, err := peersChecker.HandleResponse(ctxb, listPeersResp)
140147
require.NoError(t, err)
141148
require.Equal(t, listPeersResp, resp)
142149

@@ -150,21 +157,23 @@ func TestFullChecker(t *testing.T) {
150157
myErr := fmt.Errorf("some error happened")
151158
peersChecker := NewFullChecker(
152159
listPeersReq, listPeersResp,
153-
func(peer *lnrpc.ListPeersRequest) error {
160+
func(ctx context.Context, peer *lnrpc.ListPeersRequest) error {
154161
return myErr
155162
},
156-
func(*lnrpc.ListPeersResponse) (proto.Message, error) {
163+
func(context.Context, *lnrpc.ListPeersResponse) (proto.Message,
164+
error) {
165+
157166
return nil, myErr
158167
}, PassThroughErrorHandler,
159168
)
160169

161170
require.True(t, peersChecker.HandlesRequest(listPeersReqType))
162171
require.True(t, peersChecker.HandlesResponse(listPeersRespType))
163172

164-
_, err := peersChecker.HandleRequest(listPeersReq)
173+
_, err := peersChecker.HandleRequest(ctxb, listPeersReq)
165174
require.Equal(t, myErr, err)
166175

167-
resp, err := peersChecker.HandleResponse(listPeersResp)
176+
resp, err := peersChecker.HandleResponse(ctxb, listPeersResp)
168177
require.Error(t, err)
169178
require.Equal(t, myErr, err)
170179
require.Nil(t, resp)
@@ -179,21 +188,25 @@ func TestFullRewriter(t *testing.T) {
179188
myErr := fmt.Errorf("some error happened")
180189
peersChecker := NewFullRewriter(
181190
listPeersReq, listPeersResp,
182-
func(peer *lnrpc.ListPeersRequest) (proto.Message, error) {
191+
func(ctx context.Context,
192+
peer *lnrpc.ListPeersRequest) (proto.Message, error) {
193+
183194
return nil, myErr
184195
},
185-
func(*lnrpc.ListPeersResponse) (proto.Message, error) {
196+
func(context.Context, *lnrpc.ListPeersResponse) (proto.Message,
197+
error) {
198+
186199
return nil, myErr
187200
}, PassThroughErrorHandler,
188201
)
189202

190203
require.True(t, peersChecker.HandlesRequest(listPeersReqType))
191204
require.True(t, peersChecker.HandlesResponse(listPeersRespType))
192205

193-
_, err := peersChecker.HandleRequest(listPeersReq)
206+
_, err := peersChecker.HandleRequest(ctxb, listPeersReq)
194207
require.Equal(t, myErr, err)
195208

196-
resp, err := peersChecker.HandleResponse(listPeersResp)
209+
resp, err := peersChecker.HandleResponse(ctxb, listPeersResp)
197210
require.Error(t, err)
198211
require.Equal(t, myErr, err)
199212
require.Nil(t, resp)
@@ -213,7 +226,7 @@ func TestImplementationPanics(t *testing.T) {
213226
},
214227
)
215228
require.PanicsWithError(
216-
t, "request handler must have exactly one parameter and one "+
229+
t, "request handler must have exactly two parameter and one "+
217230
"return value",
218231
func() {
219232
_ = NewRequestChecker(
@@ -232,7 +245,7 @@ func TestImplementationPanics(t *testing.T) {
232245
},
233246
)
234247
require.PanicsWithError(
235-
t, "message handler must have exactly one parameter and two "+
248+
t, "message handler must have exactly two parameter and two "+
236249
"return values",
237250
func() {
238251
_ = NewRequestRewriter(
@@ -252,7 +265,7 @@ func TestImplementationPanics(t *testing.T) {
252265
},
253266
)
254267
require.PanicsWithError(
255-
t, "message handler must have exactly one parameter and two "+
268+
t, "message handler must have exactly two parameter and two "+
256269
"return values",
257270
func() {
258271
_ = NewResponseRewriter(

rpcmiddleware/proto.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ func RPCErr(req *lnrpc.RPCMiddlewareRequest,
5757
return RPCErrString(req, err.Error())
5858
}
5959

60-
return RPCErrString(req, "")
60+
return RPCOk(req)
6161
}
6262

6363
// RPCErrString constructs a middleware response. If an empty format param is

0 commit comments

Comments
 (0)