1+ using System . Net . WebSockets ;
2+ using System . Text ;
3+ using System . Text . Json ;
4+ using Microsoft . AspNetCore . Http ;
5+ using NzbWebDAV . Utils ;
6+ using Serilog ;
7+
8+ namespace NzbWebDAV . Websocket ;
9+
10+ public class WebsocketManager
11+ {
12+ private readonly HashSet < WebSocket > _authenticatedSockets = [ ] ;
13+ private readonly Dictionary < string , string > _lastMessage = new ( ) ;
14+
15+ public async Task HandleRoute ( HttpContext context )
16+ {
17+ if ( context . WebSockets . IsWebSocketRequest )
18+ {
19+ using var webSocket = await context . WebSockets . AcceptWebSocketAsync ( ) ;
20+ if ( ! await Authenticate ( webSocket ) )
21+ {
22+ Log . Warning ( $ "Closing unauthenticated websocket connection from { context . Connection . RemoteIpAddress } ") ;
23+ await CloseUnauthorizedConnection ( webSocket ) ;
24+ return ;
25+ }
26+
27+ // mark the socket as authenticated
28+ lock ( _authenticatedSockets )
29+ _authenticatedSockets . Add ( webSocket ) ;
30+
31+ // send current state for all topics
32+ List < KeyValuePair < string , string > > ? lastMessage = null ;
33+ lock ( _lastMessage ) lastMessage = _lastMessage . ToList ( ) ;
34+ foreach ( var message in lastMessage )
35+ await SendMessage ( webSocket , message . Key , message . Value ) ;
36+
37+ // wait for the socket to disconnect
38+ await WaitForDisconnected ( webSocket ) ;
39+ lock ( _authenticatedSockets )
40+ _authenticatedSockets . Remove ( webSocket ) ;
41+ }
42+ else
43+ {
44+ context . Response . StatusCode = 400 ;
45+ }
46+ }
47+
48+ /// <summary>
49+ /// Send a message to all authenticated websockets.
50+ /// </summary>
51+ /// <param name="topic">The topic of the message to send</param>
52+ /// <param name="message">The message to send</param>
53+ public Task SendMessage ( string topic , string message )
54+ {
55+ lock ( _lastMessage ) _lastMessage [ topic ] = message ;
56+ List < WebSocket > ? authenticatedSockets ;
57+ lock ( _authenticatedSockets ) authenticatedSockets = _authenticatedSockets . ToList ( ) ;
58+ var topicMessage = new TopicMessage ( topic , message ) ;
59+ var bytes = new ArraySegment < byte > ( Encoding . UTF8 . GetBytes ( topicMessage . ToString ( ) ) ) ;
60+ return Task . WhenAll ( authenticatedSockets . Select ( x => SendMessage ( x , bytes ) ) ) ;
61+ }
62+
63+ /// <summary>
64+ /// Ensure a websocket sends a valid api key.
65+ /// </summary>
66+ /// <param name="socket">The websocket to authenticate.</param>
67+ /// <returns>True if authenticated, False otherwise.</returns>
68+ private static async Task < bool > Authenticate ( WebSocket socket )
69+ {
70+ var apiKey = await ReceiveAuthToken ( socket ) ;
71+ return apiKey == EnvironmentUtil . GetVariable ( "FRONTEND_BACKEND_API_KEY" ) ;
72+ }
73+
74+ /// <summary>
75+ /// Ignore all messages from the websocket and
76+ /// wait for it to disconnect.
77+ /// </summary>
78+ /// <param name="socket">The websocket to wait for disconnect.</param>
79+ private static async Task WaitForDisconnected ( WebSocket socket )
80+ {
81+ var buffer = new byte [ 1024 ] ;
82+ WebSocketReceiveResult ? result = null ;
83+ while ( result is not { CloseStatus : not null } )
84+ result = await socket . ReceiveAsync ( new ArraySegment < byte > ( buffer ) , default ) ;
85+ await socket . CloseAsync ( result . CloseStatus . Value , result . CloseStatusDescription , default ) ;
86+ }
87+
88+ /// <summary>
89+ /// Send a message to a connected websocket.
90+ /// </summary>
91+ /// <param name="socket">The websocket to send the message to.</param>
92+ /// <param name="topic">The topic of the message to send</param>
93+ /// <param name="message">The message to send</param>
94+ private static async Task SendMessage ( WebSocket socket , string topic , string message )
95+ {
96+ var topicMessage = new TopicMessage ( topic , message ) ;
97+ var bytes = new ArraySegment < byte > ( Encoding . UTF8 . GetBytes ( topicMessage . ToString ( ) ) ) ;
98+ await SendMessage ( socket , bytes ) ;
99+ }
100+
101+ /// <summary>
102+ /// Send a message to a connected websocket.
103+ /// </summary>
104+ /// <param name="socket">The websocket to send the message to.</param>
105+ /// <param name="message">The message to send.</param>
106+ private static async Task SendMessage ( WebSocket socket , ArraySegment < byte > message )
107+ {
108+ try
109+ {
110+ await socket . SendAsync ( message , WebSocketMessageType . Text , true , default ) ;
111+ }
112+ catch ( Exception e )
113+ {
114+ Log . Debug ( $ "Failed to send message to websocket. { e . Message } ") ;
115+ }
116+ }
117+
118+ /// <summary>
119+ /// Receive an authentication token from a connected websocket.
120+ /// With timeout after five seconds.
121+ /// </summary>
122+ /// <param name="socket">The websocket to receive from.</param>
123+ /// <returns>The authentication token. Or null if none provided.</returns>
124+ private static async Task < string ? > ReceiveAuthToken ( WebSocket socket )
125+ {
126+ try
127+ {
128+ var buffer = new byte [ 1024 ] ;
129+ using var cts = new CancellationTokenSource ( ) ;
130+ cts . CancelAfter ( TimeSpan . FromSeconds ( 5 ) ) ;
131+ var result = await socket . ReceiveAsync ( new ArraySegment < byte > ( buffer ) , cts . Token ) ;
132+ return result . MessageType == WebSocketMessageType . Text
133+ ? Encoding . UTF8 . GetString ( buffer , 0 , result . Count )
134+ : null ;
135+ }
136+ catch ( OperationCanceledException )
137+ {
138+ return null ;
139+ }
140+ }
141+
142+ /// <summary>
143+ /// Close a websocket connection as unauthorized.
144+ /// </summary>
145+ /// <param name="socket">The websocket whose connection to close.</param>
146+ private static async Task CloseUnauthorizedConnection ( WebSocket socket )
147+ {
148+ if ( socket . State == WebSocketState . Open )
149+ await socket . CloseAsync ( WebSocketCloseStatus . PolicyViolation , "Unauthorized" , CancellationToken . None ) ;
150+ }
151+
152+ private sealed class TopicMessage ( string topic , string message )
153+ {
154+ public string Topic { get ; } = topic ;
155+ public string Message { get ; } = message ;
156+ public override string ToString ( ) => JsonSerializer . Serialize ( this ) ;
157+ }
158+ }
0 commit comments