99using System . Net . WebSockets ;
1010using System . Text ;
1111using System . Text . Json ;
12+ using System . Threading ;
1213using System . Threading . Tasks ;
1314
1415namespace Discord . API . Gateways
@@ -17,26 +18,135 @@ internal partial class Gateway
1718 {
1819 private readonly JsonSerializerOptions _serialiseOptions ;
1920 private readonly JsonSerializerOptions _deserialiseOptions ;
20- private WebSocketClient _socket ;
21+ private ClientWebSocket ? _socket ;
22+ private Task ? _task ;
23+ private CancellationTokenSource _tokenSource = new CancellationTokenSource ( ) ;
2124 private DeflateStream ? _decompressor ;
2225 private MemoryStream ? _decompressionBuffer ;
23-
24- private WebSocketClient CreateSocket ( )
25- {
26- _socket ? . Dispose ( ) ;
27- _socket = new WebSocketClient ( ) ;
28- _socket . TextMessage += HandleTextMessage ;
29- _socket . BinaryMessage += HandleBinaryMessage ;
30- _socket . Closed += HandleClosed ;
31- return _socket ;
32- }
26+
3327
3428 private void SetupCompression ( )
3529 {
3630 _decompressionBuffer = new MemoryStream ( ) ;
3731 _decompressor = new DeflateStream ( _decompressionBuffer , CompressionMode . Decompress ) ;
3832 }
39-
33+
34+ /// <summary>
35+ /// Sets up a connection to the gateway.
36+ /// </summary>
37+ /// <exception cref="Exception">An exception will be thrown when connection fails, but not when the handshake fails.</exception>
38+ public async Task Connect ( string token )
39+ {
40+ _token = token ;
41+ await ConnectAsync ( ) ;
42+ }
43+
44+ public async Task ConnectAsync ( )
45+ {
46+ GatewayStatus = GatewayStatus == GatewayStatus . Initialized ? GatewayStatus . Connecting : GatewayStatus . Reconnecting ;
47+ await ConnectAsync ( _gatewayConfig . GetFullGatewayUrl ( "json" , "9" , "&compress=zlib-stream" ) ) ;
48+ _task = Task . Run ( async ( ) =>
49+ {
50+ await ListenOnSocket ( ) ;
51+ _socket = null ;
52+ } ) ;
53+ }
54+
55+ /// <summary>
56+ /// Resumes a connection to the gateway.
57+ /// </summary>
58+ /// <exception cref="Exception">An exception will be thrown when connection fails, but not when the handshake fails.</exception>
59+ public async Task ResumeAsync ( )
60+ {
61+ GatewayStatus = GatewayStatus . Resuming ;
62+ await ConnectAsync ( _gatewayConfig . GetFullGatewayUrl ( "json" , "9" , "&compress=zlib-stream" ) ) ;
63+ _task = Task . Run ( ListenOnSocket ) ;
64+ }
65+
66+ private async Task ListenOnSocket ( )
67+ {
68+ var buffer = new ArraySegment < byte > ( new byte [ 16 * 1024 ] ) ;
69+ while ( _tokenSource . IsCancellationRequested && _socket ! . State == WebSocketState . Open )
70+ {
71+ WebSocketReceiveResult socketResult = await _socket . ReceiveAsync ( buffer , _tokenSource . Token ) . ConfigureAwait ( false ) ;
72+ if ( socketResult . MessageType == WebSocketMessageType . Close )
73+ {
74+ switch ( socketResult . CloseStatus )
75+ {
76+ case ( WebSocketCloseStatus ) 4000 :
77+ case ( WebSocketCloseStatus ) 4001 :
78+ case ( WebSocketCloseStatus ) 4002 :
79+ case ( WebSocketCloseStatus ) 4003 :
80+ case ( WebSocketCloseStatus ) 4005 :
81+ case ( WebSocketCloseStatus ) 4007 :
82+ case ( WebSocketCloseStatus ) 4008 :
83+ case ( WebSocketCloseStatus ) 4009 :
84+ GatewayStatus = GatewayStatus . Reconnecting ;
85+ _ = ConnectAsync ( ) ;
86+ return ;
87+
88+ case ( WebSocketCloseStatus ) 4004 :
89+ default :
90+ GatewayStatus = GatewayStatus . Disconnected ;
91+ return ;
92+
93+ }
94+ }
95+
96+ byte [ ] bytes = buffer . Array ;
97+ int length = socketResult . Count ;
98+
99+ if ( ! socketResult . EndOfMessage )
100+ {
101+ // This is a large message (likely just READY), lets create a temporary expandable stream
102+ var stream = new MemoryStream ( ) ;
103+ await stream . WriteAsync ( buffer . Array , 0 , socketResult . Count ) . ConfigureAwait ( false ) ;
104+ do
105+ {
106+ if ( _tokenSource . Token . IsCancellationRequested )
107+ {
108+ return ;
109+ }
110+ socketResult = await _socket . ReceiveAsync ( buffer , _tokenSource . Token ) . ConfigureAwait ( false ) ;
111+ await stream . WriteAsync ( buffer . Array , 0 , socketResult . Count ) . ConfigureAwait ( false ) ;
112+ }
113+ while ( ! socketResult . EndOfMessage ) ;
114+
115+ bytes = stream . GetBuffer ( ) ;
116+ length = ( int ) stream . Length ;
117+ }
118+
119+ if ( socketResult . MessageType == WebSocketMessageType . Text )
120+ {
121+ HandleTextMessage ( bytes ) ;
122+ }
123+ else
124+ {
125+ HandleBinaryMessage ( bytes , length ) ;
126+ }
127+ }
128+ }
129+
130+ private async Task ReconnectAsync ( )
131+ {
132+ await CloseSocket ( ) ;
133+ await ConnectAsync ( _connectionUrl ! ) ;
134+ }
135+ private async Task ConnectAsync ( string connectionUrl )
136+ {
137+ _connectionUrl = connectionUrl ;
138+ SetupCompression ( ) ;
139+ _tokenSource = new CancellationTokenSource ( ) ;
140+ _socket ??= new ClientWebSocket ( ) ;
141+
142+
143+ if ( _socket . State is WebSocketState . Connecting or WebSocketState . Open )
144+ {
145+ throw new Exception ( "Tried to connect to socket while already connected" ) ;
146+ }
147+ await _socket . ConnectAsync ( new Uri ( connectionUrl ) , CancellationToken . None ) ;
148+ }
149+
40150 private async Task SendMessageAsync < T > ( SocketFrame < T > frame )
41151 {
42152 var stream = new MemoryStream ( ) ;
@@ -46,56 +156,42 @@ private async Task SendMessageAsync<T>(SocketFrame<T> frame)
46156
47157 private async Task SendMessageAsync ( MemoryStream stream )
48158 {
49- try
50- {
51- await _socket . SendAsync ( stream . GetBuffer ( ) , 0 , ( int ) stream . Length , true ) ;
52- }
53- catch ( WebSocketClosedException exception )
54- {
55- GatewayClosed ( exception ) ;
56- }
159+ await _socket ! . SendAsync ( new ArraySegment < byte > ( stream . GetBuffer ( ) , 0 , ( int ) stream . Length ) , WebSocketMessageType . Text , true , _tokenSource . Token ) ;
57160 }
58161
59- private void HandleTextMessage ( string message )
162+ private void HandleTextMessage ( byte [ ] buffer )
60163 {
61- using var reader = new StreamReader ( new MemoryStream ( Encoding . ASCII . GetBytes ( message ) ) ) ;
62- HandleMessage ( reader ) ;
164+ HandleMessage ( new MemoryStream ( buffer ) ) ;
63165 }
64166
65- private void HandleBinaryMessage ( byte [ ] bytes , int _ , int count )
167+ private async void HandleBinaryMessage ( byte [ ] buffer , int count )
66168 {
67169 Guard . IsNotNull ( _decompressor , nameof ( _decompressor ) ) ;
68170 Guard . IsNotNull ( _decompressionBuffer , nameof ( _decompressionBuffer ) ) ;
69-
70- using var ms = new MemoryStream ( bytes ) ;
71- ms . Position = 0 ;
72- byte [ ] data = new byte [ count ] ;
73- ms . Read ( data , 0 , count ) ;
74- int index = 0 ;
171+
75172 using var decompressed = new MemoryStream ( ) ;
76- if ( data [ 0 ] == 0x78 )
173+
174+ if ( buffer [ 0 ] == 0x78 )
77175 {
78- _decompressionBuffer . Write ( data , index + 2 , count - 2 ) ;
176+ await _decompressionBuffer . WriteAsync ( buffer , 2 , count - 2 ) ;
79177 _decompressionBuffer . SetLength ( count - 2 ) ;
80178 }
81179 else
82180 {
83- _decompressionBuffer . Write ( data , index , count ) ;
181+ await _decompressionBuffer . WriteAsync ( buffer , 0 , count ) ;
84182 _decompressionBuffer . SetLength ( count ) ;
85183 }
86184
87185 _decompressionBuffer . Position = 0 ;
88- _decompressor . CopyTo ( decompressed ) ;
186+ await _decompressor . CopyToAsync ( decompressed ) ;
89187 _decompressionBuffer . Position = 0 ;
90188 decompressed . Position = 0 ;
91-
92- using var reader = new StreamReader ( decompressed ) ;
93- HandleMessage ( reader ) ;
189+
190+ HandleMessage ( decompressed ) ;
94191 }
95192
96- private async void HandleMessage ( TextReader reader )
193+ private async void HandleMessage ( Stream stream )
97194 {
98- Stream stream = ( ( StreamReader ) reader ) . BaseStream ;
99195 SocketFrame ? frame = await ParseFrame ( stream ) ;
100196 if ( frame is null ) return ;
101197
@@ -107,18 +203,14 @@ private async void HandleMessage(TextReader reader)
107203 ProcessEvents ( frame ) ;
108204 }
109205
110- private void HandleClosed ( Exception exception )
111- {
112- GatewayClosed ( exception ) ;
113- }
114-
115206 private async Task CloseSocket ( )
116207 {
117- if ( _socket != null )
118- {
119- await _socket . DisconnectAsync ( ( WebSocketCloseStatus ) 4000 ) ;
120- await _socket . DisconnectAsync ( ) ;
121- }
208+ if ( _socket is { State : WebSocketState . Open } )
209+ await _socket . CloseAsync ( ( WebSocketCloseStatus ) 4000 , string . Empty , CancellationToken . None ) ;
210+ _tokenSource . Cancel ( ) ;
211+ if ( _task != null )
212+ await _task ;
213+ _task = null ;
122214 }
123215
124216 private async Task < SocketFrame ? > ParseFrame ( Stream stream )
0 commit comments