Skip to content

Commit 7010ee6

Browse files
committed
all: ratelimit package
1 parent 0647504 commit 7010ee6

File tree

7 files changed

+320
-183
lines changed

7 files changed

+320
-183
lines changed

internal/ratelimit/ratelimit.go

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
// Package ratelimit provides a rate limiting functionality.
2+
package ratelimit
3+
4+
import (
5+
"fmt"
6+
"log/slog"
7+
"net/netip"
8+
"slices"
9+
"sync"
10+
"time"
11+
12+
"github.com/AdguardTeam/dnsproxy/proxy"
13+
"github.com/AdguardTeam/golibs/logutil/slogutil"
14+
rate "github.com/beefsack/go-rate"
15+
gocache "github.com/patrickmn/go-cache"
16+
)
17+
18+
// Config is the configuration for the ratelimit handler.
19+
type Config struct {
20+
// Logger is used for logging in the ratelimit handler. It must not be nil.
21+
Logger *slog.Logger
22+
23+
// AllowlistAddrs is a list of IP addresses excluded from rate limiting.
24+
AllowlistAddrs []netip.Addr
25+
26+
// Ratelimit is a maximum number of requests per second from a given IP (0
27+
// to disable).
28+
Ratelimit int
29+
30+
// SubnetLenIPv4 is a subnet length for IPv4 addresses used for rate
31+
// limiting requests.
32+
SubnetLenIPv4 int
33+
34+
// SubnetLenIPv6 is a subnet length for IPv6 addresses used for rate
35+
// limiting requests.
36+
SubnetLenIPv6 int
37+
}
38+
39+
// handler implements [proxy.RequestHandler] with rate limiting functionality.
40+
type handler struct {
41+
buckets *gocache.Cache
42+
handler proxy.RequestHandler
43+
logger *slog.Logger
44+
45+
// mu protects buckets.
46+
mu *sync.Mutex
47+
48+
allowlistAddrs []netip.Addr
49+
ratelimit int
50+
subnetLenIPv4 int
51+
subnetLenIPv6 int
52+
}
53+
54+
// NewRatelimitedRequestHandler wraps a RequestHandler with rate limiting
55+
// functionality. h must not be nil, c must be valid.
56+
//
57+
// TODO(d.kolyshev): !! Use.
58+
func NewRatelimitedRequestHandler(h proxy.RequestHandler, c *Config) (wrapped proxy.RequestHandler) {
59+
if c.Ratelimit <= 0 {
60+
return h
61+
}
62+
63+
return &handler{
64+
handler: h,
65+
logger: c.Logger,
66+
mu: &sync.Mutex{},
67+
allowlistAddrs: c.AllowlistAddrs,
68+
ratelimit: c.Ratelimit,
69+
subnetLenIPv4: c.SubnetLenIPv4,
70+
subnetLenIPv6: c.SubnetLenIPv6,
71+
}
72+
}
73+
74+
// type check
75+
var _ proxy.RequestHandler = (*handler)(nil)
76+
77+
// Handle implements the [proxy.RequestHandler] interface for *handler. If the
78+
// client is rate limited, it returns [proxy.ErrDrop] to signal that no response
79+
// should be sent.
80+
func (h *handler) Handle(p *proxy.Proxy, dctx *proxy.DNSContext) (err error) {
81+
if dctx.Proto == proxy.ProtoUDP && h.isRatelimited(dctx.Addr.Addr()) {
82+
h.logger.Debug("ratelimited based on ip only", "addr", dctx.Addr)
83+
84+
return proxy.ErrDrop
85+
}
86+
87+
return h.handler.Handle(p, dctx)
88+
}
89+
90+
// limiterForIP returns a rate limiter for the specified IP address.
91+
func (h *handler) limiterForIP(ip string) interface{} {
92+
h.mu.Lock()
93+
defer h.mu.Unlock()
94+
95+
if h.buckets == nil {
96+
h.buckets = gocache.New(time.Hour, time.Hour)
97+
}
98+
99+
value, found := h.buckets.Get(ip)
100+
if !found {
101+
value = rate.New(h.ratelimit, time.Second)
102+
h.buckets.Set(ip, value, time.Hour)
103+
}
104+
105+
return value
106+
}
107+
108+
// isRatelimited checks if the specified address should be rate limited.
109+
func (h *handler) isRatelimited(addr netip.Addr) (ok bool) {
110+
addr = addr.Unmap()
111+
_, ok = slices.BinarySearchFunc(h.allowlistAddrs, addr, netip.Addr.Compare)
112+
if ok {
113+
return false
114+
}
115+
116+
var pref netip.Prefix
117+
if addr.Is4() {
118+
pref = netip.PrefixFrom(addr, h.subnetLenIPv4)
119+
} else {
120+
pref = netip.PrefixFrom(addr, h.subnetLenIPv6)
121+
}
122+
pref = pref.Masked()
123+
124+
// TODO(d.kolyshev): Improve caching. Decrease allocations.
125+
ipStr := pref.Addr().String()
126+
value := h.limiterForIP(ipStr)
127+
rl, ok := value.(*rate.RateLimiter)
128+
if !ok {
129+
h.logger.Error(
130+
"invalid value found in ratelimit cache",
131+
slogutil.KeyError,
132+
fmt.Errorf("bad type %T", value),
133+
)
134+
135+
return false
136+
}
137+
138+
allow, _ := rl.Try()
139+
140+
return !allow
141+
}
Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
package ratelimit_test
2+
3+
import (
4+
"net/netip"
5+
"testing"
6+
7+
"github.com/AdguardTeam/dnsproxy/internal/ratelimit"
8+
"github.com/AdguardTeam/dnsproxy/proxy"
9+
"github.com/AdguardTeam/golibs/logutil/slogutil"
10+
"github.com/stretchr/testify/assert"
11+
"github.com/stretchr/testify/require"
12+
)
13+
14+
// Subnet lengths used in tests.
15+
const (
16+
subnetLenIPv4 = 24
17+
subnetLenIPv6 = 64
18+
)
19+
20+
// testLogger is a test logger used in tests.
21+
var testLogger = slogutil.NewDiscardLogger()
22+
23+
func TestHandler_Handle(t *testing.T) {
24+
t.Parallel()
25+
26+
testAddr := netip.MustParseAddrPort("192.0.2.0:53")
27+
28+
testCases := []struct {
29+
config *ratelimit.Config
30+
dctx *proxy.DNSContext
31+
want error
32+
name string
33+
}{{
34+
name: "disabled_ratelimit",
35+
config: &ratelimit.Config{
36+
Logger: testLogger,
37+
Ratelimit: 0,
38+
SubnetLenIPv4: subnetLenIPv4,
39+
SubnetLenIPv6: subnetLenIPv6,
40+
},
41+
dctx: &proxy.DNSContext{
42+
Addr: testAddr,
43+
Proto: proxy.ProtoUDP,
44+
},
45+
want: nil,
46+
}, {
47+
name: "tcp_not_ratelimited",
48+
config: &ratelimit.Config{
49+
Logger: testLogger,
50+
Ratelimit: 1,
51+
SubnetLenIPv4: subnetLenIPv4,
52+
SubnetLenIPv6: subnetLenIPv6,
53+
},
54+
dctx: &proxy.DNSContext{
55+
Addr: testAddr,
56+
Proto: proxy.ProtoTCP,
57+
},
58+
want: nil,
59+
}, {
60+
name: "ratelimited",
61+
config: &ratelimit.Config{
62+
Logger: testLogger,
63+
Ratelimit: 1,
64+
SubnetLenIPv4: subnetLenIPv4,
65+
SubnetLenIPv6: subnetLenIPv6,
66+
},
67+
dctx: &proxy.DNSContext{
68+
Addr: testAddr,
69+
Proto: proxy.ProtoUDP,
70+
},
71+
want: proxy.ErrDrop,
72+
}}
73+
74+
mock := &TestRequestHandler{
75+
OnHandle: func(p *proxy.Proxy, dctx *proxy.DNSContext) (err error) {
76+
return nil
77+
},
78+
}
79+
80+
for _, tc := range testCases {
81+
t.Run(tc.name, func(t *testing.T) {
82+
t.Parallel()
83+
84+
wrapped := ratelimit.NewRatelimitedRequestHandler(mock, tc.config)
85+
86+
err := wrapped.Handle(nil, tc.dctx)
87+
require.NoError(t, err, "first request should not be ratelimited")
88+
89+
err = wrapped.Handle(nil, tc.dctx)
90+
assert.Equal(t, tc.want, err)
91+
})
92+
}
93+
}
94+
95+
func TestHandler_Handle_allowlist(t *testing.T) {
96+
t.Parallel()
97+
98+
var (
99+
addrAllow = netip.MustParseAddr("192.0.2.0")
100+
addrPortAllow = netip.AddrPortFrom(addrAllow, 53)
101+
addrPortDrop = netip.MustParseAddrPort("192.0.2.1:53")
102+
)
103+
104+
conf := &ratelimit.Config{
105+
Logger: testLogger,
106+
Ratelimit: 1,
107+
SubnetLenIPv4: subnetLenIPv4,
108+
SubnetLenIPv6: subnetLenIPv6,
109+
AllowlistAddrs: []netip.Addr{
110+
addrAllow,
111+
},
112+
}
113+
114+
mock := &TestRequestHandler{
115+
OnHandle: func(p *proxy.Proxy, dctx *proxy.DNSContext) (err error) {
116+
return nil
117+
},
118+
}
119+
handler := ratelimit.NewRatelimitedRequestHandler(mock, conf)
120+
121+
t.Run("block", func(t *testing.T) {
122+
dctx := &proxy.DNSContext{
123+
Addr: addrPortDrop,
124+
Proto: proxy.ProtoUDP,
125+
}
126+
127+
err := handler.Handle(nil, dctx)
128+
require.NoError(t, err, "first request should not be ratelimited")
129+
130+
err = handler.Handle(nil, dctx)
131+
require.Error(t, err, "second request should be ratelimited")
132+
assert.Equal(t, proxy.ErrDrop, err)
133+
})
134+
135+
t.Run("allow", func(t *testing.T) {
136+
dctx := &proxy.DNSContext{
137+
Addr: addrPortAllow,
138+
Proto: proxy.ProtoUDP,
139+
}
140+
141+
err := handler.Handle(nil, dctx)
142+
require.NoError(t, err, "first request should not be ratelimited")
143+
144+
err = handler.Handle(nil, dctx)
145+
require.NoError(t, err, "second request should not be ratelimited due to whitelist")
146+
})
147+
}
148+
149+
// TestRequestHandler is a mock request handler implementation to simplify
150+
// testing.
151+
//
152+
// TODO(d.kolyshev): Move to internal/dnsproxytest.
153+
type TestRequestHandler struct {
154+
OnHandle func(p *proxy.Proxy, dctx *proxy.DNSContext) (err error)
155+
}
156+
157+
// type check
158+
var _ proxy.RequestHandler = (*TestRequestHandler)(nil)
159+
160+
// Handle implements the [RequestHandler] interface for *TestRequestHandler.
161+
func (h *TestRequestHandler) Handle(p *proxy.Proxy, dctx *proxy.DNSContext) (err error) {
162+
return h.OnHandle(p, dctx)
163+
}

