Skip to content

Commit 5bdac67

Browse files
committed
BW limit parameters by JS script
1 parent d2827e3 commit 5bdac67

File tree

10 files changed

+439
-178
lines changed

10 files changed

+439
-178
lines changed

forward/bwlimit.go

Lines changed: 93 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,86 @@ package forward
33
import (
44
"context"
55
"errors"
6+
"fmt"
67
"io"
78
"sync"
89
"time"
910

1011
"github.com/Snawoot/secache"
1112
"github.com/Snawoot/secache/randmap"
1213

14+
clog "github.com/SenseUnit/dumbproxy/log"
1315
"github.com/SenseUnit/dumbproxy/rate"
1416
)
1517

1618
const copyChunkSize = 128 * 1024
1719

20+
type LimitKind int
21+
22+
const (
23+
LimitKindNone LimitKind = iota
24+
LimitKindStatic
25+
LimitKindJS
26+
)
27+
28+
type LimitSpec struct {
29+
Kind LimitKind
30+
Spec any
31+
}
32+
33+
type StaticLimitSpec struct {
34+
BPS uint64
35+
Burst int64
36+
Separate bool
37+
}
38+
39+
type JSLimitSpec struct {
40+
Filename string
41+
Instances int
42+
}
43+
44+
type LimitParameters struct {
45+
UploadBPS float64 `json:"uploadBPS"`
46+
UploadBurst int64 `json:"uploadBurst"`
47+
DownloadBPS float64 `json:"downloadBPS"`
48+
DownloadBurst int64 `json:"downloadBurst"`
49+
GroupKey *string `json:"groupKey"`
50+
Separate bool `json:"separate"`
51+
}
52+
53+
type LimitProvider = func(context.Context, string, string, string) (*LimitParameters, error)
54+
55+
func ProviderFromSpec(spec LimitSpec, logger *clog.CondLogger) (LimitProvider, error) {
56+
switch spec.Kind {
57+
case LimitKindStatic:
58+
staticSpec, ok := spec.Spec.(StaticLimitSpec)
59+
if !ok {
60+
return nil, fmt.Errorf("incorrect payload type in BW limit spec: %T", spec)
61+
}
62+
return func(_ context.Context, username, _, _ string) (*LimitParameters, error) {
63+
return &LimitParameters{
64+
UploadBPS: float64(staticSpec.BPS),
65+
UploadBurst: staticSpec.Burst,
66+
DownloadBPS: float64(staticSpec.BPS),
67+
DownloadBurst: staticSpec.Burst,
68+
GroupKey: &username,
69+
Separate: staticSpec.Separate,
70+
}, nil
71+
}, nil
72+
case LimitKindJS:
73+
jsSpec, ok := spec.Spec.(JSLimitSpec)
74+
if !ok {
75+
return nil, fmt.Errorf("incorrect payload type in BW limit spec: %T", spec)
76+
}
77+
j, err := NewJSLimitProvider(jsSpec.Filename, jsSpec.Instances, logger)
78+
if err != nil {
79+
return nil, err
80+
}
81+
return j.Parameters, nil
82+
}
83+
return nil, fmt.Errorf("unsupported BW limit kind %d", int(spec.Kind))
84+
}
85+
1886
type cacheItem struct {
1987
mux sync.RWMutex
2088
ul *rate.Limiter
@@ -38,17 +106,13 @@ func (i *cacheItem) unlock() {
38106
}
39107

40108
type BWLimit struct {
41-
bps float64
42-
burst int64
43-
separate bool
44-
cache secache.Cache[string, *cacheItem]
109+
paramFn LimitProvider
110+
cache secache.Cache[string, *cacheItem]
45111
}
46112

47-
func NewBWLimit(bytesPerSecond float64, burst int64, separate bool) *BWLimit {
113+
func NewBWLimit(p LimitProvider) *BWLimit {
48114
return &BWLimit{
49-
bps: bytesPerSecond,
50-
burst: burst,
51-
separate: separate,
115+
paramFn: p,
52116
cache: *(secache.New[string, *cacheItem](3, func(_ string, item *cacheItem) bool {
53117
if item.tryLock() {
54118
if item.ul.Tokens() >= float64(item.ul.Burst()) && item.dl.Tokens() >= float64(item.dl.Burst()) {
@@ -120,35 +184,46 @@ func (l *BWLimit) futureCopyAndCloseWrite(ctx context.Context, c chan<- error, r
120184
close(c)
121185
}
122186

123-
func (l *BWLimit) getRatelimiters(username string) (res *cacheItem) {
187+
func (l *BWLimit) getRatelimiters(ctx context.Context, username, network, address string) (*cacheItem, error) {
188+
params, err := l.paramFn(ctx, username, network, address)
189+
if err != nil {
190+
return nil, err
191+
}
192+
groupKey := username
193+
if params.GroupKey != nil {
194+
groupKey = *params.GroupKey
195+
}
196+
var res *cacheItem
124197
l.cache.Do(func(m *randmap.RandMap[string, *cacheItem]) {
125198
var ok bool
126-
res, ok = m.Get(username)
199+
res, ok = m.Get(groupKey)
127200
if ok {
128201
res.rLock()
129202
} else {
130-
ul := rate.NewLimiter(rate.Limit(l.bps), max(copyChunkSize, l.burst))
203+
ul := rate.NewLimiter(rate.Limit(params.UploadBPS), max(copyChunkSize, params.UploadBurst))
131204
dl := ul
132-
if l.separate {
133-
dl = rate.NewLimiter(rate.Limit(l.bps), max(copyChunkSize, l.burst))
205+
if params.Separate {
206+
dl = rate.NewLimiter(rate.Limit(params.DownloadBPS), max(copyChunkSize, params.DownloadBurst))
134207
}
135208
res = &cacheItem{
136209
ul: ul,
137210
dl: dl,
138211
}
139212
res.rLock()
140-
l.cache.SetLocked(m, username, res)
213+
l.cache.SetLocked(m, groupKey, res)
141214
}
142215
return
143216
})
144-
return
217+
return res, nil
145218
}
146219

147-
func (l *BWLimit) PairConnections(ctx context.Context, username string, incoming, outgoing io.ReadWriteCloser) error {
148-
ci := l.getRatelimiters(username)
220+
func (l *BWLimit) PairConnections(ctx context.Context, username string, incoming, outgoing io.ReadWriteCloser, network, address string) error {
221+
ci, err := l.getRatelimiters(ctx, username, network, address)
222+
if err != nil {
223+
return fmt.Errorf("ratelimit parameter computarion failed for user %q: %w", username, err)
224+
}
149225
defer ci.rUnlock()
150226

151-
var err error
152227
i2oErr := make(chan error, 1)
153228
o2iErr := make(chan error, 1)
154229
ctxErr := ctx.Done()

forward/direct.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ func futureCopyAndCloseWrite(c chan<- error, dst io.WriteCloser, src io.ReadClos
2222
close(c)
2323
}
2424

25-
func PairConnections(ctx context.Context, username string, incoming, outgoing io.ReadWriteCloser) error {
25+
func PairConnections(ctx context.Context, username string, incoming, outgoing io.ReadWriteCloser, _, _ string) error {
2626
var err error
2727
i2oErr := make(chan error, 1)
2828
o2iErr := make(chan error, 1)

forward/jslimit.go

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
package forward
2+
3+
import (
4+
"context"
5+
"errors"
6+
"fmt"
7+
"os"
8+
9+
"github.com/dop251/goja"
10+
"golang.org/x/sync/errgroup"
11+
12+
"github.com/SenseUnit/dumbproxy/dialer/dto"
13+
"github.com/SenseUnit/dumbproxy/jsext"
14+
clog "github.com/SenseUnit/dumbproxy/log"
15+
)
16+
17+
type JSLimitFunc = func(req *jsext.JSRequestInfo, dst *jsext.JSDstInfo, username string) (*LimitParameters, error)
18+
19+
type JSLimitProvider struct {
20+
funcPool chan JSLimitFunc
21+
logger *clog.CondLogger
22+
}
23+
24+
func NewJSLimitProvider(filename string, instances int, logger *clog.CondLogger) (*JSLimitProvider, error) {
25+
script, err := os.ReadFile(filename)
26+
if err != nil {
27+
return nil, fmt.Errorf("unable to load JS script file %q: %w", filename, err)
28+
}
29+
30+
instances = max(1, instances)
31+
pool := make(chan JSLimitFunc, instances)
32+
initGroup, _ := errgroup.WithContext(context.Background())
33+
34+
for i := 0; i < instances; i++ {
35+
initGroup.Go(func() error {
36+
vm := goja.New()
37+
err := jsext.AddPrinter(vm, logger)
38+
if err != nil {
39+
return fmt.Errorf("can't add print function to runtime: %w", err)
40+
}
41+
err = jsext.ConfigureRuntime(vm)
42+
if err != nil {
43+
return fmt.Errorf("can't configure runtime: %w", err)
44+
}
45+
_, err = vm.RunString(string(script))
46+
if err != nil {
47+
return fmt.Errorf("script run failed: %w", err)
48+
}
49+
50+
var f JSLimitFunc
51+
var limitFnJSVal goja.Value
52+
if ex := vm.Try(func() {
53+
limitFnJSVal = vm.Get("bwLimit")
54+
}); ex != nil {
55+
return fmt.Errorf("\"bwLimit\" function cannot be located in VM context: %w", err)
56+
}
57+
if limitFnJSVal == nil {
58+
return errors.New("\"bwLimit\" function is not defined")
59+
}
60+
err = vm.ExportTo(limitFnJSVal, &f)
61+
if err != nil {
62+
return fmt.Errorf("can't export \"bwLimit\" function from JS VM: %w", err)
63+
}
64+
65+
pool <- f
66+
return nil
67+
})
68+
}
69+
70+
err = initGroup.Wait()
71+
if err != nil {
72+
return nil, err
73+
}
74+
75+
return &JSLimitProvider{
76+
funcPool: pool,
77+
logger: logger,
78+
}, nil
79+
}
80+
81+
func (j *JSLimitProvider) Parameters(ctx context.Context, username, network, address string) (res *LimitParameters, err error) {
82+
defer func() {
83+
if err != nil {
84+
j.logger.Error("%v", err)
85+
}
86+
}()
87+
req, _ := dto.FilterParamsFromContext(ctx)
88+
ri := jsext.JSRequestInfoFromRequest(req)
89+
di, err := jsext.JSDstInfoFromContext(ctx, network, address)
90+
if err != nil {
91+
return nil, fmt.Errorf("unable to construct dst info: %w", err)
92+
}
93+
func() {
94+
f := <-j.funcPool
95+
defer func(pool chan JSLimitFunc, f JSLimitFunc) {
96+
pool <- f
97+
}(j.funcPool, f)
98+
res, err = f(ri, di, username)
99+
}()
100+
if err != nil {
101+
return nil, fmt.Errorf("JS limit script exception: %w", err)
102+
}
103+
if res == nil {
104+
return nil, fmt.Errorf("JS limit script returned null object")
105+
}
106+
return res, nil
107+
}

handler/config.go

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
package handler
22

33
import (
4-
"context"
5-
"io"
6-
74
"github.com/SenseUnit/dumbproxy/auth"
85
clog "github.com/SenseUnit/dumbproxy/log"
96
)
@@ -19,7 +16,7 @@ type Config struct {
1916
Logger *clog.CondLogger
2017
// Forward optionally specifies custom connection pairing function
2118
// which does actual data forwarding.
22-
Forward func(ctx context.Context, username string, incoming, outgoing io.ReadWriteCloser) error
19+
Forward ForwardFunc
2320
// UserIPHints specifies whether allow IP hints set by user or not
2421
UserIPHints bool
2522
}

handler/handler.go

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@ import (
1010
"math/rand/v2"
1111
"net"
1212
"net/http"
13+
"net/url"
1314
"strconv"
15+
"strings"
1416
"sync"
1517

1618
"github.com/SenseUnit/dumbproxy/auth"
@@ -27,7 +29,7 @@ type HandlerDialer interface {
2729
DialContext(ctx context.Context, net, address string) (net.Conn, error)
2830
}
2931

30-
type ForwardFunc = func(ctx context.Context, username string, incoming, outgoing io.ReadWriteCloser) error
32+
type ForwardFunc = func(ctx context.Context, username string, incoming, outgoing io.ReadWriteCloser, network, address string) error
3133

3234
type ProxyHandler struct {
3335
auth auth.Auth
@@ -114,6 +116,8 @@ func (s *ProxyHandler) HandleTunnel(wr http.ResponseWriter, req *http.Request, u
114116
username,
115117
wrapH1ReqBody(io.NopCloser(io.LimitReader(rw.Reader, int64(buffered)))),
116118
wrapH1RespWriter(conn),
119+
"tcp",
120+
req.RequestURI,
117121
)
118122
s.forward(
119123
req.Context(),
@@ -123,31 +127,48 @@ func (s *ProxyHandler) HandleTunnel(wr http.ResponseWriter, req *http.Request, u
123127
localconn,
124128
),
125129
conn,
130+
"tcp",
131+
req.RequestURI,
126132
)
127133
} else {
128134
s.logger.Debug("not rescuing remaining data in bufio.ReadWriter")
129135
fmt.Fprintf(localconn, "HTTP/%d.%d 200 OK\r\n\r\n", req.ProtoMajor, req.ProtoMinor)
130-
s.forward(req.Context(), username, localconn, conn)
136+
s.forward(req.Context(), username, localconn, conn, "tcp", req.RequestURI)
131137
}
132138
} else if req.ProtoMajor == 2 {
133139
wr.Header()["Date"] = nil
134140
wr.WriteHeader(http.StatusOK)
135141
flush(wr)
136-
s.forward(req.Context(), username, wrapH2(req.Body, wr), conn)
142+
s.forward(req.Context(), username, wrapH2(req.Body, wr), conn, "tcp", req.RequestURI)
137143
} else {
138144
s.logger.Error("Unsupported protocol version: %s", req.Proto)
139145
http.Error(wr, "Unsupported protocol version.", http.StatusBadRequest)
140146
return
141147
}
142148
}
143149

150+
func addressFromURL(u *url.URL) string {
151+
host := u.Hostname()
152+
port := u.Port()
153+
if port == "" {
154+
switch strings.ToLower(u.Scheme) {
155+
case "http":
156+
port = "80"
157+
case "https":
158+
port = "443"
159+
}
160+
}
161+
return net.JoinHostPort(host, port)
162+
}
163+
144164
func (s *ProxyHandler) HandleRequest(wr http.ResponseWriter, req *http.Request, username string) {
145165
req.RequestURI = ""
146166
forwardReqBody := newH1ReqBodyPipe()
147167
origBody := req.Body
148168
req.Body = forwardReqBody.Body()
169+
address := addressFromURL(req.URL)
149170
go func() {
150-
s.forward(req.Context(), username, wrapH1ReqBody(origBody), forwardReqBody)
171+
s.forward(req.Context(), username, wrapH1ReqBody(origBody), forwardReqBody, "tcp", address)
151172
}()
152173
if req.ProtoMajor == 2 {
153174
req.URL.Scheme = "http" // We can't access :scheme pseudo-header, so assume http
@@ -171,7 +192,7 @@ func (s *ProxyHandler) HandleRequest(wr http.ResponseWriter, req *http.Request,
171192
copyHeader(wr.Header(), resp.Header)
172193
wr.WriteHeader(resp.StatusCode)
173194
flush(wr)
174-
s.forward(req.Context(), username, wrapH1RespWriter(wr), wrapH1ReqBody(resp.Body))
195+
s.forward(req.Context(), username, wrapH1RespWriter(wr), wrapH1ReqBody(resp.Body), "tcp", address)
175196
}
176197

177198
func (s *ProxyHandler) HandleGetRandom(wr http.ResponseWriter, req *http.Request, username string) {

handler/socks.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ func SOCKSHandler(dialer HandlerDialer, logger *clog.CondLogger, forward Forward
7070
return fmt.Errorf("failed to send reply, %v", err)
7171
}
7272

73-
return forward(ctx, username, wrapSOCKS(request.Reader, writer), target)
73+
return forward(ctx, username, wrapSOCKS(request.Reader, writer), target, "tcp", request.DestAddr.String())
7474
}
7575
}
7676

0 commit comments

Comments
 (0)