Skip to content

Commit b745304

Browse files
authored
Merge pull request #11 from koralium/10_ssl_support
Added support for ssl traffic
2 parents 4f2a031 + d14b2e4 commit b745304

File tree

3 files changed

+138
-21
lines changed

3 files changed

+138
-21
lines changed

README.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,13 @@ The connection string takes the following parameters:
9797
| ExtraCredentials | Extra credentials to send. | ExtraCredentials=key1:value1,key2:value2; |
9898
| Trino | Use trino headers (required for trino) | Trino=true; |
9999
| Password | Password for the user | Password=test; |
100+
| Ssl | Https or http protocol | Ssl=true; |
101+
102+
# SSL Traffic
103+
104+
If the SSL connection string option is left out, the ADO.Net provider tries to figure out the protocol by itself.
105+
It first tries https but if that fails it tests http. This is saved as long as the application is running.
106+
But for better first time performance if one is not using https is to set ssl=false in the connection string.
100107

101108
# Nuget Package
102109

src/Data.Presto/Client/PrestoClient.cs

Lines changed: 108 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using Data.Presto.Models;
22
using Data.Presto.Utils;
33
using System;
4+
using System.Collections.Concurrent;
45
using System.Collections.Generic;
56
using System.Net.Http;
67
using System.Net.Http.Headers;
@@ -15,10 +16,28 @@ class PrestoClient
1516
private readonly HttpClient _httpClient;
1617
private readonly PrestoConnectionStringBuilder _connectionString;
1718

19+
private bool? _useSsl = null;
20+
21+
//Dictionary used to store which http protocol to use for connections where it is not marked explicitly.
22+
private static readonly ConcurrentDictionary<string, bool> _protocolLookup = new ConcurrentDictionary<string, bool>();
23+
1824
public PrestoClient(PrestoConnectionStringBuilder prestoConnectionString)
1925
{
2026
_connectionString = prestoConnectionString;
2127
_httpClient = new HttpClient();
28+
29+
if (_connectionString.Ssl.HasValue)
30+
{
31+
_useSsl = _connectionString.Ssl.Value;
32+
}
33+
else
34+
{
35+
//Check if another presto client has already done connections against the host
36+
if (_protocolLookup.TryGetValue(prestoConnectionString.Host, out var canUseSsl))
37+
{
38+
_useSsl = canUseSsl;
39+
}
40+
}
2241
}
2342

2443
private void AddHeaders(HttpRequestMessage httpRequestMessage)
@@ -100,18 +119,98 @@ private async Task<DecodeResult> CheckResult(HttpResponseMessage httpResponseMes
100119
return decodeResults;
101120
}
102121

103-
public async Task<DecodeResult> Query(string statement, CancellationToken cancellationToken)
122+
private string GetProtocol()
104123
{
105-
using var httpRequestMessage = new HttpRequestMessage()
124+
if (!_useSsl.Value)
106125
{
107-
Method = HttpMethod.Post,
108-
RequestUri = new Uri($"http://{_connectionString.Host}/v1/statement"),
109-
Content = new StringContent(statement)
110-
};
126+
return "http://";
127+
}
128+
return "https://";
129+
}
111130

112-
AddHeaders(httpRequestMessage);
131+
private void SetHostSsl(in string host, bool canUseSsl)
132+
{
133+
if (_useSsl == null)
134+
{
135+
_useSsl = canUseSsl;
136+
_protocolLookup.AddOrUpdate(host, canUseSsl, (key, old) => canUseSsl);
137+
}
138+
}
113139

