Skip to content

Commit 561ce91

Browse files
committed
resolver: add DNS address preference wrapper
1 parent c5aadbc commit 561ce91

File tree

2 files changed

+129
-5
lines changed

2 files changed

+129
-5
lines changed

main.go

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,25 @@ type autocertCache struct {
197197

198198
const envCacheEncKey = "DUMBPROXY_CACHE_ENC_KEY"
199199

200+
type dnsPreferenceArg resolver.Preference
201+
202+
func (a *dnsPreferenceArg) String() string {
203+
return resolver.Preference(*a).String()
204+
}
205+
206+
func (a *dnsPreferenceArg) Set(s string) error {
207+
p, err := resolver.ParsePreference(s)
208+
if err != nil {
209+
return nil
210+
}
211+
*a = dnsPreferenceArg(p)
212+
return nil
213+
}
214+
215+
func (a *dnsPreferenceArg) Value() resolver.Preference {
216+
return resolver.Preference(*a)
217+
}
218+
200219
type bindSpec struct {
201220
af string
202221
address string
@@ -241,6 +260,8 @@ type CLIArgs struct {
241260
bwBurst int64
242261
bwBuckets uint
243262
bwSeparate bool
263+
dnsServers []string
264+
dnsPreferAddress dnsPreferenceArg
244265
dnsCacheTTL time.Duration
245266
dnsCacheNegTTL time.Duration
246267
dnsCacheTimeout time.Duration
@@ -251,7 +272,6 @@ type CLIArgs struct {
251272
jsProxyRouterInstances int
252273
proxyproto bool
253274
shutdownTimeout time.Duration
254-
nameservers []string
255275
}
256276

257277
func parse_args() CLIArgs {
@@ -277,6 +297,7 @@ func parse_args() CLIArgs {
277297
address: ":8080",
278298
af: "tcp",
279299
},
300+
dnsPreferAddress: dnsPreferenceArg(resolver.PreferenceIPv4),
280301
}
281302
args.autocertCacheEncKey.Set(os.Getenv(envCacheEncKey))
282303
flag.Func("bind-address", "HTTP proxy listen address. Set empty value to use systemd socket activation. (default \":8080\")", func(p string) error {
@@ -364,12 +385,13 @@ func parse_args() CLIArgs {
364385
flag.BoolVar(&args.bwSeparate, "bw-limit-separate", false, "separate upload and download bandwidth limits")
365386
flag.Func("dns-server", "nameserver specification (udp://..., tcp://..., https://..., tls://..., doh://..., dot://..., default://). Option can be used multiple times for parallel use of multiple nameservers. Empty string resets list", func(p string) error {
366387
if p == "" {
367-
args.nameservers = nil
388+
args.dnsServers = nil
368389
} else {
369-
args.nameservers = append(args.nameservers, p)
390+
args.dnsServers = append(args.dnsServers, p)
370391
}
371392
return nil
372393
})
394+
flag.Var(&args.dnsPreferAddress, "dns-prefer-address", "address resolution preference (none/ipv4/ipv6)")
373395
flag.DurationVar(&args.dnsCacheTTL, "dns-cache-ttl", 0, "enable DNS cache with specified fixed TTL")
374396
flag.DurationVar(&args.dnsCacheNegTTL, "dns-cache-neg-ttl", time.Second, "TTL for negative responses of DNS cache")
375397
flag.DurationVar(&args.dnsCacheTimeout, "dns-cache-timeout", 5*time.Second, "timeout for shared resolves of DNS cache")
@@ -521,13 +543,14 @@ func run() int {
521543
dialerRoot = dialer.NewFilterDialer(filterRoot.Access, dialerRoot) // must follow after resolving in chain
522544

523545
var nameResolver dialer.Resolver = net.DefaultResolver
524-
if len(args.nameservers) > 0 {
525-
nameResolver, err = resolver.FastFromURLs(args.nameservers...)
546+
if len(args.dnsServers) > 0 {
547+
nameResolver, err = resolver.FastFromURLs(args.dnsServers...)
526548
if err != nil {
527549
mainLogger.Critical("Failed to create name resolver: %v", err)
528550
return 3
529551
}
530552
}
553+
nameResolver = resolver.Prefer(nameResolver, args.dnsPreferAddress.Value())
531554
if args.dnsCacheTTL > 0 {
532555
cd := dialer.NewNameResolveCachingDialer(
533556
dialerRoot,

resolver/prefer.go

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
package resolver
2+
3+
import (
4+
"cmp"
5+
"context"
6+
"fmt"
7+
"net/netip"
8+
"slices"
9+
"strings"
10+
)
11+
12+
type Preference int
13+
14+
const (
15+
PreferenceNothing Preference = iota
16+
PreferenceIPv4
17+
PreferenceIPv6
18+
)
19+
20+
func (p Preference) String() string {
21+
switch p {
22+
case PreferenceNothing:
23+
return "none"
24+
case PreferenceIPv4:
25+
return "ipv4"
26+
case PreferenceIPv6:
27+
return "ipv6"
28+
default:
29+
return fmt.Sprintf("Preference(%d)", int(p))
30+
}
31+
}
32+
33+
func ParsePreference(p string) (Preference, error) {
34+
var res Preference
35+
switch lp := strings.ToLower(p); lp {
36+
case "none", "nothing", "any", "anything":
37+
res = PreferenceNothing
38+
case "ipv4", "ip4", "v4", "4":
39+
res = PreferenceIPv4
40+
case "ipv6", "ip6", "v6", "6":
41+
res = PreferenceIPv6
42+
default:
43+
return 0, fmt.Errorf("unknown preference specification %q", p)
44+
}
45+
return res, nil
46+
}
47+
48+
func boolToInt(x bool) int {
49+
if x {
50+
return 0
51+
}
52+
return 1
53+
}
54+
55+
type PreferIPv4 struct {
56+
LookupNetIPer
57+
}
58+
59+
func (p PreferIPv4) LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error) {
60+
addrs, err := p.LookupNetIPer.LookupNetIP(ctx, network, host)
61+
if err != nil {
62+
return nil, err
63+
}
64+
slices.SortStableFunc(addrs, func(a, b netip.Addr) int {
65+
return cmp.Compare(
66+
boolToInt(a.Unmap().Is4()),
67+
boolToInt(b.Unmap().Is4()),
68+
)
69+
})
70+
return addrs, nil
71+
}
72+
73+
type PreferIPv6 struct {
74+
LookupNetIPer
75+
}
76+
77+
func (p PreferIPv6) LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error) {
78+
addrs, err := p.LookupNetIPer.LookupNetIP(ctx, network, host)
79+
if err != nil {
80+
return nil, err
81+
}
82+
slices.SortStableFunc(addrs, func(a, b netip.Addr) int {
83+
return cmp.Compare(
84+
boolToInt(a.Unmap().Is6()),
85+
boolToInt(b.Unmap().Is6()),
86+
)
87+
})
88+
return addrs, nil
89+
}
90+
91+
func Prefer(resolver LookupNetIPer, p Preference) LookupNetIPer {
92+
switch p {
93+
case PreferenceNothing:
94+
return resolver
95+
case PreferenceIPv4:
96+
return PreferIPv4{resolver}
97+
case PreferenceIPv6:
98+
return PreferIPv6{resolver}
99+
}
100+
panic("unknown address family preference")
101+
}

0 commit comments

Comments
 (0)