Skip to content

Commit dce1ef7

Browse files
JamesNKvasicvuk
andauthored
DNS resolver fix port usage (#1493)
Co-authored-by: Vuk Vasić <[email protected]>
1 parent 2029852 commit dce1ef7

File tree

6 files changed

+217
-78
lines changed

6 files changed

+217
-78
lines changed

src/Grpc.Net.Client/Balancer/DnsResolver.cs

Lines changed: 30 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,9 @@ internal sealed class DnsResolver : Resolver
4242
// To prevent excessive re-resolution, we enforce a rate limit on DNS resolution requests.
4343
private static readonly TimeSpan MinimumDnsResolutionRate = TimeSpan.FromSeconds(15);
4444

45-
private readonly Uri _address;
45+
private readonly Uri _originalAddress;
46+
private readonly string _dnsAddress;
47+
private readonly int _port;
4648
private readonly TimeSpan _refreshInterval;
4749
private readonly ILogger _logger;
4850

@@ -56,11 +58,18 @@ internal sealed class DnsResolver : Resolver
5658
/// Initializes a new instance of the <see cref="DnsResolver"/> class with the specified target <see cref="Uri"/>.
5759
/// </summary>
5860
/// <param name="address">The target <see cref="Uri"/>.</param>
61+
/// <param name="defaultPort">The default port.</param>
5962
/// <param name="loggerFactory">The logger factory.</param>
6063
/// <param name="refreshInterval">An interval for automatically refreshing the DNS hostname.</param>
61-
public DnsResolver(Uri address, ILoggerFactory loggerFactory, TimeSpan refreshInterval) : base(loggerFactory)
64+
public DnsResolver(Uri address, int defaultPort, ILoggerFactory loggerFactory, TimeSpan refreshInterval) : base(loggerFactory)
6265
{
63-
_address = address;
66+
_originalAddress = address;
67+
68+
// DNS address has the format: dns:[//authority/]host[:port]
69+
// Because the host is specified in the path, the port needs to be parsed manually
70+
var addressParsed = new Uri("temp://" + address.AbsolutePath.TrimStart('/'));
71+
_dnsAddress = addressParsed.Host;
72+
_port = addressParsed.Port == -1 ? defaultPort : addressParsed.Port;
6473
_refreshInterval = refreshInterval;
6574
_logger = loggerFactory.CreateLogger<DnsResolver>();
6675
}
@@ -92,28 +101,25 @@ protected override async Task ResolveAsync(CancellationToken cancellationToken)
92101

93102
_lastResolveStart = SystemClock.UtcNow;
94103

95-
var dnsAddress = _address.AbsolutePath.TrimStart('/');
96-
97-
if (string.IsNullOrEmpty(dnsAddress))
104+
if (string.IsNullOrEmpty(_dnsAddress))
98105
{
99-
throw new InvalidOperationException($"Resolver address '{_address}' doesn't have a path.");
106+
throw new InvalidOperationException($"Resolver address '{_originalAddress}' is not valid. Please use dns:/// for DNS provider.");
100107
}
101108

102-
DnsResolverLog.StartingDnsQuery(_logger, _address);
103-
var addresses = await Dns.GetHostAddressesAsync(dnsAddress).ConfigureAwait(false);
109+
DnsResolverLog.StartingDnsQuery(_logger, _dnsAddress);
110+
var addresses = await Dns.GetHostAddressesAsync(_dnsAddress).ConfigureAwait(false);
104111

105-
DnsResolverLog.ReceivedDnsResults(_logger, addresses.Length, _address, addresses);
112+
DnsResolverLog.ReceivedDnsResults(_logger, addresses.Length, _dnsAddress, addresses);
106113

107-
var resolvedPort = _address.Port == -1 ? 80 : _address.Port;
108-
var endpoints = addresses.Select(a => new BalancerAddress(a.ToString(), resolvedPort)).ToArray();
114+
var endpoints = addresses.Select(a => new BalancerAddress(a.ToString(), _port)).ToArray();
109115
var resolverResult = ResolverResult.ForResult(endpoints);
110116
Listener(resolverResult);
111117
}
112118
catch (Exception ex)
113119
{
114-
var message = $"Error getting DNS hosts for address '{_address}'.";
120+
var message = $"Error getting DNS hosts for address '{_dnsAddress}'.";
115121

116-
DnsResolverLog.ErrorQueryingDns(_logger, _address, ex);
122+
DnsResolverLog.ErrorQueryingDns(_logger, _dnsAddress, ex);
117123
Listener(ResolverResult.ForFailure(GrpcProtocolHelpers.CreateStatusFromException(message, ex, StatusCode.Unavailable)));
118124
}
119125
}
@@ -144,14 +150,14 @@ internal static class DnsResolverLog
144150
private static readonly Action<ILogger, TimeSpan, TimeSpan, Exception?> _startingRateLimitDelay =
145151
LoggerMessage.Define<TimeSpan, TimeSpan>(LogLevel.Debug, new EventId(1, "StartingRateLimitDelay"), "Starting rate limit delay of {DelayDuration}. DNS resolution rate limit is once every {RateLimitDuration}.");
146152