114-
var result = await _httpClient.SendAsync(httpRequestMessage).ConfigureAwait(false);
140+
private async Task<HttpResponseMessage> SendMessage(HttpMethod httpMethod, string path, CancellationToken cancellationToken, string content = null)
141+
{
142+
//Protocol has not yet been determined
143+
if (_useSsl == null)
144+
{
145+
try
146+
{
147+
using var httpRequestMessage = new HttpRequestMessage()
148+
{
149+
Method = httpMethod,
150+
RequestUri = new Uri($"https://{_connectionString.Host}{path}"),
151+
};
152+
153+
if (content != null)
154+
{
155+
httpRequestMessage.Content = new StringContent(content);
156+
}
157+
158+
AddHeaders(httpRequestMessage);
159+
160+
var response = await _httpClient.SendAsync(httpRequestMessage).ConfigureAwait(false);
161+
SetHostSsl(_connectionString.Host, true);
162+
return response;
163+
}
164+
catch (HttpRequestException requestException)
165+
{
166+
if (requestException?.InnerException?.Source == "System.Net.Security")
167+
{
168+
//Exception regarding security, test http:// instead
169+
using var httpRequestMessage = new HttpRequestMessage()
170+
{
171+
Method = httpMethod,
172+
RequestUri = new Uri($"http://{_connectionString.Host}{path}"),
173+
};
174+
175+
if (content != null)
176+
{
177+
httpRequestMessage.Content = new StringContent(content);
178+
}
179+
180+
AddHeaders(httpRequestMessage);
181+
182+
var response = await _httpClient.SendAsync(httpRequestMessage).ConfigureAwait(false);
183+
SetHostSsl(_connectionString.Host, false);
184+
return response;
185+
}
186+
else
187+
{
188+
throw;
189+
}
190+
}
191+
}
192+
else
193+
{
194+
using var httpRequestMessage = new HttpRequestMessage()
195+
{
196+
Method = httpMethod,
197+
RequestUri = new Uri($"{GetProtocol()}{_connectionString.Host}{path}"),
198+
};
199+
200+
if (content != null)
201+
{
202+
httpRequestMessage.Content = new StringContent(content);
203+
}
204+
205+
AddHeaders(httpRequestMessage);
206+
207+
return await _httpClient.SendAsync(httpRequestMessage).ConfigureAwait(false);
208+
}
209+
}
210+
211+
public async Task<DecodeResult> Query(string statement, CancellationToken cancellationToken)
212+
{
213+
var result = await SendMessage(HttpMethod.Post, "/v1/statement", cancellationToken, statement);
115214

116215
switch (result.StatusCode)
117216
{
@@ -126,15 +225,7 @@ public async Task<DecodeResult> Query(string statement, CancellationToken cancel
126225

127226
public async Task KillQuery(string queryId, CancellationToken cancellationToken)
128227
{
129-
using var httpRequestMessage = new HttpRequestMessage()
130-
{
131-
Method = HttpMethod.Delete,
132-
RequestUri = new Uri($"http://{_connectionString.Host}/v1/query/{queryId}")
133-
};
134-
135-
AddHeaders(httpRequestMessage);
136-
137-
var result = await _httpClient.SendAsync(httpRequestMessage).ConfigureAwait(false);
228+
await SendMessage(HttpMethod.Delete, $"/v1/query/{queryId}", cancellationToken).ConfigureAwait(false);
138229
}
139230
}
140231
}

src/Data.Presto/PrestoConnectionStringBuilder.cs

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ public class PrestoConnectionStringBuilder : DbConnectionStringBuilder
2424
private const string StreamingKeyword = "Streaming";
2525
private const string TrinoKeyword = "Trino";
2626
private const string PasswordKeyword = "Password";
27+
private const string SslKeyword = "Ssl";
2728

2829
private static readonly IReadOnlyList<string> _validKeywords;
2930
private static readonly IReadOnlyDictionary<string, Keywords> _keywords;
@@ -35,6 +36,7 @@ public class PrestoConnectionStringBuilder : DbConnectionStringBuilder
3536
private bool _streaming = true;
3637
private bool _trino = false;
3738
private string _password = string.Empty;
39+
private bool? _ssl = null;
3840
private ImmutableList<KeyValuePair<string, string>> _extraCredentials = ImmutableList.Create<KeyValuePair<string, string>>();
3941

4042
private enum Keywords
@@ -46,12 +48,13 @@ private enum Keywords
4648
ExtraCredentials,
4749
Streaming,
4850
Trino,
49-
Password
51+
Password,
52+
Ssl,
5053
}
5154

5255
static PrestoConnectionStringBuilder()
5356
{
54-
var validKeywords = new string[8];
57+
var validKeywords = new string[9];
5558
validKeywords[(int)Keywords.DataSource] = DataSourceKeyword;
5659
validKeywords[(int)Keywords.User] = UserKeyword;
5760
validKeywords[(int)Keywords.Catalog] = CatalogKeyword;
@@ -60,9 +63,10 @@ static PrestoConnectionStringBuilder()
6063
validKeywords[(int)Keywords.Streaming] = StreamingKeyword;
6164
validKeywords[(int)Keywords.Trino] = TrinoKeyword;
6265
validKeywords[(int)Keywords.Password] = PasswordKeyword;
66+
validKeywords[(int)Keywords.Ssl] = SslKeyword;
6367
_validKeywords = validKeywords;
6468

65-
_keywords = new Dictionary<string, Keywords>(9, StringComparer.OrdinalIgnoreCase)
69+
_keywords = new Dictionary<string, Keywords>(10, StringComparer.OrdinalIgnoreCase)
6670
{
6771
[DataSourceKeyword] = Keywords.DataSource,
6872
[DataSourceNoSpaceKeyword] = Keywords.DataSource,
@@ -72,7 +76,8 @@ static PrestoConnectionStringBuilder()
7276
[ExtraCredentialKeyword] = Keywords.ExtraCredentials,
7377
[StreamingKeyword] = Keywords.Streaming,
7478
[TrinoKeyword] = Keywords.Trino,
75-
[PasswordKeyword] = Keywords.Password
79+
[PasswordKeyword] = Keywords.Password,
80+
[SslKeyword] = Keywords.Ssl
7681
};
7782
}
7883

@@ -129,6 +134,12 @@ public virtual bool Trino
129134
set => base[TrinoKeyword] = _trino = value;
130135
}
131136

137+
public virtual bool? Ssl
138+
{
139+
get => _ssl;
140+
set => base[SslKeyword] = _ssl = value;
141+
}
142+
132143
public virtual string Password
133144
{
134145
get => _password;
@@ -195,6 +206,9 @@ public override object this[string keyword]
195206
case Keywords.Password:
196207
Password = Convert.ToString(value, CultureInfo.InvariantCulture);
197208
return;
209+
case Keywords.Ssl:
210+
Ssl = Convert.ToBoolean(value, CultureInfo.InvariantCulture);
211+
return;
198212
default:
199213
Debug.Assert(false, "Unexpected keyword: " + keyword);
200214
return;
@@ -290,6 +304,8 @@ private object GetAt(Keywords index)
290304
return Trino;
291305
case Keywords.Password:
292306
return Password;
307+
case Keywords.Ssl:
308+
return Ssl;
293309
default:
294310
Debug.Assert(false, "Unexpected keyword: " + index);
295311
return null;
@@ -329,6 +345,9 @@ private void Reset(Keywords index)
329345
case Keywords.Password:
330346
_password = string.Empty;
331347
return;
348+
case Keywords.Ssl:
349+
_ssl = null;
350+
return;
332351
default:
333352
Debug.Assert(false, "Unexpected keyword: " + index);
334353
return;

0 commit comments

Comments
 (0)