1+ using System . Collections . Concurrent ;
2+ using System . Diagnostics ;
3+ using System . Net ;
4+ using System . Net . Sockets ;
5+
6+ // <DnsRoundRobinConnector>
7+ // This is available as NuGet Package: https://www.nuget.org/packages/DnsRoundRobin/
8+ // The original source code can be found also here: https://github.com/MihaZupan/DnsRoundRobin
9+ public sealed class DnsRoundRobinConnector : IDisposable
10+ // </DnsRoundRobinConnector>
11+ {
12+ private const int DefaultDnsRefreshIntervalSeconds = 2 * 60 ;
13+ private const int MaxCleanupIntervalSeconds = 60 ;
14+
15+ public static DnsRoundRobinConnector Shared { get ; } = new ( ) ;
16+
17+ private readonly ConcurrentDictionary < string , HostRoundRobinState > _states = new ( StringComparer . Ordinal ) ;
18+ private readonly Timer _cleanupTimer ;
19+ private readonly TimeSpan _cleanupInterval ;
20+ private readonly long _cleanupIntervalTicks ;
21+ private readonly long _dnsRefreshTimeoutTicks ;
22+ private readonly TimeSpan _endpointConnectTimeout ;
23+
24+ /// <summary>
25+ /// Creates a new <see cref="DnsRoundRobinConnector"/>.
26+ /// </summary>
27+ /// <param name="dnsRefreshInterval">Maximum amount of time a Dns resolution is cached for. Default to 2 minutes.</param>
28+ /// <param name="endpointConnectTimeout">Maximum amount of time allowed for a connection attempt to any individual endpoint. Defaults to infinite.</param>
29+ public DnsRoundRobinConnector ( TimeSpan ? dnsRefreshInterval = null , TimeSpan ? endpointConnectTimeout = null )
30+ {
31+ dnsRefreshInterval = TimeSpan . FromSeconds ( Math . Max ( 1 , dnsRefreshInterval ? . TotalSeconds ?? DefaultDnsRefreshIntervalSeconds ) ) ;
32+ _cleanupInterval = TimeSpan . FromSeconds ( Math . Clamp ( dnsRefreshInterval . Value . TotalSeconds / 2 , 1 , MaxCleanupIntervalSeconds ) ) ;
33+ _cleanupIntervalTicks = ( long ) ( _cleanupInterval . TotalSeconds * Stopwatch . Frequency ) ;
34+ _dnsRefreshTimeoutTicks = ( long ) ( dnsRefreshInterval . Value . TotalSeconds * Stopwatch . Frequency ) ;
35+ _endpointConnectTimeout = endpointConnectTimeout is null || endpointConnectTimeout . Value . Ticks < 1 ? Timeout . InfiniteTimeSpan : endpointConnectTimeout . Value ;
36+
37+ bool restoreFlow = false ;
38+ try
39+ {
40+ // Don't capture the current ExecutionContext and its AsyncLocals onto the timer causing them to live forever
41+ if ( ! ExecutionContext . IsFlowSuppressed ( ) )
42+ {
43+ ExecutionContext . SuppressFlow ( ) ;
44+ restoreFlow = true ;
45+ }
46+
47+ // Ensure the Timer has a weak reference to the connector; otherwise, it
48+ // can introduce a cycle that keeps the connector rooted by the Timer
49+ _cleanupTimer = new Timer ( static state =>
50+ {
51+ var thisWeakRef = ( WeakReference < DnsRoundRobinConnector > ) state ! ;
52+ if ( thisWeakRef . TryGetTarget ( out DnsRoundRobinConnector ? thisRef ) )
53+ {
54+ thisRef . Cleanup ( ) ;
55+ thisRef . _cleanupTimer . Change ( thisRef . _cleanupInterval , Timeout . InfiniteTimeSpan ) ;
56+ }
57+ } , new WeakReference < DnsRoundRobinConnector > ( this ) , Timeout . Infinite , Timeout . Infinite ) ;
58+
59+ _cleanupTimer . Change ( _cleanupInterval , Timeout . InfiniteTimeSpan ) ;
60+ }
61+ finally
62+ {
63+ if ( restoreFlow )
64+ {
65+ ExecutionContext . RestoreFlow ( ) ;
66+ }
67+ }
68+ }
69+
70+ private void Cleanup ( )
71+ {
72+ long minTimestamp = Stopwatch . GetTimestamp ( ) - _cleanupIntervalTicks ;
73+
74+ foreach ( KeyValuePair < string , HostRoundRobinState > state in _states )
75+ {
76+ if ( state . Value . LastAccessTimestamp < minTimestamp )
77+ {
78+ _states . TryRemove ( state ) ;
79+ }
80+ }
81+ }
82+
83+ public void Dispose ( )
84+ {
85+ _states . Clear ( ) ;
86+ }
87+
88+ public Task < Socket > ConnectAsync ( DnsEndPoint endPoint , CancellationToken cancellationToken )
89+ {
90+ if ( cancellationToken . IsCancellationRequested )
91+ {
92+ return Task . FromCanceled < Socket > ( cancellationToken ) ;
93+ }
94+
95+ if ( IPAddress . TryParse ( endPoint . Host , out IPAddress ? address ) )
96+ {
97+ // Avoid the overhead of HostRoundRobinState if we're dealing with a single endpoint
98+ return ConnectToIPAddressAsync ( address , endPoint . Port , cancellationToken ) ;
99+ }
100+
101+ HostRoundRobinState state = _states . GetOrAdd (
102+ endPoint . Host ,
103+ static ( _ , thisRef ) => new HostRoundRobinState ( thisRef . _dnsRefreshTimeoutTicks , thisRef . _endpointConnectTimeout ) ,
104+ this ) ;
105+
106+ return state . ConnectAsync ( endPoint , cancellationToken ) ;
107+ }
108+
109+ private static async Task < Socket > ConnectToIPAddressAsync ( IPAddress address , int port , CancellationToken cancellationToken )
110+ {
111+ var socket = new Socket ( SocketType . Stream , ProtocolType . Tcp ) { NoDelay = true } ;
112+ try
113+ {
114+ await socket . ConnectAsync ( address , port , cancellationToken ) ;
115+ return socket ;
116+ }
117+ catch
118+ {
119+ socket . Dispose ( ) ;
120+ throw ;
121+ }
122+ }
123+
124+ private sealed class HostRoundRobinState
125+ {
126+ private readonly long _dnsRefreshTimeoutTicks ;
127+ private readonly TimeSpan _endpointConnectTimeout ;
128+ private long _lastAccessTimestamp ;
129+ private long _lastDnsTimestamp ;
130+ private IPAddress [ ] ? _addresses ;
131+ private uint _roundRobinIndex ;
132+
133+ public long LastAccessTimestamp => Volatile . Read ( ref _lastAccessTimestamp ) ;
134+
135+ private bool AddressesAreStale => Stopwatch . GetTimestamp ( ) - Volatile . Read ( ref _lastDnsTimestamp ) > _dnsRefreshTimeoutTicks ;
136+
137+ public HostRoundRobinState ( long dnsRefreshTimeoutTicks , TimeSpan endpointConnectTimeout )
138+ {
139+ _dnsRefreshTimeoutTicks = dnsRefreshTimeoutTicks ;
140+ _endpointConnectTimeout = endpointConnectTimeout ;
141+
142+ _roundRobinIndex -- ; // Offset the first Increment to ensure we start with the first address in the list
143+
144+ RefreshLastAccessTimestamp ( ) ;
145+ }
146+
147+ private void RefreshLastAccessTimestamp ( ) => Volatile . Write ( ref _lastAccessTimestamp , Stopwatch . GetTimestamp ( ) ) ;
148+
149+ public async Task < Socket > ConnectAsync ( DnsEndPoint endPoint , CancellationToken cancellationToken )
150+ {
151+ RefreshLastAccessTimestamp ( ) ;
152+
153+ uint sharedIndex = Interlocked . Increment ( ref _roundRobinIndex ) ;
154+ IPAddress [ ] ? attemptedAddresses = null ;
155+ IPAddress [ ] ? addresses = null ;
156+ Exception ? lastException = null ;
157+
158+ while ( attemptedAddresses is null )
159+ {
160+ if ( addresses is null )
161+ {
162+ addresses = _addresses ;
163+ }
164+ else
165+ {
166+ attemptedAddresses = addresses ;
167+
168+ // Give each connection attempt a chance to do its own Dns call.
169+ addresses = null ;
170+ }
171+
172+ if ( addresses is null || AddressesAreStale )
173+ {
174+ // It's possible that multiple connection attempts are resolving the same host concurrently - that's okay.
175+ _addresses = addresses = await Dns . GetHostAddressesAsync ( endPoint . Host , cancellationToken ) ;
176+ Volatile . Write ( ref _lastDnsTimestamp , Stopwatch . GetTimestamp ( ) ) ;
177+
178+ if ( attemptedAddresses is not null && AddressListsAreEquivalent ( attemptedAddresses , addresses ) )
179+ {
180+ // We've already tried to connect to every address in the list, and a new Dns resolution returned the same list.
181+ // Instead of attempting every address again, give up early.
182+ break ;
183+ }
184+ }
185+
186+ for ( int i = 0 ; i < addresses . Length ; i ++ )
187+ {
188+ Socket ? attemptSocket = null ;
189+ CancellationTokenSource ? endpointConnectTimeoutCts = null ;
190+ try
191+ {
192+ IPAddress address = addresses [ ( int ) ( ( sharedIndex + i ) % addresses . Length ) ] ;
193+
194+ if ( Socket . OSSupportsIPv6 && address . AddressFamily == AddressFamily . InterNetworkV6 )
195+ {
196+ attemptSocket = new Socket ( AddressFamily . InterNetworkV6 , SocketType . Stream , ProtocolType . Tcp ) ;
197+ if ( address . IsIPv4MappedToIPv6 )
198+ {
199+ attemptSocket . DualMode = true ;
200+ }
201+ }
202+ else if ( Socket . OSSupportsIPv4 && address . AddressFamily == AddressFamily . InterNetwork )
203+ {
204+ attemptSocket = new Socket ( AddressFamily . InterNetwork , SocketType . Stream , ProtocolType . Tcp ) ;
205+ }
206+
207+ if ( attemptSocket is not null )
208+ {
209+ attemptSocket . NoDelay = true ;
210+
211+ if ( _endpointConnectTimeout != Timeout . InfiniteTimeSpan )
212+ {
213+ endpointConnectTimeoutCts = CancellationTokenSource . CreateLinkedTokenSource ( cancellationToken ) ;
214+ endpointConnectTimeoutCts . CancelAfter ( _endpointConnectTimeout ) ;
215+ }
216+
217+ await attemptSocket . ConnectAsync ( address , endPoint . Port , endpointConnectTimeoutCts ? . Token ?? cancellationToken ) ;
218+
219+ RefreshLastAccessTimestamp ( ) ;
220+ return attemptSocket ;
221+ }
222+ }
223+ catch ( Exception ex )
224+ {
225+ attemptSocket ? . Dispose ( ) ;
226+
227+ if ( cancellationToken . IsCancellationRequested )
228+ {
229+ throw ;
230+ }
231+
232+ if ( endpointConnectTimeoutCts ? . IsCancellationRequested == true )
233+ {
234+ ex = new TimeoutException ( $ "Failed to connect to any endpoint within the specified endpoint connect timeout of { _endpointConnectTimeout . TotalSeconds : N2} seconds.", ex ) ;
235+ }
236+
237+ lastException = ex ;
238+ }
239+ finally
240+ {
241+ endpointConnectTimeoutCts ? . Dispose ( ) ;
242+ }
243+ }
244+ }
245+
246+ throw lastException ?? new SocketException ( ( int ) SocketError . NoData ) ;
247+ }
248+
249+ private static bool AddressListsAreEquivalent ( IPAddress [ ] left , IPAddress [ ] right )
250+ {
251+ if ( ReferenceEquals ( left , right ) )
252+ {
253+ return true ;
254+ }
255+
256+ if ( left . Length != right . Length )
257+ {
258+ return false ;
259+ }
260+
261+ for ( int i = 0 ; i < left . Length ; i ++ )
262+ {
263+ if ( ! left [ i ] . Equals ( right [ i ] ) )
264+ {
265+ return false ;
266+ }
267+ }
268+
269+ return true ;
270+ }
271+ }
272+ }
0 commit comments