@@ -43,18 +43,21 @@ type DNSVerificationResult struct {
4343type DNSVerificationConfig struct {
4444 // Timeout for DNS queries (default: 10 seconds)
4545 Timeout time.Duration
46-
46+
4747 // MaxRetries for transient failures (default: 3)
4848 MaxRetries int
49-
49+
5050 // RetryDelay base delay between retries (default: 1 second)
5151 RetryDelay time.Duration
52-
52+
5353 // UseSecureResolvers enables use of secure DNS resolvers
5454 UseSecureResolvers bool
55-
55+
5656 // CustomResolvers allows specifying custom DNS servers
5757 CustomResolvers []string
58+
59+ // Resolver allows injecting a custom DNS resolver (primarily for testing)
60+ Resolver DNSResolver
5861}
5962
6063// DefaultDNSConfig returns the default configuration for DNS verification
@@ -108,7 +111,7 @@ func VerifyDNSRecord(domain, expectedToken string) (*DNSVerificationResult, erro
108111// VerifyDNSRecordWithConfig performs DNS verification with custom configuration
109112func VerifyDNSRecordWithConfig (domain , expectedToken string , config * DNSVerificationConfig ) (* DNSVerificationResult , error ) {
110113 startTime := time .Now ()
111-
114+
112115 // Input validation
113116 if domain == "" {
114117 return nil , & DNSVerificationError {
@@ -117,15 +120,15 @@ func VerifyDNSRecordWithConfig(domain, expectedToken string, config *DNSVerifica
117120 Message : "domain cannot be empty" ,
118121 }
119122 }
120-
123+
121124 if expectedToken == "" {
122125 return nil , & DNSVerificationError {
123126 Domain : domain ,
124127 Token : expectedToken ,
125128 Message : "token cannot be empty" ,
126129 }
127130 }
128-
131+
129132 // Validate token format
130133 if ! ValidateTokenFormat (expectedToken ) {
131134 return nil , & DNSVerificationError {
@@ -134,43 +137,43 @@ func VerifyDNSRecordWithConfig(domain, expectedToken string, config *DNSVerifica
134137 Message : "invalid token format" ,
135138 }
136139 }
137-
140+
138141 // Normalize domain (remove trailing dots, convert to lowercase)
139142 domain = strings .ToLower (strings .TrimSuffix (domain , "." ))
140-
143+
141144 log .Printf ("Starting DNS verification for domain: %s with token: %s" , domain , expectedToken )
142-
145+
143146 // Create context with timeout
144147 ctx , cancel := context .WithTimeout (context .Background (), config .Timeout )
145148 defer cancel ()
146-
149+
147150 // Perform verification with retries
148151 result , err := performDNSVerificationWithRetries (ctx , domain , expectedToken , config )
149-
152+
150153 // Calculate duration
151154 duration := time .Since (startTime )
152155 if result != nil {
153156 result .Duration = duration .String ()
154157 }
155-
156- log .Printf ("DNS verification completed for domain %s in %v: success=%t" ,
158+
159+ log .Printf ("DNS verification completed for domain %s in %v: success=%t" ,
157160 domain , duration , result != nil && result .Success )
158-
161+
159162 return result , err
160163}
161164
162165// performDNSVerificationWithRetries implements the retry logic for DNS verification
163166func performDNSVerificationWithRetries (ctx context.Context , domain , expectedToken string , config * DNSVerificationConfig ) (* DNSVerificationResult , error ) {
164167 var lastErr error
165168 var lastResult * DNSVerificationResult
166-
169+
167170 retryDelay := config .RetryDelay
168-
171+
169172 for attempt := 0 ; attempt <= config .MaxRetries ; attempt ++ {
170173 if attempt > 0 {
171- log .Printf ("DNS verification retry %d/%d for domain %s after %v delay" ,
174+ log .Printf ("DNS verification retry %d/%d for domain %s after %v delay" ,
172175 attempt , config .MaxRetries , domain , retryDelay )
173-
176+
174177 // Wait before retry with context cancellation support
175178 select {
176179 case <- time .After (retryDelay ):
@@ -182,38 +185,43 @@ func performDNSVerificationWithRetries(ctx context.Context, domain, expectedToke
182185 Cause : ctx .Err (),
183186 }
184187 }
185-
188+
186189 // Exponential backoff
187190 retryDelay *= 2
188191 }
189-
192+
190193 result , err := performDNSVerification (ctx , domain , expectedToken , config )
191194 if err == nil {
192195 return result , nil
193196 }
194-
197+
195198 lastErr = err
196199 lastResult = result
197-
200+
198201 // Check if error is retryable
199202 if ! isRetryableDNSError (err ) {
200203 log .Printf ("Non-retryable DNS error for domain %s: %v" , domain , err )
201204 break
202205 }
203-
204- log .Printf ("Retryable DNS error for domain %s (attempt %d/%d): %v" ,
206+
207+ log .Printf ("Retryable DNS error for domain %s (attempt %d/%d): %v" ,
205208 domain , attempt + 1 , config .MaxRetries + 1 , err )
206209 }
207-
210+
208211 // All retries exhausted
209212 return lastResult , lastErr
210213}
211214
212215// performDNSVerification performs a single DNS verification attempt
213216func performDNSVerification (ctx context.Context , domain , expectedToken string , config * DNSVerificationConfig ) (* DNSVerificationResult , error ) {
214- // Create resolver
215- resolver := createDNSResolver (config )
216-
217+ // Get resolver (either injected or create default)
218+ var resolver DNSResolver
219+ if config .Resolver != nil {
220+ resolver = config .Resolver
221+ } else {
222+ resolver = NewDefaultDNSResolver (config )
223+ }
224+
217225 // Query TXT records
218226 txtRecords , err := resolver .LookupTXT (ctx , domain )
219227 if err != nil {
@@ -223,22 +231,22 @@ func performDNSVerification(ctx context.Context, domain, expectedToken string, c
223231 Message : "failed to query DNS TXT records" ,
224232 Cause : err ,
225233 }
226-
234+
227235 result := & DNSVerificationResult {
228236 Success : false ,
229237 Domain : domain ,
230238 Token : expectedToken ,
231239 Message : dnsErr .Message ,
232240 }
233-
241+
234242 return result , dnsErr
235243 }
236-
244+
237245 log .Printf ("Found %d TXT records for domain %s" , len (txtRecords ), domain )
238-
246+
239247 // Check for verification token
240248 expectedRecord := fmt .Sprintf ("mcp-verify=%s" , expectedToken )
241-
249+
242250 for _ , record := range txtRecords {
243251 log .Printf ("Checking TXT record: %s" , record )
244252 if record == expectedRecord {
@@ -249,12 +257,12 @@ func performDNSVerification(ctx context.Context, domain, expectedToken string, c
249257 Message : "domain verification successful" ,
250258 TXTRecords : txtRecords ,
251259 }
252-
260+
253261 log .Printf ("DNS verification successful for domain %s" , domain )
254262 return result , nil
255263 }
256264 }
257-
265+
258266 // Token not found
259267 result := & DNSVerificationResult {
260268 Success : false ,
@@ -263,67 +271,37 @@ func performDNSVerification(ctx context.Context, domain, expectedToken string, c
263271 Message : fmt .Sprintf ("verification token not found in DNS TXT records (expected: %s)" , expectedRecord ),
264272 TXTRecords : txtRecords ,
265273 }
266-
274+
267275 log .Printf ("DNS verification failed for domain %s: token not found" , domain )
268276 return result , nil
269277}
270278
271- // createDNSResolver creates a DNS resolver based on configuration
272- func createDNSResolver (config * DNSVerificationConfig ) * net.Resolver {
273- if config .UseSecureResolvers && len (config .CustomResolvers ) > 0 {
274- // Create custom dialer for secure resolvers
275- dialer := & net.Dialer {
276- Timeout : config .Timeout ,
277- }
278-
279- return & net.Resolver {
280- PreferGo : true ,
281- Dial : func (ctx context.Context , network , address string ) (net.Conn , error ) {
282- // Use first available custom resolver
283- // In a production system, you might want to implement round-robin or failover
284- for _ , resolver := range config .CustomResolvers {
285- conn , err := dialer .DialContext (ctx , network , resolver )
286- if err == nil {
287- log .Printf ("Using DNS resolver: %s" , resolver )
288- return conn , nil
289- }
290- log .Printf ("Failed to connect to DNS resolver %s: %v" , resolver , err )
291- }
292- return nil , fmt .Errorf ("all custom DNS resolvers failed" )
293- },
294- }
295- }
296-
297- // Use system default resolver
298- return net .DefaultResolver
299- }
300-
301279// isRetryableDNSError determines if a DNS error should be retried
302280func isRetryableDNSError (err error ) bool {
303281 if err == nil {
304282 return false
305283 }
306-
284+
307285 // Check for temporary network errors
308286 if netErr , ok := err .(* net.OpError ); ok {
309287 return netErr .Temporary ()
310288 }
311-
289+
312290 // Check for context timeout (might be temporary)
313291 if errors .Is (err , context .DeadlineExceeded ) {
314292 return true
315293 }
316-
294+
317295 // Check for DNS-specific temporary failures
318296 dnsErr , ok := err .(* net.DNSError )
319297 if ok {
320298 return dnsErr .Temporary ()
321299 }
322-
300+
323301 // Unwrap and check nested errors
324302 if unwrapped := errors .Unwrap (err ); unwrapped != nil {
325303 return isRetryableDNSError (unwrapped )
326304 }
327-
305+
328306 return false
329307}
0 commit comments