@@ -362,9 +362,24 @@ func copyResponse(req *dns.Msg, srcResp *dns.Msg, destResp *dns.Msg) {
362362
363363func randomNameserverAddress (nameservers map [string ][]net.IP ) net.IP {
364364 // Put all namserver addresses in single list
365- tmpNameservers := []net.IP {}
365+ tmpNameserversIpv4 := []net.IP {}
366+ tmpNameserversIpv6 := []net.IP {}
366367 for _ , addresses := range nameservers {
367- tmpNameservers = append (tmpNameservers , addresses ... )
368+ for _ , address := range addresses {
369+ if ip := address .To4 (); ip != nil {
370+ tmpNameserversIpv4 = append (tmpNameserversIpv4 , address )
371+ } else {
372+ tmpNameserversIpv6 = append (tmpNameserversIpv6 , address )
373+ }
374+ }
375+ }
376+ // Collect only IPv4 addresses unless we only have IPv6
377+ // We can't guarantee that IPv6 works, so we try not to use it
378+ var tmpNameservers []net.IP
379+ if len (tmpNameserversIpv4 ) > 0 {
380+ tmpNameservers = tmpNameserversIpv4
381+ } else {
382+ tmpNameservers = tmpNameserversIpv6
368383 }
369384 if len (tmpNameservers ) > 0 {
370385 n , err := rand .Int (rand .Reader , big .NewInt (int64 (len (tmpNameservers ))))
@@ -383,8 +398,8 @@ func doQuery(msg *dns.Msg, address string, recursive bool) (*dns.Msg, error) {
383398 address = randomFallbackServer ()
384399 }
385400 // Add default port to address if there is none
386- if ! strings . Contains (address , ":" ) {
387- address = address + `: 53`
401+ if ip := net . ParseIP (address ); ip != nil {
402+ address = net . JoinHostPort ( address , ` 53`)
388403 }
389404 slog .Debug (
390405 fmt .Sprintf (
@@ -417,22 +432,9 @@ func doQuery(msg *dns.Msg, address string, recursive bool) (*dns.Msg, error) {
417432 if recursive {
418433 if len (resp .Ns ) > 0 {
419434 nameservers := getNameserversFromResponse (resp )
420- randNsName , randNsAddress := randomNameserver (nameservers )
421- if randNsAddress == "" {
422- m := createQuery (randNsName , dns .TypeA )
423- // XXX: should this query the fallback servers or the server that gave us the NS response?
424- resp , err := doQuery (m , "" , false )
425- if err != nil {
426- return nil , err
427- }
428- randNsAddress = getAddressForNameFromResponse (resp , randNsName )
429- if randNsAddress == "" {
430- // Return the current response if we couldn't get an address for the nameserver
431- return resp , nil
432- }
433- }
435+ randNsAddress := randomNameserverAddress (nameservers )
434436 // Perform recursive query
435- return doQuery (msg , randNsAddress , true )
437+ return doQuery (msg , randNsAddress . String () , true )
436438 } else {
437439 // Return the current response if there is no authority information
438440 return resp , nil
@@ -605,31 +607,6 @@ func getAddressForNameFromResponse(msg *dns.Msg, recordName string) string {
605607 return ""
606608}
607609
608- func randomNameserver (nameservers map [string ][]net.IP ) (string , string ) {
609- mapKeys := []string {}
610- for k := range nameservers {
611- mapKeys = append (mapKeys , k )
612- }
613- if len (mapKeys ) > 0 {
614- n , err := rand .Int (rand .Reader , big .NewInt (int64 (len (mapKeys ))))
615- if err != nil {
616- return "" , ""
617- }
618- randNsName := mapKeys [n .Int64 ()]
619- randNsAddresses := nameservers [randNsName ]
620- if randNsAddresses == nil {
621- return "" , ""
622- }
623- n , err = rand .Int (rand .Reader , big .NewInt (int64 (len (randNsAddresses ))))
624- if err != nil {
625- return "" , ""
626- }
627- randNsAddress := randNsAddresses [n .Int64 ()].String ()
628- return randNsName , randNsAddress
629- }
630- return "" , ""
631- }
632-
633610func createQuery (recordName string , recordType uint16 ) * dns.Msg {
634611 m := new (dns.Msg )
635612 m .SetQuestion (recordName , recordType )
0 commit comments