diff --git a/.golangci.yml b/.golangci.yml index 74b0e4af5a..8df9bfed45 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -172,6 +172,8 @@ issues: text: 'dnsTimeout is a global variable' - path: challenge/dns01/nameserver_test.go text: 'findXByFqdnTestCases is a global variable' + - path: challenge/dns01/network.go + text: 'currentNetworkStack is a global variable' - path: challenge/http01/domain_matcher.go text: 'string `Host` has \d occurrences, make it a constant' - path: challenge/http01/domain_matcher.go diff --git a/challenge/dns01/nameserver.go b/challenge/dns01/nameserver.go index 206611be4a..8cc8d0cc65 100644 --- a/challenge/dns01/nameserver.go +++ b/challenge/dns01/nameserver.go @@ -265,7 +265,8 @@ func createDNSMsg(fqdn string, rtype uint16, recursive bool) *dns.Msg { func sendDNSQuery(m *dns.Msg, ns string) (*dns.Msg, error) { if ok, _ := strconv.ParseBool(os.Getenv("LEGO_EXPERIMENTAL_DNS_TCP_ONLY")); ok { - tcp := &dns.Client{Net: "tcp", Timeout: dnsTimeout} + network := currentNetworkStack.Network("tcp") + tcp := &dns.Client{Net: network, Timeout: dnsTimeout} r, _, err := tcp.Exchange(m, ns) if err != nil { return r, &DNSError{Message: "DNS call error", MsgIn: m, NS: ns, Err: err} @@ -274,11 +275,16 @@ func sendDNSQuery(m *dns.Msg, ns string) (*dns.Msg, error) { return r, nil } - udp := &dns.Client{Net: "udp", Timeout: dnsTimeout} + udpNetwork := currentNetworkStack.Network("udp") + udp := &dns.Client{Net: udpNetwork, Timeout: dnsTimeout} r, _, err := udp.Exchange(m, ns) - if r != nil && r.Truncated { - tcp := &dns.Client{Net: "tcp", Timeout: dnsTimeout} + // We can encounter a net.OpError if the nameserver is not listening + // on UDP at all, i.e. net.Dial could not make a connection. + var opErr *net.OpError + if (r != nil && r.Truncated) || errors.As(err, &opErr) { + tcpNetwork := currentNetworkStack.Network("tcp") + tcp := &dns.Client{Net: tcpNetwork, Timeout: dnsTimeout} // If the TCP request succeeds, the "err" will reset to nil r, _, err = tcp.Exchange(m, ns) } diff --git a/challenge/dns01/nameserver_test.go b/challenge/dns01/nameserver_test.go index 15b19beba0..957dddc196 100644 --- a/challenge/dns01/nameserver_test.go +++ b/challenge/dns01/nameserver_test.go @@ -2,14 +2,134 @@ package dns01 import ( "errors" + "net" "sort" + "sync" "testing" + "time" "github.com/miekg/dns" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +func testDNSHandler(writer dns.ResponseWriter, reply *dns.Msg) { + msg := dns.Msg{} + msg.SetReply(reply) + + if reply.Question[0].Qtype == dns.TypeA { + msg.Authoritative = true + domain := msg.Question[0].Name + msg.Answer = append( + msg.Answer, + &dns.A{ + Hdr: dns.RR_Header{ + Name: domain, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: 60, + }, + A: net.IPv4(127, 0, 0, 1), + }, + ) + } + + _ = writer.WriteMsg(&msg) +} + +// getTestNameserver constructs a new DNS server on a local address, or set +// of addresses, that responds to an `A` query for `example.com`. +func getTestNameserver(t *testing.T, network string) *dns.Server { + t.Helper() + server := &dns.Server{ + Handler: dns.HandlerFunc(testDNSHandler), + Net: network, + } + switch network { + case "tcp", "udp": + server.Addr = "0.0.0.0:0" + case "tcp4", "udp4": + server.Addr = "127.0.0.1:0" + case "tcp6", "udp6": + server.Addr = "[::1]:0" + } + + waitLock := sync.Mutex{} + waitLock.Lock() + server.NotifyStartedFunc = waitLock.Unlock + + go func() { _ = server.ListenAndServe() }() + + waitLock.Lock() + return server +} + +func startTestNameserver(t *testing.T, stack networkStack, proto string) (shutdown func(), addr string) { + t.Helper() + currentNetworkStack = stack + srv := getTestNameserver(t, currentNetworkStack.Network(proto)) + + shutdown = func() { _ = srv.Shutdown() } + if proto == "tcp" { + addr = srv.Listener.Addr().String() + } else { + addr = srv.PacketConn.LocalAddr().String() + } + return +} + +func TestSendDNSQuery(t *testing.T) { + currentNameservers := recursiveNameservers + + t.Cleanup(func() { + recursiveNameservers = currentNameservers + currentNetworkStack = dualStack + }) + + t.Run("does udp4 only", func(t *testing.T) { + stop, addr := startTestNameserver(t, ipv4only, "udp") + defer stop() + + recursiveNameservers = ParseNameservers([]string{addr}) + msg := createDNSMsg("example.com.", dns.TypeA, true) + result, queryError := sendDNSQuery(msg, addr) + require.NoError(t, queryError) + assert.Equal(t, "127.0.0.1", result.Answer[0].(*dns.A).A.String()) + }) + + t.Run("does udp6 only", func(t *testing.T) { + stop, addr := startTestNameserver(t, ipv6only, "udp") + defer stop() + + recursiveNameservers = ParseNameservers([]string{addr}) + msg := createDNSMsg("example.com.", dns.TypeA, true) + result, queryError := sendDNSQuery(msg, addr) + require.NoError(t, queryError) + assert.Equal(t, "127.0.0.1", result.Answer[0].(*dns.A).A.String()) + }) + + t.Run("does tcp4 and tcp6", func(t *testing.T) { + stop, addr := startTestNameserver(t, dualStack, "tcp") + host, port, _ := net.SplitHostPort(addr) + defer stop() + t.Logf("### port: %s", port) + + addr6 := net.JoinHostPort(host, port) + recursiveNameservers = ParseNameservers([]string{addr6}) + msg := createDNSMsg("example.com.", dns.TypeA, true) + result, queryError := sendDNSQuery(msg, addr6) + require.NoError(t, queryError) + assert.Equal(t, "127.0.0.1", result.Answer[0].(*dns.A).A.String()) + + addr4 := net.JoinHostPort("127.0.0.1", port) + recursiveNameservers = ParseNameservers([]string{addr4}) + msg = createDNSMsg("example.com.", dns.TypeA, true) + result, queryError = sendDNSQuery(msg, addr4) + require.NoError(t, queryError) + assert.Equal(t, "127.0.0.1", result.Answer[0].(*dns.A).A.String()) + }) +} + func TestLookupNameserversOK(t *testing.T) { testCases := []struct { fqdn string @@ -75,6 +195,7 @@ var findXByFqdnTestCases = []struct { primaryNs string nameservers []string expectedError string + timeout time.Duration }{ { desc: "domain is a CNAME", @@ -117,14 +238,18 @@ var findXByFqdnTestCases = []struct { zone: "google.com.", primaryNs: "ns1.google.com.", nameservers: []string{":7053", ":8053", "8.8.8.8:53"}, + timeout: 500 * time.Millisecond, }, { desc: "only non-existent nameservers", fqdn: "mail.google.com.", zone: "google.com.", nameservers: []string{":7053", ":8053", ":9053"}, - // use only the start of the message because the port changes with each call: 127.0.0.1:XXXXX->127.0.0.1:7053. - expectedError: "[fqdn=mail.google.com.] could not find the start of authority for 'mail.google.com.': DNS call error: read udp ", + // NOTE: On Windows, net.DialContext finds a way down to the ContectEx syscall. + // There a fault is marked as "connectex", not "connect", see + // https://cs.opensource.google/go/go/+/refs/tags/go1.19.5:src/net/fd_windows.go;l=112 + expectedError: "could not find the start of authority for 'mail.google.com.':", + timeout: 500 * time.Millisecond, }, { desc: "no nameservers", @@ -155,6 +280,11 @@ func TestFindZoneByFqdnCustom(t *testing.T) { func TestFindPrimaryNsByFqdnCustom(t *testing.T) { for _, test := range findXByFqdnTestCases { t.Run(test.desc, func(t *testing.T) { + origTimeout := dnsTimeout + if test.timeout > 0 { + dnsTimeout = test.timeout + } + ClearFqdnCache() ns, err := FindPrimaryNsByFqdnCustom(test.fqdn, test.nameservers) @@ -165,6 +295,8 @@ func TestFindPrimaryNsByFqdnCustom(t *testing.T) { require.NoError(t, err) assert.Equal(t, test.primaryNs, ns) } + + dnsTimeout = origTimeout }) } } diff --git a/challenge/dns01/network.go b/challenge/dns01/network.go new file mode 100644 index 0000000000..26e0ba57b0 --- /dev/null +++ b/challenge/dns01/network.go @@ -0,0 +1,41 @@ +package dns01 + +// networkStack is used to indicate which IP stack should be used for DNS queries. +type networkStack int + +const ( + dualStack networkStack = iota + ipv4only + ipv6only +) + +// currentNetworkStack is used to define which IP stack will be used. The default is +// both IPv4 and IPv6. Set to IPv4Only or IPv6Only to select either version. +var currentNetworkStack = dualStack + +// Network interprets the NetworkStack setting in relation to the desired +// protocol. The proto value should be either "udp" or "tcp". +func (s networkStack) Network(proto string) string { + // The DNS client passes whatever value is set in (*dns.Client).Net to + // the [net.Dialer](https://github.com/miekg/dns/blob/fe20d5d/client.go#L119-L141). + // And the net.Dialer accepts strings such as "udp4" or "tcp6" + // (https://cs.opensource.google/go/go/+/refs/tags/go1.18.9:src/net/dial.go;l=167-182). + switch s { + case ipv4only: + return proto + "4" + case ipv6only: + return proto + "6" + default: + return proto + } +} + +// SetIPv4Only forces DNS queries to only happen over the IPv4 stack. +func SetIPv4Only() { currentNetworkStack = ipv4only } + +// SetIPv6Only forces DNS queries to only happen over the IPv6 stack. +func SetIPv6Only() { currentNetworkStack = ipv6only } + +// SetDualStack indicates that both IPv4 and IPv6 should be allowed. +// This setting lets the OS determine which IP stack to use. +func SetDualStack() { currentNetworkStack = dualStack } diff --git a/challenge/http01/http_challenge_server.go b/challenge/http01/http_challenge_server.go index f69f5ac1f8..458995a9af 100644 --- a/challenge/http01/http_challenge_server.go +++ b/challenge/http01/http_challenge_server.go @@ -41,6 +41,28 @@ func NewUnixProviderServer(socketPath string, mode fs.FileMode) *ProviderServer return &ProviderServer{network: "unix", address: socketPath, socketMode: mode, matcher: &hostMatcher{}} } +// SetIPv4Only starts the challenge server on an IPv4 address. +// +// Calling this method has no effect if s was created with NewUnixProviderServer. +func (s *ProviderServer) SetIPv4Only() { s.setTCPStack("tcp4") } + +// SetIPv6Only starts the challenge server on an IPv6 address. +// +// Calling this method has no effect if s was created with NewUnixProviderServer. +func (s *ProviderServer) SetIPv6Only() { s.setTCPStack("tcp6") } + +// SetDualStack indicates that both IPv4 and IPv6 should be allowed. +// This setting lets the OS determine which IP stack to use for the challenge server. +// +// Calling this method has no effect if s was created with NewUnixProviderServer. +func (s *ProviderServer) SetDualStack() { s.setTCPStack("tcp") } + +func (s *ProviderServer) setTCPStack(network string) { + if s.network != "unix" { + s.network = network + } +} + // Present starts a web server and makes the token available at `ChallengePath(token)` for web requests. func (s *ProviderServer) Present(domain, token, keyAuth string) error { var err error diff --git a/challenge/http01/http_challenge_test.go b/challenge/http01/http_challenge_test.go index 3a5aa6bbe0..7971c84f2f 100644 --- a/challenge/http01/http_challenge_test.go +++ b/challenge/http01/http_challenge_test.go @@ -32,6 +32,7 @@ func TestProviderServer_GetAddress(t *testing.T) { testCases := []struct { desc string server *ProviderServer + network func(server *ProviderServer) expected string }{ { @@ -49,6 +50,18 @@ func TestProviderServer_GetAddress(t *testing.T) { server: NewProviderServer("localhost", "8080"), expected: "localhost:8080", }, + { + desc: "TCP4 with host and port", + server: NewProviderServer("localhost", "8080"), + network: func(s *ProviderServer) { s.SetIPv4Only() }, + expected: "localhost:8080", + }, + { + desc: "TCP6 with host and port", + server: NewProviderServer("localhost", "8080"), + network: func(s *ProviderServer) { s.SetIPv6Only() }, + expected: "localhost:8080", + }, { desc: "UDS socket", server: NewUnixProviderServer(sock, fs.ModeSocket|0o666), @@ -60,6 +73,10 @@ func TestProviderServer_GetAddress(t *testing.T) { t.Run(test.desc, func(t *testing.T) { t.Parallel() + if test.network != nil { + test.network(test.server) + } + address := test.server.GetAddress() assert.Equal(t, test.expected, address) }) diff --git a/challenge/tlsalpn01/tls_alpn_challenge_server.go b/challenge/tlsalpn01/tls_alpn_challenge_server.go index e5581b1e3f..83b7106604 100644 --- a/challenge/tlsalpn01/tls_alpn_challenge_server.go +++ b/challenge/tlsalpn01/tls_alpn_challenge_server.go @@ -26,6 +26,7 @@ const ( type ProviderServer struct { iface string port string + network string listener net.Listener } @@ -33,9 +34,22 @@ type ProviderServer struct { // Setting iface and / or port to an empty string will make the server fall back to // the "any" interface and port 443 respectively. func NewProviderServer(iface, port string) *ProviderServer { - return &ProviderServer{iface: iface, port: port} + if port == "" { + port = defaultTLSPort + } + return &ProviderServer{iface: iface, port: port, network: "tcp"} } +// SetIPv4Only starts the challenge server on an IPv4 address. +func (s *ProviderServer) SetIPv4Only() { s.network = "tcp4" } + +// SetIPv6Only starts the challenge server on an IPv6 address. +func (s *ProviderServer) SetIPv6Only() { s.network = "tcp6" } + +// SetDualStack indicates that both IPv4 and IPv6 should be allowed. +// This setting lets the OS determine which IP stack to use for the challenge server. +func (s *ProviderServer) SetDualStack() { s.network = "tcp" } + func (s *ProviderServer) GetAddress() string { return net.JoinHostPort(s.iface, s.port) } @@ -65,7 +79,7 @@ func (s *ProviderServer) Present(domain, token, keyAuth string) error { tlsConf.NextProtos = []string{ACMETLS1Protocol} // Create the listener with the created tls.Config. - s.listener, err = tls.Listen("tcp", s.GetAddress(), tlsConf) + s.listener, err = tls.Listen(s.network, s.GetAddress(), tlsConf) if err != nil { return fmt.Errorf("could not start HTTPS server for challenge: %w", err) } diff --git a/challenge/tlsalpn01/tls_alpn_challenge_test.go b/challenge/tlsalpn01/tls_alpn_challenge_test.go index 8725a1360f..2f57867c13 100644 --- a/challenge/tlsalpn01/tls_alpn_challenge_test.go +++ b/challenge/tlsalpn01/tls_alpn_challenge_test.go @@ -9,6 +9,7 @@ import ( "encoding/asn1" "net" "net/http" + "os" "testing" "github.com/go-acme/lego/v4/acme" @@ -20,6 +21,59 @@ import ( "github.com/stretchr/testify/require" ) +func TestProviderServer_GetAddress(t *testing.T) { + dir := t.TempDir() + t.Cleanup(func() { _ = os.RemoveAll(dir) }) + + testCases := []struct { + desc string + server *ProviderServer + network func(*ProviderServer) + expected string + }{ + { + desc: "TCP default address", + server: NewProviderServer("", ""), + expected: ":443", + }, + { + desc: "TCP with explicit port", + server: NewProviderServer("", "4443"), + expected: ":4443", + }, + { + desc: "TCP with host and port", + server: NewProviderServer("localhost", "4443"), + expected: "localhost:4443", + }, + { + desc: "TCP4 with host and port", + server: NewProviderServer("localhost", "4443"), + network: func(s *ProviderServer) { s.SetIPv4Only() }, + expected: "localhost:4443", + }, + { + desc: "TCP6 with host and port", + server: NewProviderServer("localhost", "4443"), + network: func(s *ProviderServer) { s.SetIPv6Only() }, + expected: "localhost:4443", + }, + } + + for _, test := range testCases { + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + if test.network != nil { + test.network(test.server) + } + + address := test.server.GetAddress() + assert.Equal(t, test.expected, address) + }) + } +} + func TestChallenge(t *testing.T) { _, apiURL := tester.SetupFakeAPI(t) @@ -75,7 +129,7 @@ func TestChallenge(t *testing.T) { solver := NewChallenge( core, mockValidate, - &ProviderServer{port: port}, + &ProviderServer{port: port, network: "tcp"}, ) authz := acme.Authorization{ @@ -104,7 +158,7 @@ func TestChallengeInvalidPort(t *testing.T) { solver := NewChallenge( core, func(_ *api.Core, _ string, _ acme.Challenge) error { return nil }, - &ProviderServer{port: "123456"}, + &ProviderServer{port: "123456", network: "tcp"}, ) authz := acme.Authorization{ @@ -176,7 +230,7 @@ func TestChallengeIPaddress(t *testing.T) { solver := NewChallenge( core, mockValidate, - &ProviderServer{port: port}, + &ProviderServer{port: port, network: "tcp"}, ) authz := acme.Authorization{ diff --git a/cmd/flags.go b/cmd/flags.go index cd21c466bb..69840be9d1 100644 --- a/cmd/flags.go +++ b/cmd/flags.go @@ -11,6 +11,16 @@ import ( func CreateFlags(defaultPath string) []cli.Flag { return []cli.Flag{ + &cli.BoolFlag{ + Name: "ipv4only", + Aliases: []string{"4"}, + Usage: "Use IPv4 only. This flag is ignored if ipv6only is also specified.", + }, + &cli.BoolFlag{ + Name: "ipv6only", + Aliases: []string{"6"}, + Usage: "Use IPv6 only. This flag is ignored if ipv4only is also specified.", + }, &cli.StringSliceFlag{ Name: "domains", Aliases: []string{"d"}, diff --git a/cmd/setup_challenges.go b/cmd/setup_challenges.go index 719f8dd6cf..a7d50729ad 100644 --- a/cmd/setup_challenges.go +++ b/cmd/setup_challenges.go @@ -42,6 +42,24 @@ func setupChallenges(ctx *cli.Context, client *lego.Client) { } } +type networkStackSetter interface { + SetIPv4Only() + SetIPv6Only() + SetDualStack() +} + +func setNetwork(ctx *cli.Context, srv networkStackSetter) { + switch v4, v6 := ctx.IsSet("ipv4only"), ctx.IsSet("ipv6only"); { + case v4 && !v6: + srv.SetIPv4Only() + case !v4 && v6: + srv.SetIPv6Only() + default: + // setting both --ipv4only and --ipv6only is not an error, just a no-op + srv.SetDualStack() + } +} + //nolint:gocyclo // the complexity is expected. func setupHTTPProvider(ctx *cli.Context) challenge.Provider { switch { @@ -75,12 +93,14 @@ func setupHTTPProvider(ctx *cli.Context) challenge.Provider { } srv := http01.NewProviderServer(host, port) + setNetwork(ctx, srv) if header := ctx.String("http.proxy-header"); header != "" { srv.SetProxyHeader(header) } return srv case ctx.Bool("http"): srv := http01.NewProviderServer("", "") + setNetwork(ctx, srv) if header := ctx.String("http.proxy-header"); header != "" { srv.SetProxyHeader(header) } @@ -104,9 +124,13 @@ func setupTLSProvider(ctx *cli.Context) challenge.Provider { log.Fatal(err) } - return tlsalpn01.NewProviderServer(host, port) + srv := tlsalpn01.NewProviderServer(host, port) + setNetwork(ctx, srv) + return srv case ctx.Bool("tls"): - return tlsalpn01.NewProviderServer("", "") + srv := tlsalpn01.NewProviderServer("", "") + setNetwork(ctx, srv) + return srv default: log.Fatal("Invalid HTTP challenge options.") return nil @@ -119,6 +143,16 @@ func setupDNS(ctx *cli.Context, client *lego.Client) { log.Fatal(err) } + switch v4, v6 := ctx.IsSet("ipv4only"), ctx.IsSet("ipv6only"); { + case v4 && !v6: + dns01.SetIPv4Only() + case !v4 && v6: + dns01.SetIPv6Only() + default: + // setting both --ipv4only and --ipv6only is not an error, just a no-op + dns01.SetDualStack() + } + servers := ctx.StringSlice("dns.resolvers") err = client.Challenge.SetDNS01Provider(provider, dns01.CondOption(len(servers) > 0, diff --git a/docs/data/zz_cli_help.toml b/docs/data/zz_cli_help.toml index 6b3d91c0ec..da92dde158 100644 --- a/docs/data/zz_cli_help.toml +++ b/docs/data/zz_cli_help.toml @@ -19,6 +19,8 @@ COMMANDS: help, h Shows a list of commands or help for one command GLOBAL OPTIONS: + --ipv4only, -4 Use IPv4 only. This flag is ignored if ipv6only is also specified. (default: false) + --ipv6only, -6 Use IPv6 only. This flag is ignored if ipv4only is also specified. (default: false) --domains value, -d value [ --domains value, -d value ] Add a domain to the process. Can be specified multiple times. --server value, -s value CA hostname (and optionally :port). The server certificate must be trusted in order to avoid further modifications to the client. (default: "https://acme-v02.api.letsencrypt.org/directory") [$LEGO_SERVER] --accept-tos, -a By setting this flag to true you indicate that you accept the current Let's Encrypt terms of service. (default: false)