@@ -15,98 +15,76 @@ import (
15
15
)
16
16
17
17
// UnaryServerInterceptor returns a new unary server interceptor that validates incoming messages.
18
+ // If the request is invalid, clients may access a structured representation of the validation failure as an error detail.
18
19
func UnaryServerInterceptor (validator * protovalidate.Validator , opts ... Option ) grpc.UnaryServerInterceptor {
20
+ o := evaluateOpts (opts )
21
+
19
22
return func (
20
23
ctx context.Context ,
21
24
req interface {},
22
25
info * grpc.UnaryServerInfo ,
23
26
handler grpc.UnaryHandler ,
24
27
) (resp interface {}, err error ) {
25
- o := evaluateOpts (opts )
26
- switch msg := req .(type ) {
27
- case proto.Message :
28
- if o .shouldIgnoreMessage (msg .ProtoReflect ().Type ()) {
29
- break
30
- }
31
- if err = validator .Validate (msg ); err != nil {
32
- return nil , validationErrToStatus (err ).Err ()
33
- }
34
- default :
35
- return nil , errors .New ("unsupported message type" )
28
+ if err := validateMsg (req , validator , o ); err != nil {
29
+ return nil , err
36
30
}
37
-
38
31
return handler (ctx , req )
39
32
}
40
33
}
41
34
42
35
// StreamServerInterceptor returns a new streaming server interceptor that validates incoming messages.
36
+ // If the request is invalid, clients may access a structured representation of the validation failure as an error detail.
43
37
func StreamServerInterceptor (validator * protovalidate.Validator , opts ... Option ) grpc.StreamServerInterceptor {
44
38
return func (
45
39
srv interface {},
46
40
stream grpc.ServerStream ,
47
41
info * grpc.StreamServerInfo ,
48
42
handler grpc.StreamHandler ,
49
43
) error {
50
- ctx := stream .Context ()
44
+ return handler (srv , & wrappedServerStream {
45
+ ServerStream : stream ,
46
+ validator : validator ,
47
+ options : evaluateOpts (opts ),
48
+ })
49
+ }
50
+ }
51
51
52
- wrapped := wrapServerStream (stream )
53
- wrapped .wrappedContext = ctx
54
- wrapped .validator = validator
55
- wrapped .options = evaluateOpts (opts )
52
+ // wrappedServerStream is a thin wrapper around grpc.ServerStream that allows modifying context.
53
+ type wrappedServerStream struct {
54
+ grpc.ServerStream
56
55
57
- return handler ( srv , wrapped )
58
- }
56
+ validator * protovalidate. Validator
57
+ options * options
59
58
}
60
59
61
60
func (w * wrappedServerStream ) RecvMsg (m interface {}) error {
62
61
if err := w .ServerStream .RecvMsg (m ); err != nil {
63
62
return err
64
63
}
64
+ return validateMsg (m , w .validator , w .options )
65
+ }
65
66
67
+ func validateMsg (m interface {}, validator * protovalidate.Validator , opts * options ) error {
66
68
msg , ok := m .(proto.Message )
67
69
if ! ok {
68
70
return errors .New ("unsupported message type" )
69
71
}
70
- if w . options . shouldIgnoreMessage (msg .ProtoReflect ().Type ()) {
72
+ if opts . shouldIgnoreMessage (msg .ProtoReflect ().Descriptor (). FullName ()) {
71
73
return nil
72
74
}
73
- if err := w .validator .Validate (msg ); err != nil {
74
- return validationErrToStatus (err ).Err ()
75
+ err := validator .Validate (msg )
76
+ if err == nil {
77
+ return nil
75
78
}
76
-
77
- return nil
78
- }
79
-
80
- // wrappedServerStream is a thin wrapper around grpc.ServerStream that allows modifying context.
81
- type wrappedServerStream struct {
82
- grpc.ServerStream
83
- // wrappedContext is the wrapper's own Context. You can assign it.
84
- wrappedContext context.Context
85
-
86
- validator * protovalidate.Validator
87
- options * options
88
- }
89
-
90
- // Context returns the wrapper's WrappedContext, overwriting the nested grpc.ServerStream.Context()
91
- func (w * wrappedServerStream ) Context () context.Context {
92
- return w .wrappedContext
93
- }
94
-
95
- // wrapServerStream returns a ServerStream that has the ability to overwrite context.
96
- func wrapServerStream (stream grpc.ServerStream ) * wrappedServerStream {
97
- return & wrappedServerStream {ServerStream : stream , wrappedContext : stream .Context ()}
98
- }
99
-
100
- func validationErrToStatus (err error ) * status.Status {
101
- // Message is invalid.
102
79
if valErr := new (protovalidate.ValidationError ); errors .As (err , & valErr ) {
80
+ // Message is invalid.
103
81
st := status .New (codes .InvalidArgument , err .Error ())
104
82
ds , detErr := st .WithDetails (valErr .ToProto ())
105
83
if detErr != nil {
106
- return st
84
+ return st . Err ()
107
85
}
108
- return ds
86
+ return ds . Err ()
109
87
}
110
88
// CEL expression doesn't compile or type-check.
111
- return status .New (codes .Unknown , err .Error ())
89
+ return status .Error (codes .Unknown , err .Error ())
112
90
}
0 commit comments