Skip to content

Commit 3606823

Browse files
Merge pull request #715 from akshayjshah/ajs/robust
protovalidate: avoid pointer comparisons
2 parents 8036513 + 21bacae commit 3606823

File tree

3 files changed

+64
-65
lines changed

3 files changed

+64
-65
lines changed

interceptors/protovalidate/options.go

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ import (
1212
)
1313

1414
type options struct {
15-
ignoreMessages []protoreflect.MessageType
15+
ignoreMessages []protoreflect.FullName
1616
}
1717

1818
// An Option lets you add options to protovalidate interceptors using With* funcs.
@@ -26,16 +26,24 @@ func evaluateOpts(opts []Option) *options {
2626
return optCopy
2727
}
2828

29-
// WithIgnoreMessages sets the messages that should be ignored by the validator. Use with
30-
// caution and ensure validation is performed elsewhere.
29+
// WithIgnoreMessages sets the messages that should be ignored by the
30+
// validator. Message types are matched using their fully-qualified Protobuf
31+
// names.
32+
//
33+
// Use with caution and ensure validation is performed elsewhere.
3134
func WithIgnoreMessages(msgs ...protoreflect.MessageType) Option {
35+
names := make([]protoreflect.FullName, 0, len(msgs))
36+
for _, msg := range msgs {
37+
names = append(names, msg.Descriptor().FullName())
38+
}
39+
slices.Sort(names)
3240
return func(o *options) {
33-
o.ignoreMessages = msgs
41+
o.ignoreMessages = names
3442
}
3543
}
3644

