Skip to content

Commit 001ba37

Browse files
authored
✨ new feat: selector middleware (#511)
Signed-off-by: aimuz <[email protected]> Signed-off-by: aimuz <[email protected]>
1 parent f3ab992 commit 001ba37

File tree

4 files changed

+290
-0
lines changed

4 files changed

+290
-0
lines changed

interceptors/selector/doc.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
// Copyright (c) The go-grpc-middleware Authors.
2+
// Licensed under the Apache License 2.0.
3+
4+
/*
5+
Package selector
6+
7+
`selector` a generic server-side selector middleware for gRPC.
8+
9+
# Server Side Selector Middleware
10+
It allows to set check rules to allowlist or blocklist middleware such as Auth
11+
interceptors to toggle behavior on or off based on the request path.
12+
13+
Please see examples for simple examples of use.
14+
*/
15+
package selector

interceptors/selector/selector.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// Copyright (c) The go-grpc-middleware Authors.
2+
// Licensed under the Apache License 2.0.
3+
4+
package selector
5+
6+
import (
7+
"context"
8+
9+
"google.golang.org/grpc"
10+
)
11+
12+
type MatchFunc func(ctx context.Context, fullMethod string) bool
13+
14+
// UnaryServerInterceptor returns a new unary server interceptor that will decide whether to call
15+
// the interceptor on the behavior of the MatchFunc.
16+
func UnaryServerInterceptor(interceptors grpc.UnaryServerInterceptor, match MatchFunc) grpc.UnaryServerInterceptor {
17+
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
18+
if match(ctx, info.FullMethod) {
19+
return interceptors(ctx, req, info, handler)
20+
}
21+
return handler(ctx, req)
22+
}
23+
}
24+
25+
// StreamServerInterceptor returns a new stream server interceptor that will decide whether to call
26+
// the interceptor on the behavior of the MatchFunc.
27+
func StreamServerInterceptor(interceptors grpc.StreamServerInterceptor, match MatchFunc) grpc.StreamServerInterceptor {
28+
return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
29+
if match(ss.Context(), info.FullMethod) {
30+
return interceptors(srv, ss, info, handler)
31+
}
32+
return handler(srv, ss)
33+
}
34+
}
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
// Copyright (c) The go-grpc-middleware Authors.
2+
// Licensed under the Apache License 2.0.
3+
4+
package selector_test
5+
6+
import (
7+
"context"
8+
9+
"google.golang.org/grpc"
10+
"google.golang.org/grpc/codes"
11+
"google.golang.org/grpc/status"
12+
13+
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/auth"
14+
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/logging"
15+
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/ratelimit"
16+
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/selector"
17+
)
18+
19+
// alwaysPassLimiter is an example limiter which implements Limiter interface.
20+
// It does not limit any request because Limit function always returns false.
21+
type alwaysPassLimiter struct{}
22+
23+
func (*alwaysPassLimiter) Limit(_ context.Context) error {
24+
return nil
25+
}
26+
27+
func healthSkip(ctx context.Context, fullMethod string) bool {
28+
return fullMethod != "/ping.v1.PingService/Health"
29+
}
30+
31+
func Example_ratelimit() {
32+
limiter := &alwaysPassLimiter{}
33+
_ = grpc.NewServer(
34+
grpc.ChainUnaryInterceptor(
35+
selector.UnaryServerInterceptor(ratelimit.UnaryServerInterceptor(limiter), healthSkip),
36+
),
37+
grpc.ChainStreamInterceptor(
38+
selector.StreamServerInterceptor(ratelimit.StreamServerInterceptor(limiter), healthSkip),
39+
),
40+
)
41+
}
42+
43+
var tokenInfoKey struct{}
44+
45+
func parseToken(token string) (struct{}, error) {
46+
return struct{}{}, nil
47+
}
48+
49+
func userClaimFromToken(struct{}) string {
50+
return "foobar"
51+
}
52+
53+
// exampleAuthFunc is used by a middleware to authenticate requests
54+
func exampleAuthFunc(ctx context.Context) (context.Context, error) {
55+
token, err := auth.AuthFromMD(ctx, "bearer")
56+
if err != nil {
57+
return nil, err
58+
}
59+
60+
tokenInfo, err := parseToken(token)
61+
if err != nil {
62+
return nil, status.Errorf(codes.Unauthenticated, "invalid auth token: %v", err)
63+
}
64+
65+
ctx = logging.InjectFields(ctx, logging.Fields{"auth.sub", userClaimFromToken(tokenInfo)})
66+
67+
// WARNING: In production define your own type to avoid context collisions.
68+
return context.WithValue(ctx, tokenInfoKey, tokenInfo), nil
69+
}
70+
71+
func loginSkip(ctx context.Context, fullMethod string) bool {
72+
return fullMethod != "/auth.v1.AuthService/Login"
73+
}
74+
75+
func Example_login() {
76+
_ = grpc.NewServer(
77+
grpc.ChainUnaryInterceptor(
78+
selector.UnaryServerInterceptor(auth.UnaryServerInterceptor(exampleAuthFunc), loginSkip),
79+
),
80+
grpc.ChainStreamInterceptor(
81+
selector.StreamServerInterceptor(auth.StreamServerInterceptor(exampleAuthFunc), loginSkip),
82+
),
83+
)
84+
}
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
// Copyright (c) The go-grpc-middleware Authors.
2+
// Licensed under the Apache License 2.0.
3+
4+
package selector
5+
6+
import (
7+
"context"
8+
"testing"
9+
10+
"github.com/pkg/errors"
11+
"github.com/stretchr/testify/assert"
12+
"google.golang.org/grpc"
13+
14+
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/auth"
15+
)
16+
17+
var blockList = []string{"/auth.v1beta1.AuthService/Login"}
18+
19+
const errMsgFake = "fake error"
20+
21+
var ctxKey = struct{}{}
22+
23+
// allow After the method is matched, the interceptor is run
24+
func allow(methods []string) MatchFunc {
25+
return func(ctx context.Context, fullMethod string) bool {
26+
for _, s := range methods {
27+
if s == fullMethod {
28+
return true
29+
}
30+
}
31+
return false
32+
}
33+
}
34+
35+
// Block the interceptor will not run after the method matches
36+
func block(methods []string) MatchFunc {
37+
allow := allow(methods)
38+
return func(ctx context.Context, fullMethod string) bool {
39+
return !allow(ctx, fullMethod)
40+
}
41+
}
42+
43+
type mockGRPCServerStream struct {
44+
grpc.ServerStream
45+
46+
ctx context.Context
47+
}
48+
49+
func (m *mockGRPCServerStream) Context() context.Context {
50+
return m.ctx
51+
}
52+
53+
func TestUnaryServerInterceptor(t *testing.T) {
54+
ctx := context.Background()
55+
interceptor := UnaryServerInterceptor(auth.UnaryServerInterceptor(
56+
func(ctx context.Context) (context.Context, error) {
57+
newCtx := context.WithValue(ctx, ctxKey, true)
58+
return newCtx, nil
59+
},
60+
), block(blockList))
61+
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
62+
val := ctx.Value(ctxKey)
63+
if b, ok := val.(bool); ok && b {
64+
return "good", nil
65+
}
66+
return nil, errors.New(errMsgFake)
67+
}
68+
69+
t.Run("nextStep", func(t *testing.T) {
70+
info := &grpc.UnaryServerInfo{
71+
FullMethod: "FakeMethod",
72+
}
73+
resp, err := interceptor(ctx, nil, info, handler)
74+
assert.Nil(t, err)
75+
assert.Equal(t, resp, "good")
76+
})
77+
78+
t.Run("skipped", func(t *testing.T) {
79+
info := &grpc.UnaryServerInfo{
80+
FullMethod: "/auth.v1beta1.AuthService/Login",
81+
}
82+
resp, err := interceptor(ctx, nil, info, handler)
83+
assert.Nil(t, resp)
84+
assert.EqualError(t, err, errMsgFake)
85+
})
86+
}
87+
88+
func TestStreamServerInterceptor(t *testing.T) {
89+
ctx := context.Background()
90+
interceptor := StreamServerInterceptor(auth.StreamServerInterceptor(
91+
func(ctx context.Context) (context.Context, error) {
92+
newCtx := context.WithValue(ctx, ctxKey, true)
93+
return newCtx, nil
94+
},
95+
), block(blockList))
96+
97+
handler := func(srv interface{}, stream grpc.ServerStream) error {
98+
ctx := stream.Context()
99+
val := ctx.Value(ctxKey)
100+
if b, ok := val.(bool); ok && b {
101+
return nil
102+
}
103+
return errors.New(errMsgFake)
104+
}
105+
106+
t.Run("nextStep", func(t *testing.T) {
107+
info := &grpc.StreamServerInfo{
108+
FullMethod: "FakeMethod",
109+
}
110+
err := interceptor(nil, &mockGRPCServerStream{ctx: ctx}, info, handler)
111+
assert.Nil(t, err)
112+
})
113+
114+
t.Run("skipped", func(t *testing.T) {
115+
info := &grpc.StreamServerInfo{
116+
FullMethod: "/auth.v1beta1.AuthService/Login",
117+
}
118+
err := interceptor(nil, &mockGRPCServerStream{ctx: ctx}, info, handler)
119+
assert.EqualError(t, err, errMsgFake)
120+
})
121+
}
122+
123+
func TestAllow(t *testing.T) {
124+
type args struct {
125+
methods []string
126+
}
127+
tests := []struct {
128+
name string
129+
args args
130+
method string
131+
want bool
132+
}{
133+
{
134+
name: "false",
135+
args: args{
136+
methods: []string{"/auth.v1beta1.AuthService/Login"},
137+
},
138+
method: "/testing.testpb.v1.TestService/PingList",
139+
want: false,
140+
},
141+
{
142+
name: "true",
143+
args: args{
144+
methods: []string{"/auth.v1beta1.AuthService/Login"},
145+
},
146+
method: "/auth.v1beta1.AuthService/Login",
147+
want: true,
148+
},
149+
}
150+
for _, tt := range tests {
151+
t.Run(tt.name, func(t *testing.T) {
152+
allow := allow(tt.args.methods)
153+
want := allow(context.Background(), tt.method)
154+
assert.Equalf(t, tt.want, want, "Allow(%v)(ctx, %v)", tt.args.methods, tt.method)
155+
})
156+
}
157+
}

0 commit comments

Comments
 (0)