Skip to content

Commit 1202ae0

Browse files
committed
Implement DNS record verification for domain ownership validation
1 parent 60ac929 commit 1202ae0

File tree

5 files changed

+427
-107
lines changed

5 files changed

+427
-107
lines changed

internal/verification/dns.go

Lines changed: 50 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -43,18 +43,21 @@ type DNSVerificationResult struct {
4343
type DNSVerificationConfig struct {
4444
// Timeout for DNS queries (default: 10 seconds)
4545
Timeout time.Duration
46-
46+
4747
// MaxRetries for transient failures (default: 3)
4848
MaxRetries int
49-
49+
5050
// RetryDelay base delay between retries (default: 1 second)
5151
RetryDelay time.Duration
52-
52+
5353
// UseSecureResolvers enables use of secure DNS resolvers
5454
UseSecureResolvers bool
55-
55+
5656
// CustomResolvers allows specifying custom DNS servers
5757
CustomResolvers []string
58+
59+
// Resolver allows injecting a custom DNS resolver (primarily for testing)
60+
Resolver DNSResolver
5861
}
5962

6063
// DefaultDNSConfig returns the default configuration for DNS verification
@@ -108,7 +111,7 @@ func VerifyDNSRecord(domain, expectedToken string) (*DNSVerificationResult, erro
108111
// VerifyDNSRecordWithConfig performs DNS verification with custom configuration
109112
func VerifyDNSRecordWithConfig(domain, expectedToken string, config *DNSVerificationConfig) (*DNSVerificationResult, error) {
110113
startTime := time.Now()
111-
114+
112115
// Input validation
113116
if domain == "" {
114117
return nil, &DNSVerificationError{
@@ -117,15 +120,15 @@ func VerifyDNSRecordWithConfig(domain, expectedToken string, config *DNSVerifica
117120
Message: "domain cannot be empty",
118121
}
119122
}
120-
123+
121124
if expectedToken == "" {
122125
return nil, &DNSVerificationError{
123126
Domain: domain,
124127
Token: expectedToken,
125128
Message: "token cannot be empty",
126129
}
127130
}
128-
131+
129132
// Validate token format
130133
if !ValidateTokenFormat(expectedToken) {
131134
return nil, &DNSVerificationError{
@@ -134,43 +137,43 @@ func VerifyDNSRecordWithConfig(domain, expectedToken string, config *DNSVerifica
134137
Message: "invalid token format",
135138
}
136139
}
137-
140+
138141
// Normalize domain (remove trailing dots, convert to lowercase)
139142
domain = strings.ToLower(strings.TrimSuffix(domain, "."))
140-
143+
141144
log.Printf("Starting DNS verification for domain: %s with token: %s", domain, expectedToken)
142-
145+
143146
// Create context with timeout
144147
ctx, cancel := context.WithTimeout(context.Background(), config.Timeout)
145148
defer cancel()
146-
149+
147150
// Perform verification with retries
148151
result, err := performDNSVerificationWithRetries(ctx, domain, expectedToken, config)
149-
152+
150153
// Calculate duration
151154
duration := time.Since(startTime)
152155
if result != nil {
153156
result.Duration = duration.String()
154157
}
155-
156-
log.Printf("DNS verification completed for domain %s in %v: success=%t",
158+
159+
log.Printf("DNS verification completed for domain %s in %v: success=%t",
157160
domain, duration, result != nil && result.Success)
158-
161+
159162
return result, err
160163
}
161164

162165
// performDNSVerificationWithRetries implements the retry logic for DNS verification
163166
func performDNSVerificationWithRetries(ctx context.Context, domain, expectedToken string, config *DNSVerificationConfig) (*DNSVerificationResult, error) {
164167
var lastErr error
165168
var lastResult *DNSVerificationResult
166-
169+
167170
retryDelay := config.RetryDelay
168-
171+
169172
for attempt := 0; attempt <= config.MaxRetries; attempt++ {
170173
if attempt > 0 {
171-
log.Printf("DNS verification retry %d/%d for domain %s after %v delay",
174+
log.Printf("DNS verification retry %d/%d for domain %s after %v delay",
172175
attempt, config.MaxRetries, domain, retryDelay)
173-
176+
174177
// Wait before retry with context cancellation support
175178
select {
176179
case <-time.After(retryDelay):
@@ -182,38 +185,43 @@ func performDNSVerificationWithRetries(ctx context.Context, domain, expectedToke
182185
Cause: ctx.Err(),
183186
}
184187
}
185-
188+
186189
// Exponential backoff
187190
retryDelay *= 2
188191
}
189-
192+
190193
result, err := performDNSVerification(ctx, domain, expectedToken, config)
191194
if err == nil {
192195
return result, nil
193196
}
194-
197+
195198
lastErr = err
196199
lastResult = result
197-
200+
198201
// Check if error is retryable
199202
if !isRetryableDNSError(err) {
200203
log.Printf("Non-retryable DNS error for domain %s: %v", domain, err)
201204
break
202205
}
203-
204-
log.Printf("Retryable DNS error for domain %s (attempt %d/%d): %v",
206+
207+
log.Printf("Retryable DNS error for domain %s (attempt %d/%d): %v",
205208
domain, attempt+1, config.MaxRetries+1, err)
206209
}
207-
210+
208211
// All retries exhausted
209212
return lastResult, lastErr
210213
}
211214

212215
// performDNSVerification performs a single DNS verification attempt
213216
func performDNSVerification(ctx context.Context, domain, expectedToken string, config *DNSVerificationConfig) (*DNSVerificationResult, error) {
214-
// Create resolver
215-
resolver := createDNSResolver(config)
216-
217+
// Get resolver (either injected or create default)
218+
var resolver DNSResolver
219+
if config.Resolver != nil {
220+
resolver = config.Resolver
221+
} else {
222+
resolver = NewDefaultDNSResolver(config)
223+
}
224+
217225
// Query TXT records
218226
txtRecords, err := resolver.LookupTXT(ctx, domain)
219227
if err != nil {
@@ -223,22 +231,22 @@ func performDNSVerification(ctx context.Context, domain, expectedToken string, c
223231
Message: "failed to query DNS TXT records",
224232
Cause: err,
225233
}
226-
234+
227235
result := &DNSVerificationResult{
228236
Success: false,
229237
Domain: domain,
230238
Token: expectedToken,
231239
Message: dnsErr.Message,
232240
}
233-
241+
234242
return result, dnsErr
235243
}
236-
244+
237245
log.Printf("Found %d TXT records for domain %s", len(txtRecords), domain)
238-
246+
239247
// Check for verification token
240248
expectedRecord := fmt.Sprintf("mcp-verify=%s", expectedToken)
241-
249+
242250
for _, record := range txtRecords {
243251
log.Printf("Checking TXT record: %s", record)
244252
if record == expectedRecord {
@@ -249,12 +257,12 @@ func performDNSVerification(ctx context.Context, domain, expectedToken string, c
249257
Message: "domain verification successful",
250258
TXTRecords: txtRecords,
251259
}
252-
260+
253261
log.Printf("DNS verification successful for domain %s", domain)
254262
return result, nil
255263
}
256264
}
257-
265+
258266
// Token not found
259267
result := &DNSVerificationResult{
260268
Success: false,
@@ -263,67 +271,37 @@ func performDNSVerification(ctx context.Context, domain, expectedToken string, c
263271
Message: fmt.Sprintf("verification token not found in DNS TXT records (expected: %s)", expectedRecord),
264272
TXTRecords: txtRecords,
265273
}
266-
274+
267275
log.Printf("DNS verification failed for domain %s: token not found", domain)
268276
return result, nil
269277
}
270278

271-
// createDNSResolver creates a DNS resolver based on configuration
272-
func createDNSResolver(config *DNSVerificationConfig) *net.Resolver {
273-
if config.UseSecureResolvers && len(config.CustomResolvers) > 0 {
274-
// Create custom dialer for secure resolvers
275-
dialer := &net.Dialer{
276-
Timeout: config.Timeout,
277-
}
278-
279-
return &net.Resolver{
280-
PreferGo: true,
281-
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
282-
// Use first available custom resolver
283-
// In a production system, you might want to implement round-robin or failover
284-
for _, resolver := range config.CustomResolvers {
285-
conn, err := dialer.DialContext(ctx, network, resolver)
286-
if err == nil {
287-
log.Printf("Using DNS resolver: %s", resolver)
288-
return conn, nil
289-
}
290-
log.Printf("Failed to connect to DNS resolver %s: %v", resolver, err)
291-
}
292-
return nil, fmt.Errorf("all custom DNS resolvers failed")
293-
},
294-
}
295-
}
296-
297-
// Use system default resolver
298-
return net.DefaultResolver
299-
}
300-
301279
// isRetryableDNSError determines if a DNS error should be retried
302280
func isRetryableDNSError(err error) bool {
303281
if err == nil {
304282
return false
305283
}
306-
284+
307285
// Check for temporary network errors
308286
if netErr, ok := err.(*net.OpError); ok {
309287
return netErr.Temporary()
310288
}
311-
289+
312290
// Check for context timeout (might be temporary)
313291
if errors.Is(err, context.DeadlineExceeded) {
314292
return true
315293
}
316-
294+
317295
// Check for DNS-specific temporary failures
318296
dnsErr, ok := err.(*net.DNSError)
319297
if ok {
320298
return dnsErr.Temporary()
321299
}
322-
300+
323301
// Unwrap and check nested errors
324302
if unwrapped := errors.Unwrap(err); unwrapped != nil {
325303
return isRetryableDNSError(unwrapped)
326304
}
327-
305+
328306
return false
329307
}

internal/verification/dns_mock.go

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
package verification
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"time"
7+
)
8+
9+
// MockDNSResolver implements DNSResolver for testing
10+
type MockDNSResolver struct {
11+
// TXTRecords maps domain names to their TXT records
12+
TXTRecords map[string][]string
13+
14+
// Errors maps domain names to errors that should be returned
15+
Errors map[string]error
16+
17+
// Delay simulates DNS query latency
18+
Delay time.Duration
19+
20+
// CallCount tracks how many times LookupTXT was called
21+
CallCount int
22+
23+
// LastDomain tracks the last domain that was queried
24+
LastDomain string
25+
}
26+
27+
func (m *MockDNSResolver) LookupTXT(ctx context.Context, name string) ([]string, error) {
28+
m.CallCount++
29+
m.LastDomain = name
30+
31+
// Simulate delay if configured
32+
if m.Delay > 0 {
33+
select {
34+
case <-time.After(m.Delay):
35+
case <-ctx.Done():
36+
return nil, ctx.Err()
37+
}
38+
}
39+
40+
// Return error if configured for this domain
41+
if err, exists := m.Errors[name]; exists {
42+
return nil, err
43+
}
44+
45+
// Return TXT records if configured
46+
if records, exists := m.TXTRecords[name]; exists {
47+
return records, nil
48+
}
49+
50+
// Default: return empty records (domain exists but no TXT records)
51+
return []string{}, nil
52+
}
53+
54+
// Reset clears all state in the mock resolver
55+
func (m *MockDNSResolver) Reset() {
56+
m.CallCount = 0
57+
m.LastDomain = ""
58+
if m.TXTRecords != nil {
59+
for k := range m.TXTRecords {
60+
delete(m.TXTRecords, k)
61+
}
62+
}
63+
if m.Errors != nil {
64+
for k := range m.Errors {
65+
delete(m.Errors, k)
66+
}
67+
}
68+
}
69+
70+
// SetTXTRecord sets a TXT record for a domain
71+
func (m *MockDNSResolver) SetTXTRecord(domain string, records ...string) {
72+
if m.TXTRecords == nil {
73+
m.TXTRecords = make(map[string][]string)
74+
}
75+
m.TXTRecords[domain] = records
76+
}
77+
78+
// SetError sets an error to be returned for a domain
79+
func (m *MockDNSResolver) SetError(domain string, err error) {
80+
if m.Errors == nil {
81+
m.Errors = make(map[string]error)
82+
}
83+
m.Errors[domain] = err
84+
}
85+
86+
// SetVerificationToken is a convenience method to set up a valid verification token
87+
func (m *MockDNSResolver) SetVerificationToken(domain, token string) {
88+
m.SetTXTRecord(domain, fmt.Sprintf("mcp-verify=%s", token))
89+
}
90+
91+
// NewMockDNSResolver creates a new mock DNS resolver
92+
func NewMockDNSResolver() *MockDNSResolver {
93+
return &MockDNSResolver{
94+
TXTRecords: make(map[string][]string),
95+
Errors: make(map[string]error),
96+
}
97+
}

0 commit comments

Comments
 (0)