Skip to content
This repository was archived by the owner on Oct 9, 2023. It is now read-only.

Commit b4f04b6

Browse files
Add Interceptor for RateLimit
1 parent 2195181 commit b4f04b6

File tree

3 files changed

+122
-39
lines changed

3 files changed

+122
-39
lines changed

pkg/server/service.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,11 @@ func newGRPCServer(ctx context.Context, pluginRegistry *plugins.Registry, cfg *c
9090
auth.AuthenticationLoggingInterceptor,
9191
middlewareInterceptors,
9292
)
93+
if cfg.Security.RateLimit.Enabled {
94+
rateLimiter := plugins.NewRateLimiter(cfg.Security.RateLimit.RequestsPerSecond, cfg.Security.RateLimit.BurstSize, cfg.Security.RateLimit.CleanupInterval.Duration)
95+
rateLimitInterceptors := plugins.RateLimiteInterceptor(*rateLimiter)
96+
chainedUnaryInterceptors = grpcmiddleware.ChainUnaryServer(chainedUnaryInterceptors, rateLimitInterceptors)
97+
}
9398
} else {
9499
logger.Infof(ctx, "Creating gRPC server without authentication")
95100
chainedUnaryInterceptors = grpcmiddleware.ChainUnaryServer(grpcprometheus.UnaryServerInterceptor)
@@ -257,6 +262,7 @@ func serveGatewayInsecure(ctx context.Context, pluginRegistry *plugins.Registry,
257262
}
258263

259264
oauth2ResourceServer = oauth2Provider
265+
260266
} else {
261267
oauth2ResourceServer, err = authzserver.NewOAuth2ResourceServer(ctx, authCfg.AppAuth.ExternalAuthServer, authCfg.UserAuth.OpenID.BaseURL)
262268
if err != nil {

plugins/rate_limit.go

Lines changed: 45 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,28 @@
11
package plugins
22

33
import (
4+
"context"
5+
"errors"
46
"fmt"
57
"sync"
68
"time"
79

10+
auth "github.com/flyteorg/flyteadmin/auth"
811
"golang.org/x/time/rate"
12+
"google.golang.org/grpc"
13+
"google.golang.org/grpc/codes"
14+
"google.golang.org/grpc/status"
915
)
1016

11-
type RateLimitError error
17+
type RateLimitExceeded error
1218

1319
// define a struct that contains a map of rate limiters, and a time stamp of last access and a mutex to protect the map
1420
type accessRecords struct {
1521
limiter *rate.Limiter
1622
lastAccess time.Time
1723
}
1824

19-
type Limiter struct {
25+
type LimiterStore struct {
2026
accessPerUser map[string]*accessRecords
2127
mutex *sync.Mutex
2228
requestPerSec int
@@ -27,7 +33,7 @@ type Limiter struct {
2733
// define a function named Allow that takes userID and returns RateLimitError
2834
// the function check if the user is in the map, if not, create a new accessRecords for the user
2935
// then it check if the user can access the resource, if not, return RateLimitError
30-
func (l *Limiter) Allow(userID string) error {
36+
func (l *LimiterStore) Allow(userID string) error {
3137
l.mutex.Lock()
3238
defer l.mutex.Unlock()
3339
if _, ok := l.accessPerUser[userID]; !ok {
@@ -38,13 +44,13 @@ func (l *Limiter) Allow(userID string) error {
3844
}
3945

4046
if !l.accessPerUser[userID].limiter.Allow() {
41-
return RateLimitError(fmt.Errorf("rate limit exceeded"))
47+
return RateLimitExceeded(fmt.Errorf("rate limit exceeded"))
4248
}
4349

4450
return nil
4551
}
4652

47-
func (l *Limiter) clean() {
53+
func (l *LimiterStore) clean() {
4854
l.mutex.Lock()
4955
defer l.mutex.Unlock()
5056
for userID, accessRecord := range l.accessPerUser {
@@ -54,8 +60,8 @@ func (l *Limiter) clean() {
5460
}
5561
}
5662

57-
func NewRateLimiter(requestPerSec int, burstSize int, cleanupInterval time.Duration) *Limiter {
58-
l := &Limiter{
63+
func newRateLimitStore(requestPerSec int, burstSize int, cleanupInterval time.Duration) *LimiterStore {
64+
l := &LimiterStore{
5965
accessPerUser: make(map[string]*accessRecords),
6066
mutex: &sync.Mutex{},
6167
requestPerSec: requestPerSec,
@@ -72,3 +78,35 @@ func NewRateLimiter(requestPerSec int, burstSize int, cleanupInterval time.Durat
7278

7379
return l
7480
}
81+
82+
type RateLimiter struct {
83+
limiter *LimiterStore
84+
}
85+
86+
func (r *RateLimiter) Limit(ctx context.Context) error {
87+
IdenCtx := auth.IdentityContextFromContext(ctx)
88+
if IdenCtx.IsEmpty() {
89+
return errors.New("no identity context found")
90+
}
91+
userID := IdenCtx.UserID()
92+
if err := r.limiter.Allow(userID); err != nil {
93+
return err
94+
}
95+
return nil
96+
}
97+
98+
func NewRateLimiter(requestPerSec int, burstSize int, cleanupInterval time.Duration) *RateLimiter {
99+
limiter := newRateLimitStore(requestPerSec, burstSize, cleanupInterval)
100+
return &RateLimiter{limiter: limiter}
101+
}
102+
103+
func RateLimiteInterceptor(limiter RateLimiter) grpc.UnaryServerInterceptor {
104+
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (
105+
resp interface{}, err error) {
106+
if err := limiter.Limit(ctx); err != nil {
107+
return nil, status.Errorf(codes.ResourceExhausted, "rate limit exceeded")
108+
}
109+
110+
return handler(ctx, req)
111+
}
112+
}

plugins/rate_limit_test.go

Lines changed: 71 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,56 +1,95 @@
11
package plugins
22

33
import (
4+
"context"
45
"testing"
56
"time"
67

8+
auth "github.com/flyteorg/flyteadmin/auth"
79
"github.com/stretchr/testify/assert"
810
)
911

1012
func TestNewRateLimiter(t *testing.T) {
11-
rl := NewRateLimiter(1, 1, time.Second)
12-
assert.NotNil(t, rl)
13+
rlStore := newRateLimitStore(1, 1, time.Second)
14+
assert.NotNil(t, rlStore)
1315
}
1416

15-
func TestLimiter_Allow(t *testing.T) {
16-
rl := NewRateLimiter(1, 1, time.Second)
17-
assert.NoError(t, rl.Allow("hello"))
18-
// assert error type is RateLimitError
19-
assert.Error(t, rl.Allow("hello"))
17+
func TestLimiterAllow(t *testing.T) {
18+
rlStore := newRateLimitStore(1, 1, time.Second)
19+
assert.NoError(t, rlStore.Allow("hello"))
20+
assert.Error(t, rlStore.Allow("hello"))
2021
time.Sleep(time.Second)
21-
assert.NoError(t, rl.Allow("hello"))
22+
assert.NoError(t, rlStore.Allow("hello"))
2223
}
2324

24-
func TestLimiter_AllowBurst(t *testing.T) {
25-
rl := NewRateLimiter(1, 2, time.Second)
26-
assert.NoError(t, rl.Allow("hello"))
27-
assert.NoError(t, rl.Allow("hello"))
28-
assert.Error(t, rl.Allow("hello"))
29-
assert.NoError(t, rl.Allow("world"))
25+
func TestLimiterAllowBurst(t *testing.T) {
26+
rlStore := newRateLimitStore(1, 2, time.Second)
27+
assert.NoError(t, rlStore.Allow("hello"))
28+
assert.NoError(t, rlStore.Allow("hello"))
29+
assert.Error(t, rlStore.Allow("hello"))
30+
assert.NoError(t, rlStore.Allow("world"))
3031
}
3132

32-
func TestLimiter_Clean(t *testing.T) {
33-
rl := NewRateLimiter(1, 1, time.Second)
34-
assert.NoError(t, rl.Allow("hello"))
35-
assert.Error(t, rl.Allow("hello"))
33+
func TestLimiterClean(t *testing.T) {
34+
rlStore := newRateLimitStore(1, 1, time.Second)
35+
assert.NoError(t, rlStore.Allow("hello"))
36+
assert.Error(t, rlStore.Allow("hello"))
3637
time.Sleep(time.Second)
37-
rl.clean()
38-
assert.NoError(t, rl.Allow("hello"))
38+
rlStore.clean()
39+
assert.NoError(t, rlStore.Allow("hello"))
3940
}
4041

41-
func TestLimiter_AllowOnMultipleRequests(t *testing.T) {
42-
rl := NewRateLimiter(1, 1, time.Second)
43-
assert.NoError(t, rl.Allow("a"))
44-
assert.NoError(t, rl.Allow("b"))
45-
assert.NoError(t, rl.Allow("c"))
46-
assert.Error(t, rl.Allow("a"))
47-
assert.Error(t, rl.Allow("b"))
42+
func TestLimiterAllowOnMultipleRequests(t *testing.T) {
43+
rlStore := newRateLimitStore(1, 1, time.Second)
44+
assert.NoError(t, rlStore.Allow("a"))
45+
assert.NoError(t, rlStore.Allow("b"))
46+
assert.NoError(t, rlStore.Allow("c"))
47+
assert.Error(t, rlStore.Allow("a"))
48+
assert.Error(t, rlStore.Allow("b"))
4849

4950
time.Sleep(time.Second)
5051

51-
assert.NoError(t, rl.Allow("a"))
52-
assert.Error(t, rl.Allow("a"))
53-
assert.NoError(t, rl.Allow("b"))
54-
assert.Error(t, rl.Allow("b"))
55-
assert.NoError(t, rl.Allow("c"))
52+
assert.NoError(t, rlStore.Allow("a"))
53+
assert.Error(t, rlStore.Allow("a"))
54+
assert.NoError(t, rlStore.Allow("b"))
55+
assert.Error(t, rlStore.Allow("b"))
56+
assert.NoError(t, rlStore.Allow("c"))
57+
}
58+
59+
func TestRateLimiterLimitPass(t *testing.T) {
60+
rateLimit := NewRateLimiter(1, 1, time.Second)
61+
assert.NotNil(t, rateLimit)
62+
63+
identityCtx, err := auth.NewIdentityContext("audience", "user1", "app1", time.Now(), nil, nil, nil)
64+
assert.NoError(t, err)
65+
66+
ctx := context.WithValue(context.TODO(), auth.ContextKeyIdentityContext, identityCtx)
67+
err = rateLimit.Limit(ctx)
68+
assert.NoError(t, err)
69+
70+
}
71+
72+
func TestRateLimiterLimitStop(t *testing.T) {
73+
rateLimit := NewRateLimiter(1, 1, time.Second)
74+
assert.NotNil(t, rateLimit)
75+
76+
identityCtx, err := auth.NewIdentityContext("audience", "user1", "app1", time.Now(), nil, nil, nil)
77+
assert.NoError(t, err)
78+
ctx := context.WithValue(context.TODO(), auth.ContextKeyIdentityContext, identityCtx)
79+
err = rateLimit.Limit(ctx)
80+
assert.NoError(t, err)
81+
82+
err = rateLimit.Limit(ctx)
83+
assert.Error(t, err)
84+
85+
}
86+
87+
func TestRateLimiterLimitWithoutUserIdentity(t *testing.T) {
88+
rateLimit := NewRateLimiter(1, 1, time.Second)
89+
assert.NotNil(t, rateLimit)
90+
91+
ctx := context.TODO()
92+
93+
err := rateLimit.Limit(ctx)
94+
assert.Error(t, err)
5695
}

0 commit comments

Comments
 (0)