@@ -48,28 +48,18 @@ func handleQuery(w dns.ResponseWriter, r *dns.Msg) {
4848 }
4949 }
5050
51- // Split query name into labels and lookup each domain and parent until we get a hit
52- queryLabels := dns .SplitDomainName (r .Question [0 ].Name )
53- for startLabelIdx := 0 ; startLabelIdx < len (queryLabels ); startLabelIdx ++ {
54- lookupDomainName := strings .Join (queryLabels [startLabelIdx :], "." )
55- nameServers , err := state .GetState ().LookupDomain (lookupDomainName )
56- if err != nil {
57- logger .Errorf ("failed to lookup domain: %s" , err )
58- }
59- if nameServers == nil {
60- continue
61- }
51+ nameserverDomain , nameservers , err := findNameserversForDomain (r .Question [0 ].Name )
52+ if err != nil {
53+ logger .Errorf ("failed to lookup nameservers for %s: %s" , r .Question [0 ].Name , err )
54+ }
55+ if nameservers != nil {
6256 // Assemble response
6357 m .SetReply (r )
6458 if cfg .Dns .RecursionEnabled {
6559 // Pick random nameserver for domain
66- tmpNameservers := []string {}
67- for nameserver := range nameServers {
68- tmpNameservers = append (tmpNameservers , nameserver )
69- }
70- tmpNameserver := nameServers [tmpNameservers [rand .Intn (len (tmpNameservers ))]]
60+ tmpNameserver := randomNameserverAddress (nameservers )
7161 // Query the random domain nameserver we picked above
72- resp , err := queryServer (r , tmpNameserver )
62+ resp , err := queryServer (r , tmpNameserver . String () )
7363 if err != nil {
7464 // Send failure response
7565 m .SetRcode (r , dns .RcodeServerFailure )
@@ -87,31 +77,30 @@ func handleQuery(w dns.ResponseWriter, r *dns.Msg) {
8777 return
8878 }
8979 } else {
90- for nameserver , ipAddress := range nameServers {
91- // Add trailing dot to make everybody happy
92- nameserver = nameserver + `.`
80+ for nameserver , addresses := range nameservers {
9381 // NS record
9482 ns := & dns.NS {
95- Hdr : dns.RR_Header {Name : (lookupDomainName + `.` ), Rrtype : dns .TypeNS , Class : dns .ClassINET , Ttl : 999 },
83+ Hdr : dns.RR_Header {Name : (nameserverDomain ), Rrtype : dns .TypeNS , Class : dns .ClassINET , Ttl : 999 },
9684 Ns : nameserver ,
9785 }
9886 m .Ns = append (m .Ns , ns )
99- // A or AAAA record
100- ipAddr := net .ParseIP (ipAddress )
101- if ipAddr .To4 () != nil {
102- // IPv4
103- a := & dns.A {
104- Hdr : dns.RR_Header {Name : nameserver , Rrtype : dns .TypeA , Class : dns .ClassINET , Ttl : 999 },
105- A : ipAddr ,
106- }
107- m .Extra = append (m .Extra , a )
108- } else {
109- // IPv6
110- aaaa := & dns.AAAA {
111- Hdr : dns.RR_Header {Name : nameserver , Rrtype : dns .TypeAAAA , Class : dns .ClassINET , Ttl : 999 },
112- AAAA : ipAddr ,
87+ for _ , address := range addresses {
88+ // A or AAAA record
89+ if address .To4 () != nil {
90+ // IPv4
91+ a := & dns.A {
92+ Hdr : dns.RR_Header {Name : nameserver , Rrtype : dns .TypeA , Class : dns .ClassINET , Ttl : 999 },
93+ A : address ,
94+ }
95+ m .Extra = append (m .Extra , a )
96+ } else {
97+ // IPv6
98+ aaaa := & dns.AAAA {
99+ Hdr : dns.RR_Header {Name : nameserver , Rrtype : dns .TypeAAAA , Class : dns .ClassINET , Ttl : 999 },
100+ AAAA : address ,
101+ }
102+ m .Extra = append (m .Extra , aaaa )
113103 }
114- m .Extra = append (m .Extra , aaaa )
115104 }
116105 }
117106 }
@@ -178,3 +167,94 @@ func queryServer(req *dns.Msg, nameserver string) (*dns.Msg, error) {
178167 in , err := dns .Exchange (m , fmt .Sprintf ("%s:53" , nameserver ))
179168 return in , err
180169}
170+
171+ func randomNameserverAddress (nameservers map [string ][]net.IP ) net.IP {
172+ // Put all namserver addresses in single list
173+ tmpNameservers := []net.IP {}
174+ for _ , addresses := range nameservers {
175+ tmpNameservers = append (tmpNameservers , addresses ... )
176+ }
177+ tmpNameserver := tmpNameservers [rand .Intn (len (tmpNameservers ))]
178+ return tmpNameserver
179+ }
180+
181+ func doQuery (msg * dns.Msg , address string ) (* dns.Msg , error ) {
182+ logger := logging .GetLogger ()
183+ logger .Debugf ("querying %s: %s %s" , address , dns .Type (msg .Question [0 ].Qtype ).String (), msg .Question [0 ].Name )
184+ resp , err := dns .Exchange (msg , address )
185+ return resp , err
186+ }
187+
188+ func findNameserversForDomain (recordName string ) (string , map [string ][]net.IP , error ) {
189+ cfg := config .GetConfig ()
190+
191+ // Split record name into labels and lookup each domain and parent until we get a hit
192+ queryLabels := dns .SplitDomainName (recordName )
193+
194+ // Check on-chain domains first
195+ for startLabelIdx := 0 ; startLabelIdx < len (queryLabels ); startLabelIdx ++ {
196+ lookupDomainName := strings .Join (queryLabels [startLabelIdx :], "." )
197+ nameservers , err := state .GetState ().LookupDomain (lookupDomainName )
198+ if err != nil {
199+ return "" , nil , err
200+ }
201+ if nameservers != nil {
202+ ret := map [string ][]net.IP {}
203+ for k , v := range nameservers {
204+ k = k + `.`
205+ ret [k ] = append (ret [k ], net .ParseIP (v ))
206+ }
207+ return dns .Fqdn (lookupDomainName ), ret , nil
208+ }
209+ }
210+
211+ // Query fallback servers, if configured
212+ if len (cfg .Dns .FallbackServers ) > 0 {
213+ // Pick random fallback server
214+ fallbackServer := cfg .Dns .FallbackServers [rand .Intn (len (cfg .Dns .FallbackServers ))]
215+ serverWithPort := fmt .Sprintf ("%s:53" , fallbackServer )
216+ for startLabelIdx := 0 ; startLabelIdx < len (queryLabels ); startLabelIdx ++ {
217+ lookupDomainName := dns .Fqdn (strings .Join (queryLabels [startLabelIdx :], "." ))
218+ m := new (dns.Msg )
219+ m .SetQuestion (lookupDomainName , dns .TypeNS )
220+ m .RecursionDesired = false
221+ in , err := doQuery (m , serverWithPort )
222+ if err != nil {
223+ return "" , nil , err
224+ }
225+ if in .Rcode == dns .RcodeSuccess {
226+ if len (in .Answer ) > 0 {
227+ ret := map [string ][]net.IP {}
228+ for _ , answer := range in .Answer {
229+ switch v := answer .(type ) {
230+ case * dns.NS :
231+ ns := v .Ns
232+ ret [ns ] = make ([]net.IP , 0 )
233+ // Query for matching A/AAAA records
234+ m2 := new (dns.Msg )
235+ m2 .SetQuestion (ns , dns .TypeA )
236+ m2 .RecursionDesired = false
237+ in2 , err := doQuery (m2 , serverWithPort )
238+ if err != nil {
239+ return "" , nil , err
240+ }
241+ for _ , answer2 := range in2 .Answer {
242+ switch v := answer2 .(type ) {
243+ case * dns.A :
244+ ret [ns ] = append (ret [ns ], v .A )
245+ case * dns.AAAA :
246+ ret [ns ] = append (ret [ns ], v .AAAA )
247+ }
248+ }
249+ }
250+ }
251+ if len (ret ) > 0 {
252+ return lookupDomainName , ret , nil
253+ }
254+ }
255+ }
256+ }
257+ }
258+
259+ return "" , nil , nil
260+ }
0 commit comments