Skip to content

Commit 160ee68

Browse files
committed
refactor: prefer IPv4 addresses when doing recursive lookups
This also fixes querying IPv6 addresses and removes the randomNameserver function in favor of randomNameserverAddress when doing recursive queries Fixes #477 Signed-off-by: Aurora Gaffney <[email protected]>
1 parent be3591c commit 160ee68

File tree

1 file changed

+21
-44
lines changed

1 file changed

+21
-44
lines changed

internal/dns/dns.go

Lines changed: 21 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -362,9 +362,24 @@ func copyResponse(req *dns.Msg, srcResp *dns.Msg, destResp *dns.Msg) {
362362

363363
func randomNameserverAddress(nameservers map[string][]net.IP) net.IP {
364364
// Put all namserver addresses in single list
365-
tmpNameservers := []net.IP{}
365+
tmpNameserversIpv4 := []net.IP{}
366+
tmpNameserversIpv6 := []net.IP{}
366367
for _, addresses := range nameservers {
367-
tmpNameservers = append(tmpNameservers, addresses...)
368+
for _, address := range addresses {
369+
if ip := address.To4(); ip != nil {
370+
tmpNameserversIpv4 = append(tmpNameserversIpv4, address)
371+
} else {
372+
tmpNameserversIpv6 = append(tmpNameserversIpv6, address)
373+
}
374+
}
375+
}
376+
// Collect only IPv4 addresses unless we only have IPv6
377+
// We can't guarantee that IPv6 works, so we try not to use it
378+
var tmpNameservers []net.IP
379+
if len(tmpNameserversIpv4) > 0 {
380+
tmpNameservers = tmpNameserversIpv4
381+
} else {
382+
tmpNameservers = tmpNameserversIpv6
368383
}
369384
if len(tmpNameservers) > 0 {
370385
n, err := rand.Int(rand.Reader, big.NewInt(int64(len(tmpNameservers))))
@@ -383,8 +398,8 @@ func doQuery(msg *dns.Msg, address string, recursive bool) (*dns.Msg, error) {
383398
address = randomFallbackServer()
384399
}
385400
// Add default port to address if there is none
386-
if !strings.Contains(address, ":") {
387-
address = address + `:53`
401+
if ip := net.ParseIP(address); ip != nil {
402+
address = net.JoinHostPort(address, `53`)
388403
}
389404
slog.Debug(
390405
fmt.Sprintf(
@@ -417,22 +432,9 @@ func doQuery(msg *dns.Msg, address string, recursive bool) (*dns.Msg, error) {
417432
if recursive {
418433
if len(resp.Ns) > 0 {
419434
nameservers := getNameserversFromResponse(resp)
420-
randNsName, randNsAddress := randomNameserver(nameservers)
421-
if randNsAddress == "" {
422-
m := createQuery(randNsName, dns.TypeA)
423-
// XXX: should this query the fallback servers or the server that gave us the NS response?
424-
resp, err := doQuery(m, "", false)
425-
if err != nil {
426-
return nil, err
427-
}
428-
randNsAddress = getAddressForNameFromResponse(resp, randNsName)
429-
if randNsAddress == "" {
430-
// Return the current response if we couldn't get an address for the nameserver
431-
return resp, nil
432-
}
433-
}
435+
randNsAddress := randomNameserverAddress(nameservers)
434436
// Perform recursive query
435-
return doQuery(msg, randNsAddress, true)
437+
return doQuery(msg, randNsAddress.String(), true)
436438
} else {
437439
// Return the current response if there is no authority information
438440
return resp, nil
@@ -605,31 +607,6 @@ func getAddressForNameFromResponse(msg *dns.Msg, recordName string) string {
605607
return ""
606608
}
607609

608-
func randomNameserver(nameservers map[string][]net.IP) (string, string) {
609-
mapKeys := []string{}
610-
for k := range nameservers {
611-
mapKeys = append(mapKeys, k)
612-
}
613-
if len(mapKeys) > 0 {
614-
n, err := rand.Int(rand.Reader, big.NewInt(int64(len(mapKeys))))
615-
if err != nil {
616-
return "", ""
617-
}
618-
randNsName := mapKeys[n.Int64()]
619-
randNsAddresses := nameservers[randNsName]
620-
if randNsAddresses == nil {
621-
return "", ""
622-
}
623-
n, err = rand.Int(rand.Reader, big.NewInt(int64(len(randNsAddresses))))
624-
if err != nil {
625-
return "", ""
626-
}
627-
randNsAddress := randNsAddresses[n.Int64()].String()
628-
return randNsName, randNsAddress
629-
}
630-
return "", ""
631-
}
632-
633610
func createQuery(recordName string, recordType uint16) *dns.Msg {
634611
m := new(dns.Msg)
635612
m.SetQuestion(recordName, recordType)

0 commit comments

Comments
 (0)