1- package proxy
1+ package proxy_test
22
33import (
4- "context"
54 "net"
5+ "net/netip"
66 "sync"
77 "testing"
8+ "time"
89
910 "github.com/AdguardTeam/dnsproxy/internal/dnsproxytest"
11+ "github.com/AdguardTeam/dnsproxy/proxy"
1012 "github.com/AdguardTeam/dnsproxy/upstream"
1113 "github.com/AdguardTeam/golibs/logutil/slogutil"
14+ "github.com/AdguardTeam/golibs/netutil"
1215 "github.com/AdguardTeam/golibs/testutil"
1316 "github.com/miekg/dns"
1417 "github.com/stretchr/testify/assert"
1518 "github.com/stretchr/testify/require"
1619)
1720
18- // testPendingRequests is a mock implementation of [PendingRequests] for tests.
19- //
20- // TODO(e.burkov): Think of a better way to test [PendingRequests].
21- type testPendingRequests struct {
22- onQueue func (ctx context.Context , dctx * DNSContext ) (exists bool , err error )
23- onDone func (ctx context.Context , dctx * DNSContext , err error )
24- }
21+ // TODO(e.burkov): Merge those with the ones in internal tests and move to
22+ // dnsproxytest.
2523
26- // type check
27- var _ PendingRequests = (* testPendingRequests )(nil )
24+ const (
25+ // testTimeout is the common timeout for tests and contexts.
26+ testTimeout = 1 * time .Second
2827
29- // queue implements the [proxy.PendingRequests] interface for
30- // *testPendingRequests.
31- func (p * testPendingRequests ) queue (
32- ctx context.Context ,
33- dctx * DNSContext ,
34- ) (exists bool , err error ) {
35- return p .onQueue (ctx , dctx )
36- }
28+ // testCacheSize is the default size of the cache in bytes.
29+ testCacheSize = 64 * 1024
30+ )
3731
38- // done implements the [proxy.PendingRequests] interface for
39- // *testPendingRequests.
40- func (p * testPendingRequests ) done (ctx context.Context , dctx * DNSContext , err error ) {
41- p .onDone (ctx , dctx , err )
42- }
32+ var (
33+ // localhostAnyPort is a localhost address with an arbitrary port.
34+ localhostAnyPort = netip .AddrPortFrom (netutil .IPv4Localhost (), 0 )
35+
36+ // testTrustedProxies is a set of trusted proxies that includes all
37+ // addresses used in tests.
38+ testTrustedProxies = netutil.SliceSubnetSet {
39+ netip .MustParsePrefix ("0.0.0.0/0" ),
40+ netip .MustParsePrefix ("::0/0" ),
41+ }
42+ )
4343
4444// assertEqualResponses is a helper function that checks if two DNS messages are
4545// equal, excluding their ID.
@@ -66,6 +66,7 @@ func assertEqualResponses(tb testing.TB, expected, actual *dns.Msg) {
6666 assert .Equal (tb , expected .Extra , actual .Extra )
6767}
6868
69+ // TODO(e.burkov): Consider unexporting the [proxy.PendingRequests] interface.
6970func TestPendingRequests (t * testing.T ) {
7071 t .Parallel ()
7172
@@ -77,57 +78,54 @@ func TestPendingRequests(t *testing.T) {
7778 once := & sync.Once {}
7879 u := & dnsproxytest.FakeUpstream {
7980 OnExchange : func (req * dns.Msg ) (resp * dns.Msg , err error ) {
80- loadWG .Wait ()
8181 once .Do (func () {
8282 resp = (& dns.Msg {}).SetReply (req )
8383 })
8484
8585 // Only allow a single request to be processed.
8686 require .NotNil (testutil.PanicT {}, resp )
8787
88+ loadWG .Wait ()
89+
8890 return resp , nil
8991 },
9092 OnAddress : func () (addr string ) { return "" },
9193 OnClose : func () (err error ) { return nil },
9294 }
9395
94- pending := NewDefaultPendingRequests ()
95- testPending := & testPendingRequests {
96- onQueue : func (ctx context.Context , dctx * DNSContext ) (exists bool , err error ) {
97- loadWG .Done ()
98-
99- return pending .queue (ctx , dctx )
100- },
101- onDone : pending .done ,
102- }
103-
104- p := mustNew (t , & Config {
96+ p , err := proxy .New (& proxy.Config {
10597 Logger : slogutil .NewDiscardLogger (),
10698 UDPListenAddr : []* net.UDPAddr {net .UDPAddrFromAddrPort (localhostAnyPort )},
10799 TCPListenAddr : []* net.TCPAddr {net .TCPAddrFromAddrPort (localhostAnyPort )},
108- UpstreamConfig : & UpstreamConfig {Upstreams : []upstream.Upstream {u }},
109- TrustedProxies : defaultTrustedProxies ,
100+ UpstreamConfig : & proxy. UpstreamConfig {Upstreams : []upstream.Upstream {u }},
101+ TrustedProxies : testTrustedProxies ,
110102 RatelimitSubnetLenIPv4 : 24 ,
111103 RatelimitSubnetLenIPv6 : 64 ,
112104 Ratelimit : 0 ,
113105 CacheEnabled : true ,
114- CacheSizeBytes : defaultCacheSize ,
106+ CacheSizeBytes : testCacheSize ,
115107 EnableEDNSClientSubnet : true ,
116- PendingRequests : testPending ,
108+ PendingRequests : proxy .NewDefaultPendingRequests (),
109+ RequestHandler : func (prx * proxy.Proxy , dctx * proxy.DNSContext ) (err error ) {
110+ loadWG .Done ()
111+
112+ return prx .Resolve (dctx )
113+ },
117114 })
115+ require .NoError (t , err )
118116
119117 ctx := testutil .ContextWithTimeout (t , testTimeout )
120- err : = p .Start (ctx )
118+ err = p .Start (ctx )
121119 require .NoError (t , err )
122120 testutil .CleanupAndRequireSuccess (t , func () (err error ) {
123121 ctx = testutil .ContextWithTimeout (t , testTimeout )
124122
125123 return p .Shutdown (ctx )
126124 })
127125
128- addr := p .Addr (ProtoTCP ).String ()
126+ addr := p .Addr (proxy . ProtoTCP ).String ()
129127 client := & dns.Client {
130- Net : string (ProtoTCP ),
128+ Net : string (proxy . ProtoTCP ),
131129 Timeout : testTimeout ,
132130 }
133131
@@ -138,11 +136,11 @@ func TestPendingRequests(t *testing.T) {
138136 for i := range reqsNum {
139137 resolveWG .Add (1 )
140138
139+ req := (& dns.Msg {}).SetQuestion ("domain.example." , dns .TypeA )
140+
141141 go func () {
142142 defer resolveWG .Done ()
143143
144- req := (& dns.Msg {}).SetQuestion ("domain.example." , dns .TypeA )
145-
146144 reqCtx := testutil .ContextWithTimeout (t , testTimeout )
147145 responses [i ], _ , errs [i ] = client .ExchangeContext (reqCtx , req , addr )
148146 }()
0 commit comments