88 "net"
99 "net/http"
1010 "reflect"
11+ "slices"
1112 "strconv"
1213 "strings"
1314 "sync"
@@ -21,6 +22,7 @@ import (
2122 "github.com/gravitl/netclient/daemon"
2223 "github.com/gravitl/netclient/dns"
2324 "github.com/gravitl/netclient/firewall"
25+ "github.com/gravitl/netclient/metrics"
2426 "github.com/gravitl/netclient/ncutils"
2527 "github.com/gravitl/netclient/networking"
2628 "github.com/gravitl/netclient/wireguard"
@@ -292,6 +294,10 @@ func HostPeerUpdate(client mqtt.Client, msg mqtt.Message) {
292294 if peerUpdate .ServerConfig .EndpointDetection {
293295 go handleEndpointDetection (peerUpdate .Peers , peerUpdate .HostNetworkInfo )
294296 }
297+ if len (peerUpdate .EgressWithDomains ) > 0 {
298+ wireguard .SetEgressDomains (peerUpdate .EgressWithDomains )
299+ }
300+ go CheckEgressDomainUpdates ()
295301
296302 if len (server .NameServers ) != len (peerUpdate .NameServers ) || reflect .DeepEqual (server .NameServers , peerUpdate .NameServers ) {
297303 server .NameServers = peerUpdate .NameServers
@@ -495,6 +501,11 @@ func HostUpdate(client mqtt.Client, msg mqtt.Message) {
495501 mqFallbackPull (response , resetInterface , replacePeers )
496502 }
497503 writeToDisk = false
504+ case models .EgressUpdate :
505+
506+ slog .Info ("processing egress update" , "domain" , hostUpdate .EgressDomain .Domain )
507+ go processEgressDomain (hostUpdate .EgressDomain )
508+
498509 default :
499510 slog .Error ("unknown host action" , "action" , hostUpdate .Action )
500511 return
@@ -571,10 +582,10 @@ func handleEndpointDetection(peers []wgtypes.PeerConfig, peerInfo models.HostInf
571582 }
572583 }
573584 if peerInfo , ok := peerInfo [peerPubKey ]; ok {
574- if peerInfo .IsStatic {
575- // peer is a static host shouldn't disturb the configuration set by the user
576- continue
577- }
585+ // if peerInfo.IsStatic {
586+ // // peer is a static host shouldn't disturb the configuration set by the user
587+ // continue
588+ // }
578589 for i := range peerInfo .Interfaces {
579590 peerIface := peerInfo .Interfaces [i ]
580591 peerIP := peerIface .Address .IP
@@ -836,9 +847,197 @@ func mqFallbackPull(pullResponse models.HostPull, resetInterface, replacePeers b
836847 cache .EndpointCache = sync.Map {}
837848 cache .SkipEndpointCache = sync.Map {}
838849 }
850+ if len (pullResponse .EgressWithDomains ) > 0 {
851+ wireguard .SetEgressDomains (pullResponse .EgressWithDomains )
852+ }
853+ go CheckEgressDomainUpdates ()
839854 handleFwUpdate (serverName , & pullResponse .FwUpdate )
840855
841856 if resetInterface {
842857 resetInterfaceFunc ()
843858 }
844859}
860+
861+ func CheckEgressDomainUpdates () {
862+ slog .Debug ("checking egress domain updates" )
863+
864+ egressDomains := wireguard .GetEgressDomains ()
865+ if len (egressDomains ) == 0 {
866+ slog .Debug ("no egress domains to process" )
867+ return
868+ }
869+
870+ slog .Info ("processing egress domains" , "count" , len (egressDomains ))
871+ for _ , domainI := range egressDomains {
872+ slog .Debug ("checking egress domain" , "domain" , domainI .Domain )
873+ processEgressDomain (domainI )
874+ }
875+ }
876+
877+ func processEgressDomain (domainI models.EgressDomain ) {
878+ slog .Info ("processing egress domain" , "domain" , domainI .Domain )
879+
880+ // Resolve domain to IP addresses
881+ ips , err := resolveDomainToIPs (domainI .Domain )
882+ if err != nil {
883+ slog .Error ("failed to resolve egress domain" , "domain" , domainI .Domain , "error" , err )
884+ return
885+ }
886+ if len (ips ) == 0 {
887+ slog .Warn ("no IP addresses resolved for domain" , "domain" , domainI .Domain )
888+ return
889+ }
890+
891+ // Get current cached IPs for this domain
892+ currentIps := wireguard .GetDomainAnsFromCache (domainI )
893+ slog .Debug ("domain resolution check" , "domain" , domainI .Domain , "domain_id" , domainI .ID , "cached_ips" , currentIps , "resolved_ips" , ips )
894+
895+ // Check if there are any changes
896+ hasChanges := false
897+ shouldUpdateCache := false
898+
899+ if len (currentIps ) == 0 {
900+ // First time processing this domain or cache miss
901+ hasChanges = true
902+ shouldUpdateCache = true
903+ slog .Info ("first time processing domain or cache miss" , "domain" , domainI .Domain , "ips" , ips )
904+ } else {
905+ // Compare current and new IPs
906+ slices .Sort (currentIps )
907+ slices .Sort (ips )
908+ if ! slices .Equal (currentIps , ips ) {
909+ hasChanges = true
910+ slog .Info ("domain IPs changed" , "domain" , domainI .Domain , "old_ips" , currentIps , "new_ips" , ips )
911+
912+ // Check if old IPs are still reachable before updating cache
913+ oldIPsReachable := checkIPConnectivity (currentIps )
914+ if ! oldIPsReachable {
915+ shouldUpdateCache = true
916+ slog .Info ("old IPs are not reachable, updating cache with new IPs" , "domain" , domainI .Domain , "old_ips" , currentIps , "new_ips" , ips )
917+ } else {
918+ slog .Info ("old IPs are still reachable, keeping current cache to maintain stability" , "domain" , domainI .Domain , "old_ips" , currentIps )
919+ }
920+ } else {
921+ slog .Debug ("no changes detected for domain" , "domain" , domainI .Domain , "ips" , ips )
922+ }
923+ }
924+
925+ // Only proceed if there are changes and we should update the cache
926+ if ! hasChanges || ! shouldUpdateCache {
927+ if ! hasChanges {
928+ slog .Debug ("skipping server update - no changes detected" , "domain" , domainI .Domain )
929+ } else {
930+ slog .Debug ("skipping server update - old IPs are still reachable" , "domain" , domainI .Domain )
931+ }
932+ return
933+ }
934+
935+ // Update the cache with new IPs only after confirming changes and connectivity check
936+ wireguard .SetDomainAnsInCache (domainI , ips )
937+ // Clear existing ranges and add new ones
938+ domainI .Node .EgressGatewayRanges = []string {}
939+ for _ , ip := range ips {
940+ // Add as /32 for IPv4 or /128 for IPv6
941+ if net .ParseIP (ip ).To4 () != nil {
942+ domainI .Node .EgressGatewayRanges = append (domainI .Node .EgressGatewayRanges , ip + "/32" )
943+ } else {
944+ domainI .Node .EgressGatewayRanges = append (domainI .Node .EgressGatewayRanges , ip + "/128" )
945+ }
946+ }
947+
948+ slog .Info ("sending egress domain update to server" , "domain" , domainI .Domain , "ips" , ips , "ranges" , domainI .Node .EgressGatewayRanges )
949+
950+ // Send the updated host info back to server
951+ hostServerUpdate (models.HostUpdate {
952+ Action : models .EgressUpdate ,
953+ Host : domainI .Host ,
954+ Node : domainI .Node ,
955+ EgressDomain : domainI ,
956+ })
957+ }
958+
959+ // checkIPConnectivity checks if all of the given IP addresses are reachable
960+ func checkIPConnectivity (ips []string ) bool {
961+ if len (ips ) == 0 {
962+ return false
963+ }
964+
965+ // Check connectivity for each IP - ALL must be reachable
966+ for _ , ipStr := range ips {
967+ ip := net .ParseIP (ipStr )
968+ if ip == nil {
969+ slog .Debug ("invalid IP address" , "ip" , ipStr )
970+ return false
971+ }
972+
973+ ipReachable := false
974+
975+ // Use a simple TCP connection test to check reachability
976+ // Try common ports that might be open (80, 443, 22)
977+ ports := []int {80 , 443 , 22 }
978+ for _ , port := range ports {
979+ address := fmt .Sprintf ("%s:%d" , ipStr , port )
980+ conn , err := net .DialTimeout ("tcp" , address , 3 * time .Second )
981+ if err == nil {
982+ conn .Close ()
983+ slog .Debug ("IP is reachable" , "ip" , ipStr , "port" , port )
984+ ipReachable = true
985+ break
986+ }
987+ }
988+
989+ // If TCP connection fails, try ICMP ping for external IPs
990+ if ! ipReachable {
991+ // Use the existing ping functionality from metrics package
992+ connected , _ := metrics .ExtPeerConnStatus (ipStr )
993+ if connected {
994+ slog .Debug ("IP is reachable via ping" , "ip" , ipStr )
995+ ipReachable = true
996+ }
997+ }
998+
999+ // If this IP is not reachable, fail the entire check
1000+ if ! ipReachable {
1001+ slog .Debug ("IP is not reachable" , "ip" , ipStr )
1002+ return false
1003+ }
1004+ }
1005+
1006+ slog .Debug ("all IPs are reachable" , "ips" , ips )
1007+ return true
1008+ }
1009+
1010+ // resolveDomainToIPs resolves a domain name to IP addresses using the existing DNS infrastructure
1011+ func resolveDomainToIPs (domain string ) ([]string , error ) {
1012+ if domain == "" {
1013+ return nil , fmt .Errorf ("domain cannot be empty" )
1014+ }
1015+
1016+ // Use the existing DNS infrastructure to resolve the domain
1017+ ips := dns .FindDnsAns (domain )
1018+ if len (ips ) == 0 {
1019+ lookUpIPs , err := net .LookupIP (domain )
1020+ if err != nil {
1021+ return nil , fmt .Errorf ("failed to resolve domain %s: %w" , domain , err )
1022+ }
1023+ if len (lookUpIPs ) == 0 {
1024+ return nil , fmt .Errorf ("no IP addresses found for domain %s" , domain )
1025+ }
1026+ ips = lookUpIPs
1027+ }
1028+
1029+ // Filter out any invalid IPs and return unique IPs
1030+ uniqueIPs := make (map [string ]string )
1031+ for _ , ip := range ips {
1032+ if ip != nil {
1033+ uniqueIPs [ip .String ()] = ip .String ()
1034+ }
1035+ }
1036+
1037+ result := make ([]string , 0 , len (uniqueIPs ))
1038+ for _ , ip := range uniqueIPs {
1039+ result = append (result , ip )
1040+ }
1041+
1042+ return result , nil
1043+ }
0 commit comments