Skip to content

Commit a6312e5

Browse files
authored
Merge pull request #48 from blinklabs-io/feat/rework-recursion
feat: rework recursion support
2 parents e18102f + 952e840 commit a6312e5

File tree

1 file changed

+189
-49
lines changed

1 file changed

+189
-49
lines changed

internal/dns/dns.go

Lines changed: 189 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ func handleQuery(w dns.ResponseWriter, r *dns.Msg) {
5959
// Pick random nameserver for domain
6060
tmpNameserver := randomNameserverAddress(nameservers)
6161
// Query the random domain nameserver we picked above
62-
resp, err := queryServer(r, tmpNameserver.String())
62+
resp, err := doQuery(r, tmpNameserver.String(), true)
6363
if err != nil {
6464
// Send failure response
6565
m.SetRcode(r, dns.RcodeServerFailure)
@@ -112,30 +112,6 @@ func handleQuery(w dns.ResponseWriter, r *dns.Msg) {
112112
return
113113
}
114114

115-
// Query fallback servers if recursion is enabled
116-
if cfg.Dns.RecursionEnabled {
117-
// Pick random fallback server
118-
fallbackServer := cfg.Dns.FallbackServers[rand.Intn(len(cfg.Dns.FallbackServers))]
119-
// Query chosen server
120-
fallbackResp, err := queryServer(r, fallbackServer)
121-
if err != nil {
122-
// Send failure response
123-
m.SetRcode(r, dns.RcodeServerFailure)
124-
if err := w.WriteMsg(m); err != nil {
125-
logger.Errorf("failed to write response: %s", err)
126-
}
127-
logger.Errorf("failed to query fallback server: %s", err)
128-
return
129-
} else {
130-
copyResponse(r, fallbackResp, m)
131-
// Send response
132-
if err := w.WriteMsg(m); err != nil {
133-
logger.Errorf("failed to write response: %s", err)
134-
}
135-
return
136-
}
137-
}
138-
139115
// Return NXDOMAIN if we have no information about the requested domain or any of its parents
140116
m.SetRcode(r, dns.RcodeNameError)
141117
if err := w.WriteMsg(m); err != nil {
@@ -159,30 +135,64 @@ func copyResponse(req *dns.Msg, srcResp *dns.Msg, destResp *dns.Msg) {
159135
}
160136
}
161137

162-
func queryServer(req *dns.Msg, nameserver string) (*dns.Msg, error) {
163-
m := new(dns.Msg)
164-
m.Id = dns.Id()
165-
m.RecursionDesired = req.RecursionDesired
166-
m.Question = append(m.Question, req.Question...)
167-
in, err := dns.Exchange(m, fmt.Sprintf("%s:53", nameserver))
168-
return in, err
169-
}
170-
171138
func randomNameserverAddress(nameservers map[string][]net.IP) net.IP {
172139
// Put all namserver addresses in single list
173140
tmpNameservers := []net.IP{}
174141
for _, addresses := range nameservers {
175142
tmpNameservers = append(tmpNameservers, addresses...)
176143
}
177-
tmpNameserver := tmpNameservers[rand.Intn(len(tmpNameservers))]
178-
return tmpNameserver
144+
if len(tmpNameservers) > 0 {
145+
tmpNameserver := tmpNameservers[rand.Intn(len(tmpNameservers))]
146+
return tmpNameserver
147+
}
148+
return nil
179149
}
180150

181-
func doQuery(msg *dns.Msg, address string) (*dns.Msg, error) {
151+
func doQuery(msg *dns.Msg, address string, recursive bool) (*dns.Msg, error) {
182152
logger := logging.GetLogger()
183-
logger.Debugf("querying %s: %s %s", address, dns.Type(msg.Question[0].Qtype).String(), msg.Question[0].Name)
153+
// Default to a random fallback server if no address is specified
154+
if address == "" {
155+
address = randomFallbackServer()
156+
}
157+
// Add default port to address if there is none
158+
if !strings.Contains(address, ":") {
159+
address = address + `:53`
160+
}
161+
logger.Debugf("querying %s: %s", address, formatMessageQuestionSection(msg.Question))
184162
resp, err := dns.Exchange(msg, address)
185-
return resp, err
163+
if err != nil {
164+
return nil, err
165+
}
166+
logger.Debugf("response: rcode=%s, authoritative=%v, authority=%s, answer=%s, extra=%s", dns.RcodeToString[resp.Rcode], resp.Authoritative, formatMessageAnswerSection(resp.Ns), formatMessageAnswerSection(resp.Answer), formatMessageAnswerSection(resp.Extra))
167+
// Immediately return authoritative response
168+
if resp.Authoritative {
169+
return resp, nil
170+
}
171+
if recursive {
172+
if len(resp.Ns) > 0 {
173+
nameservers := getNameserversFromResponse(resp)
174+
randNsName, randNsAddress := randomNameserver(nameservers)
175+
if randNsAddress == "" {
176+
m := createQuery(randNsName, dns.TypeA)
177+
// XXX: should this query the fallback servers or the server that gave us the NS response?
178+
resp, err := doQuery(m, "", false)
179+
if err != nil {
180+
return nil, err
181+
}
182+
randNsAddress = getAddressForNameFromResponse(resp, randNsName)
183+
if randNsAddress == "" {
184+
// Return the current response if we couldn't get an address for the nameserver
185+
return resp, nil
186+
}
187+
}
188+
// Perform recursive query
189+
return doQuery(msg, randNsAddress, true)
190+
} else {
191+
// Return the current response if there is no authority information
192+
return resp, nil
193+
}
194+
}
195+
return resp, nil
186196
}
187197

188198
func findNameserversForDomain(recordName string) (string, map[string][]net.IP, error) {
@@ -211,14 +221,11 @@ func findNameserversForDomain(recordName string) (string, map[string][]net.IP, e
211221
// Query fallback servers, if configured
212222
if len(cfg.Dns.FallbackServers) > 0 {
213223
// Pick random fallback server
214-
fallbackServer := cfg.Dns.FallbackServers[rand.Intn(len(cfg.Dns.FallbackServers))]
215-
serverWithPort := fmt.Sprintf("%s:53", fallbackServer)
224+
fallbackServer := randomFallbackServer()
216225
for startLabelIdx := 0; startLabelIdx < len(queryLabels); startLabelIdx++ {
217226
lookupDomainName := dns.Fqdn(strings.Join(queryLabels[startLabelIdx:], "."))
218-
m := new(dns.Msg)
219-
m.SetQuestion(lookupDomainName, dns.TypeNS)
220-
m.RecursionDesired = false
221-
in, err := doQuery(m, serverWithPort)
227+
m := createQuery(lookupDomainName, dns.TypeNS)
228+
in, err := doQuery(m, fallbackServer, false)
222229
if err != nil {
223230
return "", nil, err
224231
}
@@ -231,10 +238,8 @@ func findNameserversForDomain(recordName string) (string, map[string][]net.IP, e
231238
ns := v.Ns
232239
ret[ns] = make([]net.IP, 0)
233240
// Query for matching A/AAAA records
234-
m2 := new(dns.Msg)
235-
m2.SetQuestion(ns, dns.TypeA)
236-
m2.RecursionDesired = false
237-
in2, err := doQuery(m2, serverWithPort)
241+
m2 := createQuery(ns, dns.TypeA)
242+
in2, err := doQuery(m2, fallbackServer, false)
238243
if err != nil {
239244
return "", nil, err
240245
}
@@ -258,3 +263,138 @@ func findNameserversForDomain(recordName string) (string, map[string][]net.IP, e
258263

259264
return "", nil, nil
260265
}
266+
267+
func getNameserversFromResponse(msg *dns.Msg) map[string][]net.IP {
268+
if len(msg.Ns) == 0 {
269+
return nil
270+
}
271+
ret := map[string][]net.IP{}
272+
for _, ns := range msg.Ns {
273+
// TODO: handle SOA
274+
switch v := ns.(type) {
275+
case *dns.NS:
276+
nsName := v.Ns
277+
ret[nsName] = []net.IP{}
278+
for _, extra := range msg.Extra {
279+
if extra.Header().Name != nsName {
280+
continue
281+
}
282+
switch v := extra.(type) {
283+
case *dns.A:
284+
ret[nsName] = append(
285+
ret[nsName],
286+
v.A,
287+
)
288+
case *dns.AAAA:
289+
ret[nsName] = append(
290+
ret[nsName],
291+
v.AAAA,
292+
)
293+
}
294+
}
295+
}
296+
}
297+
return ret
298+
}
299+
300+
func getAddressForNameFromResponse(msg *dns.Msg, recordName string) string {
301+
var retRR dns.RR
302+
for _, answer := range msg.Answer {
303+
if answer.Header().Name == recordName {
304+
retRR = answer
305+
break
306+
}
307+
}
308+
if retRR == nil {
309+
for _, extra := range msg.Extra {
310+
if extra.Header().Name == recordName {
311+
retRR = extra
312+
break
313+
}
314+
}
315+
}
316+
if retRR == nil {
317+
return ""
318+
}
319+
switch v := retRR.(type) {
320+
case *dns.A:
321+
return v.A.String()
322+
case *dns.AAAA:
323+
return v.AAAA.String()
324+
}
325+
return ""
326+
}
327+
328+
func randomNameserver(nameservers map[string][]net.IP) (string, string) {
329+
mapKeys := []string{}
330+
for k := range nameservers {
331+
mapKeys = append(mapKeys, k)
332+
}
333+
if len(mapKeys) > 0 {
334+
randNsName := mapKeys[rand.Intn(len(mapKeys))]
335+
randNsAddresses := nameservers[randNsName]
336+
randNsAddress := randNsAddresses[rand.Intn(len(randNsAddresses))].String()
337+
return randNsName, randNsAddress
338+
}
339+
return "", ""
340+
}
341+
342+
func createQuery(recordName string, recordType uint16) *dns.Msg {
343+
m := new(dns.Msg)
344+
m.SetQuestion(recordName, recordType)
345+
m.RecursionDesired = false
346+
return m
347+
}
348+
349+
func randomFallbackServer() string {
350+
cfg := config.GetConfig()
351+
return cfg.Dns.FallbackServers[rand.Intn(
352+
len(cfg.Dns.FallbackServers),
353+
)]
354+
}
355+
356+
func formatMessageAnswerSection(section []dns.RR) string {
357+
ret := "[ "
358+
for idx, rr := range section {
359+
ret += fmt.Sprintf(
360+
"< %s >",
361+
strings.ReplaceAll(
362+
strings.TrimPrefix(
363+
rr.String(),
364+
";",
365+
),
366+
"\t",
367+
" ",
368+
),
369+
)
370+
if idx != len(section)-1 {
371+
ret += `,`
372+
}
373+
ret += ` `
374+
}
375+
ret += "]"
376+
return ret
377+
}
378+
379+
func formatMessageQuestionSection(section []dns.Question) string {
380+
ret := "[ "
381+
for idx, question := range section {
382+
ret += fmt.Sprintf(
383+
"< %s >",
384+
strings.ReplaceAll(
385+
strings.TrimPrefix(
386+
question.String(),
387+
";",
388+
),
389+
"\t",
390+
" ",
391+
),
392+
)
393+
if idx != len(section)-1 {
394+
ret += `,`
395+
}
396+
ret += ` `
397+
}
398+
ret += "]"
399+
return ret
400+
}

0 commit comments

Comments
 (0)