Skip to content

Commit aec8785

Browse files
committed
protovalidate: send violations as error details
Amend the unary and streaming interceptors to send validation errors to the client as an error detail. This allows client code to easily parse and work with the structured validation information: for example, a UI might want to display validation errors next to the relevant fields in a form. Signed-off-by: Akshay Shah <[email protected]>
1 parent 6e75075 commit aec8785

File tree

2 files changed

+45
-6
lines changed

2 files changed

+45
-6
lines changed

interceptors/protovalidate/protovalidate.go

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ func UnaryServerInterceptor(validator *protovalidate.Validator, opts ...Option)
2929
break
3030
}
3131
if err = validator.Validate(msg); err != nil {
32-
return nil, status.Error(codes.InvalidArgument, err.Error())
32+
return nil, validationErrToStatus(err).Err()
3333
}
3434
default:
3535
return nil, errors.New("unsupported message type")
@@ -71,7 +71,7 @@ func (w *wrappedServerStream) RecvMsg(m interface{}) error {
7171
return nil
7272
}
7373
if err := w.validator.Validate(msg); err != nil {
74-
return status.Error(codes.InvalidArgument, err.Error())
74+
return validationErrToStatus(err).Err()
7575
}
7676

7777
return nil
@@ -96,3 +96,17 @@ func (w *wrappedServerStream) Context() context.Context {
9696
func wrapServerStream(stream grpc.ServerStream) *wrappedServerStream {
9797
return &wrappedServerStream{ServerStream: stream, wrappedContext: stream.Context()}
9898
}
99+
100+
func validationErrToStatus(err error) *status.Status {
101+
// Message is invalid.
102+
if valErr := new(protovalidate.ValidationError); errors.As(err, &valErr) {
103+
st := status.New(codes.InvalidArgument, err.Error())
104+
ds, detErr := st.WithDetails(valErr.ToProto())
105+
if detErr != nil {
106+
return st
107+
}
108+
return ds
109+
}
110+
// CEL expression doesn't compile or type-check.
111+
return status.New(codes.Unknown, err.Error())
112+
}

interceptors/protovalidate/protovalidate_test.go

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,19 @@ import (
99
"net"
1010
"testing"
1111

12+
"buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go/buf/validate"
1213
"github.com/bufbuild/protovalidate-go"
1314
protovalidate_middleware "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/protovalidate"
1415
"github.com/grpc-ecosystem/go-grpc-middleware/v2/testing/testvalidate"
1516
testvalidatev1 "github.com/grpc-ecosystem/go-grpc-middleware/v2/testing/testvalidate/v1"
1617
"github.com/stretchr/testify/assert"
18+
"github.com/stretchr/testify/require"
1719
"google.golang.org/grpc"
1820
"google.golang.org/grpc/codes"
1921
"google.golang.org/grpc/credentials/insecure"
2022
"google.golang.org/grpc/status"
2123
"google.golang.org/grpc/test/bufconn"
24+
"google.golang.org/protobuf/proto"
2225
"google.golang.org/protobuf/reflect/protoreflect"
2326
)
2427

@@ -41,8 +44,11 @@ func TestUnaryServerInterceptor(t *testing.T) {
4144

4245
t.Run("invalid_email", func(t *testing.T) {
4346
_, err = interceptor(context.TODO(), testvalidate.BadUnaryRequest, info, handler)
44-
assert.Error(t, err)
45-
assert.Equal(t, codes.InvalidArgument, status.Code(err))
47+
assertEqualViolation(t, &validate.Violation{
48+
FieldPath: "message",
49+
ConstraintId: "string.email",
50+
Message: "value must be a valid email address",
51+
}, err)
4652
})
4753

4854
t.Run("not_protobuf", func(t *testing.T) {
@@ -140,8 +146,11 @@ func TestStreamServerInterceptor(t *testing.T) {
140146
assert.Nil(t, err)
141147

142148
_, err = out.Recv()
143-
assert.Error(t, err)
144-
assert.Equal(t, codes.InvalidArgument, status.Code(err))
149+
assertEqualViolation(t, &validate.Violation{
150+
FieldPath: "message",
151+
ConstraintId: "string.email",
152+
Message: "value must be a valid email address",
153+
}, err)
145154
})
146155

147156
t.Run("invalid_email_ignored", func(t *testing.T) {
@@ -156,3 +165,19 @@ func TestStreamServerInterceptor(t *testing.T) {
156165
assert.Nil(t, err)
157166
})
158167
}
168+
169+
func assertEqualViolation(tb testing.TB, want *validate.Violation, got error) bool {
170+
require.Error(tb, got)
171+
st := status.Convert(got)
172+
assert.Equal(tb, codes.InvalidArgument, st.Code())
173+
details := st.Proto().GetDetails()
174+
require.Len(tb, details, 1)
175+
gotpb, unwrapErr := details[0].UnmarshalNew()
176+
require.Nil(tb, unwrapErr)
177+
violations := &validate.Violations{
178+
Violations: []*validate.Violation{want},
179+
}
180+
tb.Logf("got: %v", gotpb)
181+
tb.Logf("want: %v", violations)
182+
return assert.True(tb, proto.Equal(gotpb, violations))
183+
}

0 commit comments

Comments
 (0)