Skip to content

Commit 8d66784

Browse files
authored
NM-79: Egress Domain Routing (#1087)
* resolve egress domains to ips * handle egress updates, resolve domain to ips * use internal resolver for querying * resolve using net Lookup if not through nameserver querying * remove node ip from ns list for internal lookup * fix egress domains server update * use api for checkin * only update cache if old ips are not resolved
1 parent 6009e1b commit 8d66784

File tree

8 files changed

+290
-149
lines changed

8 files changed

+290
-149
lines changed

dns/resolver.go

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,58 @@ func handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
184184
_ = w.WriteMsg(reply)
185185
}
186186

187+
func FindDnsAns(domain string) []net.IP {
188+
nslist := []string{}
189+
if config.Netclient().CurrGwNmIP != nil {
190+
nslist = append(nslist, config.Netclient().CurrGwNmIP.String())
191+
} else {
192+
query := canonicalizeDomainForMatching(domain)
193+
matchNsList := findBestMatch(query, config.GetServer(config.CurrServer).DnsNameservers)
194+
for i := len(matchNsList) - 1; i >= 0; i-- {
195+
nslist = append(nslist, matchNsList[i].IPs...)
196+
}
197+
}
198+
nslist = append(nslist, "8.8.8.8")
199+
nslist = append(nslist, "8.8.4.4")
200+
nslist = append(nslist, "1.1.1.1")
201+
nslist = append(nslist, "2001:4860:4860::8888")
202+
nslist = append(nslist, "2001:4860:4860::8844")
203+
server := config.GetServer(config.CurrServer)
204+
if server != nil {
205+
nslist = append(nslist, server.NameServers...)
206+
}
207+
for _, v := range nslist {
208+
if strings.Contains(v, ":") {
209+
v = "[" + v + "]"
210+
}
211+
if ansIps, err := internalLookupA(domain, v); err == nil && len(ansIps) > 0 {
212+
return ansIps
213+
}
214+
}
215+
return []net.IP{}
216+
}
217+
218+
// Build a query and send via your pool/upstream.
219+
func internalLookupA(name, ns string) ([]net.IP, error) {
220+
r := new(dns.Msg)
221+
r.Id = dns.Id()
222+
r.RecursionDesired = true
223+
r.SetQuestion(dns.Fqdn(name), dns.TypeA)
224+
r.SetEdns0(1232, true)
225+
226+
resp, err := exchangeDNSQueryWithPool(r, ns) // e.g., "1.1.1.1"
227+
if err != nil || resp == nil {
228+
return nil, err
229+
}
230+
var ips []net.IP
231+
for _, rr := range resp.Answer {
232+
if a, ok := rr.(*dns.A); ok {
233+
ips = append(ips, a.A)
234+
}
235+
}
236+
return ips, nil
237+
}
238+
187239
// Register A record
188240
func (d *DNSResolver) RegisterA(record dnsRecord) error {
189241
dnsMapMutex.Lock()

functions/mqhandlers.go

Lines changed: 203 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"net"
99
"net/http"
1010
"reflect"
11+
"slices"
1112
"strconv"
1213
"strings"
1314
"sync"
@@ -21,6 +22,7 @@ import (
2122
"github.com/gravitl/netclient/daemon"
2223
"github.com/gravitl/netclient/dns"
2324
"github.com/gravitl/netclient/firewall"
25+
"github.com/gravitl/netclient/metrics"
2426
"github.com/gravitl/netclient/ncutils"
2527
"github.com/gravitl/netclient/networking"
2628
"github.com/gravitl/netclient/wireguard"
@@ -292,6 +294,10 @@ func HostPeerUpdate(client mqtt.Client, msg mqtt.Message) {
292294
if peerUpdate.ServerConfig.EndpointDetection {
293295
go handleEndpointDetection(peerUpdate.Peers, peerUpdate.HostNetworkInfo)
294296
}
297+
if len(peerUpdate.EgressWithDomains) > 0 {
298+
wireguard.SetEgressDomains(peerUpdate.EgressWithDomains)
299+
}
300+
go CheckEgressDomainUpdates()
295301

296302
if len(server.NameServers) != len(peerUpdate.NameServers) || reflect.DeepEqual(server.NameServers, peerUpdate.NameServers) {
297303
server.NameServers = peerUpdate.NameServers
@@ -495,6 +501,11 @@ func HostUpdate(client mqtt.Client, msg mqtt.Message) {
495501
mqFallbackPull(response, resetInterface, replacePeers)
496502
}
497503
writeToDisk = false
504+
case models.EgressUpdate:
505+
506+
slog.Info("processing egress update", "domain", hostUpdate.EgressDomain.Domain)
507+
go processEgressDomain(hostUpdate.EgressDomain)
508+
498509
default:
499510
slog.Error("unknown host action", "action", hostUpdate.Action)
500511
return
@@ -571,10 +582,10 @@ func handleEndpointDetection(peers []wgtypes.PeerConfig, peerInfo models.HostInf
571582
}
572583
}
573584
if peerInfo, ok := peerInfo[peerPubKey]; ok {
574-
if peerInfo.IsStatic {
575-
// peer is a static host shouldn't disturb the configuration set by the user
576-
continue
577-
}
585+
// if peerInfo.IsStatic {
586+
// // peer is a static host shouldn't disturb the configuration set by the user
587+
// continue
588+
// }
578589
for i := range peerInfo.Interfaces {
579590
peerIface := peerInfo.Interfaces[i]
580591
peerIP := peerIface.Address.IP
@@ -836,9 +847,197 @@ func mqFallbackPull(pullResponse models.HostPull, resetInterface, replacePeers b
836847
cache.EndpointCache = sync.Map{}
837848
cache.SkipEndpointCache = sync.Map{}
838849
}
850+
if len(pullResponse.EgressWithDomains) > 0 {
851+
wireguard.SetEgressDomains(pullResponse.EgressWithDomains)
852+
}
853+
go CheckEgressDomainUpdates()
839854
handleFwUpdate(serverName, &pullResponse.FwUpdate)
840855

841856
if resetInterface {
842857
resetInterfaceFunc()
843858
}
844859
}
860+
861+
func CheckEgressDomainUpdates() {
862+
slog.Debug("checking egress domain updates")
863+
864+
egressDomains := wireguard.GetEgressDomains()
865+
if len(egressDomains) == 0 {
866+
slog.Debug("no egress domains to process")
867+
return
868+
}
869+
870+
slog.Info("processing egress domains", "count", len(egressDomains))
871+
for _, domainI := range egressDomains {
872+
slog.Debug("checking egress domain", "domain", domainI.Domain)
873+
processEgressDomain(domainI)
874+
}
875+
}
876+
877+
func processEgressDomain(domainI models.EgressDomain) {
878+
slog.Info("processing egress domain", "domain", domainI.Domain)
879+
880+
// Resolve domain to IP addresses
881+
ips, err := resolveDomainToIPs(domainI.Domain)
882+
if err != nil {
883+
slog.Error("failed to resolve egress domain", "domain", domainI.Domain, "error", err)
884+
return
885+
}
886+
if len(ips) == 0 {
887+
slog.Warn("no IP addresses resolved for domain", "domain", domainI.Domain)
888+
return
889+
}
890+
891+
// Get current cached IPs for this domain
892+
currentIps := wireguard.GetDomainAnsFromCache(domainI)
893+
slog.Debug("domain resolution check", "domain", domainI.Domain, "domain_id", domainI.ID, "cached_ips", currentIps, "resolved_ips", ips)
894+
895+
// Check if there are any changes
896+
hasChanges := false
897+
shouldUpdateCache := false
898+
899+
if len(currentIps) == 0 {
900+
// First time processing this domain or cache miss
901+
hasChanges = true
902+
shouldUpdateCache = true
903+
slog.Info("first time processing domain or cache miss", "domain", domainI.Domain, "ips", ips)
904+
} else {
905+
// Compare current and new IPs
906+
slices.Sort(currentIps)
907+
slices.Sort(ips)
908+
if !slices.Equal(currentIps, ips) {
909+
hasChanges = true
910+
slog.Info("domain IPs changed", "domain", domainI.Domain, "old_ips", currentIps, "new_ips", ips)
911+
912+
// Check if old IPs are still reachable before updating cache
913+
oldIPsReachable := checkIPConnectivity(currentIps)
914+
if !oldIPsReachable {
915+
shouldUpdateCache = true
916+
slog.Info("old IPs are not reachable, updating cache with new IPs", "domain", domainI.Domain, "old_ips", currentIps, "new_ips", ips)
917+
} else {
918+
slog.Info("old IPs are still reachable, keeping current cache to maintain stability", "domain", domainI.Domain, "old_ips", currentIps)
919+
}
920+
} else {
921+
slog.Debug("no changes detected for domain", "domain", domainI.Domain, "ips", ips)
922+
}
923+
}
924+
925+
// Only proceed if there are changes and we should update the cache
926+
if !hasChanges || !shouldUpdateCache {
927+
if !hasChanges {
928+
slog.Debug("skipping server update - no changes detected", "domain", domainI.Domain)
929+
} else {
930+
slog.Debug("skipping server update - old IPs are still reachable", "domain", domainI.Domain)
931+
}
932+
return
933+
}
934+
935+
// Update the cache with new IPs only after confirming changes and connectivity check
936+
wireguard.SetDomainAnsInCache(domainI, ips)
937+
// Clear existing ranges and add new ones
938+
domainI.Node.EgressGatewayRanges = []string{}
939+
for _, ip := range ips {
940+
// Add as /32 for IPv4 or /128 for IPv6
941+
if net.ParseIP(ip).To4() != nil {
942+
domainI.Node.EgressGatewayRanges = append(domainI.Node.EgressGatewayRanges, ip+"/32")
943+
} else {
944+
domainI.Node.EgressGatewayRanges = append(domainI.Node.EgressGatewayRanges, ip+"/128")
945+
}
946+
}
947+
948+
slog.Info("sending egress domain update to server", "domain", domainI.Domain, "ips", ips, "ranges", domainI.Node.EgressGatewayRanges)
949+
950+
// Send the updated host info back to server
951+
hostServerUpdate(models.HostUpdate{
952+
Action: models.EgressUpdate,
953+
Host: domainI.Host,
954+
Node: domainI.Node,
955+
EgressDomain: domainI,
956+
})
957+
}
958+
959+
// checkIPConnectivity checks if all of the given IP addresses are reachable
960+
func checkIPConnectivity(ips []string) bool {
961+
if len(ips) == 0 {
962+
return false
963+
}
964+
965+
// Check connectivity for each IP - ALL must be reachable
966+
for _, ipStr := range ips {
967+
ip := net.ParseIP(ipStr)
968+
if ip == nil {
969+
slog.Debug("invalid IP address", "ip", ipStr)
970+
return false
971+
}
972+
973+
ipReachable := false
974+
975+
// Use a simple TCP connection test to check reachability
976+
// Try common ports that might be open (80, 443, 22)
977+
ports := []int{80, 443, 22}
978+
for _, port := range ports {
979+
address := fmt.Sprintf("%s:%d", ipStr, port)
980+
conn, err := net.DialTimeout("tcp", address, 3*time.Second)
981+
if err == nil {
982+
conn.Close()
983+
slog.Debug("IP is reachable", "ip", ipStr, "port", port)
984+
ipReachable = true
985+
break
986+
}
987+
}
988+
989+
// If TCP connection fails, try ICMP ping for external IPs
990+
if !ipReachable {
991+
// Use the existing ping functionality from metrics package
992+
connected, _ := metrics.ExtPeerConnStatus(ipStr)
993+
if connected {
994+
slog.Debug("IP is reachable via ping", "ip", ipStr)
995+
ipReachable = true
996+
}
997+
}
998+
999+
// If this IP is not reachable, fail the entire check
1000+
if !ipReachable {
1001+
slog.Debug("IP is not reachable", "ip", ipStr)
1002+
return false
1003+
}
1004+
}
1005+
1006+
slog.Debug("all IPs are reachable", "ips", ips)
1007+
return true
1008+
}
1009+
1010+
// resolveDomainToIPs resolves a domain name to IP addresses using the existing DNS infrastructure
1011+
func resolveDomainToIPs(domain string) ([]string, error) {
1012+
if domain == "" {
1013+
return nil, fmt.Errorf("domain cannot be empty")
1014+
}
1015+
1016+
// Use the existing DNS infrastructure to resolve the domain
1017+
ips := dns.FindDnsAns(domain)
1018+
if len(ips) == 0 {
1019+
lookUpIPs, err := net.LookupIP(domain)
1020+
if err != nil {
1021+
return nil, fmt.Errorf("failed to resolve domain %s: %w", domain, err)
1022+
}
1023+
if len(lookUpIPs) == 0 {
1024+
return nil, fmt.Errorf("no IP addresses found for domain %s", domain)
1025+
}
1026+
ips = lookUpIPs
1027+
}
1028+
1029+
// Filter out any invalid IPs and return unique IPs
1030+
uniqueIPs := make(map[string]string)
1031+
for _, ip := range ips {
1032+
if ip != nil {
1033+
uniqueIPs[ip.String()] = ip.String()
1034+
}
1035+
}
1036+
1037+
result := make([]string, 0, len(uniqueIPs))
1038+
for _, ip := range uniqueIPs {
1039+
result = append(result, ip)
1040+
}
1041+
1042+
return result, nil
1043+
}

functions/mqpublish.go

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -163,10 +163,7 @@ func hostServerUpdate(hu models.HostUpdate) error {
163163
}
164164

165165
func checkin() {
166-
if err := PublishHostUpdate(config.CurrServer, models.HostMqAction(models.CheckIn)); err != nil {
167-
logger.Log(0, "error publishing checkin", err.Error())
168-
return
169-
}
166+
hostServerUpdate(models.HostUpdate{Action: models.CheckIn})
170167
}
171168

172169
// PublishNodeUpdate -- pushes node to broker
@@ -416,6 +413,7 @@ func UpdateHostSettings(fallback bool) error {
416413
}
417414
}
418415
}
416+
go CheckEgressDomainUpdates()
419417
if restartDaemon {
420418
if err := daemon.Restart(); err != nil {
421419
slog.Error("failed to restart daemon", "error", err)

0 commit comments

Comments
 (0)