@@ -2056,3 +2056,124 @@ func TestLocalResolverPriorityConstants(t *testing.T) {
2056
2056
assert .Equal (t , PriorityLocal , localMuxUpdates [0 ].priority , "Local handler should use PriorityLocal" )
2057
2057
assert .Equal (t , "local.example.com" , localMuxUpdates [0 ].domain )
2058
2058
}
2059
+
2060
+ func TestDNSLoopPrevention (t * testing.T ) {
2061
+ wgInterface := & mocWGIface {}
2062
+ service := NewServiceViaMemory (wgInterface )
2063
+ dnsServerIP := service .RuntimeIP ()
2064
+
2065
+ server := & DefaultServer {
2066
+ ctx : context .Background (),
2067
+ wgInterface : wgInterface ,
2068
+ service : service ,
2069
+ localResolver : local .NewResolver (),
2070
+ handlerChain : NewHandlerChain (),
2071
+ hostManager : & noopHostConfigurator {},
2072
+ dnsMuxMap : make (registeredHandlerMap ),
2073
+ }
2074
+
2075
+ tests := []struct {
2076
+ name string
2077
+ nsGroups []* nbdns.NameServerGroup
2078
+ expectedHandlers int
2079
+ expectedServers []netip.Addr
2080
+ shouldFilterOwnIP bool
2081
+ }{
2082
+ {
2083
+ name : "FilterOwnDNSServerIP" ,
2084
+ nsGroups : []* nbdns.NameServerGroup {
2085
+ {
2086
+ Primary : true ,
2087
+ NameServers : []nbdns.NameServer {
2088
+ {IP : netip .MustParseAddr ("8.8.8.8" ), NSType : nbdns .UDPNameServerType , Port : 53 },
2089
+ {IP : dnsServerIP , NSType : nbdns .UDPNameServerType , Port : 53 },
2090
+ {IP : netip .MustParseAddr ("1.1.1.1" ), NSType : nbdns .UDPNameServerType , Port : 53 },
2091
+ },
2092
+ Domains : []string {},
2093
+ },
2094
+ },
2095
+ expectedHandlers : 1 ,
2096
+ expectedServers : []netip.Addr {netip .MustParseAddr ("8.8.8.8" ), netip .MustParseAddr ("1.1.1.1" )},
2097
+ shouldFilterOwnIP : true ,
2098
+ },
2099
+ {
2100
+ name : "AllServersFiltered" ,
2101
+ nsGroups : []* nbdns.NameServerGroup {
2102
+ {
2103
+ Primary : false ,
2104
+ NameServers : []nbdns.NameServer {
2105
+ {IP : dnsServerIP , NSType : nbdns .UDPNameServerType , Port : 53 },
2106
+ },
2107
+ Domains : []string {"example.com" },
2108
+ },
2109
+ },
2110
+ expectedHandlers : 0 ,
2111
+ expectedServers : []netip.Addr {},
2112
+ shouldFilterOwnIP : true ,
2113
+ },
2114
+ {
2115
+ name : "MixedServersWithOwnIP" ,
2116
+ nsGroups : []* nbdns.NameServerGroup {
2117
+ {
2118
+ Primary : false ,
2119
+ NameServers : []nbdns.NameServer {
2120
+ {IP : netip .MustParseAddr ("8.8.8.8" ), NSType : nbdns .UDPNameServerType , Port : 53 },
2121
+ {IP : dnsServerIP , NSType : nbdns .UDPNameServerType , Port : 53 },
2122
+ {IP : netip .MustParseAddr ("1.1.1.1" ), NSType : nbdns .UDPNameServerType , Port : 53 },
2123
+ {IP : dnsServerIP , NSType : nbdns .UDPNameServerType , Port : 53 }, // duplicate
2124
+ },
2125
+ Domains : []string {"test.com" },
2126
+ },
2127
+ },
2128
+ expectedHandlers : 1 ,
2129
+ expectedServers : []netip.Addr {netip .MustParseAddr ("8.8.8.8" ), netip .MustParseAddr ("1.1.1.1" )},
2130
+ shouldFilterOwnIP : true ,
2131
+ },
2132
+ {
2133
+ name : "NoOwnIPInList" ,
2134
+ nsGroups : []* nbdns.NameServerGroup {
2135
+ {
2136
+ Primary : true ,
2137
+ NameServers : []nbdns.NameServer {
2138
+ {IP : netip .MustParseAddr ("8.8.8.8" ), NSType : nbdns .UDPNameServerType , Port : 53 },
2139
+ {IP : netip .MustParseAddr ("1.1.1.1" ), NSType : nbdns .UDPNameServerType , Port : 53 },
2140
+ },
2141
+ Domains : []string {},
2142
+ },
2143
+ },
2144
+ expectedHandlers : 1 ,
2145
+ expectedServers : []netip.Addr {netip .MustParseAddr ("8.8.8.8" ), netip .MustParseAddr ("1.1.1.1" )},
2146
+ shouldFilterOwnIP : false ,
2147
+ },
2148
+ }
2149
+
2150
+ for _ , tt := range tests {
2151
+ t .Run (tt .name , func (t * testing.T ) {
2152
+ muxUpdates , err := server .buildUpstreamHandlerUpdate (tt .nsGroups )
2153
+ assert .NoError (t , err )
2154
+ assert .Len (t , muxUpdates , tt .expectedHandlers )
2155
+
2156
+ if tt .expectedHandlers > 0 {
2157
+ handler := muxUpdates [0 ].handler .(* upstreamResolver )
2158
+ assert .Len (t , handler .upstreamServers , len (tt .expectedServers ))
2159
+
2160
+ if tt .shouldFilterOwnIP {
2161
+ for _ , upstream := range handler .upstreamServers {
2162
+ assert .NotEqual (t , dnsServerIP , upstream .Addr ())
2163
+ }
2164
+ }
2165
+
2166
+ for _ , expected := range tt .expectedServers {
2167
+ found := false
2168
+ for _ , upstream := range handler .upstreamServers {
2169
+ if upstream .Addr () == expected {
2170
+ found = true
2171
+ break
2172
+ }
2173
+ }
2174
+ assert .True (t , found , "Expected server %s not found" , expected )
2175
+ }
2176
+ }
2177
+ })
2178
+ }
2179
+ }
0 commit comments