|
| 1 | +using System.Security.Claims; |
| 2 | +using System; |
| 3 | +using System.Linq; |
| 4 | +using System.Threading.Tasks; |
| 5 | +using Microsoft.AspNetCore.Authorization; |
| 6 | +using Microsoft.AspNetCore.SignalR; |
| 7 | +using Microsoft.Extensions.Configuration; |
| 8 | +using Microsoft.Extensions.DependencyInjection; |
| 9 | +using Microsoft.Extensions.Logging; |
| 10 | + |
| 11 | +namespace Scv.Api.Hubs; |
| 12 | + |
| 13 | +[Authorize] |
| 14 | +public class NotificationsHub : Hub |
| 15 | +{ |
| 16 | + public override Task OnConnectedAsync() |
| 17 | + { |
| 18 | + var httpContext = Context.GetHttpContext(); |
| 19 | + var config = httpContext?.RequestServices.GetService<IConfiguration>(); |
| 20 | + var logger = httpContext?.RequestServices.GetService<ILogger<NotificationsHub>>(); |
| 21 | + var allowedOrigin = config?.GetValue<string>("CORS_DOMAIN"); |
| 22 | + var disableOriginCheck = config?.GetValue<bool>("DISABLE_SIGNALR_ORIGIN_CHECK") ?? false; |
| 23 | + var origin = httpContext?.Request.Headers["Origin"].ToString(); |
| 24 | + |
| 25 | + logger?.LogInformation( |
| 26 | + "SignalR connect attempt. Origin={Origin}, CORS_DOMAIN={CorsDomain}", |
| 27 | + origin, |
| 28 | + allowedOrigin); |
| 29 | + |
| 30 | + if (disableOriginCheck) |
| 31 | + { |
| 32 | + logger?.LogWarning("SignalR origin check disabled via DISABLE_SIGNALR_ORIGIN_CHECK."); |
| 33 | + } |
| 34 | + else if (!string.IsNullOrWhiteSpace(allowedOrigin)) |
| 35 | + { |
| 36 | + var allowedOrigins = allowedOrigin |
| 37 | + .Split(';', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries) |
| 38 | + .Select(value => value.Trim().Trim('"', '\'')) |
| 39 | + .Where(value => !string.IsNullOrWhiteSpace(value)) |
| 40 | + .ToArray(); |
| 41 | + |
| 42 | + logger?.LogInformation( |
| 43 | + "SignalR allowed origins resolved to {AllowedOrigins}", |
| 44 | + string.Join(";", allowedOrigins)); |
| 45 | + |
| 46 | + if (allowedOrigins.Length > 0 && |
| 47 | + (string.IsNullOrWhiteSpace(origin) || |
| 48 | + !allowedOrigins.Any(value => string.Equals(origin, value, StringComparison.OrdinalIgnoreCase)))) |
| 49 | + { |
| 50 | + logger?.LogWarning( |
| 51 | + "SignalR connection aborted due to origin mismatch. Origin={Origin}", |
| 52 | + origin); |
| 53 | + Context.Abort(); |
| 54 | + return Task.CompletedTask; |
| 55 | + } |
| 56 | + } |
| 57 | + |
| 58 | + var userId = Context.User?.FindFirstValue(ClaimTypes.NameIdentifier); |
| 59 | + if (string.IsNullOrWhiteSpace(userId)) |
| 60 | + { |
| 61 | + if (logger != null) |
| 62 | + { |
| 63 | + var claims = Context.User?.Claims |
| 64 | + .Select(claim => $"{claim.Type}={claim.Value}") |
| 65 | + .ToArray() ?? Array.Empty<string>(); |
| 66 | + logger.LogDebug( |
| 67 | + "SignalR user claims: {Claims}", |
| 68 | + string.Join(";", claims)); |
| 69 | + } |
| 70 | + logger?.LogWarning("SignalR connection aborted due to missing user id claim."); |
| 71 | + Context.Abort(); |
| 72 | + return Task.CompletedTask; |
| 73 | + } |
| 74 | + |
| 75 | + logger?.LogInformation("SignalR connection accepted for user {UserId}.", userId); |
| 76 | + |
| 77 | + return base.OnConnectedAsync(); |
| 78 | + } |
| 79 | +} |
0 commit comments