Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion client/cmd/up.go
Original file line number Diff line number Diff line change
Expand Up @@ -667,7 +667,7 @@ func validateDnsLabels(labels []string) (domain.List, error) {
return domains, nil
}

domains, err = domain.ValidateDomains(labels)
domains, err = domain.ValidateFQDNs(labels)
if err != nil {
return nil, fmt.Errorf("failed to validate dns labels: %v", err)
}
Expand Down
14 changes: 4 additions & 10 deletions management/server/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"net/netip"
"os"
"reflect"
"regexp"
"slices"
"strconv"
"strings"
Expand Down Expand Up @@ -45,6 +44,7 @@ import (
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/status"
)

Expand Down Expand Up @@ -224,7 +224,7 @@ func BuildManager(
// enable single account mode only if configured by user and number of existing accounts is not grater than 1
am.singleAccountMode = singleAccountModeDomain != "" && accountsCounter <= 1
if am.singleAccountMode {
if !isDomainValid(singleAccountModeDomain) {
if !domain.IsValidDomain(singleAccountModeDomain, false, false) {
return nil, status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for a single account mode. Please review your input for --single-account-mode-domain", singleAccountModeDomain)
}
am.singleAccountModeDomain = singleAccountModeDomain
Expand Down Expand Up @@ -401,7 +401,7 @@ func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, tra
return status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour")
}

if newSettings.DNSDomain != "" && !isDomainValid(newSettings.DNSDomain) {
if newSettings.DNSDomain != "" && !domain.IsValidDomain(newSettings.DNSDomain, false, true) {
return status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for DNS domain", newSettings.DNSDomain)
}

Expand Down Expand Up @@ -1518,7 +1518,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context
return userAuth.AccountId, nil
}

if userAuth.DomainCategory != types.PrivateCategory || !isDomainValid(userAuth.Domain) {
if userAuth.DomainCategory != types.PrivateCategory || !domain.IsValidDomain(userAuth.Domain, false, false) {
return am.GetAccountIDByUserID(ctx, userAuth.UserId, userAuth.Domain)
}

Expand Down Expand Up @@ -1701,12 +1701,6 @@ func (am *DefaultAccountManager) HasConnectedChannel(peerID string) bool {
return am.peersUpdateManager.HasChannel(peerID)
}

var invalidDomainRegexp = regexp.MustCompile(`^([a-z0-9]+(-[a-z0-9]+)*\.)+[a-z]{2,}$`)

func isDomainValid(domain string) bool {
return invalidDomainRegexp.MatchString(domain)
}

// GetDNSDomain returns the configured dnsDomain
func (am *DefaultAccountManager) GetDNSDomain(settings *types.Settings) string {
if settings == nil {
Expand Down
6 changes: 3 additions & 3 deletions management/server/http/handlers/routes/routes_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@ import (

"github.com/gorilla/mux"

"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/route"
)

const failedToConvertRoute = "failed to convert route to response: %v"
Expand Down Expand Up @@ -217,7 +217,7 @@ func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request) {
}

if req.Domains != nil {
d, err := domain.ValidateDomains(*req.Domains)
d, err := domain.ValidateFQDNs(*req.Domains)
if err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid domains: %v", err), w)
return
Expand Down
27 changes: 3 additions & 24 deletions management/server/nameserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,8 @@ package server

import (
"context"
"errors"
"regexp"
"unicode/utf8"

"github.com/miekg/dns"
"github.com/rs/xid"

nbdns "github.com/netbirdio/netbird/dns"
Expand All @@ -15,13 +12,10 @@ import (
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
nbDomain "github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/status"
)

const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*[*.a-z]{1,}$`

var invalidDomainName = errors.New("invalid domain name")

// GetNameServerGroup gets a nameserver group object from account and nameserver group IDs
func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Nameservers, operations.Read)
Expand Down Expand Up @@ -268,8 +262,8 @@ func validateDomainInput(primary bool, domains []string, searchDomainsEnabled bo
}

for _, domain := range domains {
if err := validateDomain(domain); err != nil {
return status.Errorf(status.InvalidArgument, "nameserver group got an invalid domain: %s %q", domain, err)
if !nbDomain.IsValidDomain(domain, false, true) {
return status.Errorf(status.InvalidArgument, "nameserver group got an invalid domain: %s", domain)
}
}
return nil
Expand Down Expand Up @@ -313,18 +307,3 @@ func validateGroups(list []string, groups map[string]*types.Group) error {

return nil
}

var domainMatcher = regexp.MustCompile(domainPattern)

func validateDomain(domain string) error {
if !domainMatcher.MatchString(domain) {
return errors.New("domain should consists of only letters, numbers, and hyphens with no leading, trailing hyphens, or spaces")
}

_, valid := dns.IsDomainName(domain)
if !valid {
return invalidDomainName
}

return nil
}
10 changes: 5 additions & 5 deletions management/server/nameserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -910,12 +910,12 @@ func TestValidateDomain(t *testing.T) {
errFunc: require.NoError,
},
{
name: "Valid domain name with trailing dot",
name: "Invalid domain name with trailing dot",
domain: "example.",
errFunc: require.NoError,
errFunc: require.Error,
},
{
name: "Invalid wildcard domain name",
name: "Valid wildcard domain name",
domain: "*.example",
errFunc: require.Error,
},
Expand All @@ -932,7 +932,7 @@ func TestValidateDomain(t *testing.T) {
{
name: "Invalid domain name with double hyphen",
domain: "test--example.com",
errFunc: require.Error,
errFunc: require.NoError, // Note: Double hyphen is not valid but due to punicode hard to filter out
},
{
name: "Invalid domain name with a label exceeding 63 characters",
Expand Down Expand Up @@ -968,7 +968,7 @@ func TestValidateDomain(t *testing.T) {

for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
testCase.errFunc(t, validateDomain(testCase.domain))
testCase.errFunc(t, validateDomainInput(false, []string{testCase.domain}, false))
})
}

Expand Down
8 changes: 3 additions & 5 deletions management/server/networks/resources/types/resource.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,14 @@ import (
"errors"
"fmt"
"net/netip"
"regexp"

"github.com/rs/xid"

nbDomain "github.com/netbirdio/netbird/shared/management/domain"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/route"
nbDomain "github.com/netbirdio/netbird/shared/management/domain"

"github.com/netbirdio/netbird/shared/management/http/api"
)
Expand Down Expand Up @@ -166,9 +165,8 @@ func GetResourceType(address string) (NetworkResourceType, string, netip.Prefix,
return Host, "", netip.PrefixFrom(ip, ip.BitLen()), nil
}

domainRegex := regexp.MustCompile(`^(\*\.)?([a-zA-Z0-9-]+\.)+[a-zA-Z]{2,}$`)
if domainRegex.MatchString(address) {
return Domain, address, netip.Prefix{}, nil
if domain, err := nbDomain.ToValidDomain(address, true, false); err == nil {
return Domain, string(domain), netip.Prefix{}, nil
}

return "", "", netip.Prefix{}, errors.New("not a valid host, subnet, or domain")
Expand Down
2 changes: 1 addition & 1 deletion management/server/peer.go
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
}
}

if err := domain.ValidateDomainsList(peer.ExtraDNSLabels); err != nil {
if err := domain.ValidateFQDNsList(peer.ExtraDNSLabels); err != nil {
return nil, nil, nil, status.Errorf(status.InvalidArgument, "invalid extra DNS labels: %v", err)
}

Expand Down
136 changes: 115 additions & 21 deletions shared/management/domain/validate.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,53 +4,147 @@ import (
"fmt"
"regexp"
"strings"
"sync"
)

const maxDomains = 32
const maxFQDN = 32

var domainRegex = regexp.MustCompile(`^(?:\*\.)?(?:(?:xn--)?[a-zA-Z0-9_](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?\.)*(?:xn--)?[a-zA-Z0-9](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?$`)
var regexCache = map[string]*regexp.Regexp{}
var regexCacheMu sync.Mutex

var fqdnRegex = regexp.MustCompile(`^(?:(?:xn--)?[a-zA-Z0-9_](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?\.)*(?:xn--)?[a-zA-Z0-9](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?$`)

func buildDomainRegex(allowWildcard, allowSingleToplevel bool) *regexp.Regexp {
key := fmt.Sprintf("%t:%t", allowWildcard, allowSingleToplevel)

regexCacheMu.Lock()
defer regexCacheMu.Unlock()

if re, ok := regexCache[key]; ok {
return re
}

var pattern strings.Builder
pattern.WriteString("^")

if allowWildcard {
pattern.WriteString(`(?:\*\.)?`)
}

label := `(?:xn--)?[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?`

if allowSingleToplevel {
pattern.WriteString(label + `(?:\.` + label + `)*`)
} else {
pattern.WriteString(label + `(?:\.` + label + `)+`)
}

pattern.WriteString("$")

re := regexp.MustCompile(pattern.String())
regexCache[key] = re
return re
}

// ValidateFQDNs checks if each domain in the list is valid and returns a punycode-encoded DomainList.
func ValidateFQDNs(fqdns []string) (List, error) {
if len(fqdns) == 0 {
return nil, fmt.Errorf("fqdns list is empty")
}
if len(fqdns) > maxFQDN {
return nil, fmt.Errorf("fqdns list exceeds maximum allowed fqdns: %d", maxFQDN)
}

var domainList List

for _, d := range fqdns {
validDomain, err := ToValidFQDN(d)
if err != nil {
return nil, fmt.Errorf("invalid domain %s: %w", d, err)
}
domainList = append(domainList, validDomain)
}
return domainList, nil
}

// ValidateDomains checks if each domain in the list is valid and returns a punycode-encoded DomainList.
func ValidateDomains(domains []string) (List, error) {
if len(domains) == 0 {
return nil, fmt.Errorf("domains list is empty")
}
if len(domains) > maxDomains {
return nil, fmt.Errorf("domains list exceeds maximum allowed domains: %d", maxDomains)
if len(domains) > maxFQDN {
return nil, fmt.Errorf("domains list exceeds maximum allowed domains: %d", maxFQDN)
}

var domainList List

for _, d := range domains {
// handles length and idna conversion
punycode, err := FromString(d)
validDomain, err := ToValidDomain(d, true, true)
if err != nil {
return domainList, fmt.Errorf("convert domain to punycode: %s: %w", d, err)
}

if !domainRegex.MatchString(string(punycode)) {
return domainList, fmt.Errorf("invalid domain format: %s", d)
return nil, fmt.Errorf("invalid domain %s: %w", d, err)
}

domainList = append(domainList, punycode)
domainList = append(domainList, validDomain)
}
return domainList, nil
}

// ValidateDomainsList checks if each domain in the list is valid
func ValidateDomainsList(domains []string) error {
if len(domains) == 0 {
// ValidateFQDNsList checks if each domain in the list is valid
func ValidateFQDNsList(fqdns []string) error {
if len(fqdns) == 0 {
return nil
}
if len(domains) > maxDomains {
return fmt.Errorf("domains list exceeds maximum allowed domains: %d", maxDomains)
if len(fqdns) > maxFQDN {
return fmt.Errorf("fqdns list exceeds maximum allowed fqdns: %d", maxFQDN)
}

for _, d := range domains {
for _, d := range fqdns {
d := strings.ToLower(d)
if !domainRegex.MatchString(d) {
return fmt.Errorf("invalid domain format: %s", d)
if !fqdnRegex.MatchString(d) {
return fmt.Errorf("invalid fqdns format: %s", d)
}
}
return nil
}

// IsValidDomain checks if the given domain is valid.
func IsValidDomain(domain string, allowWildcard, allowSingleToplevel bool) bool {
// handles length and idna conversion
punycode, err := FromString(domain)
if err != nil {
return false
}

domainRegex := buildDomainRegex(allowWildcard, allowSingleToplevel)
return domainRegex.MatchString(string(punycode))
}

// ToValidDomain converts a domain to a valid domain format.
func ToValidDomain(domain string, allowWildcard, allowSingleToplevel bool) (Domain, error) {
// handles length and idna conversion
punycode, err := FromString(domain)
if err != nil {
return "", fmt.Errorf("convert domain to punycode: %s: %w", domain, err)
}

domainRegex := buildDomainRegex(allowWildcard, allowSingleToplevel)
if !domainRegex.MatchString(string(punycode)) {
return "", fmt.Errorf("invalid domain format: %s", domain)
}

return punycode, nil
}

// ToValidFQDN converts a domain to a valid fqdn format.
func ToValidFQDN(domain string) (Domain, error) {
// handles length and idna conversion
punycode, err := FromString(domain)
if err != nil {
return "", fmt.Errorf("convert domain to punycode: %s: %w", domain, err)
}

if !fqdnRegex.MatchString(string(punycode)) {
return "", fmt.Errorf("invalid domain format: %s", domain)
}

return punycode, nil
}
Loading
Loading