147-
private static readonly Action<ILogger, Uri, Exception?> _startingDnsQuery =
148-
LoggerMessage.Define<Uri>(LogLevel.Trace, new EventId(2, "StartingDnsQuery"), "Starting DNS query to get hosts from '{DnsAddress}'.");
153+
private static readonly Action<ILogger, string, Exception?> _startingDnsQuery =
154+
LoggerMessage.Define<string>(LogLevel.Trace, new EventId(2, "StartingDnsQuery"), "Starting DNS query to get hosts from '{DnsAddress}'.");
149155

150-
private static readonly Action<ILogger, int, Uri, string, Exception?> _receivedDnsResults =
151-
LoggerMessage.Define<int, Uri, string>(LogLevel.Debug, new EventId(3, "ReceivedDnsResults"), "Received {ResultCount} DNS results from '{DnsAddress}'. Results: {DnsResults}");
156+
private static readonly Action<ILogger, int, string, string, Exception?> _receivedDnsResults =
157+
LoggerMessage.Define<int, string, string>(LogLevel.Debug, new EventId(3, "ReceivedDnsResults"), "Received {ResultCount} DNS results from '{DnsAddress}'. Results: {DnsResults}");
152158

153-
private static readonly Action<ILogger, Uri, Exception?> _errorQueryingDns =
154-
LoggerMessage.Define<Uri>(LogLevel.Error, new EventId(4, "ErrorQueryingDns"), "Error querying DNS hosts for '{DnsAddress}'.");
159+
private static readonly Action<ILogger, string, Exception?> _errorQueryingDns =
160+
LoggerMessage.Define<string>(LogLevel.Error, new EventId(4, "ErrorQueryingDns"), "Error querying DNS hosts for '{DnsAddress}'.");
155161

156162
private static readonly Action<ILogger, Exception?> _errorFromRefreshInterval =
157163
LoggerMessage.Define(LogLevel.Error, new EventId(5, "ErrorFromRefreshIntervalTimer"), "Error from refresh interval timer.");
@@ -161,20 +167,20 @@ public static void StartingRateLimitDelay(ILogger logger, TimeSpan delayDuration
161167
_startingRateLimitDelay(logger, delayDuration, rateLimitDuration, null);
162168
}
163169

164-
public static void StartingDnsQuery(ILogger logger, Uri dnsAddress)
170+
public static void StartingDnsQuery(ILogger logger, string dnsAddress)
165171
{
166172
_startingDnsQuery(logger, dnsAddress, null);
167173
}
168174

