@@ -68,41 +68,25 @@ func handleQuery(w dns.ResponseWriter, r *dns.Msg) {
6868 }
6969 }
7070
71- // Check for known record for domain nameserver
72- records , err := state .GetState ().LookupNameserverRecord (
71+ // Check for known record from local storage
72+ records , err := state .GetState ().LookupRecords (
73+ []string {dns .Type (r .Question [0 ].Qtype ).String ()},
7374 strings .TrimSuffix (r .Question [0 ].Name , "." ),
7475 )
7576 if err != nil {
76- logger .Errorf ("failed to lookup record in state: %s" , err )
77+ logger .Errorf ("failed to lookup records in state: %s" , err )
7778 return
7879 }
7980 if records != nil {
8081 // Assemble response
8182 m .SetReply (r )
82- for k , v := range records {
83- k = dns .Fqdn (k )
84- address := net .ParseIP (v )
85- // A or AAAA record
86- if address .To4 () != nil {
87- // IPv4
88- a := & dns.A {
89- Hdr : dns.RR_Header {
90- Name : k ,
91- Rrtype : dns .TypeA ,
92- Class : dns .ClassINET ,
93- Ttl : 999 ,
94- },
95- A : address ,
96- }
97- m .Answer = append (m .Answer , a )
98- } else {
99- // IPv6
100- aaaa := & dns.AAAA {
101- Hdr : dns.RR_Header {Name : k , Rrtype : dns .TypeAAAA , Class : dns .ClassINET , Ttl : 999 },
102- AAAA : address ,
103- }
104- m .Answer = append (m .Answer , aaaa )
83+ for _ , tmpRecord := range records {
84+ tmpRR , err := stateRecordToDnsRR (tmpRecord )
85+ if err != nil {
86+ logger .Errorf ("failed to convert state record to dns.RR: %s" , err )
87+ return
10588 }
89+ m .Answer = append (m .Answer , tmpRR )
10690 }
10791 // Send response
10892 if err := w .WriteMsg (m ); err != nil {
@@ -112,6 +96,7 @@ func handleQuery(w dns.ResponseWriter, r *dns.Msg) {
11296 return
11397 }
11498
99+ // Check for any NS records for parent domains from local storage
115100 nameserverDomain , nameservers , err := findNameserversForDomain (
116101 r .Question [0 ].Name ,
117102 )
@@ -182,13 +167,52 @@ func handleQuery(w dns.ResponseWriter, r *dns.Msg) {
182167 return
183168 }
184169
170+ // Query fallback servers, if configured
171+ if len (cfg .Dns .FallbackServers ) > 0 {
172+ // Pick random fallback server
173+ fallbackServer := randomFallbackServer ()
174+ // Pass along query to chosen fallback server
175+ resp , err := doQuery (r , fallbackServer , false )
176+ if err != nil {
177+ // Send failure response
178+ m .SetRcode (r , dns .RcodeServerFailure )
179+ if err := w .WriteMsg (m ); err != nil {
180+ logger .Errorf ("failed to write response: %s" , err )
181+ }
182+ logger .Errorf ("failed to query domain nameserver: %s" , err )
183+ return
184+ } else {
185+ copyResponse (r , resp , m )
186+ // Send response
187+ if err := w .WriteMsg (m ); err != nil {
188+ logger .Errorf ("failed to write response: %s" , err )
189+ }
190+ return
191+ }
192+ }
193+
185194 // Return NXDOMAIN if we have no information about the requested domain or any of its parents
186195 m .SetRcode (r , dns .RcodeNameError )
187196 if err := w .WriteMsg (m ); err != nil {
188197 logger .Errorf ("failed to write response: %s" , err )
189198 }
190199}
191200
201+ func stateRecordToDnsRR (record state.DomainRecord ) (dns.RR , error ) {
202+ tmpTtl := ""
203+ if record .Ttl > 0 {
204+ tmpTtl = fmt .Sprintf ("%d" , record .Ttl )
205+ }
206+ tmpRR := fmt .Sprintf (
207+ "%s %s IN %s %s" ,
208+ record .Lhs ,
209+ tmpTtl ,
210+ record .Type ,
211+ record .Rhs ,
212+ )
213+ return dns .NewRR (tmpRR )
214+ }
215+
192216func copyResponse (req * dns.Msg , srcResp * dns.Msg , destResp * dns.Msg ) {
193217 // Copy relevant data from original request and source response into destination response
194218 destResp .SetRcode (req , srcResp .MsgHdr .Rcode )
@@ -279,8 +303,6 @@ func doQuery(msg *dns.Msg, address string, recursive bool) (*dns.Msg, error) {
279303func findNameserversForDomain (
280304 recordName string ,
281305) (string , map [string ][]net.IP , error ) {
282- cfg := config .GetConfig ()
283-
284306 // Split record name into labels and lookup each domain and parent until we get a hit
285307 queryLabels := dns .SplitDomainName (recordName )
286308
@@ -314,51 +336,6 @@ func findNameserversForDomain(
314336 }
315337 }
316338
317- // Query fallback servers, if configured
318- if len (cfg .Dns .FallbackServers ) > 0 {
319- // Pick random fallback server
320- fallbackServer := randomFallbackServer ()
321- for startLabelIdx := 0 ; startLabelIdx < len (queryLabels ); startLabelIdx ++ {
322- lookupDomainName := dns .Fqdn (
323- strings .Join (queryLabels [startLabelIdx :], "." ),
324- )
325- m := createQuery (lookupDomainName , dns .TypeNS )
326- in , err := doQuery (m , fallbackServer , false )
327- if err != nil {
328- return "" , nil , err
329- }
330- if in .Rcode == dns .RcodeSuccess {
331- if len (in .Answer ) > 0 {
332- ret := map [string ][]net.IP {}
333- for _ , answer := range in .Answer {
334- switch v := answer .(type ) {
335- case * dns.NS :
336- ns := v .Ns
337- ret [ns ] = make ([]net.IP , 0 )
338- // Query for matching A/AAAA records
339- m2 := createQuery (ns , dns .TypeA )
340- in2 , err := doQuery (m2 , fallbackServer , false )
341- if err != nil {
342- return "" , nil , err
343- }
344- for _ , answer2 := range in2 .Answer {
345- switch v := answer2 .(type ) {
346- case * dns.A :
347- ret [ns ] = append (ret [ns ], v .A )
348- case * dns.AAAA :
349- ret [ns ] = append (ret [ns ], v .AAAA )
350- }
351- }
352- }
353- }
354- if len (ret ) > 0 {
355- return lookupDomainName , ret , nil
356- }
357- }
358- }
359- }
360- }
361-
362339 return "" , nil , nil
363340}
364341
0 commit comments