11package dns01
22
33import (
4+ "fmt"
45 getport "github.com/jsumners/go-getport"
56 "github.com/miekg/dns"
67 "net"
@@ -13,6 +14,10 @@ import (
1314)
1415
1516type testDnsHandler struct {}
17+ type testDnsServer struct {
18+ * dns.Server
19+ getport.PortResult
20+ }
1621
1722func (handler * testDnsHandler ) ServeDNS (writer dns.ResponseWriter , reply * dns.Msg ) {
1823 msg := dns.Msg {}
@@ -39,32 +44,45 @@ func (handler *testDnsHandler) ServeDNS(writer dns.ResponseWriter, reply *dns.Ms
3944 writer .WriteMsg (& msg )
4045}
4146
42- func getTestNameserver (t * testing.T , network string ) * dns.Server {
47+ // getTestNameserver constructs a new DNS server on a local address, or set
48+ // of addresses, that responds to an `A` query for `example.com`.
49+ func getTestNameserver (t * testing.T , network string ) testDnsServer {
4350 server := & dns.Server {
4451 Handler : new (testDnsHandler ),
4552 Net : network ,
4653 }
54+ testServer := testDnsServer {
55+ Server : server ,
56+ }
4757
4858 var protocol getport.Protocol
59+ var address string
4960 switch network {
5061 case "tcp" :
5162 protocol = getport .TCP
63+ address = "0.0.0.0"
5264 case "tcp4" :
5365 protocol = getport .TCP4
66+ address = "127.0.0.1"
5467 case "tcp6" :
5568 protocol = getport .TCP6
69+ address = "::1"
5670 case "udp" :
5771 protocol = getport .UDP
72+ address = "0.0.0.0"
5873 case "udp4" :
5974 protocol = getport .UDP4
75+ address = "127.0.0.1"
6076 case "udp6" :
6177 protocol = getport .UDP6
78+ address = "::1"
6279 }
63- portResult , portError := getport .GetPort (protocol , "127.0.0.1" )
80+ portResult , portError := getport .GetPort (protocol , address )
6481 if portError != nil {
6582 t .Error (portError )
66- return server
83+ return testServer
6784 }
85+ testServer .PortResult = portResult
6886 server .Addr = getport .PortResultToAddress (portResult )
6987
7088 waitLock := sync.Mutex {}
@@ -77,18 +95,52 @@ func getTestNameserver(t *testing.T, network string) *dns.Server {
7795 }()
7896
7997 waitLock .Lock ()
80- return server
98+ return testServer
8199}
82100
83101func TestSendDNSQuery (t * testing.T ) {
84102 t .Run ("does udp4 only" , func (t * testing.T ) {
85103 SetNetworkStack (IPv4Only )
86104 nameserver := getTestNameserver (t , getNetwork ("udp" ))
87- defer nameserver .Shutdown ()
105+ defer nameserver .Server .Shutdown ()
106+
107+ serverAddress := fmt .Sprintf ("127.0.0.1:%d" , nameserver .PortResult .Port )
108+ recursiveNameservers = ParseNameservers ([]string {serverAddress })
109+ msg := createDNSMsg ("example.com." , dns .TypeA , true )
110+ result , queryError := sendDNSQuery (msg , serverAddress )
111+ assert .NoError (t , queryError )
112+ assert .Equal (t , result .Answer [0 ].(* dns.A ).A .String (), "127.0.0.1" )
113+ })
114+
115+ t .Run ("does udp6 only" , func (t * testing.T ) {
116+ SetNetworkStack (IPv6Only )
117+ nameserver := getTestNameserver (t , getNetwork ("udp" ))
118+ defer nameserver .Server .Shutdown ()
88119
89- recursiveNameservers = ParseNameservers ([]string {nameserver .Addr })
120+ serverAddress := fmt .Sprintf ("[::1]:%d" , nameserver .PortResult .Port )
121+ recursiveNameservers = ParseNameservers ([]string {serverAddress })
90122 msg := createDNSMsg ("example.com." , dns .TypeA , true )
91- result , queryError := sendDNSQuery (msg , nameserver .Addr )
123+ result , queryError := sendDNSQuery (msg , serverAddress )
124+ assert .NoError (t , queryError )
125+ assert .Equal (t , result .Answer [0 ].(* dns.A ).A .String (), "127.0.0.1" )
126+ })
127+
128+ t .Run ("does tcp4 and tcp6" , func (t * testing.T ) {
129+ SetNetworkStack (DefaultNetworkStack )
130+ nameserver := getTestNameserver (t , getNetwork ("tcp" ))
131+ defer nameserver .Server .Shutdown ()
132+
133+ serverAddress := fmt .Sprintf ("[::1]:%d" , nameserver .PortResult .Port )
134+ recursiveNameservers = ParseNameservers ([]string {serverAddress })
135+ msg := createDNSMsg ("example.com." , dns .TypeA , true )
136+ result , queryError := sendDNSQuery (msg , serverAddress )
137+ assert .NoError (t , queryError )
138+ assert .Equal (t , result .Answer [0 ].(* dns.A ).A .String (), "127.0.0.1" )
139+
140+ serverAddress = fmt .Sprintf ("127.0.0.1:%d" , nameserver .PortResult .Port )
141+ recursiveNameservers = ParseNameservers ([]string {serverAddress })
142+ msg = createDNSMsg ("example.com." , dns .TypeA , true )
143+ result , queryError = sendDNSQuery (msg , serverAddress )
92144 assert .NoError (t , queryError )
93145 assert .Equal (t , result .Answer [0 ].(* dns.A ).A .String (), "127.0.0.1" )
94146 })
0 commit comments