Skip to content

Commit 77de03b

Browse files
committed
protovalidate: avoid pointer comparisons
When deciding whether the user has bypassed validation for a message, we should compare the message's fully-qualified name rather than doing a fragile pointer comparison. Along the way, this commit removes some puzzling and unused code to plumb a context through the wrappedServerStream. Signed-off-by: Akshay Shah <[email protected]>
1 parent 53a6d9e commit 77de03b

File tree

2 files changed

+37
-57
lines changed

2 files changed

+37
-57
lines changed

interceptors/protovalidate/options.go

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ import (
1212
)
1313

1414
type options struct {
15-
ignoreMessages []protoreflect.MessageType
15+
ignoreMessages []protoreflect.FullName
1616
}
1717

1818
// An Option lets you add options to protovalidate interceptors using With* funcs.
@@ -29,13 +29,17 @@ func evaluateOpts(opts []Option) *options {
2929
// WithIgnoreMessages sets the messages that should be ignored by the validator. Use with
3030
// caution and ensure validation is performed elsewhere.
3131
func WithIgnoreMessages(msgs ...protoreflect.MessageType) Option {
32+
names := make([]protoreflect.FullName, 0, len(msgs))
33+
for _, msg := range msgs {
34+
names = append(names, msg.Descriptor().FullName())
35+
}
36+
slices.Sort(names)
3237
return func(o *options) {
33-
o.ignoreMessages = msgs
38+
o.ignoreMessages = names
3439
}
3540
}
3641

37-
func (o *options) shouldIgnoreMessage(m protoreflect.MessageType) bool {
38-
return slices.ContainsFunc(o.ignoreMessages, func(t protoreflect.MessageType) bool {
39-
return m == t
40-
})
42+
func (o *options) shouldIgnoreMessage(fqn protoreflect.FullName) bool {
43+
_, found := slices.BinarySearch(o.ignoreMessages, fqn)
44+
return found
4145
}

interceptors/protovalidate/protovalidate.go

Lines changed: 27 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -17,25 +17,17 @@ import (
1717
// UnaryServerInterceptor returns a new unary server interceptor that validates incoming messages.
1818
// If the request is invalid, clients may access a structured representation of the validation failure as an error detail.
1919
func UnaryServerInterceptor(validator *protovalidate.Validator, opts ...Option) grpc.UnaryServerInterceptor {
20+
o := evaluateOpts(opts)
21+
2022
return func(
2123
ctx context.Context,
2224
req interface{},
2325
info *grpc.UnaryServerInfo,
2426
handler grpc.UnaryHandler,
2527
) (resp interface{}, err error) {
26-
o := evaluateOpts(opts)
27-
switch msg := req.(type) {
28-
case proto.Message:
29-
if o.shouldIgnoreMessage(msg.ProtoReflect().Type()) {
30-
break
31-
}
32-
if err = validator.Validate(msg); err != nil {
33-
return nil, validationErrToStatus(err).Err()
34-
}
35-
default:
36-
return nil, errors.New("unsupported message type")
28+
if err := validateMsg(req, validator, o); err != nil {
29+
return nil, err
3730
}
38-
3931
return handler(ctx, req)
4032
}
4133
}
@@ -49,66 +41,50 @@ func StreamServerInterceptor(validator *protovalidate.Validator, opts ...Option)
4941
info *grpc.StreamServerInfo,
5042
handler grpc.StreamHandler,
5143
) error {
52-
ctx := stream.Context()
44+
return handler(srv, &wrappedServerStream{
45+
ServerStream: stream,
46+
validator: validator,
47+
options: evaluateOpts(opts),
48+
})
49+
}
50+
}
5351

54-
wrapped := wrapServerStream(stream)
55-
wrapped.wrappedContext = ctx
56-
wrapped.validator = validator
57-
wrapped.options = evaluateOpts(opts)
52+
// wrappedServerStream is a thin wrapper around grpc.ServerStream that allows modifying context.
53+
type wrappedServerStream struct {
54+
grpc.ServerStream
5855

59-
return handler(srv, wrapped)
60-
}
56+
validator *protovalidate.Validator
57+
options *options
6158
}
6259

6360
func (w *wrappedServerStream) RecvMsg(m interface{}) error {
6461
if err := w.ServerStream.RecvMsg(m); err != nil {
6562
return err
6663
}
64+
return validateMsg(m, w.validator, w.options)
65+
}
6766

67+
func validateMsg(m interface{}, validator *protovalidate.Validator, opts *options) error {
6868
msg, ok := m.(proto.Message)
6969
if !ok {
7070
return errors.New("unsupported message type")
7171
}
72-
if w.options.shouldIgnoreMessage(msg.ProtoReflect().Type()) {
72+
if opts.shouldIgnoreMessage(msg.ProtoReflect().Descriptor().FullName()) {
7373
return nil
7474
}
75-
if err := w.validator.Validate(msg); err != nil {
76-
return validationErrToStatus(err).Err()
75+
err := validator.Validate(msg)
76+
if err == nil {
77+
return nil
7778
}
78-
79-
return nil
80-
}
81-
82-
// wrappedServerStream is a thin wrapper around grpc.ServerStream that allows modifying context.
83-
type wrappedServerStream struct {
84-
grpc.ServerStream
85-
// wrappedContext is the wrapper's own Context. You can assign it.
86-
wrappedContext context.Context
87-
88-
validator *protovalidate.Validator
89-
options *options
90-
}
91-
92-
// Context returns the wrapper's WrappedContext, overwriting the nested grpc.ServerStream.Context()
93-
func (w *wrappedServerStream) Context() context.Context {
94-
return w.wrappedContext
95-
}
96-
97-
// wrapServerStream returns a ServerStream that has the ability to overwrite context.
98-
func wrapServerStream(stream grpc.ServerStream) *wrappedServerStream {
99-
return &wrappedServerStream{ServerStream: stream, wrappedContext: stream.Context()}
100-
}
101-
102-
func validationErrToStatus(err error) *status.Status {
103-
// Message is invalid.
10479
if valErr := new(protovalidate.ValidationError); errors.As(err, &valErr) {
80+
// Message is invalid.
10581
st := status.New(codes.InvalidArgument, err.Error())
10682
ds, detErr := st.WithDetails(valErr.ToProto())
10783
if detErr != nil {
108-
return st
84+
return st.Err()
10985
}
110-
return ds
86+
return ds.Err()
11187
}
11288
// CEL expression doesn't compile or type-check.
113-
return status.New(codes.Unknown, err.Error())
89+
return status.Error(codes.Unknown, err.Error())
11490
}

0 commit comments

Comments
 (0)