From 6124405f945df7fdd7c892755bfa3337bb00d92f Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Thu, 7 Aug 2025 15:24:29 +0200 Subject: [PATCH 1/6] unify domain validation into single package used --- management/server/account.go | 14 ++----- management/server/nameserver.go | 27 ++----------- .../networks/resources/types/resource.go | 8 ++-- shared/management/domain/validate.go | 38 ++++++++++++++----- 4 files changed, 39 insertions(+), 48 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 0f60bc91cb8..333c9cca795 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -9,7 +9,6 @@ import ( "net/netip" "os" "reflect" - "regexp" "slices" "strconv" "strings" @@ -45,6 +44,7 @@ import ( "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/status" ) @@ -224,7 +224,7 @@ func BuildManager( // enable single account mode only if configured by user and number of existing accounts is not grater than 1 am.singleAccountMode = singleAccountModeDomain != "" && accountsCounter <= 1 if am.singleAccountMode { - if !isDomainValid(singleAccountModeDomain) { + if !domain.IsValidDomain(singleAccountModeDomain) { return nil, status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for a single account mode. Please review your input for --single-account-mode-domain", singleAccountModeDomain) } am.singleAccountModeDomain = singleAccountModeDomain @@ -401,7 +401,7 @@ func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, tra return status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour") } - if newSettings.DNSDomain != "" && !isDomainValid(newSettings.DNSDomain) { + if newSettings.DNSDomain != "" && !domain.IsValidDomain(newSettings.DNSDomain) { return status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for DNS domain", newSettings.DNSDomain) } @@ -1518,7 +1518,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context return userAuth.AccountId, nil } - if userAuth.DomainCategory != types.PrivateCategory || !isDomainValid(userAuth.Domain) { + if userAuth.DomainCategory != types.PrivateCategory || !domain.IsValidDomain(userAuth.Domain) { return am.GetAccountIDByUserID(ctx, userAuth.UserId, userAuth.Domain) } @@ -1701,12 +1701,6 @@ func (am *DefaultAccountManager) HasConnectedChannel(peerID string) bool { return am.peersUpdateManager.HasChannel(peerID) } -var invalidDomainRegexp = regexp.MustCompile(`^([a-z0-9]+(-[a-z0-9]+)*\.)+[a-z]{2,}$`) - -func isDomainValid(domain string) bool { - return invalidDomainRegexp.MatchString(domain) -} - // GetDNSDomain returns the configured dnsDomain func (am *DefaultAccountManager) GetDNSDomain(settings *types.Settings) string { if settings == nil { diff --git a/management/server/nameserver.go b/management/server/nameserver.go index 1ee8805fc9d..0a55d80f756 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -2,11 +2,8 @@ package server import ( "context" - "errors" - "regexp" "unicode/utf8" - "github.com/miekg/dns" "github.com/rs/xid" nbdns "github.com/netbirdio/netbird/dns" @@ -15,13 +12,10 @@ import ( "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" + nbDomain "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/status" ) -const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*[*.a-z]{1,}$` - -var invalidDomainName = errors.New("invalid domain name") - // GetNameServerGroup gets a nameserver group object from account and nameserver group IDs func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) { allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Nameservers, operations.Read) @@ -268,8 +262,8 @@ func validateDomainInput(primary bool, domains []string, searchDomainsEnabled bo } for _, domain := range domains { - if err := validateDomain(domain); err != nil { - return status.Errorf(status.InvalidArgument, "nameserver group got an invalid domain: %s %q", domain, err) + if nbDomain.IsValidDomain(domain) { + return status.Errorf(status.InvalidArgument, "nameserver group got an invalid domain: %s", domain) } } return nil @@ -313,18 +307,3 @@ func validateGroups(list []string, groups map[string]*types.Group) error { return nil } - -var domainMatcher = regexp.MustCompile(domainPattern) - -func validateDomain(domain string) error { - if !domainMatcher.MatchString(domain) { - return errors.New("domain should consists of only letters, numbers, and hyphens with no leading, trailing hyphens, or spaces") - } - - _, valid := dns.IsDomainName(domain) - if !valid { - return invalidDomainName - } - - return nil -} diff --git a/management/server/networks/resources/types/resource.go b/management/server/networks/resources/types/resource.go index 7874be85865..1c5248f889d 100644 --- a/management/server/networks/resources/types/resource.go +++ b/management/server/networks/resources/types/resource.go @@ -4,15 +4,14 @@ import ( "errors" "fmt" "net/netip" - "regexp" "github.com/rs/xid" - nbDomain "github.com/netbirdio/netbird/shared/management/domain" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/route" + nbDomain "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/http/api" ) @@ -166,9 +165,8 @@ func GetResourceType(address string) (NetworkResourceType, string, netip.Prefix, return Host, "", netip.PrefixFrom(ip, ip.BitLen()), nil } - domainRegex := regexp.MustCompile(`^(\*\.)?([a-zA-Z0-9-]+\.)+[a-zA-Z]{2,}$`) - if domainRegex.MatchString(address) { - return Domain, address, netip.Prefix{}, nil + if domain, err := nbDomain.ToValidDomain(address); err == nil { + return Domain, string(domain), netip.Prefix{}, nil } return "", "", netip.Prefix{}, errors.New("not a valid host, subnet, or domain") diff --git a/shared/management/domain/validate.go b/shared/management/domain/validate.go index bf2af7116fc..6744c704ba8 100644 --- a/shared/management/domain/validate.go +++ b/shared/management/domain/validate.go @@ -22,17 +22,11 @@ func ValidateDomains(domains []string) (List, error) { var domainList List for _, d := range domains { - // handles length and idna conversion - punycode, err := FromString(d) + validDomain, err := ToValidDomain(d) if err != nil { - return domainList, fmt.Errorf("convert domain to punycode: %s: %w", d, err) + return nil, fmt.Errorf("invalid domain %s: %w", d, err) } - - if !domainRegex.MatchString(string(punycode)) { - return domainList, fmt.Errorf("invalid domain format: %s", d) - } - - domainList = append(domainList, punycode) + domainList = append(domainList, validDomain) } return domainList, nil } @@ -54,3 +48,29 @@ func ValidateDomainsList(domains []string) error { } return nil } + +// IsValidDomain checks if the given domain is valid. +func IsValidDomain(domain string) bool { + // handles length and idna conversion + punycode, err := FromString(domain) + if err != nil { + return false + } + + return !domainRegex.MatchString(string(punycode)) +} + +// ToValidDomain converts a domain to a valid domain format. +func ToValidDomain(domain string) (Domain, error) { + // handles length and idna conversion + punycode, err := FromString(domain) + if err != nil { + return "", fmt.Errorf("convert domain to punycode: %s: %w", domain, err) + } + + if !domainRegex.MatchString(string(punycode)) { + return "", fmt.Errorf("invalid domain format: %s", domain) + } + + return punycode, nil +} From 0af0447f1b7e5b83fd09fd482e60beef22427db2 Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Thu, 7 Aug 2025 17:05:29 +0200 Subject: [PATCH 2/6] update regex and tests --- management/server/nameserver.go | 6 +++++- management/server/nameserver_test.go | 10 +++++----- shared/management/domain/validate.go | 4 ++-- 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/management/server/nameserver.go b/management/server/nameserver.go index 0a55d80f756..a9b7e3cf742 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -2,6 +2,7 @@ package server import ( "context" + "strings" "unicode/utf8" "github.com/rs/xid" @@ -262,7 +263,10 @@ func validateDomainInput(primary bool, domains []string, searchDomainsEnabled bo } for _, domain := range domains { - if nbDomain.IsValidDomain(domain) { + if strings.HasPrefix(domain, "*") { + return status.Errorf(status.InvalidArgument, "wildcard prefix is not allowed: %s", domain) + } + if !nbDomain.IsValidDomain(domain) { return status.Errorf(status.InvalidArgument, "nameserver group got an invalid domain: %s", domain) } } diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index 959e7856a78..ad46443e023 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -910,12 +910,12 @@ func TestValidateDomain(t *testing.T) { errFunc: require.NoError, }, { - name: "Valid domain name with trailing dot", + name: "Invalid domain name with trailing dot", domain: "example.", - errFunc: require.NoError, + errFunc: require.Error, }, { - name: "Invalid wildcard domain name", + name: "Valid wildcard domain name", domain: "*.example", errFunc: require.Error, }, @@ -932,7 +932,7 @@ func TestValidateDomain(t *testing.T) { { name: "Invalid domain name with double hyphen", domain: "test--example.com", - errFunc: require.Error, + errFunc: require.NoError, // Note: Double hyphen is not valid but due to punicode hard to filter out }, { name: "Invalid domain name with a label exceeding 63 characters", @@ -968,7 +968,7 @@ func TestValidateDomain(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - testCase.errFunc(t, validateDomain(testCase.domain)) + testCase.errFunc(t, validateDomainInput(false, []string{testCase.domain}, false)) }) } diff --git a/shared/management/domain/validate.go b/shared/management/domain/validate.go index 6744c704ba8..6a52e636fd5 100644 --- a/shared/management/domain/validate.go +++ b/shared/management/domain/validate.go @@ -8,7 +8,7 @@ import ( const maxDomains = 32 -var domainRegex = regexp.MustCompile(`^(?:\*\.)?(?:(?:xn--)?[a-zA-Z0-9_](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?\.)*(?:xn--)?[a-zA-Z0-9](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?$`) +var domainRegex = regexp.MustCompile(`^(?:\*\.)?(?:(?:xn--)?[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)(?:\.(?:xn--)?[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)*$`) // ValidateDomains checks if each domain in the list is valid and returns a punycode-encoded DomainList. func ValidateDomains(domains []string) (List, error) { @@ -57,7 +57,7 @@ func IsValidDomain(domain string) bool { return false } - return !domainRegex.MatchString(string(punycode)) + return domainRegex.MatchString(string(punycode)) } // ToValidDomain converts a domain to a valid domain format. From b5da6d3f8e9c66cea4561d29cf7904f00d856dc9 Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Thu, 7 Aug 2025 17:26:22 +0200 Subject: [PATCH 3/6] dynamic regex --- management/server/account.go | 6 +-- management/server/nameserver.go | 6 +-- .../networks/resources/types/resource.go | 2 +- shared/management/domain/validate.go | 46 +++++++++++++++++-- 4 files changed, 47 insertions(+), 13 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 333c9cca795..47e09b95eee 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -224,7 +224,7 @@ func BuildManager( // enable single account mode only if configured by user and number of existing accounts is not grater than 1 am.singleAccountMode = singleAccountModeDomain != "" && accountsCounter <= 1 if am.singleAccountMode { - if !domain.IsValidDomain(singleAccountModeDomain) { + if !domain.IsValidDomain(singleAccountModeDomain, false, false) { return nil, status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for a single account mode. Please review your input for --single-account-mode-domain", singleAccountModeDomain) } am.singleAccountModeDomain = singleAccountModeDomain @@ -401,7 +401,7 @@ func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, tra return status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour") } - if newSettings.DNSDomain != "" && !domain.IsValidDomain(newSettings.DNSDomain) { + if newSettings.DNSDomain != "" && !domain.IsValidDomain(newSettings.DNSDomain, false, true) { return status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for DNS domain", newSettings.DNSDomain) } @@ -1518,7 +1518,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context return userAuth.AccountId, nil } - if userAuth.DomainCategory != types.PrivateCategory || !domain.IsValidDomain(userAuth.Domain) { + if userAuth.DomainCategory != types.PrivateCategory || !domain.IsValidDomain(userAuth.Domain, false, false) { return am.GetAccountIDByUserID(ctx, userAuth.UserId, userAuth.Domain) } diff --git a/management/server/nameserver.go b/management/server/nameserver.go index a9b7e3cf742..380625ecb49 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -2,7 +2,6 @@ package server import ( "context" - "strings" "unicode/utf8" "github.com/rs/xid" @@ -263,10 +262,7 @@ func validateDomainInput(primary bool, domains []string, searchDomainsEnabled bo } for _, domain := range domains { - if strings.HasPrefix(domain, "*") { - return status.Errorf(status.InvalidArgument, "wildcard prefix is not allowed: %s", domain) - } - if !nbDomain.IsValidDomain(domain) { + if !nbDomain.IsValidDomain(domain, false, true) { return status.Errorf(status.InvalidArgument, "nameserver group got an invalid domain: %s", domain) } } diff --git a/management/server/networks/resources/types/resource.go b/management/server/networks/resources/types/resource.go index 1c5248f889d..2f9347394c3 100644 --- a/management/server/networks/resources/types/resource.go +++ b/management/server/networks/resources/types/resource.go @@ -165,7 +165,7 @@ func GetResourceType(address string) (NetworkResourceType, string, netip.Prefix, return Host, "", netip.PrefixFrom(ip, ip.BitLen()), nil } - if domain, err := nbDomain.ToValidDomain(address); err == nil { + if domain, err := nbDomain.ToValidDomain(address, true, false); err == nil { return Domain, string(domain), netip.Prefix{}, nil } diff --git a/shared/management/domain/validate.go b/shared/management/domain/validate.go index 6a52e636fd5..1d508a33151 100644 --- a/shared/management/domain/validate.go +++ b/shared/management/domain/validate.go @@ -4,11 +4,45 @@ import ( "fmt" "regexp" "strings" + "sync" ) const maxDomains = 32 -var domainRegex = regexp.MustCompile(`^(?:\*\.)?(?:(?:xn--)?[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)(?:\.(?:xn--)?[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)*$`) +var regexCache = map[string]*regexp.Regexp{} +var regexCacheMu sync.Mutex + +func buildDomainRegex(allowWildcard, allowSingleToplevel bool) *regexp.Regexp { + key := fmt.Sprintf("%t:%t", allowWildcard, allowSingleToplevel) + + regexCacheMu.Lock() + defer regexCacheMu.Unlock() + + if re, ok := regexCache[key]; ok { + return re + } + + var pattern strings.Builder + pattern.WriteString("^") + + if allowWildcard { + pattern.WriteString(`(?:\*\.)?`) + } + + label := `(?:xn--)?[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?` + + if allowSingleToplevel { + pattern.WriteString(label + `(?:\.` + label + `)*`) + } else { + pattern.WriteString(label + `(?:\.` + label + `)+`) + } + + pattern.WriteString("$") + + re := regexp.MustCompile(pattern.String()) + regexCache[key] = re + return re +} // ValidateDomains checks if each domain in the list is valid and returns a punycode-encoded DomainList. func ValidateDomains(domains []string) (List, error) { @@ -22,7 +56,7 @@ func ValidateDomains(domains []string) (List, error) { var domainList List for _, d := range domains { - validDomain, err := ToValidDomain(d) + validDomain, err := ToValidDomain(d, false, false) if err != nil { return nil, fmt.Errorf("invalid domain %s: %w", d, err) } @@ -40,6 +74,8 @@ func ValidateDomainsList(domains []string) error { return fmt.Errorf("domains list exceeds maximum allowed domains: %d", maxDomains) } + domainRegex := buildDomainRegex(false, true) + for _, d := range domains { d := strings.ToLower(d) if !domainRegex.MatchString(d) { @@ -50,24 +86,26 @@ func ValidateDomainsList(domains []string) error { } // IsValidDomain checks if the given domain is valid. -func IsValidDomain(domain string) bool { +func IsValidDomain(domain string, allowWildcard, allowSingleToplevel bool) bool { // handles length and idna conversion punycode, err := FromString(domain) if err != nil { return false } + domainRegex := buildDomainRegex(allowWildcard, allowSingleToplevel) return domainRegex.MatchString(string(punycode)) } // ToValidDomain converts a domain to a valid domain format. -func ToValidDomain(domain string) (Domain, error) { +func ToValidDomain(domain string, allowWildcard, allowSingleToplevel bool) (Domain, error) { // handles length and idna conversion punycode, err := FromString(domain) if err != nil { return "", fmt.Errorf("convert domain to punycode: %s: %w", domain, err) } + domainRegex := buildDomainRegex(allowWildcard, allowSingleToplevel) if !domainRegex.MatchString(string(punycode)) { return "", fmt.Errorf("invalid domain format: %s", domain) } From 1c1706753df30cdb2fa46932f65f7afa08d467bc Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Thu, 7 Aug 2025 17:53:39 +0200 Subject: [PATCH 4/6] allow wildcard --- shared/management/domain/validate.go | 2 +- shared/management/domain/validate_test.go | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/shared/management/domain/validate.go b/shared/management/domain/validate.go index 1d508a33151..74a7901c19e 100644 --- a/shared/management/domain/validate.go +++ b/shared/management/domain/validate.go @@ -56,7 +56,7 @@ func ValidateDomains(domains []string) (List, error) { var domainList List for _, d := range domains { - validDomain, err := ToValidDomain(d, false, false) + validDomain, err := ToValidDomain(d, true, false) if err != nil { return nil, fmt.Errorf("invalid domain %s: %w", d, err) } diff --git a/shared/management/domain/validate_test.go b/shared/management/domain/validate_test.go index 30efcd9a95f..c52a8ee1055 100644 --- a/shared/management/domain/validate_test.go +++ b/shared/management/domain/validate_test.go @@ -59,7 +59,7 @@ func TestValidateDomains(t *testing.T) { { name: "Multiple domains valid and invalid", domains: []string{"google.com", "invalid,nbdomain.com", "münchen.de"}, - expected: List{"google.com"}, + expected: nil, wantErr: true, }, { @@ -146,9 +146,9 @@ func TestValidateDomainsList(t *testing.T) { wantErr: true, }, { - name: "Valid wildcard domain", + name: "Invalid wildcard domain", domains: []string{"*.example.com"}, - wantErr: false, + wantErr: true, }, { name: "Wildcard with leading dot - invalid", From 30b387ba02856d1da1a100c9f3729402ca375312 Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Fri, 8 Aug 2025 15:19:13 +0200 Subject: [PATCH 5/6] separate fqdn and domain validation --- client/cmd/up.go | 2 +- .../http/handlers/routes/routes_handler.go | 8 +- management/server/peer.go | 2 +- shared/management/domain/validate.go | 53 +++++++----- shared/management/domain/validate_test.go | 85 ++++++++++++++++--- 5 files changed, 113 insertions(+), 37 deletions(-) diff --git a/client/cmd/up.go b/client/cmd/up.go index 8732a687dc0..8af3b4a5061 100644 --- a/client/cmd/up.go +++ b/client/cmd/up.go @@ -667,7 +667,7 @@ func validateDnsLabels(labels []string) (domain.List, error) { return domains, nil } - domains, err = domain.ValidateDomains(labels) + domains, err = domain.ValidateFQDNs(labels) if err != nil { return nil, fmt.Errorf("failed to validate dns labels: %v", err) } diff --git a/management/server/http/handlers/routes/routes_handler.go b/management/server/http/handlers/routes/routes_handler.go index 7950db1e845..112d43c08b2 100644 --- a/management/server/http/handlers/routes/routes_handler.go +++ b/management/server/http/handlers/routes/routes_handler.go @@ -8,13 +8,13 @@ import ( "github.com/gorilla/mux" - "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/management/server/account" nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" - "github.com/netbirdio/netbird/route" ) const failedToConvertRoute = "failed to convert route to response: %v" @@ -94,7 +94,7 @@ func (h *handler) createRoute(w http.ResponseWriter, r *http.Request) { var networkType route.NetworkType var newPrefix netip.Prefix if req.Domains != nil { - d, err := domain.ValidateDomains(*req.Domains) + d, err := domain.ValidateFQDNs(*req.Domains) if err != nil { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid domains: %v", err), w) return @@ -217,7 +217,7 @@ func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request) { } if req.Domains != nil { - d, err := domain.ValidateDomains(*req.Domains) + d, err := domain.ValidateFQDNs(*req.Domains) if err != nil { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid domains: %v", err), w) return diff --git a/management/server/peer.go b/management/server/peer.go index d72eac91acb..eede5c8c360 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -539,7 +539,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s } } - if err := domain.ValidateDomainsList(peer.ExtraDNSLabels); err != nil { + if err := domain.ValidateFQDNsList(peer.ExtraDNSLabels); err != nil { return nil, nil, nil, status.Errorf(status.InvalidArgument, "invalid extra DNS labels: %v", err) } diff --git a/shared/management/domain/validate.go b/shared/management/domain/validate.go index 74a7901c19e..7100eeb4c55 100644 --- a/shared/management/domain/validate.go +++ b/shared/management/domain/validate.go @@ -7,11 +7,13 @@ import ( "sync" ) -const maxDomains = 32 +const maxFQDN = 32 var regexCache = map[string]*regexp.Regexp{} var regexCacheMu sync.Mutex +var fqdnRegex = regexp.MustCompile(`^(?:(?:xn--)?[a-zA-Z0-9_](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?\.)*(?:xn--)?[a-zA-Z0-9](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?$`) + func buildDomainRegex(allowWildcard, allowSingleToplevel bool) *regexp.Regexp { key := fmt.Sprintf("%t:%t", allowWildcard, allowSingleToplevel) @@ -44,19 +46,19 @@ func buildDomainRegex(allowWildcard, allowSingleToplevel bool) *regexp.Regexp { return re } -// ValidateDomains checks if each domain in the list is valid and returns a punycode-encoded DomainList. -func ValidateDomains(domains []string) (List, error) { - if len(domains) == 0 { - return nil, fmt.Errorf("domains list is empty") +// ValidateFQDNs checks if each domain in the list is valid and returns a punycode-encoded DomainList. +func ValidateFQDNs(fqdns []string) (List, error) { + if len(fqdns) == 0 { + return nil, fmt.Errorf("fqdns list is empty") } - if len(domains) > maxDomains { - return nil, fmt.Errorf("domains list exceeds maximum allowed domains: %d", maxDomains) + if len(fqdns) > maxFQDN { + return nil, fmt.Errorf("fqdns list exceeds maximum allowed fqdns: %d", maxFQDN) } var domainList List - for _, d := range domains { - validDomain, err := ToValidDomain(d, true, false) + for _, d := range fqdns { + validDomain, err := ToValidFQDN(d) if err != nil { return nil, fmt.Errorf("invalid domain %s: %w", d, err) } @@ -65,21 +67,19 @@ func ValidateDomains(domains []string) (List, error) { return domainList, nil } -// ValidateDomainsList checks if each domain in the list is valid -func ValidateDomainsList(domains []string) error { - if len(domains) == 0 { +// ValidateFQDNsList checks if each domain in the list is valid +func ValidateFQDNsList(fqdns []string) error { + if len(fqdns) == 0 { return nil } - if len(domains) > maxDomains { - return fmt.Errorf("domains list exceeds maximum allowed domains: %d", maxDomains) + if len(fqdns) > maxFQDN { + return fmt.Errorf("fqdns list exceeds maximum allowed fqdns: %d", maxFQDN) } - domainRegex := buildDomainRegex(false, true) - - for _, d := range domains { + for _, d := range fqdns { d := strings.ToLower(d) - if !domainRegex.MatchString(d) { - return fmt.Errorf("invalid domain format: %s", d) + if !fqdnRegex.MatchString(d) { + return fmt.Errorf("invalid fqdns format: %s", d) } } return nil @@ -112,3 +112,18 @@ func ToValidDomain(domain string, allowWildcard, allowSingleToplevel bool) (Doma return punycode, nil } + +// ToValidFQDN converts a domain to a valid fqdn format. +func ToValidFQDN(domain string) (Domain, error) { + // handles length and idna conversion + punycode, err := FromString(domain) + if err != nil { + return "", fmt.Errorf("convert domain to punycode: %s: %w", domain, err) + } + + if !fqdnRegex.MatchString(string(punycode)) { + return "", fmt.Errorf("invalid domain format: %s", domain) + } + + return punycode, nil +} diff --git a/shared/management/domain/validate_test.go b/shared/management/domain/validate_test.go index c52a8ee1055..f71130f1932 100644 --- a/shared/management/domain/validate_test.go +++ b/shared/management/domain/validate_test.go @@ -7,7 +7,7 @@ import ( "github.com/stretchr/testify/assert" ) -func TestValidateDomains(t *testing.T) { +func TestValidateFQDNs(t *testing.T) { tests := []struct { name string domains []string @@ -63,10 +63,10 @@ func TestValidateDomains(t *testing.T) { wantErr: true, }, { - name: "Valid wildcard domain", + name: "Invalid wildcard domain", domains: []string{"*.example.com"}, - expected: List{"*.example.com"}, - wantErr: false, + expected: nil, + wantErr: true, }, { name: "Wildcard with dot domain", @@ -90,16 +90,16 @@ func TestValidateDomains(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := ValidateDomains(tt.domains) + got, err := ValidateFQDNs(tt.domains) assert.Equal(t, tt.wantErr, err != nil) assert.Equal(t, got, tt.expected) }) } } -func TestValidateDomainsList(t *testing.T) { - validDomains := make([]string, maxDomains) - for i := range maxDomains { +func TestValidateFQDNsList(t *testing.T) { + validDomains := make([]string, maxFQDN) + for i := range maxFQDN { validDomains[i] = fmt.Sprintf("example%d.com", i) } @@ -124,7 +124,7 @@ func TestValidateDomainsList(t *testing.T) { wantErr: false, }, { - // Unlike ValidateDomains (which converts to punycode), + // Unlike ValidateFQDNs (which converts to punycode), // ValidateDomainsStrSlice will fail on non-ASCII domain chars. name: "Unicode domain fails (no punycode conversion)", domains: []string{"münchen.de"}, @@ -161,12 +161,12 @@ func TestValidateDomainsList(t *testing.T) { wantErr: true, }, { - name: "Exactly maxDomains items (valid)", + name: "Exactly maxFQDN items (valid)", domains: validDomains, wantErr: false, }, { - name: "Exceeds maxDomains items", + name: "Exceeds maxFQDN items", domains: append(validDomains, "extra.com"), wantErr: true, }, @@ -174,7 +174,7 @@ func TestValidateDomainsList(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := ValidateDomainsList(tt.domains) + err := ValidateFQDNsList(tt.domains) if tt.wantErr { assert.Error(t, err) } else { @@ -183,3 +183,64 @@ func TestValidateDomainsList(t *testing.T) { }) } } + +func TestIsValidDomain(t *testing.T) { + tests := []struct { + name string + domain string + valid bool + }{ + { + name: "Empty domain", + domain: "", + valid: false, + }, + { + name: "Single valid ASCII domain", + domain: "sub.ex-ample.com", + valid: true, + }, + { + name: "Underscores in labels", + domain: "_jabber._tcp.gmail.com", + valid: false, + }, + { + name: "Unicode domain fails (no punycode conversion)", + domain: "münchen.de", + valid: true, + }, + { + name: "Invalid domain format - leading dash", + domain: "-example.com", + valid: false, + }, + { + name: "Invalid domain format - trailing dash", + domain: "example-.com", + valid: false, + }, + { + name: "Valid wildcard domain", + domain: "*.example.com", + valid: true, + }, + { + name: "Wildcard with leading dot - invalid", + domain: ".*.example.com", + valid: false, + }, + { + name: "Invalid wildcard with multiple asterisks", + domain: "a.*.example.com", + valid: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + valid := IsValidDomain(tt.domain, true, true) + assert.Equal(t, tt.valid, valid) + }) + } +} From 78c886eb53ecd8d712469488ebbaeca7911b48f4 Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Fri, 8 Aug 2025 15:32:33 +0200 Subject: [PATCH 6/6] use domains validate for routes --- .../http/handlers/routes/routes_handler.go | 2 +- shared/management/domain/validate.go | 21 +++++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/management/server/http/handlers/routes/routes_handler.go b/management/server/http/handlers/routes/routes_handler.go index 112d43c08b2..dcd56639320 100644 --- a/management/server/http/handlers/routes/routes_handler.go +++ b/management/server/http/handlers/routes/routes_handler.go @@ -94,7 +94,7 @@ func (h *handler) createRoute(w http.ResponseWriter, r *http.Request) { var networkType route.NetworkType var newPrefix netip.Prefix if req.Domains != nil { - d, err := domain.ValidateFQDNs(*req.Domains) + d, err := domain.ValidateDomains(*req.Domains) if err != nil { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid domains: %v", err), w) return diff --git a/shared/management/domain/validate.go b/shared/management/domain/validate.go index 7100eeb4c55..bd907852a9e 100644 --- a/shared/management/domain/validate.go +++ b/shared/management/domain/validate.go @@ -67,6 +67,27 @@ func ValidateFQDNs(fqdns []string) (List, error) { return domainList, nil } +// ValidateDomains checks if each domain in the list is valid and returns a punycode-encoded DomainList. +func ValidateDomains(domains []string) (List, error) { + if len(domains) == 0 { + return nil, fmt.Errorf("domains list is empty") + } + if len(domains) > maxFQDN { + return nil, fmt.Errorf("domains list exceeds maximum allowed domains: %d", maxFQDN) + } + + var domainList List + + for _, d := range domains { + validDomain, err := ToValidDomain(d, true, true) + if err != nil { + return nil, fmt.Errorf("invalid domain %s: %w", d, err) + } + domainList = append(domainList, validDomain) + } + return domainList, nil +} + // ValidateFQDNsList checks if each domain in the list is valid func ValidateFQDNsList(fqdns []string) error { if len(fqdns) == 0 {