diff --git a/agent/utils/ssl/manual_client.go b/agent/utils/ssl/manual_client.go index d7e7ee5a156a..dc501e893adc 100644 --- a/agent/utils/ssl/manual_client.go +++ b/agent/utils/ssl/manual_client.go @@ -12,6 +12,7 @@ import ( "fmt" "github.com/1Panel-dev/1Panel/agent/app/model" "github.com/go-acme/lego/v4/certificate" + "github.com/miekg/dns" "golang.org/x/crypto/acme" "log" "net" @@ -123,7 +124,7 @@ func queryDNSRecords(domain string) (map[string]string, error) { return records, nil } -func (c *ManualClient) handleAuthorization(ctx context.Context, authzURL string) error { +func (c *ManualClient) handleAuthorization(ctx context.Context, authzURL string, nameservers []string) error { authz, err := c.client.GetAuthorization(ctx, authzURL) if err != nil { return fmt.Errorf("failed to get authorization: %v", err) @@ -158,9 +159,22 @@ func (c *ManualClient) handleAuthorization(ctx context.Context, authzURL string) for { c.logger.Printf("[INFO] [%s] acme: Checking DNS record propagation.", domain) - currentRecords, err := queryDNSRecords(domain) - if err != nil { - return fmt.Errorf("failed to query DNS records: %v", err) + var currentRecords map[string]string + var queryErr error + if len(nameservers) == 0 { + currentRecords, queryErr = queryDNSRecords(domain) + } else { + var errs []string + for _, nameserver := range nameservers { + currentRecords, queryErr = queryDNSRecordsWithResolver(ctx, c.logger, domain, nameserver) + errs = append(errs, fmt.Sprintf("%s: %v", nameserver, queryErr)) + } + if queryErr != nil && len(errs) > 0 { + queryErr = fmt.Errorf("all nameservers failed: %s", strings.Join(errs, "; ")) + } + } + if currentRecords == nil && queryErr != nil { + return fmt.Errorf("failed to query DNS records: %v", queryErr) } recordName := fmt.Sprintf("_acme-challenge.%s", domain) providedRecord, exists := currentRecords[recordName] @@ -262,7 +276,7 @@ func (c *ManualClient) RequestCertificate(ctx context.Context, websiteSSL *model defer delete(Orders, websiteSSL.ID) for _, authzURL := range order.AuthzURLs { - if err := c.handleAuthorization(ctx, authzURL); err != nil { + if err := c.handleAuthorization(ctx, authzURL, getNameservers(*websiteSSL)); err != nil { return res, err } } @@ -306,3 +320,73 @@ func (c *ManualClient) RequestCertificate(ctx context.Context, websiteSSL *model } return resource, nil } + +func getNameservers(websiteSSL model.WebsiteSSL) []string { + var nameservers []string + if websiteSSL.Nameserver1 != "" { + nameservers = append(nameservers, handleNameserver(websiteSSL.Nameserver1)) + } + if websiteSSL.Nameserver2 != "" { + nameservers = append(nameservers, handleNameserver(websiteSSL.Nameserver2)) + } + return nameservers +} + +func handleNameserver(nameserver string) string { + if strings.Contains(nameserver, ":") { + return nameserver + } + return fmt.Sprintf("%s:53", nameserver) +} + +func queryDNSRecordsWithResolver(ctx context.Context, logger *log.Logger, domain string, dnsServer string) (map[string]string, error) { + recordName := fmt.Sprintf("_acme-challenge.%s", domain) + c := new(dns.Client) + c.Timeout = 10 * time.Second + c.Net = "udp" + + m := new(dns.Msg) + m.SetQuestion(dns.Fqdn(recordName), dns.TypeTXT) + m.RecursionDesired = true + + r, _, err := c.ExchangeContext(ctx, m, dnsServer) + if isNetworkError(err) { + logger.Printf("[WARN] Network error occurred while querying DNS: %v", err) + return nil, nil + } + if err != nil { + return nil, fmt.Errorf("DNS query failed: %w", err) + } + + if r.Rcode == dns.RcodeNameError { + logger.Printf("[INFO] DNS record does not exist yet (NXDOMAIN)") + return nil, nil + } + + if r.Rcode != dns.RcodeSuccess { + return nil, fmt.Errorf("DNS query failed with code: %s", dns.RcodeToString[r.Rcode]) + } + + records := make(map[string]string) + + for _, answer := range r.Answer { + if txt, ok := answer.(*dns.TXT); ok { + if len(txt.Txt) > 0 { + records[recordName] = txt.Txt[0] + break + } + } + } + + return records, nil +} + +func isNetworkError(err error) bool { + if err == nil { + return false + } + return strings.Contains(err.Error(), "i/o timeout") || + strings.Contains(err.Error(), "connection refused") || + strings.Contains(err.Error(), "network is unreachable") || + strings.Contains(err.Error(), "no route to host") +}