Skip to content

Commit 1fad59e

Browse files
committed
chore: revert changes to handwrittenvalidation
1 parent 041ecce commit 1fad59e

File tree

7 files changed

+68
-130
lines changed

7 files changed

+68
-130
lines changed

internal/middleware/handwrittenvalidation/handwrittenvalidation.go

Lines changed: 15 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -2,128 +2,53 @@ package handwrittenvalidation
22

33
import (
44
"context"
5-
"errors"
6-
"fmt"
75

8-
"buf.build/go/protovalidate"
9-
"go.opentelemetry.io/otel"
10-
otelcodes "go.opentelemetry.io/otel/codes"
11-
"go.opentelemetry.io/otel/trace"
126
"google.golang.org/grpc"
137
"google.golang.org/grpc/codes"
148
"google.golang.org/grpc/status"
15-
"google.golang.org/protobuf/proto"
169
)
1710

18-
var tracer = otel.Tracer("spicedb/internal/middleware")
19-
20-
// mustNewProtoValidator wraps protovalidate.New() to panic
21-
// if the validator can't be constructed.
22-
func mustNewProtoValidator(opts ...protovalidate.ValidatorOption) protovalidate.Validator {
23-
validator, err := protovalidate.New(opts...)
24-
if err != nil {
25-
wrappedErr := fmt.Errorf("could not construct validator: %w", err)
26-
panic(wrappedErr)
27-
}
28-
return validator
29-
}
30-
3111
type handwrittenValidator interface {
3212
HandwrittenValidate() error
3313
}
3414

