Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cli/cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ type Flags struct {
UDPBuffer uint16 `long:"udp-buffer" description:"Set EDNS0 UDP size in query" default:"1232"`
Verbose bool `short:"v" long:"verbose" description:"Show verbose log messages"`
Trace bool `long:"trace" description:"Show trace log messages"`
Recursive bool `long:"recursive" description:"Do recursive query from authentic servers"`
ForceIPv4 bool `short:"4" long:"ipv4" description:"Force use the ipv4 address"`
ShowVersion bool `short:"V" long:"version" description:"Show version and exit"`
}

Expand Down
33 changes: 33 additions & 0 deletions ipv6.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package main

import (
"fmt"
"net"
)

func supportIPv6() (bool, error) {
interfaces, err := net.Interfaces()
if err != nil {
return false, fmt.Errorf("failed to get network interfaces: %w", err)
}

for _, iface := range interfaces {
if iface.Flags&net.FlagLoopback != 0 || iface.Flags&net.FlagUp == 0 {
continue
}

addrs, err := iface.Addrs()
if err != nil {
continue // Continue if we can't get addresses
}

for _, addr := range addrs {
n, ok := addr.(*net.IPNet)
if ok && n.IP.To4() == nil && n.IP.IsGlobalUnicast() {
return true, nil
}
}
}

return false, nil
}
84 changes: 80 additions & 4 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ All long form (--) flags can be toggled with the dig-standard +[no]flag notation

if opts.Verbose {
log.SetLevel(log.DebugLevel)
} else if opts.Trace {
} else if opts.Trace || opts.Recursive {
log.SetLevel(log.TraceLevel)
opts.ShowAll = true
}
Expand Down Expand Up @@ -334,7 +334,11 @@ All long form (--) flags can be toggled with the dig-standard +[no]flag notation
if len(opts.Server) == 0 {
opts.Server = make([]string, 1)

if os.Getenv(defaultServerVar) != "" {
if opts.Recursive {
if err = initRootServer(); err != nil {
return err
}
} else if os.Getenv(defaultServerVar) != "" {
opts.Server[0] = os.Getenv(defaultServerVar)
log.Debugf("Using %s from %s environment variable", opts.Server, defaultServerVar)
} else {
Expand Down Expand Up @@ -414,11 +418,23 @@ All long form (--) flags can be toggled with the dig-standard +[no]flag notation
}
msgs := createQuery(opts, rrTypesSlice)

if opts.Recursive {
if len(msgs) > 1 {
return fmt.Errorf("Only query one type in recursive mode")
}

opts.Timeout = 10 * time.Minute // FIXME(tao)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The underlying DNS transport has it's own timeout, but this overrides the flag if configured

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The default timeout is 10s, which may be too short for recursive query.

}

errChan := make(chan error)

servers := opts.Server

go func() {
recursive:
var replies []*dns.Msg
var entries []*output.Entry
for _, serverStr := range opts.Server {
for _, serverStr := range servers {
// Parse server address and transport type
server, transportType, err := parseServer(serverStr)
if err != nil {
Expand All @@ -442,7 +458,6 @@ All long form (--) flags can be toggled with the dig-standard +[no]flag notation
}

startTime := time.Now()
var replies []*dns.Msg
for _, msg := range msgs {
if txp == nil {
errChan <- fmt.Errorf("transport is nil")
Expand Down Expand Up @@ -531,6 +546,13 @@ All long form (--) flags can be toggled with the dig-standard +[no]flag notation
errChan <- fmt.Errorf("invalid output format")
}

if opts.Recursive && len(replies) > 0 {
servers = getRecursiveServers(replies)
if len(servers) > 0 {
goto recursive
}
}

errChan <- nil
}()

Expand All @@ -544,6 +566,60 @@ All long form (--) flags can be toggled with the dig-standard +[no]flag notation
return nil
}

func initRootServer() error {
ipv4s, ipv6s, err := getRootHints()
if err != nil {
return fmt.Errorf("unable to load root hints: %s", err)
}

hasIPv6, err := supportIPv6()
if err != nil {
return fmt.Errorf("unable to detect ipv6 support: %s", err)
}

if !hasIPv6 {
opts.ForceIPv4 = true
}

if opts.ForceIPv4 {
opts.Server = ipv4s[:1]
} else {
opts.Server = ipv6s[:1]
}
return nil
}

func getRecursiveServers(replies []*dns.Msg) (servers []string) {
if r := replies[0]; len(r.Answer) == 0 {
servers = []string{}
if opts.ForceIPv4 {
for _, extra := range r.Extra {
if a, ok := extra.(*dns.A); ok {
servers = append(servers, a.A.String())
break
}
}
} else {
for _, extra := range r.Extra {
if a, ok := extra.(*dns.AAAA); ok {
servers = append(servers, a.AAAA.String())
break
}
}
}

if len(servers) == 0 {
for _, ns := range r.Ns {
if a, ok := ns.(*dns.NS); ok {
servers = append(servers, a.Ns)
break
}
}
}
}
return
}

func main() {
clearOpts()
if err := driver(os.Args[1:], os.Stdout); err != nil {
Expand Down
26 changes: 26 additions & 0 deletions root.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package main

import (
"net/http"

"github.com/miekg/dns"
)

func getRootHints() (ip4s, ip6s []string, err error) {
resp, err := http.Get("https://www.internic.net/domain/named.root")
if err != nil {
return
}
defer resp.Body.Close()

p := dns.NewZoneParser(resp.Body, "", "")

for rr, ok := p.Next(); ok; rr, ok = p.Next() {
if a, ok := rr.(*dns.A); ok {
ip4s = append(ip4s, a.A.String())
} else if a, ok := rr.(*dns.AAAA); ok {
ip6s = append(ip6s, a.AAAA.String())
}
}
return
}
21 changes: 21 additions & 0 deletions root_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package main

import (
"testing"
)

func TestGetRootHints(t *testing.T) {
ip4s, ip6s, err := getRootHints()

if err != nil {
t.Fatal(err)
}

if len(ip4s) == 0 {
t.Fatal("ipv4 is empty")
}

if len(ip6s) == 0 {
t.Fatal("ipv6 is empty")
}
}
Loading