4
4
using System . IO ;
5
5
using System . Linq ;
6
6
using System . Net ;
7
+ using System . Net . Security ;
7
8
using System . Net . Sockets ;
8
9
using System . Text ;
9
10
using System . Threading ;
@@ -18,16 +19,19 @@ public partial class PooledSocket : IDisposable
18
19
private readonly ILogger _logger ;
19
20
20
21
private bool _isAlive ;
22
+ private bool _useSslStream ;
21
23
private Socket _socket ;
22
24
private readonly EndPoint _endpoint ;
23
25
private readonly int _connectionTimeout ;
24
26
25
27
private NetworkStream _inputStream ;
28
+ private SslStream _sslStream ;
26
29
27
- public PooledSocket ( EndPoint endpoint , TimeSpan connectionTimeout , TimeSpan receiveTimeout , ILogger logger )
30
+ public PooledSocket ( EndPoint endpoint , TimeSpan connectionTimeout , TimeSpan receiveTimeout , ILogger logger , bool useSslStream )
28
31
{
29
32
_logger = logger ;
30
33
_isAlive = true ;
34
+ _useSslStream = useSslStream ;
31
35
32
36
var socket = new Socket ( AddressFamily . InterNetwork , SocketType . Stream , ProtocolType . Tcp ) ;
33
37
socket . SetSocketOption ( SocketOptionLevel . Socket , SocketOptionName . KeepAlive , true ) ;
@@ -90,7 +94,15 @@ void Cancel()
90
94
91
95
if ( success )
92
96
{
93
- _inputStream = new NetworkStream ( _socket ) ;
97
+ if ( _useSslStream )
98
+ {
99
+ _sslStream = new SslStream ( new NetworkStream ( _socket ) ) ;
100
+ _sslStream . AuthenticateAsClient ( ( ( DnsEndPoint ) _endpoint ) . Host ) ;
101
+ }
102
+ else
103
+ {
104
+ _inputStream = new NetworkStream ( _socket ) ;
105
+ }
94
106
}
95
107
else
96
108
{
@@ -141,7 +153,15 @@ public async Task ConnectAsync()
141
153
142
154
if ( success )
143
155
{
144
- _inputStream = new NetworkStream ( _socket ) ;
156
+ if ( _useSslStream )
157
+ {
158
+ _sslStream = new SslStream ( new NetworkStream ( _socket ) ) ;
159
+ await _sslStream . AuthenticateAsClientAsync ( ( ( DnsEndPoint ) _endpoint ) . Host ) ;
160
+ }
161
+ else
162
+ {
163
+ _inputStream = new NetworkStream ( _socket ) ;
164
+ }
145
165
}
146
166
else
147
167
{
@@ -251,7 +271,13 @@ protected void Dispose(bool disposing)
251
271
_inputStream . Dispose ( ) ;
252
272
}
253
273
274
+ if ( _sslStream != null )
275
+ {
276
+ _sslStream . Dispose ( ) ;
277
+ }
278
+
254
279
_inputStream = null ;
280
+ _sslStream = null ;
255
281
_socket = null ;
256
282
this . CleanupCallback = null ;
257
283
}
@@ -290,7 +316,7 @@ public int ReadByte()
290
316
291
317
try
292
318
{
293
- return _inputStream . ReadByte ( ) ;
319
+ return ( _useSslStream ? _sslStream . ReadByte ( ) : _inputStream . ReadByte ( ) ) ;
294
320
}
295
321
catch ( Exception ex )
296
322
{
@@ -309,7 +335,7 @@ public int ReadByteAsync()
309
335
310
336
try
311
337
{
312
- return _inputStream . ReadByte ( ) ;
338
+ return ( _useSslStream ? _sslStream . ReadByte ( ) : _inputStream . ReadByte ( ) ) ;
313
339
}
314
340
catch ( Exception ex )
315
341
{
@@ -332,7 +358,7 @@ public async Task ReadAsync(byte[] buffer, int offset, int count)
332
358
{
333
359
try
334
360
{
335
- int currentRead = await _inputStream . ReadAsync ( buffer , offset , shouldRead ) ;
361
+ int currentRead = ( _useSslStream ? await _sslStream . ReadAsync ( buffer , offset , shouldRead ) : await _inputStream . ReadAsync ( buffer , offset , shouldRead ) ) ;
336
362
if ( currentRead == count )
337
363
break ;
338
364
if ( currentRead < 1 )
@@ -372,7 +398,7 @@ public void Read(byte[] buffer, int offset, int count)
372
398
{
373
399
try
374
400
{
375
- int currentRead = _inputStream . Read ( buffer , offset , shouldRead ) ;
401
+ int currentRead = ( _useSslStream ? _sslStream . Read ( buffer , offset , shouldRead ) : _inputStream . Read ( buffer , offset , shouldRead ) ) ;
376
402
if ( currentRead == count )
377
403
break ;
378
404
if ( currentRead < 1 )
@@ -397,15 +423,34 @@ public void Write(byte[] data, int offset, int length)
397
423
{
398
424
this . CheckDisposed ( ) ;
399
425
400
- SocketError status ;
426
+ if ( _useSslStream )
427
+ {
428
+ try
429
+ {
430
+ _inputStream . Write ( data , offset , length ) ;
431
+ _inputStream . Flush ( ) ;
432
+ }
433
+ catch ( Exception ex )
434
+ {
435
+ if ( ex is IOException || ex is SocketException )
436
+ {
437
+ _isAlive = false ;
438
+ }
439
+ throw ;
440
+ }
441
+ }
442
+ else
443
+ {
444
+ SocketError status ;
401
445
402
- _socket . Send ( data , offset , length , SocketFlags . None , out status ) ;
446
+ _socket . Send ( data , offset , length , SocketFlags . None , out status ) ;
403
447
404
- if ( status != SocketError . Success )
405
- {
406
- _isAlive = false ;
448
+ if ( status != SocketError . Success )
449
+ {
450
+ _isAlive = false ;
407
451
408
- ThrowHelper . ThrowSocketWriteError ( _endpoint , status ) ;
452
+ ThrowHelper . ThrowSocketWriteError ( _endpoint , status ) ;
453
+ }
409
454
}
410
455
}
411
456
@@ -417,11 +462,22 @@ public void Write(IList<ArraySegment<byte>> buffers)
417
462
418
463
try
419
464
{
420
- _socket . Send ( buffers , SocketFlags . None , out status ) ;
421
- if ( status != SocketError . Success )
465
+ if ( _useSslStream )
422
466
{
423
- _isAlive = false ;
424
- ThrowHelper . ThrowSocketWriteError ( _endpoint , status ) ;
467
+ foreach ( var buf in buffers )
468
+ {
469
+ _sslStream . Write ( buf . Array ) ;
470
+ }
471
+ _sslStream . Flush ( ) ;
472
+ }
473
+ else
474
+ {
475
+ _socket . Send ( buffers , SocketFlags . None , out status ) ;
476
+ if ( status != SocketError . Success )
477
+ {
478
+ _isAlive = false ;
479
+ ThrowHelper . ThrowSocketWriteError ( _endpoint , status ) ;
480
+ }
425
481
}
426
482
}
427
483
catch ( Exception ex )
@@ -441,12 +497,23 @@ public async Task WriteAsync(IList<ArraySegment<byte>> buffers)
441
497
442
498
try
443
499
{
444
- var bytesTransferred = await _socket . SendAsync ( buffers , SocketFlags . None ) ;
445
- if ( bytesTransferred <= 0 )
500
+ if ( _useSslStream )
446
501
{
447
- _isAlive = false ;
448
- _logger . LogError ( $ "Failed to { nameof ( PooledSocket . WriteAsync ) } . bytesTransferred: { bytesTransferred } ") ;
449
- ThrowHelper . ThrowSocketWriteError ( _endpoint ) ;
502
+ foreach ( var buf in buffers )
503
+ {
504
+ await _sslStream . WriteAsync ( buf . Array , 0 , buf . Count ) ;
505
+ }
506
+ await _sslStream . FlushAsync ( ) ;
507
+ }
508
+ else
509
+ {
510
+ var bytesTransferred = await _socket . SendAsync ( buffers , SocketFlags . None ) ;
511
+ if ( bytesTransferred <= 0 )
512
+ {
513
+ _isAlive = false ;
514
+ _logger . LogError ( $ "Failed to { nameof ( PooledSocket . WriteAsync ) } . bytesTransferred: { bytesTransferred } ") ;
515
+ ThrowHelper . ThrowSocketWriteError ( _endpoint ) ;
516
+ }
450
517
}
451
518
}
452
519
catch ( Exception ex )
0 commit comments