@@ -362,9 +362,24 @@ func copyResponse(req *dns.Msg, srcResp *dns.Msg, destResp *dns.Msg) {
362362
363363func randomNameserverAddress (nameservers map [string ][]net.IP ) net.IP {
364364 // Put all namserver addresses in single list
365- tmpNameservers := []net.IP {}
365+ tmpNameserversIpv4 := []net.IP {}
366+ tmpNameserversIpv6 := []net.IP {}
366367 for _ , addresses := range nameservers {
367- tmpNameservers = append (tmpNameservers , addresses ... )
368+ for _ , address := range addresses {
369+ if ip := address .To4 (); ip != nil {
370+ tmpNameserversIpv4 = append (tmpNameserversIpv4 , address )
371+ } else {
372+ tmpNameserversIpv6 = append (tmpNameserversIpv6 , address )
373+ }
374+ }
375+ }
376+ // Collect only IPv4 addresses unless we only have IPv6
377+ // We can't guarantee that IPv6 works, so we try not to use it
378+ var tmpNameservers []net.IP
379+ if len (tmpNameserversIpv4 ) > 0 {
380+ tmpNameservers = tmpNameserversIpv4
381+ } else {
382+ tmpNameservers = tmpNameserversIpv6
368383 }
369384 if len (tmpNameservers ) > 0 {
370385 n , err := rand .Int (rand .Reader , big .NewInt (int64 (len (tmpNameservers ))))
@@ -383,8 +398,8 @@ func doQuery(msg *dns.Msg, address string, recursive bool) (*dns.Msg, error) {
383398 address = randomFallbackServer ()
384399 }
385400 // Add default port to address if there is none
386- if ! strings . Contains ( address , ":" ) {
387- address = address + `: 53`
401+ if _ , _ , err := net . SplitHostPort ( address ); err != nil {
402+ address = net . JoinHostPort ( address , ` 53`)
388403 }
389404 slog .Debug (
390405 fmt .Sprintf (
@@ -417,22 +432,12 @@ func doQuery(msg *dns.Msg, address string, recursive bool) (*dns.Msg, error) {
417432 if recursive {
418433 if len (resp .Ns ) > 0 {
419434 nameservers := getNameserversFromResponse (resp )
420- randNsName , randNsAddress := randomNameserver (nameservers )
421- if randNsAddress == "" {
422- m := createQuery (randNsName , dns .TypeA )
423- // XXX: should this query the fallback servers or the server that gave us the NS response?
424- resp , err := doQuery (m , "" , false )
425- if err != nil {
426- return nil , err
427- }
428- randNsAddress = getAddressForNameFromResponse (resp , randNsName )
429- if randNsAddress == "" {
430- // Return the current response if we couldn't get an address for the nameserver
431- return resp , nil
432- }
435+ randNsAddress := randomNameserverAddress (nameservers )
436+ if randNsAddress == nil {
437+ return nil , errors .New ("could not get nameservers from response" )
433438 }
434439 // Perform recursive query
435- return doQuery (msg , randNsAddress , true )
440+ return doQuery (msg , randNsAddress . String () , true )
436441 } else {
437442 // Return the current response if there is no authority information
438443 return resp , nil
@@ -573,70 +578,6 @@ func getNameserversFromResponse(msg *dns.Msg) map[string][]net.IP {
573578 return ret
574579}
575580
576- func getAddressForNameFromResponse (msg * dns.Msg , recordName string ) string {
577- var retRR dns.RR
578- for _ , answer := range msg .Answer {
579- if answer .Header ().Name == recordName {
580- retRR = answer
581- break
582- }
583- }
584- if retRR == nil {
585- for _ , extra := range msg .Extra {
586- if extra .Header ().Name == recordName {
587- retRR = extra
588- break
589- }
590- }
591- }
592- if retRR == nil {
593- return ""
594- }
595- switch v := retRR .(type ) {
596- case * dns.A :
597- if v .A != nil {
598- return v .A .String ()
599- }
600- case * dns.AAAA :
601- if v .AAAA != nil {
602- return v .AAAA .String ()
603- }
604- }
605- return ""
606- }
607-
608- func randomNameserver (nameservers map [string ][]net.IP ) (string , string ) {
609- mapKeys := []string {}
610- for k := range nameservers {
611- mapKeys = append (mapKeys , k )
612- }
613- if len (mapKeys ) > 0 {
614- n , err := rand .Int (rand .Reader , big .NewInt (int64 (len (mapKeys ))))
615- if err != nil {
616- return "" , ""
617- }
618- randNsName := mapKeys [n .Int64 ()]
619- randNsAddresses := nameservers [randNsName ]
620- if randNsAddresses == nil {
621- return "" , ""
622- }
623- n , err = rand .Int (rand .Reader , big .NewInt (int64 (len (randNsAddresses ))))
624- if err != nil {
625- return "" , ""
626- }
627- randNsAddress := randNsAddresses [n .Int64 ()].String ()
628- return randNsName , randNsAddress
629- }
630- return "" , ""
631- }
632-
633- func createQuery (recordName string , recordType uint16 ) * dns.Msg {
634- m := new (dns.Msg )
635- m .SetQuestion (recordName , recordType )
636- m .RecursionDesired = false
637- return m
638- }
639-
640581func randomFallbackServer () string {
641582 cfg := config .GetConfig ()
642583 n , err := rand .Int (
0 commit comments