Skip to content

Commit 74cc483

Browse files
committed
Support Client TLS
1 parent f494ac2 commit 74cc483

File tree

2 files changed

+41
-19
lines changed

2 files changed

+41
-19
lines changed

src/Apache.IoTDB/SessionPool.Builder.cs

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
* under the License.
1818
*/
1919

20-
using System;
2120
using System.Collections.Generic;
2221

2322
namespace Apache.IoTDB;
@@ -35,6 +34,8 @@ public class Builder
3534
private int _poolSize = 8;
3635
private bool _enableRpcCompression = false;
3736
private int _connectionTimeoutInMs = 500;
37+
private bool _useSsl = false;
38+
private string _certificatePath = null;
3839
private string _sqlDialect = IoTDBConstant.TREE_SQL_DIALECT;
3940
private string _database = "";
4041
private List<string> _nodeUrls = new List<string>();
@@ -93,6 +94,18 @@ public Builder SetConnectionTimeoutInMs(int timeout)
9394
return this;
9495
}
9596

97+
public Builder SetUseSsl(bool useSsl)
98+
{
99+
_useSsl = useSsl;
100+
return this;
101+
}
102+
103+
public Builder SetCertificatePath(string certificatePath)
104+
{
105+
_certificatePath = certificatePath;
106+
return this;
107+
}
108+
96109
public Builder SetNodeUrl(List<string> nodeUrls)
97110
{
98111
_nodeUrls = nodeUrls;
@@ -122,6 +135,8 @@ public Builder()
122135
_poolSize = 8;
123136
_enableRpcCompression = false;
124137
_connectionTimeoutInMs = 500;
138+
_useSsl = false;
139+
_certificatePath = null;
125140
_sqlDialect = IoTDBConstant.TREE_SQL_DIALECT;
126141
_database = "";
127142
}
@@ -131,9 +146,9 @@ public SessionPool Build()
131146
// if nodeUrls is not empty, use nodeUrls to create session pool
132147
if (_nodeUrls.Count > 0)
133148
{
134-
return new SessionPool(_nodeUrls, _username, _password, _fetchSize, _zoneId, _poolSize, _enableRpcCompression, _connectionTimeoutInMs, _sqlDialect, _database);
149+
return new SessionPool(_nodeUrls, _username, _password, _fetchSize, _zoneId, _poolSize, _enableRpcCompression, _connectionTimeoutInMs, _useSsl, _certificatePath ,_sqlDialect, _database);
135150
}
136-
return new SessionPool(_host, _port, _username, _password, _fetchSize, _zoneId, _poolSize, _enableRpcCompression, _connectionTimeoutInMs, _sqlDialect, _database);
151+
return new SessionPool(_host, _port, _username, _password, _fetchSize, _zoneId, _poolSize, _enableRpcCompression, _connectionTimeoutInMs, _useSsl, _certificatePath, _sqlDialect, _database);
137152
}
138153
}
139154
}

src/Apache.IoTDB/SessionPool.cs

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,12 @@
1919

