@@ -3,18 +3,86 @@ package forward
33import (
44 "context"
55 "errors"
6+ "fmt"
67 "io"
78 "sync"
89 "time"
910
1011 "github.com/Snawoot/secache"
1112 "github.com/Snawoot/secache/randmap"
1213
14+ clog "github.com/SenseUnit/dumbproxy/log"
1315 "github.com/SenseUnit/dumbproxy/rate"
1416)
1517
1618const copyChunkSize = 128 * 1024
1719
20+ type LimitKind int
21+
22+ const (
23+ LimitKindNone LimitKind = iota
24+ LimitKindStatic
25+ LimitKindJS
26+ )
27+
28+ type LimitSpec struct {
29+ Kind LimitKind
30+ Spec any
31+ }
32+
33+ type StaticLimitSpec struct {
34+ BPS uint64
35+ Burst int64
36+ Separate bool
37+ }
38+
39+ type JSLimitSpec struct {
40+ Filename string
41+ Instances int
42+ }
43+
44+ type LimitParameters struct {
45+ UploadBPS float64 `json:"uploadBPS"`
46+ UploadBurst int64 `json:"uploadBurst"`
47+ DownloadBPS float64 `json:"downloadBPS"`
48+ DownloadBurst int64 `json:"downloadBurst"`
49+ GroupKey * string `json:"groupKey"`
50+ Separate bool `json:"separate"`
51+ }
52+
53+ type LimitProvider = func (context.Context , string , string , string ) (* LimitParameters , error )
54+
55+ func ProviderFromSpec (spec LimitSpec , logger * clog.CondLogger ) (LimitProvider , error ) {
56+ switch spec .Kind {
57+ case LimitKindStatic :
58+ staticSpec , ok := spec .Spec .(StaticLimitSpec )
59+ if ! ok {
60+ return nil , fmt .Errorf ("incorrect payload type in BW limit spec: %T" , spec )
61+ }
62+ return func (_ context.Context , username , _ , _ string ) (* LimitParameters , error ) {
63+ return & LimitParameters {
64+ UploadBPS : float64 (staticSpec .BPS ),
65+ UploadBurst : staticSpec .Burst ,
66+ DownloadBPS : float64 (staticSpec .BPS ),
67+ DownloadBurst : staticSpec .Burst ,
68+ GroupKey : & username ,
69+ Separate : staticSpec .Separate ,
70+ }, nil
71+ }, nil
72+ case LimitKindJS :
73+ jsSpec , ok := spec .Spec .(JSLimitSpec )
74+ if ! ok {
75+ return nil , fmt .Errorf ("incorrect payload type in BW limit spec: %T" , spec )
76+ }
77+ j , err := NewJSLimitProvider (jsSpec .Filename , jsSpec .Instances , logger )
78+ if err != nil {
79+ return nil , err
80+ }
81+ return j .Parameters , nil
82+ }
83+ return nil , fmt .Errorf ("unsupported BW limit kind %d" , int (spec .Kind ))
84+ }
85+
1886type cacheItem struct {
1987 mux sync.RWMutex
2088 ul * rate.Limiter
@@ -38,17 +106,13 @@ func (i *cacheItem) unlock() {
38106}
39107
40108type BWLimit struct {
41- bps float64
42- burst int64
43- separate bool
44- cache secache.Cache [string , * cacheItem ]
109+ paramFn LimitProvider
110+ cache secache.Cache [string , * cacheItem ]
45111}
46112
47- func NewBWLimit (bytesPerSecond float64 , burst int64 , separate bool ) * BWLimit {
113+ func NewBWLimit (p LimitProvider ) * BWLimit {
48114 return & BWLimit {
49- bps : bytesPerSecond ,
50- burst : burst ,
51- separate : separate ,
115+ paramFn : p ,
52116 cache : * (secache .New [string , * cacheItem ](3 , func (_ string , item * cacheItem ) bool {
53117 if item .tryLock () {
54118 if item .ul .Tokens () >= float64 (item .ul .Burst ()) && item .dl .Tokens () >= float64 (item .dl .Burst ()) {
@@ -120,35 +184,46 @@ func (l *BWLimit) futureCopyAndCloseWrite(ctx context.Context, c chan<- error, r
120184 close (c )
121185}
122186
123- func (l * BWLimit ) getRatelimiters (username string ) (res * cacheItem ) {
187+ func (l * BWLimit ) getRatelimiters (ctx context.Context , username , network , address string ) (* cacheItem , error ) {
188+ params , err := l .paramFn (ctx , username , network , address )
189+ if err != nil {
190+ return nil , err
191+ }
192+ groupKey := username
193+ if params .GroupKey != nil {
194+ groupKey = * params .GroupKey
195+ }
196+ var res * cacheItem
124197 l .cache .Do (func (m * randmap.RandMap [string , * cacheItem ]) {
125198 var ok bool
126- res , ok = m .Get (username )
199+ res , ok = m .Get (groupKey )
127200 if ok {
128201 res .rLock ()
129202 } else {
130- ul := rate .NewLimiter (rate .Limit (l . bps ), max (copyChunkSize , l . burst ))
203+ ul := rate .NewLimiter (rate .Limit (params . UploadBPS ), max (copyChunkSize , params . UploadBurst ))
131204 dl := ul
132- if l . separate {
133- dl = rate .NewLimiter (rate .Limit (l . bps ), max (copyChunkSize , l . burst ))
205+ if params . Separate {
206+ dl = rate .NewLimiter (rate .Limit (params . DownloadBPS ), max (copyChunkSize , params . DownloadBurst ))
134207 }
135208 res = & cacheItem {
136209 ul : ul ,
137210 dl : dl ,
138211 }
139212 res .rLock ()
140- l .cache .SetLocked (m , username , res )
213+ l .cache .SetLocked (m , groupKey , res )
141214 }
142215 return
143216 })
144- return
217+ return res , nil
145218}
146219
147- func (l * BWLimit ) PairConnections (ctx context.Context , username string , incoming , outgoing io.ReadWriteCloser ) error {
148- ci := l .getRatelimiters (username )
220+ func (l * BWLimit ) PairConnections (ctx context.Context , username string , incoming , outgoing io.ReadWriteCloser , network , address string ) error {
221+ ci , err := l .getRatelimiters (ctx , username , network , address )
222+ if err != nil {
223+ return fmt .Errorf ("ratelimit parameter computarion failed for user %q: %w" , username , err )
224+ }
149225 defer ci .rUnlock ()
150226
151- var err error
152227 i2oErr := make (chan error , 1 )
153228 o2iErr := make (chan error , 1 )
154229 ctxErr := ctx .Done ()
0 commit comments