Skip to content

Commit 1585373

Browse files
authored
refactor: prefer IPv4 addresses when doing recursive lookups (#480)
This also fixes querying IPv6 addresses and removes the randomNameserver function in favor of randomNameserverAddress when doing recursive queries Fixes #477 Signed-off-by: Aurora Gaffney <[email protected]>
1 parent be3591c commit 1585373

File tree

1 file changed

+23
-82
lines changed

1 file changed

+23
-82
lines changed

internal/dns/dns.go

Lines changed: 23 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -362,9 +362,24 @@ func copyResponse(req *dns.Msg, srcResp *dns.Msg, destResp *dns.Msg) {
362362

363363
func randomNameserverAddress(nameservers map[string][]net.IP) net.IP {
364364
// Put all namserver addresses in single list
365-
tmpNameservers := []net.IP{}
365+
tmpNameserversIpv4 := []net.IP{}
366+
tmpNameserversIpv6 := []net.IP{}
366367
for _, addresses := range nameservers {
367-
tmpNameservers = append(tmpNameservers, addresses...)
368+
for _, address := range addresses {
369+
if ip := address.To4(); ip != nil {
370+
tmpNameserversIpv4 = append(tmpNameserversIpv4, address)
371+
} else {
372+
tmpNameserversIpv6 = append(tmpNameserversIpv6, address)
373+
}
374+
}
375+
}
376+
// Collect only IPv4 addresses unless we only have IPv6
377+
// We can't guarantee that IPv6 works, so we try not to use it
378+
var tmpNameservers []net.IP
379+
if len(tmpNameserversIpv4) > 0 {
380+
tmpNameservers = tmpNameserversIpv4
381+
} else {
382+
tmpNameservers = tmpNameserversIpv6
368383
}
369384
if len(tmpNameservers) > 0 {
370385
n, err := rand.Int(rand.Reader, big.NewInt(int64(len(tmpNameservers))))
@@ -383,8 +398,8 @@ func doQuery(msg *dns.Msg, address string, recursive bool) (*dns.Msg, error) {
383398
address = randomFallbackServer()
384399
}
385400
// Add default port to address if there is none
386-
if !strings.Contains(address, ":") {
387-
address = address + `:53`
401+
if _, _, err := net.SplitHostPort(address); err != nil {
402+
address = net.JoinHostPort(address, `53`)
388403
}
389404
slog.Debug(
390405
fmt.Sprintf(
@@ -417,22 +432,12 @@ func doQuery(msg *dns.Msg, address string, recursive bool) (*dns.Msg, error) {
417432
if recursive {
418433
if len(resp.Ns) > 0 {
419434
nameservers := getNameserversFromResponse(resp)
420-
randNsName, randNsAddress := randomNameserver(nameservers)
421-
if randNsAddress == "" {
422-
m := createQuery(randNsName, dns.TypeA)
423-
// XXX: should this query the fallback servers or the server that gave us the NS response?
424-
resp, err := doQuery(m, "", false)
425-
if err != nil {
426-
return nil, err
427-
}
428-
randNsAddress = getAddressForNameFromResponse(resp, randNsName)
429-
if randNsAddress == "" {
430-
// Return the current response if we couldn't get an address for the nameserver
431-
return resp, nil
432-
}
435+
randNsAddress := randomNameserverAddress(nameservers)
436+
if randNsAddress == nil {
437+
return nil, errors.New("could not get nameservers from response")
433438
}
434439
// Perform recursive query
435-
return doQuery(msg, randNsAddress, true)
440+
return doQuery(msg, randNsAddress.String(), true)
436441
} else {
437442
// Return the current response if there is no authority information
438443
return resp, nil
@@ -573,70 +578,6 @@ func getNameserversFromResponse(msg *dns.Msg) map[string][]net.IP {
573578
return ret
574579
}
575580

576-
func getAddressForNameFromResponse(msg *dns.Msg, recordName string) string {
577-
var retRR dns.RR
578-
for _, answer := range msg.Answer {
579-
if answer.Header().Name == recordName {
580-
retRR = answer
581-
break
582-
}
583-
}
584-
if retRR == nil {
585-
for _, extra := range msg.Extra {
586-
if extra.Header().Name == recordName {
587-
retRR = extra
588-
break
589-
}
590-
}
591-
}
592-
if retRR == nil {
593-
return ""
594-
}
595-
switch v := retRR.(type) {
596-
case *dns.A:
597-
if v.A != nil {
598-
return v.A.String()
599-
}
600-
case *dns.AAAA:
601-
if v.AAAA != nil {
602-
return v.AAAA.String()
603-
}
604-
}
605-
return ""
606-
}
607-
608-
func randomNameserver(nameservers map[string][]net.IP) (string, string) {
609-
mapKeys := []string{}
610-
for k := range nameservers {
611-
mapKeys = append(mapKeys, k)
612-
}
613-
if len(mapKeys) > 0 {
614-
n, err := rand.Int(rand.Reader, big.NewInt(int64(len(mapKeys))))
615-
if err != nil {
616-
return "", ""
617-
}
618-
randNsName := mapKeys[n.Int64()]
619-
randNsAddresses := nameservers[randNsName]
620-
if randNsAddresses == nil {
621-
return "", ""
622-
}
623-
n, err = rand.Int(rand.Reader, big.NewInt(int64(len(randNsAddresses))))
624-
if err != nil {
625-
return "", ""
626-
}
627-
randNsAddress := randNsAddresses[n.Int64()].String()
628-
return randNsName, randNsAddress
629-
}
630-
return "", ""
631-
}
632-
633-
func createQuery(recordName string, recordType uint16) *dns.Msg {
634-
m := new(dns.Msg)
635-
m.SetQuestion(recordName, recordType)
636-
m.RecursionDesired = false
637-
return m
638-
}
639-
640581
func randomFallbackServer() string {
641582
cfg := config.GetConfig()
642583
n, err := rand.Int(

0 commit comments

Comments
 (0)