2020
using System;
2121
using System.Collections.Generic;
22+
using System.IO;
2223
using System.Linq;
23-
using System.Net.Sockets;
24-
using System.Numerics;
2524
using System.Threading;
2625
using System.Threading.Tasks;
26+
using System.Security.Cryptography.X509Certificates;
2727
using Apache.IoTDB.DataStructure;
28-
using Microsoft.Extensions.Configuration;
2928
using Microsoft.Extensions.Logging;
3029
using Thrift;
3130
using Thrift.Protocol;
@@ -47,6 +46,8 @@ public partial class SessionPool : IDisposable
4746
private readonly List<TEndPoint> _endPoints = new();
4847
private readonly string _host;
4948
private readonly int _port;
49+
private readonly bool _useSsl;
50+
private readonly string _certificatePath;
5051
private readonly int _fetchSize;
5152
/// <summary>
5253
/// _timeout is the amount of time a Session will wait for a send operation to complete successfully.
@@ -86,10 +87,10 @@ public SessionPool(string host, int port) : this(host, port, "root", "root", 102
8687
{
8788
}
8889
public SessionPool(string host, int port, string username, string password, int fetchSize, string zoneId, int poolSize, bool enableRpcCompression, int timeout)
89-
: this(host, port, username, password, fetchSize, zoneId, poolSize, enableRpcCompression, timeout, IoTDBConstant.TREE_SQL_DIALECT, "")
90+
: this(host, port, username, password, fetchSize, zoneId, poolSize, enableRpcCompression, timeout, false, null, IoTDBConstant.TREE_SQL_DIALECT, "")
9091
{
9192
}
92-
protected internal SessionPool(string host, int port, string username, string password, int fetchSize, string zoneId, int poolSize, bool enableRpcCompression, int timeout, string sqlDialect, string database)
93+
protected internal SessionPool(string host, int port, string username, string password, int fetchSize, string zoneId, int poolSize, bool enableRpcCompression, int timeout, bool useSsl, string certificatePath, string sqlDialect, string database)
9394
{
9495
_host = host;
9596
_port = port;
@@ -101,6 +102,8 @@ protected internal SessionPool(string host, int port, string username, string pa
101102
_poolSize = poolSize;
102103
_enableRpcCompression = enableRpcCompression;
103104
_timeout = timeout;
105+
_useSsl = useSsl;
106+
_certificatePath = certificatePath;
104107
_sqlDialect = sqlDialect;
105108
_database = database;
106109
}
@@ -126,11 +129,11 @@ public SessionPool(List<string> nodeUrls, string username, string password, int
126129
{
127130
}
128131
public SessionPool(List<string> nodeUrls, string username, string password, int fetchSize, string zoneId, int poolSize, bool enableRpcCompression, int timeout)
129-
: this(nodeUrls, username, password, fetchSize, zoneId, poolSize, enableRpcCompression, timeout, IoTDBConstant.TREE_SQL_DIALECT, "")
132+
: this(nodeUrls, username, password, fetchSize, zoneId, poolSize, enableRpcCompression, timeout,false, null, IoTDBConstant.TREE_SQL_DIALECT, "")
130133
{
131134

132135
}
133-
protected internal SessionPool(List<string> nodeUrls, string username, string password, int fetchSize, string zoneId, int poolSize, bool enableRpcCompression, int timeout, string sqlDialect, string database)
136+
protected internal SessionPool(List<string> nodeUrls, string username, string password, int fetchSize, string zoneId, int poolSize, bool enableRpcCompression, int timeout, bool useSsl, string certificatePath, string sqlDialect, string database)
134137
{
135138
if (nodeUrls.Count == 0)
136139
{
@@ -146,6 +149,8 @@ protected internal SessionPool(List<string> nodeUrls, string username, string pa
146149
_poolSize = poolSize;
147150
_enableRpcCompression = enableRpcCompression;
148151
_timeout = timeout;
152+
_useSsl = useSsl;
153+
_certificatePath = certificatePath;
149154
_sqlDialect = sqlDialect;
150155
_database = database;
151156
}
@@ -241,7 +246,7 @@ public async Task Open(CancellationToken cancellationToken = default)
241246
{
242247
try
243248
{
244-
_clients.Add(await CreateAndOpen(_host, _port, _enableRpcCompression, _timeout, _sqlDialect, _database, cancellationToken));
249+
_clients.Add(await CreateAndOpen(_host, _port, _enableRpcCompression, _timeout, _useSsl,_certificatePath, _sqlDialect, _database, cancellationToken));
245250
}
246251
catch (Exception e)
247252
{
@@ -264,7 +269,7 @@ public async Task Open(CancellationToken cancellationToken = default)
264269
var endPoint = _endPoints[endPointIndex];
265270
try
266271
{
267-
var client = await CreateAndOpen(endPoint.Ip, endPoint.Port, _enableRpcCompression, _timeout, _sqlDialect, _database, cancellationToken);
272+
var client = await CreateAndOpen(endPoint.Ip, endPoint.Port, _enableRpcCompression, _timeout, _useSsl,_certificatePath, _sqlDialect, _database, cancellationToken);
268273
_clients.Add(client);
269274
isConnected = true;
270275
startIndex = (endPointIndex + 1) % _endPoints.Count;
@@ -303,7 +308,7 @@ public async Task<Client> Reconnect(Client originalClient = null, CancellationTo
303308
{
304309
try
305310
{
306-
var client = await CreateAndOpen(_host, _port, _enableRpcCompression, _timeout, _sqlDialect, _database, cancellationToken);
311+
var client = await CreateAndOpen(_host, _port, _enableRpcCompression, _timeout, _useSsl,_certificatePath, _sqlDialect, _database, cancellationToken);
307312
return client;
308313
}
309314
catch (Exception e)
@@ -330,7 +335,7 @@ public async Task<Client> Reconnect(Client originalClient = null, CancellationTo
330335
int j = (startIndex + i) % _endPoints.Count;
331336
try
332337
{
333-
var client = await CreateAndOpen(_endPoints[j].Ip, _endPoints[j].Port, _enableRpcCompression, _timeout, _sqlDialect, _database, cancellationToken);
338+
var client = await CreateAndOpen(_endPoints[j].Ip, _endPoints[j].Port, _enableRpcCompression, _timeout, _useSsl,_certificatePath, _sqlDialect, _database, cancellationToken);
334339
return client;
335340
}
336341
catch (Exception e)
@@ -423,12 +428,14 @@ public async Task<string> GetTimeZone()
423428
}
424429
}
425430

426-
private async Task<Client> CreateAndOpen(string host, int port, bool enableRpcCompression, int timeout, string sqlDialect, string database, CancellationToken cancellationToken = default)
431+
private async Task<Client> CreateAndOpen(string host, int port, bool enableRpcCompression, int timeout, bool useSsl, string cert, string sqlDialect, string database, CancellationToken cancellationToken = default)
427432
{
428-
var tcpClient = new TcpClient(host, port);
429-
tcpClient.SendTimeout = timeout;
430-
tcpClient.ReceiveTimeout = timeout;
431-
var transport = new TFramedTransport(new TSocketTransport(tcpClient, null));
433+
434+
TTransport socket = useSsl ?
435+
new TTlsSocketTransport(host, port, null, timeout, new X509Certificate2(File.ReadAllBytes(cert))) :
436+
new TSocketTransport(host, port, null, timeout);
437+
438+
var transport = new TFramedTransport(socket);
432439

433440
if (!transport.IsOpen)
434441
{

0 commit comments

Comments
 (0)