@@ -35,6 +35,7 @@ func startListener(server *dns.Server) {
35
35
36
36
func handleQuery (w dns.ResponseWriter , r * dns.Msg ) {
37
37
logger := logging .GetLogger ()
38
+ cfg := config .GetConfig ()
38
39
m := new (dns.Msg )
39
40
40
41
// Split query name into labels and lookup each domain and parent until we get a hit
@@ -50,31 +51,58 @@ func handleQuery(w dns.ResponseWriter, r *dns.Msg) {
50
51
}
51
52
// Assemble response
52
53
m .SetReply (r )
53
- for nameserver , ipAddress := range nameServers {
54
- // Add trailing dot to make everybody happy
55
- nameserver = nameserver + `.`
56
- // NS record
57
- ns := & dns.NS {
58
- Hdr : dns.RR_Header {Name : (lookupDomainName + `.` ), Rrtype : dns .TypeNS , Class : dns .ClassINET , Ttl : 999 },
59
- Ns : nameserver ,
54
+ if cfg .Dns .RecursionEnabled {
55
+ // Pick random nameserver for domain
56
+ tmpNameservers := []string {}
57
+ for nameserver := range nameServers {
58
+ tmpNameservers = append (tmpNameservers , nameserver )
60
59
}
61
- m . Ns = append ( m . Ns , ns )
62
- // A or AAAA record
63
- ipAddr := net . ParseIP ( ipAddress )
64
- if ipAddr . To4 () != nil {
65
- // IPv4
66
- a := & dns.A {
67
- Hdr : dns. RR_Header { Name : nameserver , Rrtype : dns . TypeA , Class : dns . ClassINET , Ttl : 999 },
68
- A : ipAddr ,
60
+ tmpNameserver := nameServers [ tmpNameservers [ rand . Intn ( len ( tmpNameservers ))]]
61
+ // Query the random domain nameserver we picked above
62
+ resp , err := queryServer ( r , tmpNameserver )
63
+ if err != nil {
64
+ // Send failure response
65
+ m . SetRcode ( r , dns .RcodeServerFailure )
66
+ if err := w . WriteMsg ( m ); err != nil {
67
+ logger . Errorf ( "failed to write response: %s" , err )
69
68
}
70
- m .Extra = append (m .Extra , a )
69
+ logger .Errorf ("failed to query domain nameserver: %s" , err )
70
+ return
71
71
} else {
72
- // IPv6
73
- aaaa := & dns.AAAA {
74
- Hdr : dns.RR_Header {Name : nameserver , Rrtype : dns .TypeAAAA , Class : dns .ClassINET , Ttl : 999 },
75
- AAAA : ipAddr ,
72
+ copyResponse (r , resp , m )
73
+ // Send response
74
+ if err := w .WriteMsg (m ); err != nil {
75
+ logger .Errorf ("failed to write response: %s" , err )
76
+ }
77
+ return
78
+ }
79
+ } else {
80
+ for nameserver , ipAddress := range nameServers {
81
+ // Add trailing dot to make everybody happy
82
+ nameserver = nameserver + `.`
83
+ // NS record
84
+ ns := & dns.NS {
85
+ Hdr : dns.RR_Header {Name : (lookupDomainName + `.` ), Rrtype : dns .TypeNS , Class : dns .ClassINET , Ttl : 999 },
86
+ Ns : nameserver ,
87
+ }
88
+ m .Ns = append (m .Ns , ns )
89
+ // A or AAAA record
90
+ ipAddr := net .ParseIP (ipAddress )
91
+ if ipAddr .To4 () != nil {
92
+ // IPv4
93
+ a := & dns.A {
94
+ Hdr : dns.RR_Header {Name : nameserver , Rrtype : dns .TypeA , Class : dns .ClassINET , Ttl : 999 },
95
+ A : ipAddr ,
96
+ }
97
+ m .Extra = append (m .Extra , a )
98
+ } else {
99
+ // IPv6
100
+ aaaa := & dns.AAAA {
101
+ Hdr : dns.RR_Header {Name : nameserver , Rrtype : dns .TypeAAAA , Class : dns .ClassINET , Ttl : 999 },
102
+ AAAA : ipAddr ,
103
+ }
104
+ m .Extra = append (m .Extra , aaaa )
76
105
}
77
- m .Extra = append (m .Extra , aaaa )
78
106
}
79
107
}
80
108
// Send response
@@ -85,35 +113,28 @@ func handleQuery(w dns.ResponseWriter, r *dns.Msg) {
85
113
return
86
114
}
87
115
88
- // Query fallback servers
89
- fallbackResp , err := queryFallbackServer (r )
90
- if err != nil {
91
- // Send failure response
92
- m .SetRcode (r , dns .RcodeServerFailure )
93
- if err := w .WriteMsg (m ); err != nil {
94
- logger .Errorf ("failed to write response: %s" , err )
95
- }
96
- logger .Errorf ("failed to query fallback server: %s" , err )
97
- return
98
- } else {
99
- // Copy relevant data from fallback response into our response
100
- m .SetRcode (r , fallbackResp .MsgHdr .Rcode )
101
- m .RecursionDesired = r .RecursionDesired
102
- m .RecursionAvailable = fallbackResp .RecursionAvailable
103
- if fallbackResp .Ns != nil {
104
- m .Ns = append (m .Ns , fallbackResp .Ns ... )
105
- }
106
- if fallbackResp .Answer != nil {
107
- m .Answer = append (m .Answer , fallbackResp .Answer ... )
108
- }
109
- if fallbackResp .Extra != nil {
110
- m .Extra = append (m .Extra , fallbackResp .Extra ... )
111
- }
112
- // Send response
113
- if err := w .WriteMsg (m ); err != nil {
114
- logger .Errorf ("failed to write response: %s" , err )
116
+ // Query fallback servers if recursion is enabled
117
+ if cfg .Dns .RecursionEnabled {
118
+ // Pick random fallback server
119
+ fallbackServer := cfg .Dns .FallbackServers [rand .Intn (len (cfg .Dns .FallbackServers ))]
120
+ // Query chosen server
121
+ fallbackResp , err := queryServer (r , fallbackServer )
122
+ if err != nil {
123
+ // Send failure response
124
+ m .SetRcode (r , dns .RcodeServerFailure )
125
+ if err := w .WriteMsg (m ); err != nil {
126
+ logger .Errorf ("failed to write response: %s" , err )
127
+ }
128
+ logger .Errorf ("failed to query fallback server: %s" , err )
129
+ return
130
+ } else {
131
+ copyResponse (r , fallbackResp , m )
132
+ // Send response
133
+ if err := w .WriteMsg (m ); err != nil {
134
+ logger .Errorf ("failed to write response: %s" , err )
135
+ }
136
+ return
115
137
}
116
- return
117
138
}
118
139
119
140
// Return NXDOMAIN if we have no information about the requested domain or any of its parents
@@ -123,15 +144,27 @@ func handleQuery(w dns.ResponseWriter, r *dns.Msg) {
123
144
}
124
145
}
125
146
126
- func queryFallbackServer (req * dns.Msg ) (* dns.Msg , error ) {
127
- // Pick random fallback server
128
- cfg := config .GetConfig ()
129
- fallbackServer := cfg .Dns .FallbackServers [rand .Intn (len (cfg .Dns .FallbackServers ))]
130
- // Query chosen server
147
+ func copyResponse (req * dns.Msg , srcResp * dns.Msg , destResp * dns.Msg ) {
148
+ // Copy relevant data from original request and source response into destination response
149
+ destResp .SetRcode (req , srcResp .MsgHdr .Rcode )
150
+ destResp .RecursionDesired = req .RecursionDesired
151
+ destResp .RecursionAvailable = srcResp .RecursionAvailable
152
+ if srcResp .Ns != nil {
153
+ destResp .Ns = append (destResp .Ns , srcResp .Ns ... )
154
+ }
155
+ if srcResp .Answer != nil {
156
+ destResp .Answer = append (destResp .Answer , srcResp .Answer ... )
157
+ }
158
+ if srcResp .Extra != nil {
159
+ destResp .Extra = append (destResp .Extra , srcResp .Extra ... )
160
+ }
161
+ }
162
+
163
+ func queryServer (req * dns.Msg , nameserver string ) (* dns.Msg , error ) {
131
164
m := new (dns.Msg )
132
165
m .Id = dns .Id ()
133
166
m .RecursionDesired = req .RecursionDesired
134
167
m .Question = append (m .Question , req .Question ... )
135
- in , err := dns .Exchange (m , fmt .Sprintf ("%s:53" , fallbackServer ))
168
+ in , err := dns .Exchange (m , fmt .Sprintf ("%s:53" , nameserver ))
136
169
return in , err
137
170
}
0 commit comments