Skip to content

Commit 4f2f56e

Browse files
authored
Adding ProjectInfo to AccessQuota (#137)
1 parent 58e5e11 commit 4f2f56e

File tree

12 files changed

+178
-159
lines changed

12 files changed

+178
-159
lines changed

cache.go

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -54,17 +54,13 @@ var (
5454
_ UsageCache = (*RedisCache)(nil)
5555
)
5656

57-
type Service interface {
58-
GetService() proto.Service
59-
}
60-
61-
func NewLimitCounter(svc Service, cfg RedisConfig, logger *slog.Logger) httprate.LimitCounter {
57+
func NewLimitCounter(svc proto.Service, cfg RedisConfig, logger *slog.Logger) httprate.LimitCounter {
6258
if !cfg.Enabled {
6359
return nil
6460
}
6561

6662
prefix := redisRLPrefix
67-
if s := svc.GetService().String(); s != "" {
63+
if s := svc.String(); s != "" {
6864
prefix = fmt.Sprintf("%s%s:", redisRLPrefix, s)
6965
}
7066

@@ -91,7 +87,7 @@ func NewLimitCounter(svc Service, cfg RedisConfig, logger *slog.Logger) httprate
9187
const (
9288
defaultExpRedis = time.Hour
9389
defaultExpLRU = time.Minute
94-
cacheVersion = "v1"
90+
cacheVersion = "v2"
9591
)
9692

9793
// usageKey returns the redis key for storing usage amount.

client.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ func (c *Client) FetchProjectQuota(ctx context.Context, projectID uint64, chainI
133133
logger.Warn("failed to cache project quota", slog.Any("error", err))
134134
}
135135
}
136-
if err := quota.AccessKey.ValidateChains(chainIDs); err != nil {
136+
if err := quota.Info.ValidateChains(chainIDs); err != nil {
137137
return quota, proto.ErrInvalidChain.WithCause(err)
138138
}
139139
return quota, nil
@@ -163,8 +163,11 @@ func (c *Client) FetchKeyQuota(ctx context.Context, accessKey, origin string, ch
163163
logger.Warn("failed to cache access quota", slog.Any("error", err))
164164
}
165165
}
166+
if err := quota.Info.ValidateChains(chainIDs); err != nil {
167+
return quota, proto.ErrInvalidChain.WithCause(err)
168+
}
166169
// validate access key
167-
if err := c.validateAccessKey(quota.AccessKey, origin, chainIDs); err != nil {
170+
if err := c.validateAccessKey(quota.AccessKey, origin); err != nil {
168171
return quota, err
169172
}
170173
return quota, nil
@@ -325,7 +328,7 @@ func (c *Client) ClearQuotaCacheByAccessKey(ctx context.Context, accessKey strin
325328
return c.cache.QuotaCache.DeleteAccessQuota(ctx, accessKey)
326329
}
327330

328-
func (c *Client) validateAccessKey(access *proto.AccessKey, origin string, chainIDs []uint64) (err error) {
331+
func (c *Client) validateAccessKey(access *proto.AccessKey, origin string) (err error) {
329332
if !access.Active {
330333
return proto.ErrAccessKeyNotFound
331334
}
@@ -335,9 +338,6 @@ func (c *Client) validateAccessKey(access *proto.AccessKey, origin string, chain
335338
if !access.ValidateService(c.service) {
336339
return proto.ErrInvalidService
337340
}
338-
if err := access.ValidateChains(chainIDs); err != nil {
339-
return proto.ErrInvalidChain.WithCause(err)
340-
}
341341
return nil
342342
}
343343

middleware/common.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ func VerifyChains(ctx context.Context, chainIDs ...uint64) error {
8585
if !ok {
8686
return nil
8787
}
88-
if err := quota.AccessKey.ValidateChains(chainIDs); err != nil {
88+
if err := quota.Info.ValidateChains(chainIDs); err != nil {
8989
return proto.ErrInvalidChain.WithCause(err)
9090
}
9191
return nil

tests/mock/mem.go renamed to mock/mem.go

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,15 @@ import (
66
"time"
77

88
"github.com/0xsequence/authcontrol"
9-
"github.com/0xsequence/quotacontrol/internal/store"
109
"github.com/0xsequence/quotacontrol/internal/usage"
1110
"github.com/0xsequence/quotacontrol/proto"
1211
)
1312

1413
// NewMemoryStore returns a new in-memory store.
1514
func NewMemoryStore() *MemoryStore {
1615
ms := MemoryStore{
16+
infos: map[uint64]proto.ProjectInfo{},
1717
limits: map[uint64]proto.Limit{},
18-
cycles: map[uint64]proto.Cycle{},
1918
accessKeys: map[string]proto.AccessKey{},
2019
usage: map[proto.Service]usage.Record{},
2120
users: map[string]bool{},
@@ -38,16 +37,30 @@ type userPermission struct {
3837
type MemoryStore struct {
3938
sync.Mutex
4039
limits map[uint64]proto.Limit
41-
cycles map[uint64]proto.Cycle
40+
infos map[uint64]proto.ProjectInfo
4241
accessKeys map[string]proto.AccessKey
4342
usage map[proto.Service]usage.Record
4443
users map[string]bool
4544
projects map[uint64]*authcontrol.Auth
4645
permissions map[uint64]map[string]userPermission
4746
}
4847

48+
func (m *MemoryStore) SetProjectInfo(ctx context.Context, projectID uint64, info *proto.ProjectInfo) error {
49+
m.Lock()
50+
if info != nil {
51+
m.infos[projectID] = *info
52+
} else {
53+
delete(m.infos, projectID)
54+
}
55+
m.Unlock()
56+
return nil
57+
}
58+
4959
func (m *MemoryStore) SetAccessLimit(ctx context.Context, projectID uint64, config *proto.Limit) error {
5060
m.Lock()
61+
if _, ok := m.infos[projectID]; !ok {
62+
m.infos[projectID] = proto.ProjectInfo{ID: projectID}
63+
}
5164
m.limits[projectID] = *config
5265
m.Unlock()
5366
return nil
@@ -63,25 +76,11 @@ func (m *MemoryStore) GetAccessLimit(ctx context.Context, projectID uint64, cycl
6376
return &limit, nil
6477
}
6578

66-
func (m *MemoryStore) SetAccessCycle(ctx context.Context, projectID uint64, cycle *proto.Cycle) error {
79+
func (m *MemoryStore) GetProjectInfo(ctx context.Context, projectID uint64, now time.Time) (*proto.ProjectInfo, error) {
6780
m.Lock()
68-
if cycle != nil {
69-
m.cycles[projectID] = *cycle
70-
} else {
71-
delete(m.cycles, projectID)
72-
}
73-
m.Unlock()
74-
return nil
75-
}
76-
77-
func (m *MemoryStore) GetAccessCycle(ctx context.Context, projectID uint64, now time.Time) (*proto.Cycle, error) {
78-
m.Lock()
79-
cycle := m.cycles[projectID]
81+
info := m.infos[projectID]
8082
m.Unlock()
81-
if cycle.Start.IsZero() && cycle.End.IsZero() {
82-
return store.Cycle{}.GetAccessCycle(ctx, projectID, now)
83-
}
84-
return &cycle, nil
83+
return &info, nil
8584
}
8685

8786
func (m *MemoryStore) InsertAccessKey(ctx context.Context, access *proto.AccessKey) error {
Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,11 @@ func NewServer(cfg *quotacontrol.Config) (server *Server, cleanup func()) {
5050
PermissionCache: quotacontrol.NewRedisCache(client, time.Minute),
5151
}
5252
qcStore := quotacontrol.Store{
53-
LimitStore: store,
54-
AccessKeyStore: store,
55-
UsageStore: store,
56-
CycleStore: store,
57-
PermissionStore: store,
53+
ProjectInfoStore: store,
54+
LimitStore: store,
55+
AccessKeyStore: store,
56+
UsageStore: store,
57+
PermissionStore: store,
5858
}
5959

6060
logger := qc.logger.With(slog.String("server", "server"))

proto/proto.go

Lines changed: 35 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -3,42 +3,47 @@
33
package proto
44

55
import (
6+
"cmp"
67
"fmt"
78
"slices"
89
"time"
910
)
1011

12+
// Ptr is an utility function to return a pointer to the value
1113
func Ptr[T any](v T) *T {
1214
return &v
1315
}
1416

15-
func (t *AccessKey) ValidateOrigin(rawOrigin string) bool {
17+
// ValidateOrigin checks if the given origin is allowed by the access key.
18+
func (a *AccessKey) ValidateOrigin(rawOrigin string) bool {
1619
if rawOrigin == "" {
17-
return !t.RequireOrigin
20+
return !a.RequireOrigin
1821
}
19-
return t.AllowedOrigins.MatchAny(rawOrigin)
22+
return a.AllowedOrigins.MatchAny(rawOrigin)
2023
}
2124

22-
func (t *AccessKey) ValidateService(service Service) bool {
23-
if len(t.AllowedServices) == 0 {
25+
// ValidateService checks if the given service is allowed by the access key.
26+
func (a *AccessKey) ValidateService(service Service) bool {
27+
if len(a.AllowedServices) == 0 {
2428
return true
2529
}
26-
for _, s := range t.AllowedServices {
30+
for _, s := range a.AllowedServices {
2731
if service == s {
2832
return true
2933
}
3034
}
3135
return false
3236
}
3337

34-
func (t *AccessKey) ValidateChains(chainIDs []uint64) error {
35-
if len(t.ChainIDs) == 0 {
38+
// ValidateChains checks if the given chain IDs are allowed by the project.
39+
func (i *ProjectInfo) ValidateChains(chainIDs []uint64) error {
40+
if len(i.ChainIDs) == 0 {
3641
return nil
3742
}
3843

3944
invalid := make([]uint64, 0, len(chainIDs))
4045
for _, id := range chainIDs {
41-
if !slices.Contains(t.ChainIDs, id) {
46+
if !slices.Contains(i.ChainIDs, id) {
4247
invalid = append(invalid, id)
4348
}
4449
}
@@ -49,6 +54,7 @@ func (t *AccessKey) ValidateChains(chainIDs []uint64) error {
4954
return nil
5055
}
5156

57+
// Validate checks if the limit configuration is valid.
5258
func (l Limit) Validate() error {
5359
for name, cfg := range l.ServiceLimit {
5460
svc, ok := ParseService(name)
@@ -62,18 +68,21 @@ func (l Limit) Validate() error {
6268
return nil
6369
}
6470

71+
// GetSettings returns the service limit settings for the given service.
6572
func (l Limit) GetSettings(svc Service) (ServiceLimit, bool) {
6673
settings, ok := l.ServiceLimit[svc.String()]
6774
return settings, ok
6875
}
6976

77+
// SetSetting sets the service limit settings for the given service.
7078
func (l *Limit) SetSetting(svc Service, limits ServiceLimit) {
7179
if l.ServiceLimit == nil {
7280
l.ServiceLimit = make(map[string]ServiceLimit)
7381
}
7482
l.ServiceLimit[svc.String()] = limits
7583
}
7684

85+
// Validate checks if the service limit configuration is valid.
7786
func (l ServiceLimit) Validate() error {
7887
if l.RateLimit < 1 {
7988
return fmt.Errorf("rateLimit must be > 0")
@@ -104,6 +113,7 @@ func getOverThreshold(v, total, threshold int64) (int64, bool) {
104113
return max(0, total-threshold), true
105114
}
106115

116+
// GetSpendResult calculates the spend result and event type based on the service limit and usage
107117
func (l *ServiceLimit) GetSpendResult(v, total int64) (int64, *EventType) {
108118
// valid usage
109119
if total < l.FreeMax {
@@ -176,14 +186,23 @@ func (c *Cycle) GetEnd(now time.Time) time.Time {
176186
return c.GetStart(now).AddDate(0, 1, -1)
177187
}
178188

179-
func (c *Cycle) GetDuration(now time.Time) time.Duration {
180-
return c.GetEnd(now).Sub(c.GetStart(now))
181-
}
189+
func (c *Cycle) SetInterval(from, to *time.Time, now time.Time) {
190+
from = cmp.Or(from, &time.Time{})
191+
to = cmp.Or(to, &time.Time{})
192+
193+
if !from.IsZero() && !to.IsZero() {
194+
return
195+
}
182196

183-
func (c *Cycle) Advance(now time.Time) {
184-
for c.End.Before(now) {
185-
c.Start = c.Start.AddDate(0, 1, 0)
186-
c.End = c.End.AddDate(0, 1, 0)
197+
duration := c.GetEnd(now).Sub(c.GetStart(now))
198+
switch {
199+
case !to.IsZero():
200+
*from = to.Add(-duration)
201+
case !from.IsZero():
202+
*to = from.Add(duration)
203+
default:
204+
*from = c.Start
205+
*to = c.End
187206
}
188207
}
189208

@@ -194,14 +213,6 @@ func (u *UserPermission) CanAccess(perm UserPermission) bool {
194213
return *u >= perm
195214
}
196215

197-
func (e WebRPCError) WithMessage(message string) WebRPCError {
198-
err := e
199-
if message != "" {
200-
err.Message = message
201-
}
202-
return err
203-
}
204-
205216
func ParseService(v string) (Service, bool) {
206217
raw, ok := Service_value[v]
207218
if !ok {
@@ -233,7 +244,3 @@ func (x Service) GetName() string {
233244
}
234245
return ""
235246
}
236-
237-
func (x Service) GetService() Service {
238-
return x
239-
}

proto/quotacontrol.gen.go

Lines changed: 15 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)