|
| 1 | +package config |
| 2 | + |
| 3 | +import ( |
| 4 | + "errors" |
| 5 | + "fmt" |
| 6 | + "net" |
| 7 | + "net/netip" |
| 8 | + "net/url" |
| 9 | + "strings" |
| 10 | + |
| 11 | + log "github.com/sirupsen/logrus" |
| 12 | + |
| 13 | + "github.com/netbirdio/netbird/shared/management/domain" |
| 14 | + mgmProto "github.com/netbirdio/netbird/shared/management/proto" |
| 15 | +) |
| 16 | + |
| 17 | +var ( |
| 18 | + ErrEmptyURL = errors.New("empty URL") |
| 19 | + ErrEmptyHost = errors.New("empty host") |
| 20 | + ErrIPNotAllowed = errors.New("IP address not allowed") |
| 21 | +) |
| 22 | + |
| 23 | +// ServerDomains represents the management server domains extracted from NetBird configuration |
| 24 | +type ServerDomains struct { |
| 25 | + Signal domain.Domain |
| 26 | + Relay []domain.Domain |
| 27 | + Flow domain.Domain |
| 28 | + Stuns []domain.Domain |
| 29 | + Turns []domain.Domain |
| 30 | +} |
| 31 | + |
| 32 | +// ExtractFromNetbirdConfig extracts domain information from NetBird protobuf configuration |
| 33 | +func ExtractFromNetbirdConfig(config *mgmProto.NetbirdConfig) ServerDomains { |
| 34 | + if config == nil { |
| 35 | + return ServerDomains{} |
| 36 | + } |
| 37 | + |
| 38 | + domains := ServerDomains{} |
| 39 | + |
| 40 | + domains.Signal = extractSignalDomain(config) |
| 41 | + domains.Relay = extractRelayDomains(config) |
| 42 | + domains.Flow = extractFlowDomain(config) |
| 43 | + domains.Stuns = extractStunDomains(config) |
| 44 | + domains.Turns = extractTurnDomains(config) |
| 45 | + |
| 46 | + return domains |
| 47 | +} |
| 48 | + |
| 49 | +// ExtractValidDomain extracts a valid domain from a URL, filtering out IP addresses |
| 50 | +func ExtractValidDomain(rawURL string) (domain.Domain, error) { |
| 51 | + if rawURL == "" { |
| 52 | + return "", ErrEmptyURL |
| 53 | + } |
| 54 | + |
| 55 | + parsedURL, err := url.Parse(rawURL) |
| 56 | + if err == nil { |
| 57 | + if domain, err := extractFromParsedURL(parsedURL); err != nil || domain != "" { |
| 58 | + return domain, err |
| 59 | + } |
| 60 | + } |
| 61 | + |
| 62 | + return extractFromRawString(rawURL) |
| 63 | +} |
| 64 | + |
| 65 | +// extractFromParsedURL handles domain extraction from successfully parsed URLs |
| 66 | +func extractFromParsedURL(parsedURL *url.URL) (domain.Domain, error) { |
| 67 | + if parsedURL.Hostname() != "" { |
| 68 | + return extractDomainFromHost(parsedURL.Hostname()) |
| 69 | + } |
| 70 | + |
| 71 | + if parsedURL.Opaque == "" || parsedURL.Scheme == "" { |
| 72 | + return "", nil |
| 73 | + } |
| 74 | + |
| 75 | + // Handle URLs with opaque content (e.g., stun:host:port) |
| 76 | + if strings.Contains(parsedURL.Scheme, ".") { |
| 77 | + // This is likely "domain.com:port" being parsed as scheme:opaque |
| 78 | + reconstructed := parsedURL.Scheme + ":" + parsedURL.Opaque |
| 79 | + if host, _, err := net.SplitHostPort(reconstructed); err == nil { |
| 80 | + return extractDomainFromHost(host) |
| 81 | + } |
| 82 | + return extractDomainFromHost(parsedURL.Scheme) |
| 83 | + } |
| 84 | + |
| 85 | + // Valid scheme with opaque content (e.g., stun:host:port) |
| 86 | + host := parsedURL.Opaque |
| 87 | + if queryIndex := strings.Index(host, "?"); queryIndex > 0 { |
| 88 | + host = host[:queryIndex] |
| 89 | + } |
| 90 | + |
| 91 | + if hostOnly, _, err := net.SplitHostPort(host); err == nil { |
| 92 | + return extractDomainFromHost(hostOnly) |
| 93 | + } |
| 94 | + |
| 95 | + return extractDomainFromHost(host) |
| 96 | +} |
| 97 | + |
| 98 | +// extractFromRawString handles domain extraction when URL parsing fails or returns no results |
| 99 | +func extractFromRawString(rawURL string) (domain.Domain, error) { |
| 100 | + if host, _, err := net.SplitHostPort(rawURL); err == nil { |
| 101 | + return extractDomainFromHost(host) |
| 102 | + } |
| 103 | + |
| 104 | + return extractDomainFromHost(rawURL) |
| 105 | +} |
| 106 | + |
| 107 | +// extractDomainFromHost extracts domain from a host string, filtering out IP addresses |
| 108 | +func extractDomainFromHost(host string) (domain.Domain, error) { |
| 109 | + if host == "" { |
| 110 | + return "", ErrEmptyHost |
| 111 | + } |
| 112 | + |
| 113 | + if _, err := netip.ParseAddr(host); err == nil { |
| 114 | + return "", fmt.Errorf("%w: %s", ErrIPNotAllowed, host) |
| 115 | + } |
| 116 | + |
| 117 | + d, err := domain.FromString(host) |
| 118 | + if err != nil { |
| 119 | + return "", fmt.Errorf("invalid domain: %v", err) |
| 120 | + } |
| 121 | + |
| 122 | + return d, nil |
| 123 | +} |
| 124 | + |
| 125 | +// extractSingleDomain extracts a single domain from a URL with error logging |
| 126 | +func extractSingleDomain(url, serviceType string) domain.Domain { |
| 127 | + if url == "" { |
| 128 | + return "" |
| 129 | + } |
| 130 | + |
| 131 | + d, err := ExtractValidDomain(url) |
| 132 | + if err != nil { |
| 133 | + log.Debugf("Skipping %s: %v", serviceType, err) |
| 134 | + return "" |
| 135 | + } |
| 136 | + |
| 137 | + return d |
| 138 | +} |
| 139 | + |
| 140 | +// extractMultipleDomains extracts multiple domains from URLs with error logging |
| 141 | +func extractMultipleDomains(urls []string, serviceType string) []domain.Domain { |
| 142 | + var domains []domain.Domain |
| 143 | + for _, url := range urls { |
| 144 | + if url == "" { |
| 145 | + continue |
| 146 | + } |
| 147 | + d, err := ExtractValidDomain(url) |
| 148 | + if err != nil { |
| 149 | + log.Debugf("Skipping %s: %v", serviceType, err) |
| 150 | + continue |
| 151 | + } |
| 152 | + domains = append(domains, d) |
| 153 | + } |
| 154 | + return domains |
| 155 | +} |
| 156 | + |
| 157 | +// extractSignalDomain extracts the signal domain from NetBird configuration. |
| 158 | +func extractSignalDomain(config *mgmProto.NetbirdConfig) domain.Domain { |
| 159 | + if config.Signal != nil { |
| 160 | + return extractSingleDomain(config.Signal.Uri, "signal") |
| 161 | + } |
| 162 | + return "" |
| 163 | +} |
| 164 | + |
| 165 | +// extractRelayDomains extracts relay server domains from NetBird configuration. |
| 166 | +func extractRelayDomains(config *mgmProto.NetbirdConfig) []domain.Domain { |
| 167 | + if config.Relay != nil { |
| 168 | + return extractMultipleDomains(config.Relay.Urls, "relay") |
| 169 | + } |
| 170 | + return nil |
| 171 | +} |
| 172 | + |
| 173 | +// extractFlowDomain extracts the traffic flow domain from NetBird configuration. |
| 174 | +func extractFlowDomain(config *mgmProto.NetbirdConfig) domain.Domain { |
| 175 | + if config.Flow != nil { |
| 176 | + return extractSingleDomain(config.Flow.Url, "flow") |
| 177 | + } |
| 178 | + return "" |
| 179 | +} |
| 180 | + |
| 181 | +// extractStunDomains extracts STUN server domains from NetBird configuration. |
| 182 | +func extractStunDomains(config *mgmProto.NetbirdConfig) []domain.Domain { |
| 183 | + var urls []string |
| 184 | + for _, stun := range config.Stuns { |
| 185 | + if stun != nil && stun.Uri != "" { |
| 186 | + urls = append(urls, stun.Uri) |
| 187 | + } |
| 188 | + } |
| 189 | + return extractMultipleDomains(urls, "STUN") |
| 190 | +} |
| 191 | + |
| 192 | +// extractTurnDomains extracts TURN server domains from NetBird configuration. |
| 193 | +func extractTurnDomains(config *mgmProto.NetbirdConfig) []domain.Domain { |
| 194 | + var urls []string |
| 195 | + for _, turn := range config.Turns { |
| 196 | + if turn != nil && turn.HostConfig != nil && turn.HostConfig.Uri != "" { |
| 197 | + urls = append(urls, turn.HostConfig.Uri) |
| 198 | + } |
| 199 | + } |
| 200 | + return extractMultipleDomains(urls, "TURN") |
| 201 | +} |
0 commit comments