proxy/config.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,8 @@ type Config struct {
148148
DNS64Prefs []netip.Prefix
149149

150150
// RatelimitWhitelist is a list of IP addresses excluded from rate limiting.
151+
//
152+
// TODO(d.kolyshev): !! Remove.
151153
RatelimitWhitelist []netip.Addr
152154

153155
// EDNSAddr is the ECS IP used in request.
@@ -157,14 +159,20 @@ type Config struct {
157159

158160
// RatelimitSubnetLenIPv4 is a subnet length for IPv4 addresses used for
159161
// rate limiting requests.
162+
//
163+
// TODO(d.kolyshev): !! Remove.
160164
RatelimitSubnetLenIPv4 int
161165

162166
// RatelimitSubnetLenIPv6 is a subnet length for IPv6 addresses used for
163167
// rate limiting requests.
168+
//
169+
// TODO(d.kolyshev): !! Remove.
164170
RatelimitSubnetLenIPv6 int
165171

166172
// Ratelimit is a maximum number of requests per second from a given IP (0
167173
// to disable).
174+
//
175+
// TODO(d.kolyshev): !! Remove.
168176
Ratelimit int
169177

170178
// CacheSizeBytes is the maximum cache size in bytes.

proxy/proxy.go

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ import (
3030
"github.com/AdguardTeam/golibs/validate"
3131
"github.com/ameshkov/dnscrypt/v2"
3232
"github.com/miekg/dns"
33-
gocache "github.com/patrickmn/go-cache"
3433
"github.com/quic-go/quic-go"
3534
"github.com/quic-go/quic-go/http3"
3635
)
@@ -59,6 +58,10 @@ const (
5958
ProtoDNSCrypt Proto = "dnscrypt"
6059
)
6160

61+
// ErrDrop is returned by a RequestHandler to signal that the proxy should not
62+
// send any response to the client.
63+
var ErrDrop = errors.Error("drop response")
64+
6265
// Proxy combines the proxy server state and configuration.
6366
//
6467
// TODO(a.garipov): Consider extracting conf blocks for better fieldalignment.
@@ -102,9 +105,6 @@ type Proxy struct {
102105
// logger is used for logging in the proxy service. It is never nil.
103106
logger *slog.Logger
104107

105-
// ratelimitBuckets is a storage for ratelimiters for individual IPs.
106-
ratelimitBuckets *gocache.Cache
107-
108108
// fastestAddr finds the fastest IP address for the resolved domain.
109109
fastestAddr *fastip.FastestAddr
110110

@@ -203,9 +203,6 @@ type Proxy struct {
203203
// Also make it a pointer.
204204
sync.RWMutex
205205

206-
// ratelimitLock protects ratelimitBuckets.
207-
ratelimitLock sync.Mutex
208-
209206
// rttLock protects upstreamRTTStats.
210207
//
211208
// TODO(e.burkov): Make it a pointer.
@@ -234,7 +231,6 @@ func New(c *Config) (p *Proxy, err error) {
234231
requestHandler: cmp.Or[RequestHandler](c.RequestHandler, DefaultRequestHandler{}),
235232
upstreamRTTStats: map[string]upstreamRTTStats{},
236233
rttLock: sync.Mutex{},
237-
ratelimitLock: sync.Mutex{},
238234
RWMutex: sync.RWMutex{},
239235
// 2 bytes may be used to store packet length (see TCP/TLS).
240236
bytesPool: syncutil.NewSlicePool[byte](2 + dns.MaxMsgSize),

0 commit comments

Comments
 (0)