35-
// UnaryServerInterceptor returns a function that performs standard proto validation and handwritten validation (if any) on the incoming request.
36-
func UnaryServerInterceptor() grpc.UnaryServerInterceptor {
37-
protovalidator := mustNewProtoValidator()
38-
return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
39-
_, span := tracer.Start(ctx, "protovalidate")
40-
err := standardValidate(req, protovalidator, span)
41-
if err != nil {
42-
span.End()
43-
return nil, err
44-
}
45-
46-
err = handwrittenValidate(req, span)
15+
// UnaryServerInterceptor returns a new unary server interceptor that runs the handwritten validation
16+
// on the incoming request, if any.
17+
func UnaryServerInterceptor(ctx context.Context, req any, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
18+
validator, ok := req.(handwrittenValidator)
19+
if ok {
20+
err := validator.HandwrittenValidate()
4721
if err != nil {
48-
span.End()
49-
return nil, err
22+
return nil, status.Errorf(codes.InvalidArgument, "%s", err)
5023
}
51-
52-
span.End()
53-
54-
return handler(ctx, req)
5524
}
25+
26+
return handler(ctx, req)
5627
}
5728

58-
// StreamServerInterceptor returns a function that performs standard proto validation and handwritten validation (if any) on the incoming request.
29+
// StreamServerInterceptor returns a new stream server interceptor that runs the handwritten validation
30+
// on the incoming request messages, if any.
5931
func StreamServerInterceptor(srv any, stream grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
60-
wrapper := &recvWrapper{ServerStream: stream, protovalidator: mustNewProtoValidator()}
32+
wrapper := &recvWrapper{stream}
6133
return handler(srv, wrapper)
6234
}
6335

6436
type recvWrapper struct {
6537
grpc.ServerStream
66-
protovalidator protovalidate.Validator
6738
}
6839

6940
func (s *recvWrapper) RecvMsg(m any) error {
7041
if err := s.ServerStream.RecvMsg(m); err != nil {
7142
return err
7243
}
7344

74-
_, span := tracer.Start(s.Context(), "protovalidate")
75-
err := standardValidate(m, s.protovalidator, span)
76-
if err != nil {
77-
span.End()
78-
return err
79-
}
80-
81-
err = handwrittenValidate(m, span)
82-
if err != nil {
83-
span.End()
84-
return err
85-
}
86-
87-
span.End()
88-
89-
return nil
90-
}
91-
92-
// standardValidate was copied from https://github.com/grpc-ecosystem/go-grpc-middleware/blob/ab2131d954af9580c1b49a3d9475f6adbe5de9d3/interceptors/protovalidate/protovalidate.go#L68
93-
// it validates the proto and if an error occurs, marks the span as errored.
94-
func standardValidate(m any, validator protovalidate.Validator, span trace.Span) error {
95-
msg, ok := m.(proto.Message)
96-
if !ok {
97-
return status.Errorf(codes.Internal, "unsupported message type: %T", m)
98-
}
99-
err := validator.Validate(msg)
100-
if err == nil {
101-
return nil
102-
}
103-
var valErr *protovalidate.ValidationError
104-
if errors.As(err, &valErr) {
105-
span.SetStatus(otelcodes.Error, err.Error())
106-
span.RecordError(err)
107-
st := status.New(codes.InvalidArgument, err.Error())
108-
ds, detErr := st.WithDetails(valErr.ToProto())
109-
if detErr != nil {
110-
return st.Err()
111-
}
112-
return ds.Err()
113-
}
114-
115-
return status.Error(codes.Internal, err.Error())
116-
}
117-
118-
func handwrittenValidate(req any, span trace.Span) error {
119-
validator, ok := req.(handwrittenValidator)
45+
validator, ok := m.(handwrittenValidator)
12046
if ok {
12147
err := validator.HandwrittenValidate()
12248
if err != nil {
123-
span.SetStatus(otelcodes.Error, err.Error())
124-
span.RecordError(err)
125-
return status.Errorf(codes.InvalidArgument, "%s", err)
49+
return err
12650
}
12751
}
52+
12853
return nil
12954
}

internal/middleware/interceptorwrapper/interceptorwrapper.go

Lines changed: 7 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,19 @@ import (
77
"google.golang.org/grpc"
88
)
99

10+
var tracer = otel.Tracer("spicedb/internal/middleware")
11+
1012
// WrapUnaryServerInterceptorWithSpans returns a new interceptor that wraps the given interceptor
1113
// with a span, measuring the duration of the interceptor's pre-handler logic.
1214
func WrapUnaryServerInterceptorWithSpans(
1315
inner grpc.UnaryServerInterceptor,
14-
tracerName, spanName string,
16+
spanName string,
1517
) grpc.UnaryServerInterceptor {
16-
t := otel.Tracer(tracerName)
1718
return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
18-
ctx, span := t.Start(ctx, spanName)
19+
ctx, span := tracer.Start(ctx, spanName)
20+
// NOTE: this shim is what lets us measure how long the interceptor is doing work.
21+
// It's the handler that we pass to the wrapped interceptor, so `span.End()` will be
22+
// called when the handler itself is called.
1923
shimHandler := func(ctx context.Context, req any) (any, error) {
2024
span.End()
2125
return handler(ctx, req)
@@ -27,32 +31,3 @@ func WrapUnaryServerInterceptorWithSpans(
2731
return resp, err
2832
}
2933
}
30-
31-
// WrapStreamServerInterceptorWithSpans returns a new interceptor that wraps the given interceptor
32-
// with a span, measuring the duration of the interceptor's pre-handler logic.
33-
func WrapStreamServerInterceptorWithSpans(
34-
inner grpc.StreamServerInterceptor,
35-
tracerName, spanName string,
36-
) grpc.StreamServerInterceptor {
37-
t := otel.Tracer(tracerName)
38-
return func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
39-
ctx, span := t.Start(ss.Context(), spanName)
40-
wrappedStream := &contextStream{ServerStream: ss, ctx: ctx}
41-
shimHandler := func(srv any, stream grpc.ServerStream) error {
42-
span.End()
43-
return handler(srv, stream)
44-
}
45-
err := inner(srv, wrappedStream, info, shimHandler)
46-
if span.IsRecording() {
47-
span.End()
48-
}
49-
return err
50-
}
51-
}
52-
53-
type contextStream struct {
54-
grpc.ServerStream
55-
ctx context.Context
56-
}
57-
58-
func (s *contextStream) Context() context.Context { return s.ctx }

internal/services/v1/experimental.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"time"
1414

1515
"github.com/ccoveille/go-safecast/v2"
16+
grpcvalidate "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/protovalidate"
1617
"google.golang.org/grpc"
1718
"google.golang.org/grpc/codes"
1819
"google.golang.org/protobuf/types/known/timestamppb"
@@ -23,6 +24,7 @@ import (
2324
log "github.com/authzed/spicedb/internal/logging"
2425
"github.com/authzed/spicedb/internal/middleware"
2526
"github.com/authzed/spicedb/internal/middleware/handwrittenvalidation"
27+
"github.com/authzed/spicedb/internal/middleware/interceptorwrapper"
2628
"github.com/authzed/spicedb/internal/middleware/perfinsights"
2729
"github.com/authzed/spicedb/internal/middleware/streamtimeout"
2830
"github.com/authzed/spicedb/internal/middleware/usagemetrics"
@@ -35,6 +37,7 @@ import (
3537
"github.com/authzed/spicedb/pkg/datastore"
3638
dsoptions "github.com/authzed/spicedb/pkg/datastore/options"
3739
"github.com/authzed/spicedb/pkg/datastore/queryshape"
40+
"github.com/authzed/spicedb/pkg/genutil"
3841
"github.com/authzed/spicedb/pkg/middleware/consistency"
3942
core "github.com/authzed/spicedb/pkg/proto/core/v1"
4043
dispatchv1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1"
@@ -92,14 +95,18 @@ func NewExperimentalServer(dispatch dispatch.Dispatcher, permServerConfig Permis
9295
chunkSize = 100
9396
}
9497

98+
validator := genutil.MustNewProtoValidator()
99+
95100
return &experimentalServer{
96101
WithServiceSpecificInterceptors: shared.WithServiceSpecificInterceptors{
97102
Unary: middleware.ChainUnaryServer(
98-
handwrittenvalidation.UnaryServerInterceptor(),
103+
interceptorwrapper.WrapUnaryServerInterceptorWithSpans(grpcvalidate.UnaryServerInterceptor(validator), "protovalidate"),
104+
handwrittenvalidation.UnaryServerInterceptor,
99105
usagemetrics.UnaryServerInterceptor(),
100106
perfinsights.UnaryServerInterceptor(permServerConfig.PerformanceInsightMetricsEnabled),
101107
),
102108
Stream: middleware.ChainStreamServer(
109+
grpcvalidate.StreamServerInterceptor(validator),
103110
handwrittenvalidation.StreamServerInterceptor,
104111
usagemetrics.StreamServerInterceptor(),
105112
streamtimeout.MustStreamServerInterceptor(config.StreamReadTimeout),

internal/services/v1/relationships.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"fmt"
88
"time"
99

10+
grpcvalidate "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/protovalidate"
1011
"github.com/prometheus/client_golang/prometheus"
1112
"github.com/prometheus/client_golang/prometheus/promauto"
1213
"go.opentelemetry.io/otel/trace"
@@ -18,6 +19,7 @@ import (
1819
"github.com/authzed/spicedb/internal/dispatch"
1920
"github.com/authzed/spicedb/internal/middleware"
2021
"github.com/authzed/spicedb/internal/middleware/handwrittenvalidation"
22+
"github.com/authzed/spicedb/internal/middleware/interceptorwrapper"
2123
"github.com/authzed/spicedb/internal/middleware/perfinsights"
2224
"github.com/authzed/spicedb/internal/middleware/streamtimeout"
2325
"github.com/authzed/spicedb/internal/middleware/usagemetrics"
@@ -145,16 +147,20 @@ func NewPermissionsServer(
145147
ExperimentalQueryPlan: config.ExperimentalQueryPlan,
146148
}
147149

150+
validator := genutil.MustNewProtoValidator()
151+
148152
return &permissionServer{
149153
dispatch: dispatch,
150154
config: configWithDefaults,
151155
WithServiceSpecificInterceptors: shared.WithServiceSpecificInterceptors{
152156
Unary: middleware.ChainUnaryServer(
153-
handwrittenvalidation.UnaryServerInterceptor(),
157+
interceptorwrapper.WrapUnaryServerInterceptorWithSpans(grpcvalidate.UnaryServerInterceptor(validator), "protovalidate"),
158+
handwrittenvalidation.UnaryServerInterceptor,
154159
usagemetrics.UnaryServerInterceptor(),
155160
perfinsights.UnaryServerInterceptor(configWithDefaults.PerformanceInsightMetricsEnabled),
156161
),
157162
Stream: middleware.ChainStreamServer(
163+
grpcvalidate.StreamServerInterceptor(validator),
158164
handwrittenvalidation.StreamServerInterceptor,
159165
usagemetrics.StreamServerInterceptor(),
160166
streamtimeout.MustStreamServerInterceptor(configWithDefaults.StreamingAPITimeout),

internal/services/v1/schema.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@ import (
55
"sort"
66
"strings"
77

8+
grpcvalidate "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/protovalidate"
9+
810
v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
911

1012
log "github.com/authzed/spicedb/internal/logging"
1113
"github.com/authzed/spicedb/internal/middleware"
12-
"github.com/authzed/spicedb/internal/middleware/handwrittenvalidation"
14+
"github.com/authzed/spicedb/internal/middleware/interceptorwrapper"
1315
"github.com/authzed/spicedb/internal/middleware/perfinsights"
1416
"github.com/authzed/spicedb/internal/middleware/usagemetrics"
1517
"github.com/authzed/spicedb/internal/services/shared"
@@ -44,15 +46,17 @@ type SchemaServerConfig struct {
4446
func NewSchemaServer(config SchemaServerConfig) v1.SchemaServiceServer {
4547
cts := caveattypes.TypeSetOrDefault(config.CaveatTypeSet)
4648

49+
validator := genutil.MustNewProtoValidator()
50+
4751
return &schemaServer{
4852
WithServiceSpecificInterceptors: shared.WithServiceSpecificInterceptors{
4953
Unary: middleware.ChainUnaryServer(
50-
handwrittenvalidation.UnaryServerInterceptor(),
54+
interceptorwrapper.WrapUnaryServerInterceptorWithSpans(grpcvalidate.UnaryServerInterceptor(validator), "protovalidate"),
5155
usagemetrics.UnaryServerInterceptor(),
5256
perfinsights.UnaryServerInterceptor(config.PerformanceInsightMetricsEnabled),
5357
),
5458
Stream: middleware.ChainStreamServer(
55-
handwrittenvalidation.StreamServerInterceptor,
59+
grpcvalidate.StreamServerInterceptor(validator),
5660
usagemetrics.StreamServerInterceptor(),
5761
perfinsights.StreamServerInterceptor(config.PerformanceInsightMetricsEnabled),
5862
),

internal/services/v1/watch.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,18 @@ import (
55
"slices"
66
"time"
77

8+
grpcvalidate "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/protovalidate"
89
"google.golang.org/grpc/codes"
910
"google.golang.org/grpc/status"
1011
"google.golang.org/protobuf/types/known/structpb"
1112

1213
v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
1314

14-
"github.com/authzed/spicedb/internal/middleware/handwrittenvalidation"
1515
"github.com/authzed/spicedb/internal/middleware/usagemetrics"
1616
"github.com/authzed/spicedb/internal/services/shared"
1717
"github.com/authzed/spicedb/pkg/datalayer"
1818
"github.com/authzed/spicedb/pkg/datastore"
19+
"github.com/authzed/spicedb/pkg/genutil"
1920
"github.com/authzed/spicedb/pkg/genutil/mapz"
2021
dispatchv1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1"
2122
"github.com/authzed/spicedb/pkg/tuple"
@@ -31,9 +32,11 @@ type watchServer struct {
3132

3233
// NewWatchServer creates an instance of the watch server.
3334
func NewWatchServer(heartbeatDuration time.Duration) v1.WatchServiceServer {
35+
validator := genutil.MustNewProtoValidator()
36+
3437
s := &watchServer{
3538
WithStreamServiceSpecificInterceptor: shared.WithStreamServiceSpecificInterceptor{
36-
Stream: handwrittenvalidation.StreamServerInterceptor,
39+
Stream: grpcvalidate.StreamServerInterceptor(validator),
3740
},
3841
heartbeatDuration: heartbeatDuration,
3942
}
@@ -49,7 +52,7 @@ func (ws *watchServer) Watch(req *v1.WatchRequest, stream v1.WatchService_WatchS
4952
}
5053
}
5154

52-
objectTypes := mapz.NewSet[string](req.GetOptionalObjectTypes()...)
55+
objectTypes := mapz.NewSet(req.GetOptionalObjectTypes()...)
5356

5457
ctx := stream.Context()
5558
dl := datalayer.MustFromContext(ctx)

pkg/genutil/protovalidate.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
package genutil
2+
3+
import (
4+
"fmt"
5+
6+
"buf.build/go/protovalidate"
7+
)
8+
9+
// MustNewProtoValidator wraps protovalidate.New() to panic
10+
// if the validator can't be constructed.
11+
func MustNewProtoValidator(opts ...protovalidate.ValidatorOption) protovalidate.Validator {
12+
validator, err := protovalidate.New(opts...)
13+
if err != nil {
14+
wrappedErr := fmt.Errorf("could not construct validator: %w", err)
15+
panic(wrappedErr)
16+
}
17+
return validator
18+
}

0 commit comments

Comments
 (0)