169-
public static void ReceivedDnsResults(ILogger logger, int resultCount, Uri dnsAddress, IList<IPAddress> dnsResults)
175+
public static void ReceivedDnsResults(ILogger logger, int resultCount, string dnsAddress, IList<IPAddress> dnsResults)
170176
{
171177
if (logger.IsEnabled(LogLevel.Debug))
172178
{
173179
_receivedDnsResults(logger, resultCount, dnsAddress, string.Join(", ", dnsResults), null);
174180
}
175181
}
176182

177-
public static void ErrorQueryingDns(ILogger logger, Uri dnsAddress, Exception ex)
183+
public static void ErrorQueryingDns(ILogger logger, string dnsAddress, Exception ex)
178184
{
179185
_errorQueryingDns(logger, dnsAddress, ex);
180186
}
@@ -211,7 +217,7 @@ public DnsResolverFactory(TimeSpan refreshInterval)
211217
/// <inheritdoc />
212218
public override Resolver Create(ResolverOptions options)
213219
{
214-
return new DnsResolver(options.Address, options.LoggerFactory, _refreshInterval);
220+
return new DnsResolver(options.Address, options.DefaultPort, options.LoggerFactory, _refreshInterval);
215221
}
216222
}
217223
}

src/Grpc.Net.Client/Balancer/ResolverOptions.cs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,10 @@ public sealed class ResolverOptions
3333
/// <summary>
3434
/// Initializes a new instance of the <see cref="ResolverOptions"/> class.
3535
/// </summary>
36-
internal ResolverOptions(Uri address, bool disableServiceConfig, ILoggerFactory loggerFactory)
36+
internal ResolverOptions(Uri address, int defaultPort, bool disableServiceConfig, ILoggerFactory loggerFactory)
3737
{
3838
Address = address;
39+
DefaultPort = defaultPort;
3940
DisableServiceConfig = disableServiceConfig;
4041
LoggerFactory = loggerFactory;
4142
}
@@ -45,6 +46,11 @@ internal ResolverOptions(Uri address, bool disableServiceConfig, ILoggerFactory
4546
/// </summary>
4647
public Uri Address { get; }
4748

49+
/// <summary>
50+
/// Gets the default port. This port is used when the resolver address doesn't specify a port.
51+
/// </summary>
52+
public int DefaultPort { get; }
53+
4854
/// <summary>
4955
/// Gets a flag indicating whether the resolver should disable resolving a service config.
5056
/// </summary>

src/Grpc.Net.Client/GrpcChannel.cs

Lines changed: 53 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ public sealed class GrpcChannel : ChannelBase, IDisposable
5555
private readonly ConcurrentDictionary<IMethod, GrpcMethodInfo> _methodInfoCache;
5656
private readonly Func<IMethod, GrpcMethodInfo> _createMethodInfoFunc;
5757
private readonly Dictionary<MethodKey, MethodConfig>? _serviceConfigMethods;
58+
private readonly bool _isSecure;
59+
private readonly List<CallCredentials>? _callCredentials;
5860
// Internal for testing
5961
internal readonly HashSet<IDisposable> ActiveCalls;
6062

@@ -69,8 +71,8 @@ public sealed class GrpcChannel : ChannelBase, IDisposable
6971
internal ILoggerFactory LoggerFactory { get; }
7072
internal ILogger Logger { get; }
7173
internal bool ThrowOperationCanceledOnCancellation { get; }
72-
internal bool IsSecure { get; }
73-
internal List<CallCredentials>? CallCredentials { get; }
74+
internal bool IsSecure => _isSecure;
75+
internal List<CallCredentials>? CallCredentials => _callCredentials;
7476
internal Dictionary<string, ICompressionProvider> CompressionProviders { get; }
7577
internal string MessageAcceptEncoding { get; }
7678
internal bool Disposed { get; private set; }
@@ -118,21 +120,18 @@ internal GrpcChannel(Uri address, GrpcChannelOptions channelOptions) : base(addr
118120
HttpHandlerType = CalculateHandlerType(channelOptions);
119121

120122
#if SUPPORT_LOAD_BALANCING
123+
var resolverFactory = GetResolverFactory(channelOptions);
124+
ResolveCredentials(channelOptions, out _isSecure, out _callCredentials);
125+
121126
SubchannelTransportFactory = ResolveService<ISubchannelTransportFactory>(channelOptions.ServiceProvider, new SubChannelTransportFactory(this));
122127

123128
if (!IsHttpOrHttpsAddress() || channelOptions.ServiceConfig?.LoadBalancingConfigs.Count > 0)
124129
{
125130
ValidateHttpHandlerSupportsConnectivity();
126131
}
127132

128-
// Special case http and https schemes. These schemes don't use a dynamic resolver. An http/https
129-
// address is always just one address and that is enabled using the static resolver.
130-
//
131-
// Even with just one address we still want to use the load balancing infrastructure. This enables
132-
// the connectivity APIs on channel like GrpcChannel.State and GrpcChannel.WaitForStateChanged.
133-
var resolver = IsHttpOrHttpsAddress()
134-
? new StaticResolver(new[] { new BalancerAddress(Address.Host, Address.Port) }, LoggerFactory)
135-
: CreateResolver(channelOptions);
133+
var defaultPort = IsSecure ? 443 : 80;
134+
var resolver = resolverFactory.Create(new ResolverOptions(Address, defaultPort, channelOptions.DisableResolverServiceConfig, LoggerFactory));
136135

137136
ConnectionManager = new ConnectionManager(
138137
resolver,
@@ -149,6 +148,7 @@ internal GrpcChannel(Uri address, GrpcChannelOptions channelOptions) : base(addr
149148
{
150149
throw new ArgumentException($"Address '{address.OriginalString}' doesn't have a host. Address should include a scheme, host, and optional port. For example, 'https://localhost:5001'.");
151150
}
151+
ResolveCredentials(channelOptions, out _isSecure, out _callCredentials);
152152
#endif
153153

154154
HttpInvoker = channelOptions.HttpClient ?? CreateInternalHttpInvoker(channelOptions.HttpHandler);
@@ -169,25 +169,37 @@ internal GrpcChannel(Uri address, GrpcChannelOptions channelOptions) : base(addr
169169
_serviceConfigMethods = CreateServiceConfigMethods(serviceConfig);
170170
}
171171

172+
// Non-HTTP addresses (e.g. dns:///custom-hostname) usually specify a path instead of an authority.
173+
// Only log about a path being present if HTTP or HTTPS.
174+
if (!string.IsNullOrEmpty(Address.PathAndQuery) &&
175+
Address.PathAndQuery != "/" &&
176+
(Address.Scheme == Uri.UriSchemeHttps || Address.Scheme == Uri.UriSchemeHttp))
177+
{
178+
Log.AddressPathUnused(Logger, Address.OriginalString);
179+
}
180+
}
181+
182+
private void ResolveCredentials(GrpcChannelOptions channelOptions, out bool isSecure, out List<CallCredentials>? callCredentials)
183+
{
172184
if (channelOptions.Credentials != null)
173185
{
174186
var configurator = new DefaultChannelCredentialsConfigurator();
175187
channelOptions.Credentials.InternalPopulateConfiguration(configurator, null);
176188

177-
IsSecure = configurator.IsSecure ?? false;
178-
CallCredentials = configurator.CallCredentials;
189+
isSecure = configurator.IsSecure ?? false;
190+
callCredentials = configurator.CallCredentials;
179191

180192
ValidateChannelCredentials();
181193
}
182194
else
183195
{
184196
if (Address.Scheme == Uri.UriSchemeHttp)
185197
{
186-
IsSecure = false;
198+
isSecure = false;
187199
}
188200
else if (Address.Scheme == Uri.UriSchemeHttps)
189201
{
190-
IsSecure = true;
202+
isSecure = true;
191203
}
192204
else
193205
{
@@ -196,15 +208,7 @@ internal GrpcChannel(Uri address, GrpcChannelOptions channelOptions) : base(addr
196208
"To call TLS endpoints, set credentials to 'new SslCredentials()'. " +
197209
"To call non-TLS endpoints, set credentials to 'ChannelCredentials.Insecure'.");
198210
}
199-
}
200-
201-
// Non-HTTP addresses (e.g. dns:///custom-hostname) usually specify a path instead of an authority.
202-
// Only log about a path being present if HTTP or HTTPS.
203-
if (!string.IsNullOrEmpty(Address.PathAndQuery) &&
204-
Address.PathAndQuery != "/" &&
205-
(Address.Scheme == Uri.UriSchemeHttps || Address.Scheme == Uri.UriSchemeHttp))
206-
{
207-
Log.AddressPathUnused(Logger, Address.OriginalString);
211+
callCredentials = null;
208212
}
209213
}
210214

@@ -227,6 +231,32 @@ private static HttpHandlerType CalculateHandlerType(GrpcChannelOptions channelOp
227231
}
228232

229233
#if SUPPORT_LOAD_BALANCING
234+
private ResolverFactory GetResolverFactory(GrpcChannelOptions options)
235+
{
236+
// Special case http and https schemes. These schemes don't use a dynamic resolver. An http/https
237+
// address is always just one address and that is enabled using the static resolver.
238+
//
239+
// Even with just one address we still want to use the load balancing infrastructure. This enables
240+
// the connectivity APIs on channel like GrpcChannel.State and GrpcChannel.WaitForStateChanged.
241+
if (IsHttpOrHttpsAddress())
242+
{
243+
return new StaticResolverFactory(uri => new[] { new BalancerAddress(Address.Host, Address.Port) });
244+
}
245+
246+
var factories = ResolveService<IEnumerable<ResolverFactory>>(options.ServiceProvider, Array.Empty<ResolverFactory>());
247+
factories = factories.Union(ResolverFactory.KnownLoadResolverFactories);
248+
249+
foreach (var factory in factories)
250+
{
251+
if (string.Equals(factory.Name, Address.Scheme, StringComparison.OrdinalIgnoreCase))
252+
{
253+
return factory;
254+
}
255+
}
256+
257+
throw new InvalidOperationException($"No address resolver configured for the scheme '{Address.Scheme}'.");
258+
}
259+
230260
private void ValidateHttpHandlerSupportsConnectivity()
231261
{
232262
if (HttpHandlerType == HttpHandlerType.SocketsHttpHandler)
@@ -250,22 +280,6 @@ private void ValidateHttpHandlerSupportsConnectivity()
250280
$"The HTTP transport must be configured on the channel using {nameof(GrpcChannelOptions)}.{nameof(GrpcChannelOptions.HttpHandler)}.");
251281
}
252282

253-
private Resolver CreateResolver(GrpcChannelOptions options)
254-
{
255-
var factories = ResolveService<IEnumerable<ResolverFactory>>(options.ServiceProvider, Array.Empty<ResolverFactory>());
256-
factories = factories.Union(ResolverFactory.KnownLoadResolverFactories);
257-
258-
foreach (var factory in factories)
259-
{
260-
if (string.Equals(factory.Name, Address.Scheme, StringComparison.OrdinalIgnoreCase))
261-
{
262-
return factory.Create(new ResolverOptions(Address, options.DisableResolverServiceConfig, LoggerFactory));
263-
}
264-
}
265-
266-
throw new InvalidOperationException($"No address resolver configured for the scheme '{Address.Scheme}'.");
267-
}
268-
269283
private LoadBalancerFactory[] ResolveLoadBalancerFactories(IServiceProvider? serviceProvider)
270284
{
271285
var serviceFactories = ResolveService<IEnumerable<LoadBalancerFactory>?>(serviceProvider, defaultValue: null);

0 commit comments

Comments
 (0)