37-
func (o *options) shouldIgnoreMessage(m protoreflect.MessageType) bool {
38-
return slices.ContainsFunc(o.ignoreMessages, func(t protoreflect.MessageType) bool {
39-
return m == t
40-
})
45+
func (o *options) shouldIgnoreMessage(fqn protoreflect.FullName) bool {
46+
// Names are sorted in WithIgnoreMessages, so we can use binary search.
47+
_, found := slices.BinarySearch(o.ignoreMessages, fqn)
48+
return found
4149
}

interceptors/protovalidate/protovalidate.go

Lines changed: 29 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -15,98 +15,76 @@ import (
1515
)
1616

1717
// UnaryServerInterceptor returns a new unary server interceptor that validates incoming messages.
18+
// If the request is invalid, clients may access a structured representation of the validation failure as an error detail.
1819
func UnaryServerInterceptor(validator *protovalidate.Validator, opts ...Option) grpc.UnaryServerInterceptor {
20+
o := evaluateOpts(opts)
21+
1922
return func(
2023
ctx context.Context,
2124
req interface{},
2225
info *grpc.UnaryServerInfo,
2326
handler grpc.UnaryHandler,
2427
) (resp interface{}, err error) {
25-
o := evaluateOpts(opts)
26-
switch msg := req.(type) {
27-
case proto.Message:
28-
if o.shouldIgnoreMessage(msg.ProtoReflect().Type()) {
29-
break
30-
}
31-
if err = validator.Validate(msg); err != nil {
32-
return nil, validationErrToStatus(err).Err()
33-
}
34-
default:
35-
return nil, errors.New("unsupported message type")
28+
if err := validateMsg(req, validator, o); err != nil {
29+
return nil, err
3630
}
37-
3831
return handler(ctx, req)
3932
}
4033
}
4134

4235
// StreamServerInterceptor returns a new streaming server interceptor that validates incoming messages.
36+
// If the request is invalid, clients may access a structured representation of the validation failure as an error detail.
4337
func StreamServerInterceptor(validator *protovalidate.Validator, opts ...Option) grpc.StreamServerInterceptor {
4438
return func(
4539
srv interface{},
4640
stream grpc.ServerStream,
4741
info *grpc.StreamServerInfo,
4842
handler grpc.StreamHandler,
4943
) error {
50-
ctx := stream.Context()
44+
return handler(srv, &wrappedServerStream{
45+
ServerStream: stream,
46+
validator: validator,
47+
options: evaluateOpts(opts),
48+
})
49+
}
50+
}
5151

52-
wrapped := wrapServerStream(stream)
53-
wrapped.wrappedContext = ctx
54-
wrapped.validator = validator
55-
wrapped.options = evaluateOpts(opts)
52+
// wrappedServerStream is a thin wrapper around grpc.ServerStream that allows modifying context.
53+
type wrappedServerStream struct {
54+
grpc.ServerStream
5655

57-
return handler(srv, wrapped)
58-
}
56+
validator *protovalidate.Validator
57+
options *options
5958
}
6059

6160
func (w *wrappedServerStream) RecvMsg(m interface{}) error {
6261
if err := w.ServerStream.RecvMsg(m); err != nil {
6362
return err
6463
}
64+
return validateMsg(m, w.validator, w.options)
65+
}
6566

67+
func validateMsg(m interface{}, validator *protovalidate.Validator, opts *options) error {
6668
msg, ok := m.(proto.Message)
6769
if !ok {
6870
return errors.New("unsupported message type")
6971
}
70-
if w.options.shouldIgnoreMessage(msg.ProtoReflect().Type()) {
72+
if opts.shouldIgnoreMessage(msg.ProtoReflect().Descriptor().FullName()) {
7173
return nil
7274
}
73-
if err := w.validator.Validate(msg); err != nil {
74-
return validationErrToStatus(err).Err()
75+
err := validator.Validate(msg)
76+
if err == nil {
77+
return nil
7578
}
76-
77-
return nil
78-
}
79-
80-
// wrappedServerStream is a thin wrapper around grpc.ServerStream that allows modifying context.
81-
type wrappedServerStream struct {
82-
grpc.ServerStream
83-
// wrappedContext is the wrapper's own Context. You can assign it.
84-
wrappedContext context.Context
85-
86-
validator *protovalidate.Validator
87-
options *options
88-
}
89-
90-
// Context returns the wrapper's WrappedContext, overwriting the nested grpc.ServerStream.Context()
91-
func (w *wrappedServerStream) Context() context.Context {
92-
return w.wrappedContext
93-
}
94-
95-
// wrapServerStream returns a ServerStream that has the ability to overwrite context.
96-
func wrapServerStream(stream grpc.ServerStream) *wrappedServerStream {
97-
return &wrappedServerStream{ServerStream: stream, wrappedContext: stream.Context()}
98-
}
99-
100-
func validationErrToStatus(err error) *status.Status {
101-
// Message is invalid.
10279
if valErr := new(protovalidate.ValidationError); errors.As(err, &valErr) {
80+
// Message is invalid.
10381
st := status.New(codes.InvalidArgument, err.Error())
10482
ds, detErr := st.WithDetails(valErr.ToProto())
10583
if detErr != nil {
106-
return st
84+
return st.Err()
10785
}
108-
return ds
86+
return ds.Err()
10987
}
11088
// CEL expression doesn't compile or type-check.
111-
return status.New(codes.Unknown, err.Error())
89+
return status.Error(codes.Unknown, err.Error())
11290
}

interceptors/protovalidate/protovalidate_test.go

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -70,12 +70,15 @@ func TestUnaryServerInterceptor(t *testing.T) {
7070

7171
type server struct {
7272
testvalidatev1.UnimplementedTestValidateServiceServer
73+
74+
called *bool
7375
}
7476

7577
func (g *server) SendStream(
7678
_ *testvalidatev1.SendStreamRequest,
7779
stream testvalidatev1.TestValidateService_SendStreamServer,
7880
) error {
81+
*g.called = true
7982
if err := stream.Send(&testvalidatev1.SendStreamResponse{}); err != nil {
8083
return err
8184
}
@@ -85,7 +88,7 @@ func (g *server) SendStream(
8588

8689
const bufSize = 1024 * 1024
8790

88-
func startGrpcServer(t *testing.T, ignoreMessages ...protoreflect.MessageType) *grpc.ClientConn {
91+
func startGrpcServer(t *testing.T, called *bool, ignoreMessages ...protoreflect.MessageType) *grpc.ClientConn {
8992
lis := bufconn.Listen(bufSize)
9093

9194
validator, err := protovalidate.New()
@@ -98,7 +101,7 @@ func startGrpcServer(t *testing.T, ignoreMessages ...protoreflect.MessageType) *
98101
),
99102
),
100103
)
101-
testvalidatev1.RegisterTestValidateServiceServer(s, &server{})
104+
testvalidatev1.RegisterTestValidateServiceServer(s, &server{called: called})
102105
go func() {
103106
if err = s.Serve(lis); err != nil {
104107
log.Fatalf("Server exited with error: %v", err)
@@ -129,17 +132,24 @@ func startGrpcServer(t *testing.T, ignoreMessages ...protoreflect.MessageType) *
129132

130133
func TestStreamServerInterceptor(t *testing.T) {
131134
t.Run("valid_email", func(t *testing.T) {
135+
called := proto.Bool(false)
132136
client := testvalidatev1.NewTestValidateServiceClient(
133-
startGrpcServer(t),
137+
startGrpcServer(t, called),
134138
)
135139

136-
_, err := client.SendStream(context.Background(), testvalidate.GoodStreamRequest)
140+
out, err := client.SendStream(context.Background(), testvalidate.GoodStreamRequest)
141+
assert.Nil(t, err)
142+
143+
_, err = out.Recv()
144+
t.Log(err)
137145
assert.Nil(t, err)
146+
assert.True(t, *called)
138147
})
139148

140149
t.Run("invalid_email", func(t *testing.T) {
150+
called := proto.Bool(false)
141151
client := testvalidatev1.NewTestValidateServiceClient(
142-
startGrpcServer(t),
152+
startGrpcServer(t, called),
143153
)
144154

145155
out, err := client.SendStream(context.Background(), testvalidate.BadStreamRequest)
@@ -151,18 +161,21 @@ func TestStreamServerInterceptor(t *testing.T) {
151161
ConstraintId: "string.email",
152162
Message: "value must be a valid email address",
153163
}, err)
164+
assert.False(t, *called)
154165
})
155166

156167
t.Run("invalid_email_ignored", func(t *testing.T) {
168+
called := proto.Bool(false)
157169
client := testvalidatev1.NewTestValidateServiceClient(
158-
startGrpcServer(t, testvalidate.BadStreamRequest.ProtoReflect().Type()),
170+
startGrpcServer(t, called, testvalidate.BadStreamRequest.ProtoReflect().Type()),
159171
)
160172

161173
out, err := client.SendStream(context.Background(), testvalidate.BadStreamRequest)
162174
assert.Nil(t, err)
163175

164176
_, err = out.Recv()
165177
assert.Nil(t, err)
178+
assert.True(t, *called)
166179
})
167180
}
168181

0 commit comments

Comments
 (0)