11using Data . Presto . Models ;
22using Data . Presto . Utils ;
33using System ;
4+ using System . Collections . Concurrent ;
45using System . Collections . Generic ;
56using System . Net . Http ;
67using System . Net . Http . Headers ;
@@ -15,10 +16,28 @@ class PrestoClient
1516 private readonly HttpClient _httpClient ;
1617 private readonly PrestoConnectionStringBuilder _connectionString ;
1718
19+ private bool ? _useSsl = null ;
20+
21+ //Dictionary used to store which http protocol to use for connections where it is not marked explicitly.
22+ private static readonly ConcurrentDictionary < string , bool > _protocolLookup = new ConcurrentDictionary < string , bool > ( ) ;
23+
1824 public PrestoClient ( PrestoConnectionStringBuilder prestoConnectionString )
1925 {
2026 _connectionString = prestoConnectionString ;
2127 _httpClient = new HttpClient ( ) ;
28+
29+ if ( _connectionString . Ssl . HasValue )
30+ {
31+ _useSsl = _connectionString . Ssl . Value ;
32+ }
33+ else
34+ {
35+ //Check if another presto client has already done connections against the host
36+ if ( _protocolLookup . TryGetValue ( prestoConnectionString . Host , out var canUseSsl ) )
37+ {
38+ _useSsl = canUseSsl ;
39+ }
40+ }
2241 }
2342
2443 private void AddHeaders ( HttpRequestMessage httpRequestMessage )
@@ -100,18 +119,98 @@ private async Task<DecodeResult> CheckResult(HttpResponseMessage httpResponseMes
100119 return decodeResults ;
101120 }
102121
103- public async Task < DecodeResult > Query ( string statement , CancellationToken cancellationToken )
122+ private string GetProtocol ( )
104123 {
105- using var httpRequestMessage = new HttpRequestMessage ( )
124+ if ( ! _useSsl . Value )
106125 {
107- Method = HttpMethod . Post ,
108- RequestUri = new Uri ( $ "http:// { _connectionString . Host } /v1/statement" ) ,
109- Content = new StringContent ( statement )
110- } ;
126+ return "http://" ;
127+ }
128+ return "https://" ;
129+ }
111130
112- AddHeaders ( httpRequestMessage ) ;
131+ private void SetHostSsl ( in string host , bool canUseSsl )
132+ {
133+ if ( _useSsl == null )
134+ {
135+ _useSsl = canUseSsl ;
136+ _protocolLookup . AddOrUpdate ( host , canUseSsl , ( key , old ) => canUseSsl ) ;
137+ }
138+ }
113139
114- var result = await _httpClient . SendAsync ( httpRequestMessage ) . ConfigureAwait ( false ) ;
140+ private async Task < HttpResponseMessage > SendMessage ( HttpMethod httpMethod , string path , CancellationToken cancellationToken , string content = null )
141+ {
142+ //Protocol has not yet been determined
143+ if ( _useSsl == null )
144+ {
145+ try
146+ {
147+ using var httpRequestMessage = new HttpRequestMessage ( )
148+ {
149+ Method = httpMethod ,
150+ RequestUri = new Uri ( $ "https://{ _connectionString . Host } { path } ") ,
151+ } ;
152+
153+ if ( content != null )
154+ {
155+ httpRequestMessage . Content = new StringContent ( content ) ;
156+ }
157+
158+ AddHeaders ( httpRequestMessage ) ;
159+
160+ var response = await _httpClient . SendAsync ( httpRequestMessage ) . ConfigureAwait ( false ) ;
161+ SetHostSsl ( _connectionString . Host , true ) ;
162+ return response ;
163+ }
164+ catch ( HttpRequestException requestException )
165+ {
166+ if ( requestException ? . InnerException ? . Source == "System.Net.Security" )
167+ {
168+ //Exception regarding security, test http:// instead
169+ using var httpRequestMessage = new HttpRequestMessage ( )
170+ {
171+ Method = httpMethod ,
172+ RequestUri = new Uri ( $ "http://{ _connectionString . Host } { path } ") ,
173+ } ;
174+
175+ if ( content != null )
176+ {
177+ httpRequestMessage . Content = new StringContent ( content ) ;
178+ }
179+
180+ AddHeaders ( httpRequestMessage ) ;
181+
182+ var response = await _httpClient . SendAsync ( httpRequestMessage ) . ConfigureAwait ( false ) ;
183+ SetHostSsl ( _connectionString . Host , false ) ;
184+ return response ;
185+ }
186+ else
187+ {
188+ throw ;
189+ }
190+ }
191+ }
192+ else
193+ {
194+ using var httpRequestMessage = new HttpRequestMessage ( )
195+ {
196+ Method = httpMethod ,
197+ RequestUri = new Uri ( $ "{ GetProtocol ( ) } { _connectionString . Host } { path } ") ,
198+ } ;
199+
200+ if ( content != null )
201+ {
202+ httpRequestMessage . Content = new StringContent ( content ) ;
203+ }
204+
205+ AddHeaders ( httpRequestMessage ) ;
206+
207+ return await _httpClient . SendAsync ( httpRequestMessage ) . ConfigureAwait ( false ) ;
208+ }
209+ }
210+
211+ public async Task < DecodeResult > Query ( string statement , CancellationToken cancellationToken )
212+ {
213+ var result = await SendMessage ( HttpMethod . Post , "/v1/statement" , cancellationToken , statement ) ;
115214
116215 switch ( result . StatusCode )
117216 {
@@ -126,15 +225,7 @@ public async Task<DecodeResult> Query(string statement, CancellationToken cancel
126225
127226 public async Task KillQuery ( string queryId , CancellationToken cancellationToken )
128227 {
129- using var httpRequestMessage = new HttpRequestMessage ( )
130- {
131- Method = HttpMethod . Delete ,
132- RequestUri = new Uri ( $ "http://{ _connectionString . Host } /v1/query/{ queryId } ")
133- } ;
134-
135- AddHeaders ( httpRequestMessage ) ;
136-
137- var result = await _httpClient . SendAsync ( httpRequestMessage ) . ConfigureAwait ( false ) ;
228+ await SendMessage ( HttpMethod . Delete , $ "/v1/query/{ queryId } ", cancellationToken ) . ConfigureAwait ( false ) ;
138229 }
139230 }
140231}
0 commit comments