Skip to content

Commit ff4b3de

Browse files
committed
Added websocket handler to backend.
1 parent e7dfd4a commit ff4b3de

File tree

5 files changed

+207
-8
lines changed

5 files changed

+207
-8
lines changed

backend/Clients/UsenetStreamingClient.cs

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1-
using Microsoft.Extensions.Caching.Memory;
1+
using System.Text.Json;
2+
using Microsoft.EntityFrameworkCore.Storage.Json;
3+
using Microsoft.Extensions.Caching.Memory;
24
using NzbWebDAV.Clients.Connections;
35
using NzbWebDAV.Config;
46
using NzbWebDAV.Exceptions;
57
using NzbWebDAV.Extensions;
68
using NzbWebDAV.Streams;
9+
using NzbWebDAV.Websocket;
710
using Usenet.Nntp.Responses;
811
using Usenet.Nzb;
912
using Usenet.Yenc;
@@ -13,9 +16,13 @@ namespace NzbWebDAV.Clients;
1316
public class UsenetStreamingClient
1417
{
1518
private readonly INntpClient _client;
19+
private readonly WebsocketManager _websocketManager;
1620

17-
public UsenetStreamingClient(ConfigManager configManager)
21+
public UsenetStreamingClient(ConfigManager configManager, WebsocketManager websocketManager)
1822
{
23+
// initialize private members
24+
_websocketManager = websocketManager;
25+
1926
// get connection settings from config-manager
2027
var host = configManager.GetConfigValue("usenet.host") ?? string.Empty;
2128
var port = int.Parse(configManager.GetConfigValue("usenet.port") ?? "119");
@@ -26,7 +33,7 @@ public UsenetStreamingClient(ConfigManager configManager)
2633

2734
// initialize the nntp-client
2835
var createNewConnection = (CancellationToken ct) => CreateNewConnection(host, port, useSsl, user, pass, ct);
29-
ConnectionPool<INntpClient> connectionPool = new(connections, createNewConnection);
36+
var connectionPool = CreateNewConnectionPool(connections, createNewConnection);
3037
var multiConnectionClient = new MultiConnectionNntpClient(connectionPool);
3138
var cache = new MemoryCache(new MemoryCacheOptions() { SizeLimit = 8192 });
3239
_client = new CachingNntpClient(multiConnectionClient, cache);
@@ -49,8 +56,9 @@ public UsenetStreamingClient(ConfigManager configManager)
4956
var newUseSsl = bool.Parse(configEventArgs.NewConfig.GetValueOrDefault("usenet.use-ssl", "false"));
5057
var newUser = configEventArgs.NewConfig["usenet.user"];
5158
var newPass = configEventArgs.NewConfig["usenet.pass"];
52-
multiConnectionClient.UpdateConnectionPool(new(connectionCount, cancellationToken =>
53-
CreateNewConnection(newHost, newPort, newUseSsl, newUser, newPass, cancellationToken)));
59+
var newConnectionPool = CreateNewConnectionPool(connectionCount, cancellationToken =>
60+
CreateNewConnection(newHost, newPort, newUseSsl, newUser, newPass, cancellationToken));
61+
multiConnectionClient.UpdateConnectionPool(newConnectionPool);
5462
};
5563
}
5664

@@ -100,6 +108,25 @@ public Task<long> GetFileSizeAsync(NzbFile file, CancellationToken cancellationT
100108
return _client.GetFileSizeAsync(file, cancellationToken);
101109
}
102110

111+
private ConnectionPool<INntpClient> CreateNewConnectionPool
112+
(
113+
int maxConnections,
114+
Func<CancellationToken, ValueTask<INntpClient>> connectionFactory
115+
)
116+
{
117+
var connectionPool = new ConnectionPool<INntpClient>(maxConnections, connectionFactory);
118+
connectionPool.OnConnectionPoolChanged += OnConnectionPoolChanged;
119+
var args = new ConnectionPool<INntpClient>.ConnectionPoolChangedEventArgs(0, 0, maxConnections);
120+
OnConnectionPoolChanged(connectionPool, args);
121+
return connectionPool;
122+
}
123+
124+
private void OnConnectionPoolChanged(object? _, ConnectionPool<INntpClient>.ConnectionPoolChangedEventArgs args)
125+
{
126+
var message = $"{args.Live}|{args.Max}|{args.Idle}";
127+
_websocketManager.SendMessage(WebsocketTopic.UsenetConnections, message);
128+
}
129+
103130
public static async ValueTask<INntpClient> CreateNewConnection
104131
(
105132
string host,

backend/Extensions/NWebDavOptionsExtensions.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ public static Func<HttpContext, bool> GetFilter(this NWebDavOptions options)
99
{
1010
return context => !context.Request.Path.StartsWithSegments("/api") &&
1111
!context.Request.Path.StartsWithSegments("/view") &&
12-
!context.Request.Path.StartsWithSegments("/health");
12+
!context.Request.Path.StartsWithSegments("/health") &&
13+
!context.Request.Path.StartsWithSegments("/ws");
1314
}
1415
}

