@@ -18,13 +18,16 @@ var (
1818 // errorType is the reflection type of the error interface.
1919 errorType = reflect .TypeOf ((* error )(nil )).Elem ()
2020
21+ // ctxType is the reflection type of the context.Context interface.
22+ ctxType = reflect .TypeOf ((* context .Context )(nil )).Elem ()
23+
2124 // protoMessageType is the reflection type of the proto.Message
2225 // interface.
2326 protoMessageType = reflect .TypeOf ((* proto .Message )(nil )).Elem ()
2427
2528 // passThroughMessageHandler is a messageHandler that does not modify
2629 // the message and just passes it through.
27- passThroughMessageHandler messageHandler = func (
30+ passThroughMessageHandler messageHandler = func (context. Context ,
2831 proto.Message ) (proto.Message , error ) {
2932
3033 return nil , nil
3740 }
3841
3942 // messageDenyHandler disallows the given message.
40- messageDenyHandler messageHandler = func (req proto. Message ) (
41- proto.Message , error ) {
43+ messageDenyHandler messageHandler = func (context. Context ,
44+ proto.Message ) (proto. Message , error ) {
4245
4346 return nil , ErrNotSupported
4447 }
@@ -72,7 +75,7 @@ type RequestInterceptor interface {
7275// new message of the same type (=return non-nil message, nil error) or abort
7376// the call by returning a non-nil error. If the message is a request, then
7477// returning a non-nil error will reject the request.
75- type messageHandler func (req proto.Message ) (proto.Message , error )
78+ type messageHandler func (context. Context , proto.Message ) (proto.Message , error )
7679
7780// ErrorHandler is a function type for a generic gRPC error handler. It can
7881// pass through the error unchanged (=return nil, nil), replace the error with
@@ -99,14 +102,14 @@ type RoundTripChecker interface {
99102 // a new message of the same type (=return non-nil message, nil) or
100103 // refuse (=return non-nil error with rejection reason) an incoming
101104 // request.
102- HandleRequest (proto.Message ) (proto.Message , error )
105+ HandleRequest (context. Context , proto.Message ) (proto.Message , error )
103106
104107 // HandleResponse is called for each outgoing gRPC response message of
105108 // the type declared to be handled by HandlesResponse. The handler can
106109 // pass through the response (=return nil, nil), replace the response
107110 // with a new message of the same type (=return non-nil message, nil
108111 // error) or abort the call by returning a non-nil error.
109- HandleResponse (proto.Message ) (proto.Message , error )
112+ HandleResponse (context. Context , proto.Message ) (proto.Message , error )
110113
111114 // HandleErrorResponse is called for any error response.
112115 // The handler can pass through the error (=return nil, nil), replace
@@ -142,27 +145,26 @@ func (r *DefaultChecker) HandlesResponse(t protoreflect.MessageType) bool {
142145 return t == r .responseType
143146}
144147
145- // HandleRequest is called for each incoming gRPC request message of the
146- // type declared to be accepted by HandlesRequest. The handler can
147- // accept the request as is (=return nil, nil), replace the request with
148- // a new message of the same type (=return non-nil message, nil) or
149- // refuse (=return non-nil error with rejection reason) an incoming
150- // request.
151- func (r * DefaultChecker ) HandleRequest (req proto.Message ) (proto.Message ,
152- error ) {
148+ // HandleRequest is called for each incoming gRPC request message of the type
149+ // declared to be accepted by HandlesRequest. The handler can accept the request
150+ // as is (=return nil, nil), replace the request with a new message of the same
151+ // type (=return non-nil message, nil) or refuse (=return non-nil error with
152+ // rejection reason) an incoming request.
153+ func (r * DefaultChecker ) HandleRequest (ctx context.Context ,
154+ req proto.Message ) (proto.Message , error ) {
153155
154- return r .requestHandler (req )
156+ return r .requestHandler (ctx , req )
155157}
156158
157- // HandleResponse is called for each outgoing gRPC response message of
158- // the type declared to be handled by HandlesResponse. The handler can
159- // pass through the response (=return nil, nil), replace the response
160- // with a new message of the same type (=return non-nil message, nil
161- // error) or abort the call by returning a non-nil error.
162- func (r * DefaultChecker ) HandleResponse (resp proto. Message ) (proto. Message ,
163- error ) {
159+ // HandleResponse is called for each outgoing gRPC response message of the type
160+ // declared to be handled by HandlesResponse. The handler can pass through the
161+ // response (=return nil, nil), replace the response with a new message of the
162+ // same type (=return non-nil message, nil error) or abort the call by returning
163+ // a non-nil error.
164+ func (r * DefaultChecker ) HandleResponse (ctx context. Context ,
165+ resp proto. Message ) (proto. Message , error ) {
164166
165- return r .responseHandler (resp )
167+ return r .responseHandler (ctx , resp )
166168}
167169
168170// HandleErrorResponse is called for any error response.
@@ -310,7 +312,9 @@ func newReflectionRequestCheckHandler(requestSample proto.Message,
310312 panic (err )
311313 }
312314
313- return func (req proto.Message ) (proto.Message , error ) {
315+ return func (ctx context.Context , req proto.Message ) (proto.Message ,
316+ error ) {
317+
314318 if req .ProtoReflect ().Type () != requestProtoType {
315319 return nil , fmt .Errorf ("request handler called for " +
316320 "unsupported type %v (expected %v)" ,
@@ -320,6 +324,7 @@ func newReflectionRequestCheckHandler(requestSample proto.Message,
320324 // We made sure this call would succeed when creating the
321325 // handler.
322326 resp := handlerValue .Call ([]reflect.Value {
327+ reflect .ValueOf (ctx ),
323328 reflect .ValueOf (req ),
324329 })
325330
@@ -352,7 +357,9 @@ func newReflectionMessageHandler(messageSample proto.Message,
352357 panic (err )
353358 }
354359
355- return func (req proto.Message ) (proto.Message , error ) {
360+ return func (ctx context.Context , req proto.Message ) (proto.Message ,
361+ error ) {
362+
356363 if req .ProtoReflect ().Type () != messageProtoType {
357364 return nil , fmt .Errorf ("message handler called for " +
358365 "unsupported type %v (expected %v)" ,
@@ -362,6 +369,7 @@ func newReflectionMessageHandler(messageSample proto.Message,
362369 // We made sure this call would succeed when creating the
363370 // handler.
364371 resp := handlerValue .Call ([]reflect.Value {
372+ reflect .ValueOf (ctx ),
365373 reflect .ValueOf (req ),
366374 })
367375
@@ -390,12 +398,16 @@ func validateRequestCheckHandler(typedHandlerType reflect.Type,
390398 if typedHandlerType .Kind () != reflect .Func {
391399 return fmt .Errorf ("request handler must be a function" )
392400 }
393- if typedHandlerType .NumIn () != 1 || typedHandlerType .NumOut () != 1 {
394- return fmt .Errorf ("request handler must have exactly one " +
401+ if typedHandlerType .NumIn () != 2 || typedHandlerType .NumOut () != 1 {
402+ return fmt .Errorf ("request handler must have exactly two " +
395403 "parameter and one return value" )
396404 }
397- if ! typedHandlerType .In (0 ).ConvertibleTo (requestType ) {
398- return fmt .Errorf ("request handler must have one parameter " +
405+ if ! typedHandlerType .In (0 ).ConvertibleTo (ctxType ) {
406+ return fmt .Errorf ("request handler must have first parameter " +
407+ "with a sub type of context.Context" )
408+ }
409+ if ! typedHandlerType .In (1 ).ConvertibleTo (requestType ) {
410+ return fmt .Errorf ("request handler must have second parameter " +
399411 "with a sub type of proto.Message" )
400412 }
401413 if typedHandlerType .Out (0 ) != errorType {
@@ -414,12 +426,16 @@ func validateMessageHandler(typedHandlerType reflect.Type,
414426 if typedHandlerType .Kind () != reflect .Func {
415427 return fmt .Errorf ("message handler must be a function" )
416428 }
417- if typedHandlerType .NumIn () != 1 || typedHandlerType .NumOut () != 2 {
418- return fmt .Errorf ("message handler must have exactly one " +
429+ if typedHandlerType .NumIn () != 2 || typedHandlerType .NumOut () != 2 {
430+ return fmt .Errorf ("message handler must have exactly two " +
419431 "parameter and two return values" )
420432 }
421- if ! typedHandlerType .In (0 ).ConvertibleTo (messageType ) {
422- return fmt .Errorf ("message handler must have one parameter " +
433+ if ! typedHandlerType .In (0 ).ConvertibleTo (ctxType ) {
434+ return fmt .Errorf ("request handler must have first parameter " +
435+ "with a sub type of context.Context" )
436+ }
437+ if ! typedHandlerType .In (1 ).ConvertibleTo (messageType ) {
438+ return fmt .Errorf ("message handler must have second parameter " +
423439 "with a sub type of proto.Message" )
424440 }
425441 outType0 := typedHandlerType .Out (0 )
0 commit comments