@@ -59,7 +59,7 @@ func handleQuery(w dns.ResponseWriter, r *dns.Msg) {
5959 // Pick random nameserver for domain
6060 tmpNameserver := randomNameserverAddress (nameservers )
6161 // Query the random domain nameserver we picked above
62- resp , err := queryServer (r , tmpNameserver .String ())
62+ resp , err := doQuery (r , tmpNameserver .String (), true )
6363 if err != nil {
6464 // Send failure response
6565 m .SetRcode (r , dns .RcodeServerFailure )
@@ -112,30 +112,6 @@ func handleQuery(w dns.ResponseWriter, r *dns.Msg) {
112112 return
113113 }
114114
115- // Query fallback servers if recursion is enabled
116- if cfg .Dns .RecursionEnabled {
117- // Pick random fallback server
118- fallbackServer := cfg .Dns .FallbackServers [rand .Intn (len (cfg .Dns .FallbackServers ))]
119- // Query chosen server
120- fallbackResp , err := queryServer (r , fallbackServer )
121- if err != nil {
122- // Send failure response
123- m .SetRcode (r , dns .RcodeServerFailure )
124- if err := w .WriteMsg (m ); err != nil {
125- logger .Errorf ("failed to write response: %s" , err )
126- }
127- logger .Errorf ("failed to query fallback server: %s" , err )
128- return
129- } else {
130- copyResponse (r , fallbackResp , m )
131- // Send response
132- if err := w .WriteMsg (m ); err != nil {
133- logger .Errorf ("failed to write response: %s" , err )
134- }
135- return
136- }
137- }
138-
139115 // Return NXDOMAIN if we have no information about the requested domain or any of its parents
140116 m .SetRcode (r , dns .RcodeNameError )
141117 if err := w .WriteMsg (m ); err != nil {
@@ -159,30 +135,64 @@ func copyResponse(req *dns.Msg, srcResp *dns.Msg, destResp *dns.Msg) {
159135 }
160136}
161137
162- func queryServer (req * dns.Msg , nameserver string ) (* dns.Msg , error ) {
163- m := new (dns.Msg )
164- m .Id = dns .Id ()
165- m .RecursionDesired = req .RecursionDesired
166- m .Question = append (m .Question , req .Question ... )
167- in , err := dns .Exchange (m , fmt .Sprintf ("%s:53" , nameserver ))
168- return in , err
169- }
170-
171138func randomNameserverAddress (nameservers map [string ][]net.IP ) net.IP {
172139 // Put all namserver addresses in single list
173140 tmpNameservers := []net.IP {}
174141 for _ , addresses := range nameservers {
175142 tmpNameservers = append (tmpNameservers , addresses ... )
176143 }
177- tmpNameserver := tmpNameservers [rand .Intn (len (tmpNameservers ))]
178- return tmpNameserver
144+ if len (tmpNameservers ) > 0 {
145+ tmpNameserver := tmpNameservers [rand .Intn (len (tmpNameservers ))]
146+ return tmpNameserver
147+ }
148+ return nil
179149}
180150
181- func doQuery (msg * dns.Msg , address string ) (* dns.Msg , error ) {
151+ func doQuery (msg * dns.Msg , address string , recursive bool ) (* dns.Msg , error ) {
182152 logger := logging .GetLogger ()
183- logger .Debugf ("querying %s: %s %s" , address , dns .Type (msg .Question [0 ].Qtype ).String (), msg .Question [0 ].Name )
153+ // Default to a random fallback server if no address is specified
154+ if address == "" {
155+ address = randomFallbackServer ()
156+ }
157+ // Add default port to address if there is none
158+ if ! strings .Contains (address , ":" ) {
159+ address = address + `:53`
160+ }
161+ logger .Debugf ("querying %s: %s" , address , formatMessageQuestionSection (msg .Question ))
184162 resp , err := dns .Exchange (msg , address )
185- return resp , err
163+ if err != nil {
164+ return nil , err
165+ }
166+ logger .Debugf ("response: rcode=%s, authoritative=%v, authority=%s, answer=%s, extra=%s" , dns .RcodeToString [resp .Rcode ], resp .Authoritative , formatMessageAnswerSection (resp .Ns ), formatMessageAnswerSection (resp .Answer ), formatMessageAnswerSection (resp .Extra ))
167+ // Immediately return authoritative response
168+ if resp .Authoritative {
169+ return resp , nil
170+ }
171+ if recursive {
172+ if len (resp .Ns ) > 0 {
173+ nameservers := getNameserversFromResponse (resp )
174+ randNsName , randNsAddress := randomNameserver (nameservers )
175+ if randNsAddress == "" {
176+ m := createQuery (randNsName , dns .TypeA )
177+ // XXX: should this query the fallback servers or the server that gave us the NS response?
178+ resp , err := doQuery (m , "" , false )
179+ if err != nil {
180+ return nil , err
181+ }
182+ randNsAddress = getAddressForNameFromResponse (resp , randNsName )
183+ if randNsAddress == "" {
184+ // Return the current response if we couldn't get an address for the nameserver
185+ return resp , nil
186+ }
187+ }
188+ // Perform recursive query
189+ return doQuery (msg , randNsAddress , true )
190+ } else {
191+ // Return the current response if there is no authority information
192+ return resp , nil
193+ }
194+ }
195+ return resp , nil
186196}
187197
188198func findNameserversForDomain (recordName string ) (string , map [string ][]net.IP , error ) {
@@ -211,14 +221,11 @@ func findNameserversForDomain(recordName string) (string, map[string][]net.IP, e
211221 // Query fallback servers, if configured
212222 if len (cfg .Dns .FallbackServers ) > 0 {
213223 // Pick random fallback server
214- fallbackServer := cfg .Dns .FallbackServers [rand .Intn (len (cfg .Dns .FallbackServers ))]
215- serverWithPort := fmt .Sprintf ("%s:53" , fallbackServer )
224+ fallbackServer := randomFallbackServer ()
216225 for startLabelIdx := 0 ; startLabelIdx < len (queryLabels ); startLabelIdx ++ {
217226 lookupDomainName := dns .Fqdn (strings .Join (queryLabels [startLabelIdx :], "." ))
218- m := new (dns.Msg )
219- m .SetQuestion (lookupDomainName , dns .TypeNS )
220- m .RecursionDesired = false
221- in , err := doQuery (m , serverWithPort )
227+ m := createQuery (lookupDomainName , dns .TypeNS )
228+ in , err := doQuery (m , fallbackServer , false )
222229 if err != nil {
223230 return "" , nil , err
224231 }
@@ -231,10 +238,8 @@ func findNameserversForDomain(recordName string) (string, map[string][]net.IP, e
231238 ns := v .Ns
232239 ret [ns ] = make ([]net.IP , 0 )
233240 // Query for matching A/AAAA records
234- m2 := new (dns.Msg )
235- m2 .SetQuestion (ns , dns .TypeA )
236- m2 .RecursionDesired = false
237- in2 , err := doQuery (m2 , serverWithPort )
241+ m2 := createQuery (ns , dns .TypeA )
242+ in2 , err := doQuery (m2 , fallbackServer , false )
238243 if err != nil {
239244 return "" , nil , err
240245 }
@@ -258,3 +263,138 @@ func findNameserversForDomain(recordName string) (string, map[string][]net.IP, e
258263
259264 return "" , nil , nil
260265}
266+
267+ func getNameserversFromResponse (msg * dns.Msg ) map [string ][]net.IP {
268+ if len (msg .Ns ) == 0 {
269+ return nil
270+ }
271+ ret := map [string ][]net.IP {}
272+ for _ , ns := range msg .Ns {
273+ // TODO: handle SOA
274+ switch v := ns .(type ) {
275+ case * dns.NS :
276+ nsName := v .Ns
277+ ret [nsName ] = []net.IP {}
278+ for _ , extra := range msg .Extra {
279+ if extra .Header ().Name != nsName {
280+ continue
281+ }
282+ switch v := extra .(type ) {
283+ case * dns.A :
284+ ret [nsName ] = append (
285+ ret [nsName ],
286+ v .A ,
287+ )
288+ case * dns.AAAA :
289+ ret [nsName ] = append (
290+ ret [nsName ],
291+ v .AAAA ,
292+ )
293+ }
294+ }
295+ }
296+ }
297+ return ret
298+ }
299+
300+ func getAddressForNameFromResponse (msg * dns.Msg , recordName string ) string {
301+ var retRR dns.RR
302+ for _ , answer := range msg .Answer {
303+ if answer .Header ().Name == recordName {
304+ retRR = answer
305+ break
306+ }
307+ }
308+ if retRR == nil {
309+ for _ , extra := range msg .Extra {
310+ if extra .Header ().Name == recordName {
311+ retRR = extra
312+ break
313+ }
314+ }
315+ }
316+ if retRR == nil {
317+ return ""
318+ }
319+ switch v := retRR .(type ) {
320+ case * dns.A :
321+ return v .A .String ()
322+ case * dns.AAAA :
323+ return v .AAAA .String ()
324+ }
325+ return ""
326+ }
327+
328+ func randomNameserver (nameservers map [string ][]net.IP ) (string , string ) {
329+ mapKeys := []string {}
330+ for k := range nameservers {
331+ mapKeys = append (mapKeys , k )
332+ }
333+ if len (mapKeys ) > 0 {
334+ randNsName := mapKeys [rand .Intn (len (mapKeys ))]
335+ randNsAddresses := nameservers [randNsName ]
336+ randNsAddress := randNsAddresses [rand .Intn (len (randNsAddresses ))].String ()
337+ return randNsName , randNsAddress
338+ }
339+ return "" , ""
340+ }
341+
342+ func createQuery (recordName string , recordType uint16 ) * dns.Msg {
343+ m := new (dns.Msg )
344+ m .SetQuestion (recordName , recordType )
345+ m .RecursionDesired = false
346+ return m
347+ }
348+
349+ func randomFallbackServer () string {
350+ cfg := config .GetConfig ()
351+ return cfg .Dns .FallbackServers [rand .Intn (
352+ len (cfg .Dns .FallbackServers ),
353+ )]
354+ }
355+
356+ func formatMessageAnswerSection (section []dns.RR ) string {
357+ ret := "[ "
358+ for idx , rr := range section {
359+ ret += fmt .Sprintf (
360+ "< %s >" ,
361+ strings .ReplaceAll (
362+ strings .TrimPrefix (
363+ rr .String (),
364+ ";" ,
365+ ),
366+ "\t " ,
367+ " " ,
368+ ),
369+ )
370+ if idx != len (section )- 1 {
371+ ret += `,`
372+ }
373+ ret += ` `
374+ }
375+ ret += "]"
376+ return ret
377+ }
378+
379+ func formatMessageQuestionSection (section []dns.Question ) string {
380+ ret := "[ "
381+ for idx , question := range section {
382+ ret += fmt .Sprintf (
383+ "< %s >" ,
384+ strings .ReplaceAll (
385+ strings .TrimPrefix (
386+ question .String (),
387+ ";" ,
388+ ),
389+ "\t " ,
390+ " " ,
391+ ),
392+ )
393+ if idx != len (section )- 1 {
394+ ret += `,`
395+ }
396+ ret += ` `
397+ }
398+ ret += "]"
399+ return ret
400+ }
0 commit comments