backend/Program.cs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
using NzbWebDAV.Utils;
1717
using NzbWebDAV.WebDav;
1818
using NzbWebDAV.WebDav.Base;
19+
using NzbWebDAV.Websocket;
1920
using Serilog;
2021
using Serilog.Events;
2122
using Serilog.Sinks.SystemConsole.Themes;
@@ -47,13 +48,17 @@ static async Task Main(string[] args)
4748
var configManager = new ConfigManager();
4849
await configManager.LoadConfig();
4950

51+
// initialize websocket-manager
52+
var websocketManager = new WebsocketManager();
53+
5054
// initialize webapp
5155
var builder = WebApplication.CreateBuilder(args);
5256
builder.Host.UseSerilog();
5357
builder.Services.AddControllers();
5458
builder.Services.AddHealthChecks();
5559
builder.Services
5660
.AddSingleton(configManager)
61+
.AddSingleton(websocketManager)
5762
.AddSingleton<UsenetStreamingClient>()
5863
.AddSingleton<QueueManager>()
5964
.AddScoped<DavDatabaseContext>()
@@ -86,11 +91,13 @@ static async Task Main(string[] args)
8691

8792
// run
8893
var app = builder.Build();
89-
app.MapHealthChecks("/health");
9094
app.UseSerilogRequestLogging();
9195
app.UseMiddleware<ExceptionMiddleware>();
92-
app.UseAuthentication();
96+
app.UseWebSockets();
97+
app.MapHealthChecks("/health");
98+
app.Map("/ws", websocketManager.HandleRoute);
9399
app.MapControllers();
100+
app.UseAuthentication();
94101
app.UseNWebDav();
95102
await app.RunAsync();
96103
}
Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
using System.Net.WebSockets;
2+
using System.Text;
3+
using System.Text.Json;
4+
using Microsoft.AspNetCore.Http;
5+
using NzbWebDAV.Utils;
6+
using Serilog;
7+
8+
namespace NzbWebDAV.Websocket;
9+
10+
public class WebsocketManager
11+
{
12+
private readonly HashSet<WebSocket> _authenticatedSockets = [];
13+
private readonly Dictionary<string, string> _lastMessage = new();
14+
15+
public async Task HandleRoute(HttpContext context)
16+
{
17+
if (context.WebSockets.IsWebSocketRequest)
18+
{
19+
using var webSocket = await context.WebSockets.AcceptWebSocketAsync();
20+
if (!await Authenticate(webSocket))
21+
{
22+
Log.Warning($"Closing unauthenticated websocket connection from {context.Connection.RemoteIpAddress}");
23+
await CloseUnauthorizedConnection(webSocket);
24+
return;
25+
}
26+
27+
// mark the socket as authenticated
28+
lock (_authenticatedSockets)
29+
_authenticatedSockets.Add(webSocket);
30+
31+
// send current state for all topics
32+
List<KeyValuePair<string, string>>? lastMessage = null;
33+
lock (_lastMessage) lastMessage = _lastMessage.ToList();
34+
foreach (var message in lastMessage)
35+
await SendMessage(webSocket, message.Key, message.Value);
36+
37+
// wait for the socket to disconnect
38+
await WaitForDisconnected(webSocket);
39+
lock (_authenticatedSockets)
40+
_authenticatedSockets.Remove(webSocket);
41+
}
42+
else
43+
{
44+
context.Response.StatusCode = 400;
45+
}
46+
}
47+
48+
/// <summary>
49+
/// Send a message to all authenticated websockets.
50+
/// </summary>
51+
/// <param name="topic">The topic of the message to send</param>
52+
/// <param name="message">The message to send</param>
53+
public Task SendMessage(string topic, string message)
54+
{
55+
lock (_lastMessage) _lastMessage[topic] = message;
56+
List<WebSocket>? authenticatedSockets;
57+
lock (_authenticatedSockets) authenticatedSockets = _authenticatedSockets.ToList();
58+
var topicMessage = new TopicMessage(topic, message);
59+
var bytes = new ArraySegment<byte>(Encoding.UTF8.GetBytes(topicMessage.ToString()));
60+
return Task.WhenAll(authenticatedSockets.Select(x => SendMessage(x, bytes)));
61+
}
62+
63+
/// <summary>
64+
/// Ensure a websocket sends a valid api key.
65+
/// </summary>
66+
/// <param name="socket">The websocket to authenticate.</param>
67+
/// <returns>True if authenticated, False otherwise.</returns>
68+
private static async Task<bool> Authenticate(WebSocket socket)
69+
{
70+
var apiKey = await ReceiveAuthToken(socket);
71+
return apiKey == EnvironmentUtil.GetVariable("FRONTEND_BACKEND_API_KEY");
72+
}
73+
74+
/// <summary>
75+
/// Ignore all messages from the websocket and
76+
/// wait for it to disconnect.
77+
/// </summary>
78+
/// <param name="socket">The websocket to wait for disconnect.</param>
79+
private static async Task WaitForDisconnected(WebSocket socket)
80+
{
81+
var buffer = new byte[1024];
82+
WebSocketReceiveResult? result = null;
83+
while (result is not { CloseStatus: not null })
84+
result = await socket.ReceiveAsync(new ArraySegment<byte>(buffer), default);
85+
await socket.CloseAsync(result.CloseStatus.Value, result.CloseStatusDescription, default);
86+
}
87+
88+
/// <summary>
89+
/// Send a message to a connected websocket.
90+
/// </summary>
91+
/// <param name="socket">The websocket to send the message to.</param>
92+
/// <param name="topic">The topic of the message to send</param>
93+
/// <param name="message">The message to send</param>
94+
private static async Task SendMessage(WebSocket socket, string topic, string message)
95+
{
96+
var topicMessage = new TopicMessage(topic, message);
97+
var bytes = new ArraySegment<byte>(Encoding.UTF8.GetBytes(topicMessage.ToString()));
98+
await SendMessage(socket, bytes);
99+
}
100+
101+
/// <summary>
102+
/// Send a message to a connected websocket.
103+
/// </summary>
104+
/// <param name="socket">The websocket to send the message to.</param>
105+
/// <param name="message">The message to send.</param>
106+
private static async Task SendMessage(WebSocket socket, ArraySegment<byte> message)
107+
{
108+
try
109+
{
110+
await socket.SendAsync(message, WebSocketMessageType.Text, true, default);
111+
}
112+
catch (Exception e)
113+
{
114+
Log.Debug($"Failed to send message to websocket. {e.Message}");
115+
}
116+
}
117+
118+
/// <summary>
119+
/// Receive an authentication token from a connected websocket.
120+
/// With timeout after five seconds.
121+
/// </summary>
122+
/// <param name="socket">The websocket to receive from.</param>
123+
/// <returns>The authentication token. Or null if none provided.</returns>
124+
private static async Task<string?> ReceiveAuthToken(WebSocket socket)
125+
{
126+
try
127+
{
128+
var buffer = new byte[1024];
129+
using var cts = new CancellationTokenSource();
130+
cts.CancelAfter(TimeSpan.FromSeconds(5));
131+
var result = await socket.ReceiveAsync(new ArraySegment<byte>(buffer), cts.Token);
132+
return result.MessageType == WebSocketMessageType.Text
133+
? Encoding.UTF8.GetString(buffer, 0, result.Count)
134+
: null;
135+
}
136+
catch (OperationCanceledException)
137+
{
138+
return null;
139+
}
140+
}
141+
142+
/// <summary>
143+
/// Close a websocket connection as unauthorized.
144+
/// </summary>
145+
/// <param name="socket">The websocket whose connection to close.</param>
146+
private static async Task CloseUnauthorizedConnection(WebSocket socket)
147+
{
148+
if (socket.State == WebSocketState.Open)
149+
await socket.CloseAsync(WebSocketCloseStatus.PolicyViolation, "Unauthorized", CancellationToken.None);
150+
}
151+
152+
private sealed class TopicMessage(string topic, string message)
153+
{
154+
public string Topic { get; } = topic;
155+
public string Message { get; } = message;
156+
public override string ToString() => JsonSerializer.Serialize(this);
157+
}
158+
}
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
namespace NzbWebDAV.Websocket;
2+
3+
public static class WebsocketTopic
4+
{
5+
public const string UsenetConnections = "cxs";
6+
}

0 commit comments

Comments
 (0)