Skip to content

Commit 268659d

Browse files
committed
all: add proxy.RequestHandler interface
1 parent e0e1e8e commit 268659d

File tree

11 files changed

+122
-72
lines changed

11 files changed

+122
-72
lines changed

internal/cmd/proxy.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ func createProxyConfig(
8181
MaxGoroutines: conf.MaxGoRoutines,
8282
UsePrivateRDNS: conf.UsePrivateRDNS,
8383
PrivateSubnets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
84-
RequestHandler: reqHdlr.HandleRequest,
84+
RequestHandler: reqHdlr,
8585
PendingRequests: &proxy.PendingRequestsConfig{
8686
Enabled: conf.PendingRequestsEnabled,
8787
},

internal/handler/default.go

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ type Default struct {
3333
isIPv6Halted bool
3434
}
3535

36-
// NewDefault creates a new [Default] handler.
36+
// NewDefault creates a new [*Default] handler.
3737
func NewDefault(conf *DefaultConfig) (d *Default) {
3838
mc, ok := conf.MessageConstructor.(messageConstructor)
3939
if !ok {
@@ -50,10 +50,13 @@ func NewDefault(conf *DefaultConfig) (d *Default) {
5050
}
5151
}
5252

53-
// HandleRequest resolves the DNS request within proxyCtx. It only calls
54-
// [proxy.Proxy.Resolve] if the request isn't handled by any of the internal
55-
// handlers.
56-
func (h *Default) HandleRequest(p *proxy.Proxy, proxyCtx *proxy.DNSContext) (err error) {
53+
// type check
54+
var _ proxy.RequestHandler = (*Default)(nil)
55+
56+
// Handle implements the [RequestHandler] interface for *Default. It resolves
57+
// the DNS request within proxyCtx. It only calls [proxy.Proxy.Resolve] if the
58+
// request isn't handled by any of the internal handlers.
59+
func (h *Default) Handle(p *proxy.Proxy, proxyCtx *proxy.DNSContext) (err error) {
5760
// TODO(e.burkov): Use the [*context.Context] instead of
5861
// [*proxy.DNSContext] when the interface-based handler is implemented.
5962
ctx := context.TODO()

proxy/config.go

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,6 @@ const (
2828
DefaultOptimisticAnswerTTL = 30 * time.Second
2929
)
3030

31-
// RequestHandler is an optional custom handler for DNS requests. It's used
32-
// instead of [Proxy.Resolve] if set. The resulting error doesn't affect the
33-
// request processing.
34-
//
35-
// TODO(e.burkov): Use the same interface-based approach as
36-
// [BeforeRequestHandler].
37-
type RequestHandler func(p *Proxy, dctx *DNSContext) (err error)
38-
3931
// Config contains all the fields necessary for proxy configuration.
4032
//
4133
// TODO(a.garipov): Consider extracting conf blocks for better fieldalignment.
@@ -69,7 +61,7 @@ type Config struct {
6961
BeforeRequestHandler BeforeRequestHandler
7062

7163
// RequestHandler is an optional custom handler for DNS requests. It's used
72-
// instead of [Proxy.Resolve] if set. See [RequestHandler].
64+
// instead of defaultRequestHandler if set.
7365
RequestHandler RequestHandler
7466

7567
// UpstreamConfig is a general set of DNS servers to forward requests to.

proxy/handler_internal_test.go

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,21 +17,13 @@ func TestFilteringHandler(t *testing.T) {
1717
m := &sync.RWMutex{}
1818
blockResponse := false
1919

20-
// Prepare the proxy server
21-
dnsProxy := mustNew(t, &Config{
22-
Logger: slogutil.NewDiscardLogger(),
23-
UDPListenAddr: []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)},
24-
TCPListenAddr: []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
25-
UpstreamConfig: newTestUpstreamConfig(t, defaultTimeout, testDefaultUpstreamAddr),
26-
TrustedProxies: defaultTrustedProxies,
27-
RatelimitSubnetLenIPv4: 24,
28-
RatelimitSubnetLenIPv6: 64,
29-
RequestHandler: func(p *Proxy, d *DNSContext) error {
20+
reqHandler := &TestRequestHandler{
21+
OnHandle: func(p *Proxy, d *DNSContext) (err error) {
3022
m.Lock()
3123
defer m.Unlock()
3224

3325
if !blockResponse {
34-
// Use the default Resolve method if response is not blocked
26+
// Use the default Resolve method if response is not blocked.
3527
return p.Resolve(d)
3628
}
3729

@@ -41,8 +33,21 @@ func TestFilteringHandler(t *testing.T) {
4133

4234
// Set the response right away
4335
d.Res = &resp
36+
4437
return nil
4538
},
39+
}
40+
41+
// Prepare the proxy server.
42+
dnsProxy := mustNew(t, &Config{
43+
Logger: slogutil.NewDiscardLogger(),
44+
TrustedProxies: defaultTrustedProxies,
45+
UpstreamConfig: newTestUpstreamConfig(t, defaultTimeout, testDefaultUpstreamAddr),
46+
RequestHandler: reqHandler,
47+
UDPListenAddr: []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)},
48+
TCPListenAddr: []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
49+
RatelimitSubnetLenIPv4: 24,
50+
RatelimitSubnetLenIPv6: 64,
4651
})
4752

4853
servicetest.RequireRun(t, dnsProxy, testTimeout)

proxy/pending_test.go

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,14 @@ func TestPendingRequests(t *testing.T) {
7979
workloadWG := &sync.WaitGroup{}
8080
workloadWG.Add(reqsNum)
8181

82+
reqHandler := &proxy.TestRequestHandler{
83+
OnHandle: func(p *proxy.Proxy, d *proxy.DNSContext) (err error) {
84+
workloadWG.Done()
85+
86+
return p.Resolve(d)
87+
},
88+
}
89+
8290
once := &sync.Once{}
8391
u := &dnsproxytest.Upstream{
8492
OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
@@ -98,25 +106,21 @@ func TestPendingRequests(t *testing.T) {
98106
}
99107

100108
p, err := proxy.New(&proxy.Config{
101-
Logger: slogutil.NewDiscardLogger(),
109+
Logger: slogutil.NewDiscardLogger(),
110+
UpstreamConfig: &proxy.UpstreamConfig{Upstreams: []upstream.Upstream{u}},
111+
TrustedProxies: testTrustedProxies,
112+
PendingRequests: &proxy.PendingRequestsConfig{
113+
Enabled: true,
114+
},
115+
RequestHandler: reqHandler,
102116
UDPListenAddr: []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)},
103117
TCPListenAddr: []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
104-
UpstreamConfig: &proxy.UpstreamConfig{Upstreams: []upstream.Upstream{u}},
105-
TrustedProxies: testTrustedProxies,
106118
RatelimitSubnetLenIPv4: 24,
107119
RatelimitSubnetLenIPv6: 64,
108120
Ratelimit: 0,
109-
CacheEnabled: true,
110121
CacheSizeBytes: testCacheSize,
122+
CacheEnabled: true,
111123
EnableEDNSClientSubnet: true,
112-
PendingRequests: &proxy.PendingRequestsConfig{
113-
Enabled: true,
114-
},
115-
RequestHandler: func(prx *proxy.Proxy, dctx *proxy.DNSContext) (err error) {
116-
workloadWG.Done()
117-
118-
return prx.Resolve(dctx)
119-
},
120124
})
121125
require.NoError(t, err)
122126

proxy/proxy.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,10 @@ type Proxy struct {
9292
// beforeRequestHandler handles the request's context before it is resolved.
9393
beforeRequestHandler BeforeRequestHandler
9494

95+
// requestHandler handles the DNS request after it's been processed by the
96+
// beforeRequestHandler. It is never nil.
97+
requestHandler RequestHandler
98+
9599
// dnsCryptServer serves DNSCrypt queries.
96100
dnsCryptServer *dnscrypt.Server
97101

@@ -227,6 +231,7 @@ func New(c *Config) (p *Proxy, err error) {
227231
c.BeforeRequestHandler,
228232
noopRequestHandler{},
229233
),
234+
requestHandler: cmp.Or[RequestHandler](c.RequestHandler, defaultRequestHandler{}),
230235
upstreamRTTStats: map[string]upstreamRTTStats{},
231236
rttLock: sync.Mutex{},
232237
ratelimitLock: sync.Mutex{},

proxy/requesthandler.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
package proxy
2+
3+
// RequestHandler is an interface for handling DNS requests.
4+
type RequestHandler interface {
5+
// Handle resolves the DNS request within *DNSContext.
6+
Handle(p *Proxy, dctx *DNSContext) (err error)
7+
}
8+
9+
// defaultRequestHandler implements [RequestHandler] by calling [Proxy.Resolve].
10+
type defaultRequestHandler struct{}
11+
12+
// type check
13+
var _ RequestHandler = defaultRequestHandler{}
14+
15+
// Handle implements the [RequestHandler] interface for defaultRequestHandler.
16+
func (defaultRequestHandler) Handle(p *Proxy, proxyCtx *DNSContext) (err error) {
17+
return p.Resolve(proxyCtx)
18+
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
package proxy
2+
3+
// TestRequestHandler is a mock request handler implementation to simplify
4+
// testing.
5+
type TestRequestHandler struct {
6+
OnHandle func(p *Proxy, dctx *DNSContext) (err error)
7+
}
8+
9+
// type check
10+
var _ RequestHandler = (*TestRequestHandler)(nil)
11+
12+
// Handle implements the [RequestHandler] interface for *TestRequestHandler.
13+
func (h *TestRequestHandler) Handle(p *Proxy, dctx *DNSContext) (err error) {
14+
return h.OnHandle(p, dctx)
15+
}

proxy/server.go

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -121,11 +121,7 @@ func (p *Proxy) handleDNSRequest(d *DNSContext) (err error) {
121121

122122
d.Res = p.validateRequest(d)
123123
if d.Res == nil {
124-
if p.RequestHandler != nil {
125-
err = errors.Annotate(p.RequestHandler(p, d), "using request handler: %w")
126-
} else {
127-
err = errors.Annotate(p.Resolve(d), "using default request handler: %w")
128-
}
124+
err = p.requestHandler.Handle(p, d)
129125
}
130126

131127
p.logDNSMessage(d.Res)

proxy/serverhttps_internal_test.go

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -75,27 +75,30 @@ func TestProxy_trustedProxies(t *testing.T) {
7575
)
7676

7777
doRequest := func(t *testing.T, addr, expectedClientIP netip.Addr) {
78+
var gotAddr netip.Addr
79+
reqHandler := &TestRequestHandler{
80+
OnHandle: func(p *Proxy, d *DNSContext) (err error) {
81+
gotAddr = d.Addr.Addr()
82+
83+
return p.Resolve(d)
84+
},
85+
}
86+
7887
// Prepare the proxy server.
7988
tlsConf, caPem := newTLSConfig(t)
8089
dnsProxy := mustNew(t, &Config{
8190
Logger: slogutil.NewDiscardLogger(),
91+
UpstreamConfig: newTestUpstreamConfig(t, defaultTimeout, testDefaultUpstreamAddr),
92+
TrustedProxies: defaultTrustedProxies,
93+
RequestHandler: reqHandler,
94+
TLSConfig: tlsConf,
8295
TLSListenAddr: []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
8396
HTTPSListenAddr: []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
8497
QUICListenAddr: []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)},
85-
TLSConfig: tlsConf,
86-
UpstreamConfig: newTestUpstreamConfig(t, defaultTimeout, testDefaultUpstreamAddr),
87-
TrustedProxies: defaultTrustedProxies,
8898
RatelimitSubnetLenIPv4: 24,
8999
RatelimitSubnetLenIPv6: 64,
90100
})
91101

92-
var gotAddr netip.Addr
93-
dnsProxy.RequestHandler = func(_ *Proxy, d *DNSContext) (err error) {
94-
gotAddr = d.Addr.Addr()
95-
96-
return dnsProxy.Resolve(d)
97-
}
98-
99102
client := createTestHTTPClient(dnsProxy, caPem, false)
100103

101104
msg := newTestMessage()

0 commit comments

Comments
 (0)