Skip to content

Commit 66eeb50

Browse files
committed
Completely cover new nameserver code
1 parent a419925 commit 66eeb50

File tree

2 files changed

+63
-8
lines changed

2 files changed

+63
-8
lines changed

challenge/dns01/nameserver.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,10 @@ func sendDNSQuery(m *dns.Msg, ns string) (*dns.Msg, error) {
294294
in, _, err := udp.Exchange(m, ns)
295295

296296
network = getNetwork("tcp")
297-
if in != nil && in.Truncated {
297+
// We can encounter a net.OpError if the nameserver is not listening
298+
// on UDP at all, i.e. net.Dial could not make a connection.
299+
_, isOpErr := err.(*net.OpError)
300+
if (in != nil && in.Truncated) || isOpErr {
298301
tcp := &dns.Client{Net: network, Timeout: dnsTimeout}
299302
// If the TCP request succeeds, the err will reset to nil
300303
in, _, err = tcp.Exchange(m, ns)

challenge/dns01/nameserver_test.go

Lines changed: 59 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package dns01
22

33
import (
4+
"fmt"
45
getport "github.com/jsumners/go-getport"
56
"github.com/miekg/dns"
67
"net"
@@ -13,6 +14,10 @@ import (
1314
)
1415

1516
type testDnsHandler struct{}
17+
type testDnsServer struct {
18+
*dns.Server
19+
getport.PortResult
20+
}
1621

1722
func (handler *testDnsHandler) ServeDNS(writer dns.ResponseWriter, reply *dns.Msg) {
1823
msg := dns.Msg{}
@@ -39,32 +44,45 @@ func (handler *testDnsHandler) ServeDNS(writer dns.ResponseWriter, reply *dns.Ms
3944
writer.WriteMsg(&msg)
4045
}
4146

42-
func getTestNameserver(t *testing.T, network string) *dns.Server {
47+
// getTestNameserver constructs a new DNS server on a local address, or set
48+
// of addresses, that responds to an `A` query for `example.com`.
49+
func getTestNameserver(t *testing.T, network string) testDnsServer {
4350
server := &dns.Server{
4451
Handler: new(testDnsHandler),
4552
Net: network,
4653
}
54+
testServer := testDnsServer{
55+
Server: server,
56+
}
4757

4858
var protocol getport.Protocol
59+
var address string
4960
switch network {
5061
case "tcp":
5162
protocol = getport.TCP
63+
address = "0.0.0.0"
5264
case "tcp4":
5365
protocol = getport.TCP4
66+
address = "127.0.0.1"
5467
case "tcp6":
5568
protocol = getport.TCP6
69+
address = "::1"
5670
case "udp":
5771
protocol = getport.UDP
72+
address = "0.0.0.0"
5873
case "udp4":
5974
protocol = getport.UDP4
75+
address = "127.0.0.1"
6076
case "udp6":
6177
protocol = getport.UDP6
78+
address = "::1"
6279
}
63-
portResult, portError := getport.GetPort(protocol, "127.0.0.1")
80+
portResult, portError := getport.GetPort(protocol, address)
6481
if portError != nil {
6582
t.Error(portError)
66-
return server
83+
return testServer
6784
}
85+
testServer.PortResult = portResult
6886
server.Addr = getport.PortResultToAddress(portResult)
6987

7088
waitLock := sync.Mutex{}
@@ -77,18 +95,52 @@ func getTestNameserver(t *testing.T, network string) *dns.Server {
7795
}()
7896

7997
waitLock.Lock()
80-
return server
98+
return testServer
8199
}
82100

83101
func TestSendDNSQuery(t *testing.T) {
84102
t.Run("does udp4 only", func(t *testing.T) {
85103
SetNetworkStack(IPv4Only)
86104
nameserver := getTestNameserver(t, getNetwork("udp"))
87-
defer nameserver.Shutdown()
105+
defer nameserver.Server.Shutdown()
106+
107+
serverAddress := fmt.Sprintf("127.0.0.1:%d", nameserver.PortResult.Port)
108+
recursiveNameservers = ParseNameservers([]string{serverAddress})
109+
msg := createDNSMsg("example.com.", dns.TypeA, true)
110+
result, queryError := sendDNSQuery(msg, serverAddress)
111+
assert.NoError(t, queryError)
112+
assert.Equal(t, result.Answer[0].(*dns.A).A.String(), "127.0.0.1")
113+
})
114+
115+
t.Run("does udp6 only", func(t *testing.T) {
116+
SetNetworkStack(IPv6Only)
117+
nameserver := getTestNameserver(t, getNetwork("udp"))
118+
defer nameserver.Server.Shutdown()
88119

89-
recursiveNameservers = ParseNameservers([]string{nameserver.Addr})
120+
serverAddress := fmt.Sprintf("[::1]:%d", nameserver.PortResult.Port)
121+
recursiveNameservers = ParseNameservers([]string{serverAddress})
90122
msg := createDNSMsg("example.com.", dns.TypeA, true)
91-
result, queryError := sendDNSQuery(msg, nameserver.Addr)
123+
result, queryError := sendDNSQuery(msg, serverAddress)
124+
assert.NoError(t, queryError)
125+
assert.Equal(t, result.Answer[0].(*dns.A).A.String(), "127.0.0.1")
126+
})
127+
128+
t.Run("does tcp4 and tcp6", func(t *testing.T) {
129+
SetNetworkStack(DefaultNetworkStack)
130+
nameserver := getTestNameserver(t, getNetwork("tcp"))
131+
defer nameserver.Server.Shutdown()
132+
133+
serverAddress := fmt.Sprintf("[::1]:%d", nameserver.PortResult.Port)
134+
recursiveNameservers = ParseNameservers([]string{serverAddress})
135+
msg := createDNSMsg("example.com.", dns.TypeA, true)
136+
result, queryError := sendDNSQuery(msg, serverAddress)
137+
assert.NoError(t, queryError)
138+
assert.Equal(t, result.Answer[0].(*dns.A).A.String(), "127.0.0.1")
139+
140+
serverAddress = fmt.Sprintf("127.0.0.1:%d", nameserver.PortResult.Port)
141+
recursiveNameservers = ParseNameservers([]string{serverAddress})
142+
msg = createDNSMsg("example.com.", dns.TypeA, true)
143+
result, queryError = sendDNSQuery(msg, serverAddress)
92144
assert.NoError(t, queryError)
93145
assert.Equal(t, result.Answer[0].(*dns.A).A.String(), "127.0.0.1")
94146
})

0 commit comments

Comments
 (0)