Skip to content

Commit 8036513

Browse files
Merge pull request #714 from akshayjshah/ajs/details
Include error details in protovalidate responses
2 parents 7da22cf + aec8785 commit 8036513

File tree

2 files changed

+55
-18
lines changed

2 files changed

+55
-18
lines changed

interceptors/protovalidate/protovalidate.go

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ func UnaryServerInterceptor(validator *protovalidate.Validator, opts ...Option)
2929
break
3030
}
3131
if err = validator.Validate(msg); err != nil {
32-
return nil, status.Error(codes.InvalidArgument, err.Error())
32+
return nil, validationErrToStatus(err).Err()
3333
}
3434
default:
3535
return nil, errors.New("unsupported message type")
@@ -63,12 +63,15 @@ func (w *wrappedServerStream) RecvMsg(m interface{}) error {
6363
return err
6464
}
6565

66-
msg := m.(proto.Message)
66+
msg, ok := m.(proto.Message)
67+
if !ok {
68+
return errors.New("unsupported message type")
69+
}
6770
if w.options.shouldIgnoreMessage(msg.ProtoReflect().Type()) {
6871
return nil
6972
}
7073
if err := w.validator.Validate(msg); err != nil {
71-
return status.Error(codes.InvalidArgument, err.Error())
74+
return validationErrToStatus(err).Err()
7275
}
7376

7477
return nil
@@ -93,3 +96,17 @@ func (w *wrappedServerStream) Context() context.Context {
9396
func wrapServerStream(stream grpc.ServerStream) *wrappedServerStream {
9497
return &wrappedServerStream{ServerStream: stream, wrappedContext: stream.Context()}
9598
}
99+
100+
func validationErrToStatus(err error) *status.Status {
101+
// Message is invalid.
102+
if valErr := new(protovalidate.ValidationError); errors.As(err, &valErr) {
103+
st := status.New(codes.InvalidArgument, err.Error())
104+
ds, detErr := st.WithDetails(valErr.ToProto())
105+
if detErr != nil {
106+
return st
107+
}
108+
return ds
109+
}
110+
// CEL expression doesn't compile or type-check.
111+
return status.New(codes.Unknown, err.Error())
112+
}

interceptors/protovalidate/protovalidate_test.go

Lines changed: 35 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,19 @@ import (
99
"net"
1010
"testing"
1111

12+
"buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go/buf/validate"
1213
"github.com/bufbuild/protovalidate-go"
1314
protovalidate_middleware "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/protovalidate"
1415
"github.com/grpc-ecosystem/go-grpc-middleware/v2/testing/testvalidate"
1516
testvalidatev1 "github.com/grpc-ecosystem/go-grpc-middleware/v2/testing/testvalidate/v1"
1617
"github.com/stretchr/testify/assert"
18+
"github.com/stretchr/testify/require"
1719
"google.golang.org/grpc"
1820
"google.golang.org/grpc/codes"
1921
"google.golang.org/grpc/credentials/insecure"
2022
"google.golang.org/grpc/status"
2123
"google.golang.org/grpc/test/bufconn"
24+
"google.golang.org/protobuf/proto"
2225
"google.golang.org/protobuf/reflect/protoreflect"
2326
)
2427

@@ -31,36 +34,34 @@ func TestUnaryServerInterceptor(t *testing.T) {
3134
handler := func(ctx context.Context, req any) (any, error) {
3235
return "good", nil
3336
}
37+
info := &grpc.UnaryServerInfo{FullMethod: "FakeMethod"}
3438

3539
t.Run("valid_email", func(t *testing.T) {
36-
info := &grpc.UnaryServerInfo{
37-
FullMethod: "FakeMethod",
38-
}
39-
4040
resp, err := interceptor(context.TODO(), testvalidate.GoodUnaryRequest, info, handler)
4141
assert.Nil(t, err)
4242
assert.Equal(t, resp, "good")
4343
})
4444

4545
t.Run("invalid_email", func(t *testing.T) {
46-
info := &grpc.UnaryServerInfo{
47-
FullMethod: "FakeMethod",
48-
}
49-
5046
_, err = interceptor(context.TODO(), testvalidate.BadUnaryRequest, info, handler)
47+
assertEqualViolation(t, &validate.Violation{
48+
FieldPath: "message",
49+
ConstraintId: "string.email",
50+
Message: "value must be a valid email address",
51+
}, err)
52+
})
53+
54+
t.Run("not_protobuf", func(t *testing.T) {
55+
_, err = interceptor(context.Background(), "not protobuf", info, handler)
5156
assert.Error(t, err)
52-
assert.Equal(t, codes.InvalidArgument, status.Code(err))
57+
assert.Equal(t, codes.Unknown, status.Code(err))
5358
})
5459

5560
interceptor = protovalidate_middleware.UnaryServerInterceptor(validator,
5661
protovalidate_middleware.WithIgnoreMessages(testvalidate.BadUnaryRequest.ProtoReflect().Type()),
5762
)
5863

5964
t.Run("invalid_email_ignored", func(t *testing.T) {
60-
info := &grpc.UnaryServerInfo{
61-
FullMethod: "FakeMethod",
62-
}
63-
6465
resp, err := interceptor(context.TODO(), testvalidate.BadUnaryRequest, info, handler)
6566
assert.Nil(t, err)
6667
assert.Equal(t, resp, "good")
@@ -145,8 +146,11 @@ func TestStreamServerInterceptor(t *testing.T) {
145146
assert.Nil(t, err)
146147

147148
_, err = out.Recv()
148-
assert.Error(t, err)
149-
assert.Equal(t, codes.InvalidArgument, status.Code(err))
149+
assertEqualViolation(t, &validate.Violation{
150+
FieldPath: "message",
151+
ConstraintId: "string.email",
152+
Message: "value must be a valid email address",
153+
}, err)
150154
})
151155

152156
t.Run("invalid_email_ignored", func(t *testing.T) {
@@ -161,3 +165,19 @@ func TestStreamServerInterceptor(t *testing.T) {
161165
assert.Nil(t, err)
162166
})
163167
}
168+
169+
func assertEqualViolation(tb testing.TB, want *validate.Violation, got error) bool {
170+
require.Error(tb, got)
171+
st := status.Convert(got)
172+
assert.Equal(tb, codes.InvalidArgument, st.Code())
173+
details := st.Proto().GetDetails()
174+
require.Len(tb, details, 1)
175+
gotpb, unwrapErr := details[0].UnmarshalNew()
176+
require.Nil(tb, unwrapErr)
177+
violations := &validate.Violations{
178+
Violations: []*validate.Violation{want},
179+
}
180+
tb.Logf("got: %v", gotpb)
181+
tb.Logf("want: %v", violations)
182+
return assert.True(tb, proto.Equal(gotpb, violations))
183+
}

0 commit comments

Comments
 (0)