@@ -17,25 +17,17 @@ import (
17
17
// UnaryServerInterceptor returns a new unary server interceptor that validates incoming messages.
18
18
// If the request is invalid, clients may access a structured representation of the validation failure as an error detail.
19
19
func UnaryServerInterceptor (validator * protovalidate.Validator , opts ... Option ) grpc.UnaryServerInterceptor {
20
+ o := evaluateOpts (opts )
21
+
20
22
return func (
21
23
ctx context.Context ,
22
24
req interface {},
23
25
info * grpc.UnaryServerInfo ,
24
26
handler grpc.UnaryHandler ,
25
27
) (resp interface {}, err error ) {
26
- o := evaluateOpts (opts )
27
- switch msg := req .(type ) {
28
- case proto.Message :
29
- if o .shouldIgnoreMessage (msg .ProtoReflect ().Type ()) {
30
- break
31
- }
32
- if err = validator .Validate (msg ); err != nil {
33
- return nil , validationErrToStatus (err ).Err ()
34
- }
35
- default :
36
- return nil , errors .New ("unsupported message type" )
28
+ if err := validateMsg (req , validator , o ); err != nil {
29
+ return nil , err
37
30
}
38
-
39
31
return handler (ctx , req )
40
32
}
41
33
}
@@ -49,66 +41,50 @@ func StreamServerInterceptor(validator *protovalidate.Validator, opts ...Option)
49
41
info * grpc.StreamServerInfo ,
50
42
handler grpc.StreamHandler ,
51
43
) error {
52
- ctx := stream .Context ()
44
+ return handler (srv , & wrappedServerStream {
45
+ ServerStream : stream ,
46
+ validator : validator ,
47
+ options : evaluateOpts (opts ),
48
+ })
49
+ }
50
+ }
53
51
54
- wrapped := wrapServerStream (stream )
55
- wrapped .wrappedContext = ctx
56
- wrapped .validator = validator
57
- wrapped .options = evaluateOpts (opts )
52
+ // wrappedServerStream is a thin wrapper around grpc.ServerStream that allows modifying context.
53
+ type wrappedServerStream struct {
54
+ grpc.ServerStream
58
55
59
- return handler ( srv , wrapped )
60
- }
56
+ validator * protovalidate. Validator
57
+ options * options
61
58
}
62
59
63
60
func (w * wrappedServerStream ) RecvMsg (m interface {}) error {
64
61
if err := w .ServerStream .RecvMsg (m ); err != nil {
65
62
return err
66
63
}
64
+ return validateMsg (m , w .validator , w .options )
65
+ }
67
66
67
+ func validateMsg (m interface {}, validator * protovalidate.Validator , opts * options ) error {
68
68
msg , ok := m .(proto.Message )
69
69
if ! ok {
70
70
return errors .New ("unsupported message type" )
71
71
}
72
- if w . options . shouldIgnoreMessage (msg .ProtoReflect ().Type ()) {
72
+ if opts . shouldIgnoreMessage (msg .ProtoReflect ().Descriptor (). FullName ()) {
73
73
return nil
74
74
}
75
- if err := w .validator .Validate (msg ); err != nil {
76
- return validationErrToStatus (err ).Err ()
75
+ err := validator .Validate (msg )
76
+ if err == nil {
77
+ return nil
77
78
}
78
-
79
- return nil
80
- }
81
-
82
- // wrappedServerStream is a thin wrapper around grpc.ServerStream that allows modifying context.
83
- type wrappedServerStream struct {
84
- grpc.ServerStream
85
- // wrappedContext is the wrapper's own Context. You can assign it.
86
- wrappedContext context.Context
87
-
88
- validator * protovalidate.Validator
89
- options * options
90
- }
91
-
92
- // Context returns the wrapper's WrappedContext, overwriting the nested grpc.ServerStream.Context()
93
- func (w * wrappedServerStream ) Context () context.Context {
94
- return w .wrappedContext
95
- }
96
-
97
- // wrapServerStream returns a ServerStream that has the ability to overwrite context.
98
- func wrapServerStream (stream grpc.ServerStream ) * wrappedServerStream {
99
- return & wrappedServerStream {ServerStream : stream , wrappedContext : stream .Context ()}
100
- }
101
-
102
- func validationErrToStatus (err error ) * status.Status {
103
- // Message is invalid.
104
79
if valErr := new (protovalidate.ValidationError ); errors .As (err , & valErr ) {
80
+ // Message is invalid.
105
81
st := status .New (codes .InvalidArgument , err .Error ())
106
82
ds , detErr := st .WithDetails (valErr .ToProto ())
107
83
if detErr != nil {
108
- return st
84
+ return st . Err ()
109
85
}
110
- return ds
86
+ return ds . Err ()
111
87
}
112
88
// CEL expression doesn't compile or type-check.
113
- return status .New (codes .Unknown , err .Error ())
89
+ return status .Error (codes .Unknown , err .Error ())
114
90
}
0 commit comments