@@ -35,6 +35,7 @@ func startListener(server *dns.Server) {
3535
3636func handleQuery (w dns.ResponseWriter , r * dns.Msg ) {
3737 logger := logging .GetLogger ()
38+ cfg := config .GetConfig ()
3839 m := new (dns.Msg )
3940
4041 // Split query name into labels and lookup each domain and parent until we get a hit
@@ -50,31 +51,58 @@ func handleQuery(w dns.ResponseWriter, r *dns.Msg) {
5051 }
5152 // Assemble response
5253 m .SetReply (r )
53- for nameserver , ipAddress := range nameServers {
54- // Add trailing dot to make everybody happy
55- nameserver = nameserver + `.`
56- // NS record
57- ns := & dns.NS {
58- Hdr : dns.RR_Header {Name : (lookupDomainName + `.` ), Rrtype : dns .TypeNS , Class : dns .ClassINET , Ttl : 999 },
59- Ns : nameserver ,
54+ if cfg .Dns .RecursionEnabled {
55+ // Pick random nameserver for domain
56+ tmpNameservers := []string {}
57+ for nameserver := range nameServers {
58+ tmpNameservers = append (tmpNameservers , nameserver )
6059 }
61- m . Ns = append ( m . Ns , ns )
62- // A or AAAA record
63- ipAddr := net . ParseIP ( ipAddress )
64- if ipAddr . To4 () != nil {
65- // IPv4
66- a := & dns.A {
67- Hdr : dns. RR_Header { Name : nameserver , Rrtype : dns . TypeA , Class : dns . ClassINET , Ttl : 999 },
68- A : ipAddr ,
60+ tmpNameserver := nameServers [ tmpNameservers [ rand . Intn ( len ( tmpNameservers ))]]
61+ // Query the random domain nameserver we picked above
62+ resp , err := queryServer ( r , tmpNameserver )
63+ if err != nil {
64+ // Send failure response
65+ m . SetRcode ( r , dns .RcodeServerFailure )
66+ if err := w . WriteMsg ( m ); err != nil {
67+ logger . Errorf ( "failed to write response: %s" , err )
6968 }
70- m .Extra = append (m .Extra , a )
69+ logger .Errorf ("failed to query domain nameserver: %s" , err )
70+ return
7171 } else {
72- // IPv6
73- aaaa := & dns.AAAA {
74- Hdr : dns.RR_Header {Name : nameserver , Rrtype : dns .TypeAAAA , Class : dns .ClassINET , Ttl : 999 },
75- AAAA : ipAddr ,
72+ copyResponse (r , resp , m )
73+ // Send response
74+ if err := w .WriteMsg (m ); err != nil {
75+ logger .Errorf ("failed to write response: %s" , err )
76+ }
77+ return
78+ }
79+ } else {
80+ for nameserver , ipAddress := range nameServers {
81+ // Add trailing dot to make everybody happy
82+ nameserver = nameserver + `.`
83+ // NS record
84+ ns := & dns.NS {
85+ Hdr : dns.RR_Header {Name : (lookupDomainName + `.` ), Rrtype : dns .TypeNS , Class : dns .ClassINET , Ttl : 999 },
86+ Ns : nameserver ,
87+ }
88+ m .Ns = append (m .Ns , ns )
89+ // A or AAAA record
90+ ipAddr := net .ParseIP (ipAddress )
91+ if ipAddr .To4 () != nil {
92+ // IPv4
93+ a := & dns.A {
94+ Hdr : dns.RR_Header {Name : nameserver , Rrtype : dns .TypeA , Class : dns .ClassINET , Ttl : 999 },
95+ A : ipAddr ,
96+ }
97+ m .Extra = append (m .Extra , a )
98+ } else {
99+ // IPv6
100+ aaaa := & dns.AAAA {
101+ Hdr : dns.RR_Header {Name : nameserver , Rrtype : dns .TypeAAAA , Class : dns .ClassINET , Ttl : 999 },
102+ AAAA : ipAddr ,
103+ }
104+ m .Extra = append (m .Extra , aaaa )
76105 }
77- m .Extra = append (m .Extra , aaaa )
78106 }
79107 }
80108 // Send response
@@ -85,35 +113,28 @@ func handleQuery(w dns.ResponseWriter, r *dns.Msg) {
85113 return
86114 }
87115
88- // Query fallback servers
89- fallbackResp , err := queryFallbackServer (r )
90- if err != nil {
91- // Send failure response
92- m .SetRcode (r , dns .RcodeServerFailure )
93- if err := w .WriteMsg (m ); err != nil {
94- logger .Errorf ("failed to write response: %s" , err )
95- }
96- logger .Errorf ("failed to query fallback server: %s" , err )
97- return
98- } else {
99- // Copy relevant data from fallback response into our response
100- m .SetRcode (r , fallbackResp .MsgHdr .Rcode )
101- m .RecursionDesired = r .RecursionDesired
102- m .RecursionAvailable = fallbackResp .RecursionAvailable
103- if fallbackResp .Ns != nil {
104- m .Ns = append (m .Ns , fallbackResp .Ns ... )
105- }
106- if fallbackResp .Answer != nil {
107- m .Answer = append (m .Answer , fallbackResp .Answer ... )
108- }
109- if fallbackResp .Extra != nil {
110- m .Extra = append (m .Extra , fallbackResp .Extra ... )
111- }
112- // Send response
113- if err := w .WriteMsg (m ); err != nil {
114- logger .Errorf ("failed to write response: %s" , err )
116+ // Query fallback servers if recursion is enabled
117+ if cfg .Dns .RecursionEnabled {
118+ // Pick random fallback server
119+ fallbackServer := cfg .Dns .FallbackServers [rand .Intn (len (cfg .Dns .FallbackServers ))]
120+ // Query chosen server
121+ fallbackResp , err := queryServer (r , fallbackServer )
122+ if err != nil {
123+ // Send failure response
124+ m .SetRcode (r , dns .RcodeServerFailure )
125+ if err := w .WriteMsg (m ); err != nil {
126+ logger .Errorf ("failed to write response: %s" , err )
127+ }
128+ logger .Errorf ("failed to query fallback server: %s" , err )
129+ return
130+ } else {
131+ copyResponse (r , fallbackResp , m )
132+ // Send response
133+ if err := w .WriteMsg (m ); err != nil {
134+ logger .Errorf ("failed to write response: %s" , err )
135+ }
136+ return
115137 }
116- return
117138 }
118139
119140 // Return NXDOMAIN if we have no information about the requested domain or any of its parents
@@ -123,15 +144,27 @@ func handleQuery(w dns.ResponseWriter, r *dns.Msg) {
123144 }
124145}
125146
126- func queryFallbackServer (req * dns.Msg ) (* dns.Msg , error ) {
127- // Pick random fallback server
128- cfg := config .GetConfig ()
129- fallbackServer := cfg .Dns .FallbackServers [rand .Intn (len (cfg .Dns .FallbackServers ))]
130- // Query chosen server
147+ func copyResponse (req * dns.Msg , srcResp * dns.Msg , destResp * dns.Msg ) {
148+ // Copy relevant data from original request and source response into destination response
149+ destResp .SetRcode (req , srcResp .MsgHdr .Rcode )
150+ destResp .RecursionDesired = req .RecursionDesired
151+ destResp .RecursionAvailable = srcResp .RecursionAvailable
152+ if srcResp .Ns != nil {
153+ destResp .Ns = append (destResp .Ns , srcResp .Ns ... )
154+ }
155+ if srcResp .Answer != nil {
156+ destResp .Answer = append (destResp .Answer , srcResp .Answer ... )
157+ }
158+ if srcResp .Extra != nil {
159+ destResp .Extra = append (destResp .Extra , srcResp .Extra ... )
160+ }
161+ }
162+
163+ func queryServer (req * dns.Msg , nameserver string ) (* dns.Msg , error ) {
131164 m := new (dns.Msg )
132165 m .Id = dns .Id ()
133166 m .RecursionDesired = req .RecursionDesired
134167 m .Question = append (m .Question , req .Question ... )
135- in , err := dns .Exchange (m , fmt .Sprintf ("%s:53" , fallbackServer ))
168+ in , err := dns .Exchange (m , fmt .Sprintf ("%s:53" , nameserver ))
136169 return in , err
137170}
0 commit comments