Skip to content

Commit 140fb83

Browse files
authored
Merge pull request #25 from blinklabs-io/feat/dns-recursion
feat: recursive DNS lookup support
2 parents e4cf123 + bae78ff commit 140fb83

File tree

2 files changed

+92
-58
lines changed

2 files changed

+92
-58
lines changed

internal/config/config.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,10 @@ type LoggingConfig struct {
4141
}
4242

4343
type DnsConfig struct {
44-
ListenAddress string `yaml:"address" envconfig:"DNS_LISTEN_ADDRESS"`
45-
ListenPort uint `yaml:"port" envconfig:"DNS_LISTEN_PORT"`
46-
FallbackServers []string `yaml:"fallbackServers" envconfig:"DNS_FALLBACK_SERVERS"`
44+
ListenAddress string `yaml:"address" envconfig:"DNS_LISTEN_ADDRESS"`
45+
ListenPort uint `yaml:"port" envconfig:"DNS_LISTEN_PORT"`
46+
RecursionEnabled bool `yaml:"recursionEnabled" envconfig:"DNS_RECURSION"`
47+
FallbackServers []string `yaml:"fallbackServers" envconfig:"DNS_FALLBACK_SERVERS"`
4748
}
4849

4950
type DebugConfig struct {

internal/dns/dns.go

Lines changed: 88 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ func startListener(server *dns.Server) {
3535

3636
func handleQuery(w dns.ResponseWriter, r *dns.Msg) {
3737
logger := logging.GetLogger()
38+
cfg := config.GetConfig()
3839
m := new(dns.Msg)
3940

4041
// Split query name into labels and lookup each domain and parent until we get a hit
@@ -50,31 +51,58 @@ func handleQuery(w dns.ResponseWriter, r *dns.Msg) {
5051
}
5152
// Assemble response
5253
m.SetReply(r)
53-
for nameserver, ipAddress := range nameServers {
54-
// Add trailing dot to make everybody happy
55-
nameserver = nameserver + `.`
56-
// NS record
57-
ns := &dns.NS{
58-
Hdr: dns.RR_Header{Name: (lookupDomainName + `.`), Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: 999},
59-
Ns: nameserver,
54+
if cfg.Dns.RecursionEnabled {
55+
// Pick random nameserver for domain
56+
tmpNameservers := []string{}
57+
for nameserver := range nameServers {
58+
tmpNameservers = append(tmpNameservers, nameserver)
6059
}
61-
m.Ns = append(m.Ns, ns)
62-
// A or AAAA record
63-
ipAddr := net.ParseIP(ipAddress)
64-
if ipAddr.To4() != nil {
65-
// IPv4
66-
a := &dns.A{
67-
Hdr: dns.RR_Header{Name: nameserver, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 999},
68-
A: ipAddr,
60+
tmpNameserver := nameServers[tmpNameservers[rand.Intn(len(tmpNameservers))]]
61+
// Query the random domain nameserver we picked above
62+
resp, err := queryServer(r, tmpNameserver)
63+
if err != nil {
64+
// Send failure response
65+
m.SetRcode(r, dns.RcodeServerFailure)
66+
if err := w.WriteMsg(m); err != nil {
67+
logger.Errorf("failed to write response: %s", err)
6968
}
70-
m.Extra = append(m.Extra, a)
69+
logger.Errorf("failed to query domain nameserver: %s", err)
70+
return
7171
} else {
72-
// IPv6
73-
aaaa := &dns.AAAA{
74-
Hdr: dns.RR_Header{Name: nameserver, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: 999},
75-
AAAA: ipAddr,
72+
copyResponse(r, resp, m)
73+
// Send response
74+
if err := w.WriteMsg(m); err != nil {
75+
logger.Errorf("failed to write response: %s", err)
76+
}
77+
return
78+
}
79+
} else {
80+
for nameserver, ipAddress := range nameServers {
81+
// Add trailing dot to make everybody happy
82+
nameserver = nameserver + `.`
83+
// NS record
84+
ns := &dns.NS{
85+
Hdr: dns.RR_Header{Name: (lookupDomainName + `.`), Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: 999},
86+
Ns: nameserver,
87+
}
88+
m.Ns = append(m.Ns, ns)
89+
// A or AAAA record
90+
ipAddr := net.ParseIP(ipAddress)
91+
if ipAddr.To4() != nil {
92+
// IPv4
93+
a := &dns.A{
94+
Hdr: dns.RR_Header{Name: nameserver, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 999},
95+
A: ipAddr,
96+
}
97+
m.Extra = append(m.Extra, a)
98+
} else {
99+
// IPv6
100+
aaaa := &dns.AAAA{
101+
Hdr: dns.RR_Header{Name: nameserver, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: 999},
102+
AAAA: ipAddr,
103+
}
104+
m.Extra = append(m.Extra, aaaa)
76105
}
77-
m.Extra = append(m.Extra, aaaa)
78106
}
79107
}
80108
// Send response
@@ -85,35 +113,28 @@ func handleQuery(w dns.ResponseWriter, r *dns.Msg) {
85113
return
86114
}
87115

88-
// Query fallback servers
89-
fallbackResp, err := queryFallbackServer(r)
90-
if err != nil {
91-
// Send failure response
92-
m.SetRcode(r, dns.RcodeServerFailure)
93-
if err := w.WriteMsg(m); err != nil {
94-
logger.Errorf("failed to write response: %s", err)
95-
}
96-
logger.Errorf("failed to query fallback server: %s", err)
97-
return
98-
} else {
99-
// Copy relevant data from fallback response into our response
100-
m.SetRcode(r, fallbackResp.MsgHdr.Rcode)
101-
m.RecursionDesired = r.RecursionDesired
102-
m.RecursionAvailable = fallbackResp.RecursionAvailable
103-
if fallbackResp.Ns != nil {
104-
m.Ns = append(m.Ns, fallbackResp.Ns...)
105-
}
106-
if fallbackResp.Answer != nil {
107-
m.Answer = append(m.Answer, fallbackResp.Answer...)
108-
}
109-
if fallbackResp.Extra != nil {
110-
m.Extra = append(m.Extra, fallbackResp.Extra...)
111-
}
112-
// Send response
113-
if err := w.WriteMsg(m); err != nil {
114-
logger.Errorf("failed to write response: %s", err)
116+
// Query fallback servers if recursion is enabled
117+
if cfg.Dns.RecursionEnabled {
118+
// Pick random fallback server
119+
fallbackServer := cfg.Dns.FallbackServers[rand.Intn(len(cfg.Dns.FallbackServers))]
120+
// Query chosen server
121+
fallbackResp, err := queryServer(r, fallbackServer)
122+
if err != nil {
123+
// Send failure response
124+
m.SetRcode(r, dns.RcodeServerFailure)
125+
if err := w.WriteMsg(m); err != nil {
126+
logger.Errorf("failed to write response: %s", err)
127+
}
128+
logger.Errorf("failed to query fallback server: %s", err)
129+
return
130+
} else {
131+
copyResponse(r, fallbackResp, m)
132+
// Send response
133+
if err := w.WriteMsg(m); err != nil {
134+
logger.Errorf("failed to write response: %s", err)
135+
}
136+
return
115137
}
116-
return
117138
}
118139

119140
// Return NXDOMAIN if we have no information about the requested domain or any of its parents
@@ -123,15 +144,27 @@ func handleQuery(w dns.ResponseWriter, r *dns.Msg) {
123144
}
124145
}
125146

126-
func queryFallbackServer(req *dns.Msg) (*dns.Msg, error) {
127-
// Pick random fallback server
128-
cfg := config.GetConfig()
129-
fallbackServer := cfg.Dns.FallbackServers[rand.Intn(len(cfg.Dns.FallbackServers))]
130-
// Query chosen server
147+
func copyResponse(req *dns.Msg, srcResp *dns.Msg, destResp *dns.Msg) {
148+
// Copy relevant data from original request and source response into destination response
149+
destResp.SetRcode(req, srcResp.MsgHdr.Rcode)
150+
destResp.RecursionDesired = req.RecursionDesired
151+
destResp.RecursionAvailable = srcResp.RecursionAvailable
152+
if srcResp.Ns != nil {
153+
destResp.Ns = append(destResp.Ns, srcResp.Ns...)
154+
}
155+
if srcResp.Answer != nil {
156+
destResp.Answer = append(destResp.Answer, srcResp.Answer...)
157+
}
158+
if srcResp.Extra != nil {
159+
destResp.Extra = append(destResp.Extra, srcResp.Extra...)
160+
}
161+
}
162+
163+
func queryServer(req *dns.Msg, nameserver string) (*dns.Msg, error) {
131164
m := new(dns.Msg)
132165
m.Id = dns.Id()
133166
m.RecursionDesired = req.RecursionDesired
134167
m.Question = append(m.Question, req.Question...)
135-
in, err := dns.Exchange(m, fmt.Sprintf("%s:53", fallbackServer))
168+
in, err := dns.Exchange(m, fmt.Sprintf("%s:53", nameserver))
136169
return in, err
137170
}

0 commit comments

Comments
 (0)