Skip to content

Commit 96f975a

Browse files
committed
all: imp code
1 parent 6deeaa8 commit 96f975a

File tree

7 files changed

+80
-80
lines changed

7 files changed

+80
-80
lines changed

internal/cmd/config.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,8 +181,8 @@ type configuration struct {
181181
EnableEDNSSubnet bool `yaml:"edns"`
182182

183183
// PendingRequestsEnabled controls whether the server should track duplicate
184-
// queries and only send the first of them to the upstream server. It used
185-
// to mitigate the cache poisoning attacks.
184+
// queries and only send the first of them to the upstream server. It is
185+
// used to mitigate the cache poisoning attacks.
186186
PendingRequestsEnabled bool `yaml:"pending-requests-enabled"`
187187

188188
// DNS64 defines whether DNS64 functionality is enabled or not.

proxy/config.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,12 @@ type Config struct {
5858
// constructor will be used.
5959
MessageConstructor MessageConstructor
6060

61-
// PendingRequests limits the number of identical requests sent to the
62-
// upstream server. If nil, the default implementation will be used. Use:
63-
// - [DefaultPendingRequests] to enable the pending requests feature.
64-
// - [EmptyPendingRequests] to disable the pending requests feature.
61+
// PendingRequests is used to mitigate the cache poisoning attacks by
62+
// tracking identical requests and returning the same response for them,
63+
// peforming a single lookup. If nil, the default implementation will be
64+
// used. Use:
65+
// - [DefaultPendingRequests] to enable pending identical requests.
66+
// - [EmptyPendingRequests] to disable pending identical requests.
6567
PendingRequests PendingRequests
6668

6769
// BeforeRequestHandler is an optional custom handler called before each DNS

proxy/pending.go

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,10 @@ type pendingRequest struct {
4242
// channel is closed.
4343
resolveErr error
4444

45-
// cloneCtx is a clone of the DNSContext that was used to create the
45+
// cloneDNSCtx is a clone of the DNSContext that was used to create the
4646
// pendingRequest and store its result. It must only be accessed for
4747
// reading after the finish channel is closed.
48-
cloneCtx *DNSContext
48+
cloneDNSCtx *DNSContext
4949
}
5050

5151
// NewDefaultPendingRequests creates a new instance of DefaultPendingRequests.
@@ -77,23 +77,25 @@ func (pr *DefaultPendingRequests) queue(
7777
}
7878

7979
pendingVal, exists := pr.storage.LoadOrStore(string(key), req)
80-
if exists {
81-
pending := pendingVal.(*pendingRequest)
82-
<-pending.finish
80+
if !exists {
81+
return false, nil
82+
}
8383

84-
if pending.cloneCtx != nil {
85-
// TODO(e.burkov): Add cloner for DNS messages.
86-
dctx.Res = pending.cloneCtx.Res.Copy().SetReply(dctx.Req)
87-
dctx.Upstream = pending.cloneCtx.Upstream
84+
pending := pendingVal.(*pendingRequest)
85+
<-pending.finish
8886

89-
// TODO(a.garipov): !! Decide how to treat query durations.
90-
dctx.queryStatistics = pending.cloneCtx.queryStatistics
91-
}
87+
origDNSCtx := pending.cloneDNSCtx
9288

93-
return exists, pending.resolveErr
89+
// TODO(a.garipov): Perhaps, statistics should be calculated separately for
90+
// each request.
91+
dctx.queryStatistics = origDNSCtx.queryStatistics
92+
dctx.Upstream = origDNSCtx.Upstream
93+
if origDNSCtx.Res != nil {
94+
// TODO(e.burkov): Add cloner for DNS messages.
95+
dctx.Res = origDNSCtx.Res.Copy().SetReply(dctx.Req)
9496
}
9597

96-
return false, nil
98+
return exists, pending.resolveErr
9799
}
98100

99101
// done implements the [PendingRequests] interface for [DefaultPendingRequests].
@@ -114,14 +116,16 @@ func (pr *DefaultPendingRequests) done(ctx context.Context, dctx *DNSContext, er
114116
pending := pendingVal.(*pendingRequest)
115117
pending.resolveErr = err
116118

117-
cloneCtx := &DNSContext{}
119+
cloneCtx := &DNSContext{
120+
Upstream: dctx.Upstream,
121+
queryStatistics: dctx.queryStatistics,
122+
}
123+
118124
if dctx.Res != nil {
119125
cloneCtx.Res = dctx.Res.Copy()
120-
cloneCtx.Upstream = dctx.Upstream
121-
cloneCtx.queryStatistics = dctx.queryStatistics
122126
}
123127

124-
pending.cloneCtx = cloneCtx
128+
pending.cloneDNSCtx = cloneCtx
125129

126130
pr.storage.Delete(string(key))
127131
close(pending.finish)
Lines changed: 43 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,45 @@
1-
package proxy
1+
package proxy_test
22

33
import (
4-
"context"
54
"net"
5+
"net/netip"
66
"sync"
77
"testing"
8+
"time"
89

910
"github.com/AdguardTeam/dnsproxy/internal/dnsproxytest"
11+
"github.com/AdguardTeam/dnsproxy/proxy"
1012
"github.com/AdguardTeam/dnsproxy/upstream"
1113
"github.com/AdguardTeam/golibs/logutil/slogutil"
14+
"github.com/AdguardTeam/golibs/netutil"
1215
"github.com/AdguardTeam/golibs/testutil"
1316
"github.com/miekg/dns"
1417
"github.com/stretchr/testify/assert"
1518
"github.com/stretchr/testify/require"
1619
)
1720

18-
// testPendingRequests is a mock implementation of [PendingRequests] for tests.
19-
//
20-
// TODO(e.burkov): Think of a better way to test [PendingRequests].
21-
type testPendingRequests struct {
22-
onQueue func(ctx context.Context, dctx *DNSContext) (exists bool, err error)
23-
onDone func(ctx context.Context, dctx *DNSContext, err error)
24-
}
21+
// TODO(e.burkov): Merge those with the ones in internal tests and move to
22+
// dnsproxytest.
2523

26-
// type check
27-
var _ PendingRequests = (*testPendingRequests)(nil)
24+
const (
25+
// testTimeout is the common timeout for tests and contexts.
26+
testTimeout = 1 * time.Second
2827

29-
// queue implements the [proxy.PendingRequests] interface for
30-
// *testPendingRequests.
31-
func (p *testPendingRequests) queue(
32-
ctx context.Context,
33-
dctx *DNSContext,
34-
) (exists bool, err error) {
35-
return p.onQueue(ctx, dctx)
36-
}
28+
// testCacheSize is the default size of the cache in bytes.
29+
testCacheSize = 64 * 1024
30+
)
3731

38-
// done implements the [proxy.PendingRequests] interface for
39-
// *testPendingRequests.
40-
func (p *testPendingRequests) done(ctx context.Context, dctx *DNSContext, err error) {
41-
p.onDone(ctx, dctx, err)
42-
}
32+
var (
33+
// localhostAnyPort is a localhost address with an arbitrary port.
34+
localhostAnyPort = netip.AddrPortFrom(netutil.IPv4Localhost(), 0)
35+
36+
// testTrustedProxies is a set of trusted proxies that includes all
37+
// addresses used in tests.
38+
testTrustedProxies = netutil.SliceSubnetSet{
39+
netip.MustParsePrefix("0.0.0.0/0"),
40+
netip.MustParsePrefix("::0/0"),
41+
}
42+
)
4343

4444
// assertEqualResponses is a helper function that checks if two DNS messages are
4545
// equal, excluding their ID.
@@ -66,6 +66,7 @@ func assertEqualResponses(tb testing.TB, expected, actual *dns.Msg) {
6666
assert.Equal(tb, expected.Extra, actual.Extra)
6767
}
6868

69+
// TODO(e.burkov): Consider unexporting the [proxy.PendingRequests] interface.
6970
func TestPendingRequests(t *testing.T) {
7071
t.Parallel()
7172

@@ -77,57 +78,54 @@ func TestPendingRequests(t *testing.T) {
7778
once := &sync.Once{}
7879
u := &dnsproxytest.FakeUpstream{
7980
OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
80-
loadWG.Wait()
8181
once.Do(func() {
8282
resp = (&dns.Msg{}).SetReply(req)
8383
})
8484

8585
// Only allow a single request to be processed.
8686
require.NotNil(testutil.PanicT{}, resp)
8787

88+
loadWG.Wait()
89+
8890
return resp, nil
8991
},
9092
OnAddress: func() (addr string) { return "" },
9193
OnClose: func() (err error) { return nil },
9294
}
9395

94-
pending := NewDefaultPendingRequests()
95-
testPending := &testPendingRequests{
96-
onQueue: func(ctx context.Context, dctx *DNSContext) (exists bool, err error) {
97-
loadWG.Done()
98-
99-
return pending.queue(ctx, dctx)
100-
},
101-
onDone: pending.done,
102-
}
103-
104-
p := mustNew(t, &Config{
96+
p, err := proxy.New(&proxy.Config{
10597
Logger: slogutil.NewDiscardLogger(),
10698
UDPListenAddr: []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)},
10799
TCPListenAddr: []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
108-
UpstreamConfig: &UpstreamConfig{Upstreams: []upstream.Upstream{u}},
109-
TrustedProxies: defaultTrustedProxies,
100+
UpstreamConfig: &proxy.UpstreamConfig{Upstreams: []upstream.Upstream{u}},
101+
TrustedProxies: testTrustedProxies,
110102
RatelimitSubnetLenIPv4: 24,
111103
RatelimitSubnetLenIPv6: 64,
112104
Ratelimit: 0,
113105
CacheEnabled: true,
114-
CacheSizeBytes: defaultCacheSize,
106+
CacheSizeBytes: testCacheSize,
115107
EnableEDNSClientSubnet: true,
116-
PendingRequests: testPending,
108+
PendingRequests: proxy.NewDefaultPendingRequests(),
109+
RequestHandler: func(prx *proxy.Proxy, dctx *proxy.DNSContext) (err error) {
110+
loadWG.Done()
111+
112+
return prx.Resolve(dctx)
113+
},
117114
})
115+
require.NoError(t, err)
118116

119117
ctx := testutil.ContextWithTimeout(t, testTimeout)
120-
err := p.Start(ctx)
118+
err = p.Start(ctx)
121119
require.NoError(t, err)
122120
testutil.CleanupAndRequireSuccess(t, func() (err error) {
123121
ctx = testutil.ContextWithTimeout(t, testTimeout)
124122

125123
return p.Shutdown(ctx)
126124
})
127125

128-
addr := p.Addr(ProtoTCP).String()
126+
addr := p.Addr(proxy.ProtoTCP).String()
129127
client := &dns.Client{
130-
Net: string(ProtoTCP),
128+
Net: string(proxy.ProtoTCP),
131129
Timeout: testTimeout,
132130
}
133131

@@ -138,11 +136,11 @@ func TestPendingRequests(t *testing.T) {
138136
for i := range reqsNum {
139137
resolveWG.Add(1)
140138

139+
req := (&dns.Msg{}).SetQuestion("domain.example.", dns.TypeA)
140+
141141
go func() {
142142
defer resolveWG.Done()
143143

144-
req := (&dns.Msg{}).SetQuestion("domain.example.", dns.TypeA)
145-
146144
reqCtx := testutil.ContextWithTimeout(t, testTimeout)
147145
responses[i], _, errs[i] = client.ExchangeContext(reqCtx, req, addr)
148146
}()

proxy/proxy.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -693,9 +693,6 @@ func (p *Proxy) Resolve(dctx *DNSContext) (err error) {
693693

694694
dctx.calcFlagsAndSize()
695695

696-
// Also don't lookup the cache for responses with DNSSEC checking disabled
697-
// since only validated responses are cached and those may be not the
698-
// desired result for user specifying CD flag.
699696
cacheWorks := p.cacheWorks(dctx)
700697
if cacheWorks {
701698
// Only add pending requests if the cache is enabled, since this is a
@@ -771,6 +768,9 @@ func (p *Proxy) cacheWorks(dctx *DNSContext) (ok bool) {
771768
// TODO(e.burkov): It probably should be decided after resolve.
772769
reason = "custom upstreams cache is not configured"
773770
case dctx.Req.CheckingDisabled:
771+
// Also don't lookup the cache for responses with DNSSEC checking
772+
// disabled since only validated responses are cached and those may be
773+
// not the desired result for user specifying CD flag.
774774
reason = "dnssec check disabled"
775775
default:
776776
return true

proxy/stats.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,12 +74,14 @@ func cachedQueryStatistics(addr string) (s *QueryStatistics) {
7474
}
7575
}
7676

77-
// Main returns the DNS query statistics for the upstream DNS servers.
77+
// Main returns the DNS query statistics for the upstream DNS servers. us and
78+
// its items must not be modified.
7879
func (s *QueryStatistics) Main() (us []*UpstreamStatistics) {
7980
return s.main
8081
}
8182

82-
// Fallback returns the DNS query statistics for the fallback DNS servers.
83+
// Fallback returns the DNS query statistics for the fallback DNS servers. us
84+
// and its items must not be modified.
8385
func (s *QueryStatistics) Fallback() (us []*UpstreamStatistics) {
8486
return s.fallback
8587
}

proxy/stats_test.go

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,6 @@ import (
1717
)
1818

1919
func TestCollectQueryStats(t *testing.T) {
20-
const (
21-
listenIP = "127.0.0.1"
22-
)
23-
2420
var (
2521
testReq = &dns.Msg{
2622
Question: []dns.Question{{
@@ -34,8 +30,6 @@ func TestCollectQueryStats(t *testing.T) {
3430
netip.MustParsePrefix("0.0.0.0/0"),
3531
netip.MustParsePrefix("::0/0"),
3632
}
37-
38-
localhostAnyPort = netip.MustParseAddrPort(netutil.JoinHostPort(listenIP, 0))
3933
)
4034

4135
ups := &dnsproxytest.FakeUpstream{

0 commit comments

Comments
 (0)