Skip to content

Commit b9504d1

Browse files
committed
Refactor AuditSignInManager for clarity and efficiency
Refactored `AuditSignInManager` to improve readability, maintainability, and robustness. Simplified the `GetClientIpAddress` method by introducing helper methods (`TryGetSingleHeaderIp`, `TryGetXForwardedFor`, etc.) and standardizing loopback IP normalization. Removed redundant code and unused parameters in `LogLoginAuditAsync` and sign-in methods. Enhanced logging and error handling to ensure resilience during login audits. Reformatted code for better structure while preserving core functionality.
1 parent e31afcd commit b9504d1

1 file changed

Lines changed: 92 additions & 90 deletions

File tree

src/Infrastructure/Services/Identity/AuditSignInManager.cs

Lines changed: 92 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -4,27 +4,28 @@
44
using CleanArchitecture.Blazor.Domain.Identity;
55
using Microsoft.AspNetCore.Authentication;
66
using Microsoft.AspNetCore.Http;
7+
using System.Net; // added for IPAddress parsing
78

89
namespace CleanArchitecture.Blazor.Infrastructure.Services;
910

1011
public class AuditSignInManager<TUser> : SignInManager<TUser>
11-
where TUser : class
12+
where TUser : class
1213
{
1314
private readonly IApplicationDbContextFactory _dbContextFactory;
1415
private readonly IHttpContextAccessor _httpContextAccessor;
1516
private readonly ILogger<AuditSignInManager<TUser>> _logger;
1617

1718
public AuditSignInManager(
18-
UserManager<TUser> userManager,
19-
IHttpContextAccessor contextAccessor,
20-
IUserClaimsPrincipalFactory<TUser> claimsFactory,
21-
IOptions<IdentityOptions> optionsAccessor,
22-
ILogger<SignInManager<TUser>> logger,
23-
IAuthenticationSchemeProvider schemes,
24-
IUserConfirmation<TUser> confirmation,
25-
IApplicationDbContextFactory dbContextFactory,
26-
ILogger<AuditSignInManager<TUser>> auditLogger)
27-
: base(userManager, contextAccessor, claimsFactory, optionsAccessor, logger, schemes, confirmation)
19+
UserManager<TUser> userManager,
20+
IHttpContextAccessor contextAccessor,
21+
IUserClaimsPrincipalFactory<TUser> claimsFactory,
22+
IOptions<IdentityOptions> optionsAccessor,
23+
ILogger<SignInManager<TUser>> logger,
24+
IAuthenticationSchemeProvider schemes,
25+
IUserConfirmation<TUser> confirmation,
26+
IApplicationDbContextFactory dbContextFactory,
27+
ILogger<AuditSignInManager<TUser>> auditLogger)
28+
: base(userManager, contextAccessor, claimsFactory, optionsAccessor, logger, schemes, confirmation)
2829
{
2930
_dbContextFactory = dbContextFactory;
3031
_httpContextAccessor = contextAccessor;
@@ -42,20 +43,20 @@ public override async Task<SignInResult> PasswordSignInAsync(TUser user, string
4243
{
4344
var result = await base.PasswordSignInAsync(user, password, isPersistent, lockoutOnFailure);
4445
var userName = await UserManager.GetUserNameAsync(user) ?? "Unknown";
45-
var userId= await UserManager.GetUserIdAsync(user) ?? "Unknown";
46-
await LogLoginAuditAsync(userId,userName, result.Succeeded, "Local", result);
46+
var userId = await UserManager.GetUserIdAsync(user) ?? "Unknown";
47+
await LogLoginAuditAsync(userId, userName, result.Succeeded, "Local");
4748
return result;
4849
}
4950

5051
public override async Task<SignInResult> ExternalLoginSignInAsync(string loginProvider, string providerKey, bool isPersistent, bool bypassTwoFactor)
5152
{
5253
var result = await base.ExternalLoginSignInAsync(loginProvider, providerKey, isPersistent, bypassTwoFactor);
53-
54+
5455
// Try to get user information from external login
5556
var info = await GetExternalLoginInfoAsync();
5657
var userName = info?.Principal?.Identity?.Name ?? "External User";
5758
var userId = info?.Principal?.FindFirstValue(ClaimTypes.NameIdentifier) ?? "Unknown";
58-
await LogLoginAuditAsync(userId, userName, result.Succeeded, loginProvider, result);
59+
await LogLoginAuditAsync(userId, userName, result.Succeeded, loginProvider);
5960
return result;
6061
}
6162

@@ -64,21 +65,21 @@ public override async Task SignInAsync(TUser user, bool isPersistent, string? au
6465
await base.SignInAsync(user, isPersistent, authenticationMethod);
6566
var userName = await UserManager.GetUserNameAsync(user) ?? "Unknown";
6667
var userId = await UserManager.GetUserIdAsync(user) ?? "Unknown";
67-
await LogLoginAuditAsync(userId,userName, true, authenticationMethod ?? "Direct", null);
68+
await LogLoginAuditAsync(userId, userName, true, authenticationMethod ?? "Direct");
6869
}
6970

7071
public override async Task SignInAsync(TUser user, AuthenticationProperties authenticationProperties, string? authenticationMethod = null)
7172
{
7273
await base.SignInAsync(user, authenticationProperties, authenticationMethod);
7374
var userName = await UserManager.GetUserNameAsync(user) ?? "Unknown";
7475
var userId = await UserManager.GetUserIdAsync(user) ?? "Unknown";
75-
await LogLoginAuditAsync(userId, userName, true, authenticationMethod ?? "Direct", null);
76+
await LogLoginAuditAsync(userId, userName, true, authenticationMethod ?? "Direct");
7677
}
7778

7879
public override async Task<SignInResult> TwoFactorSignInAsync(string provider, string code, bool isPersistent, bool rememberClient)
7980
{
8081
var result = await base.TwoFactorSignInAsync(provider, code, isPersistent, rememberClient);
81-
82+
8283
// Get user from two factor info
8384
var userName = "Unknown";
8485
var userId = "Unknown";
@@ -88,12 +89,12 @@ public override async Task<SignInResult> TwoFactorSignInAsync(string provider, s
8889
userName = await UserManager.GetUserNameAsync(user) ?? "Unknown";
8990
userId = await UserManager.GetUserIdAsync(user) ?? "Unknown";
9091
}
91-
92-
await LogLoginAuditAsync(userId,userName, result.Succeeded, $"2FA-{provider}", result);
92+
93+
await LogLoginAuditAsync(userId, userName, result.Succeeded, $"2FA-{provider}");
9394
return result;
9495
}
9596

96-
private async Task LogLoginAuditAsync(string userId,string userName, bool success, string provider, SignInResult? result)
97+
private async Task LogLoginAuditAsync(string userId, string userName, bool success, string provider)
9798
{
9899
try
99100
{
@@ -104,21 +105,21 @@ private async Task LogLoginAuditAsync(string userId,string userName, bool succes
104105
return;
105106
}
106107

107-
108-
109108
// Extract client information
110109
var ipAddress = GetClientIpAddress(httpContext);
111110
var browserInfo = GetBrowserInfo(httpContext);
112111

113112
// Create login audit using the service
114-
var loginAudit = new LoginAudit() {
113+
var loginAudit = new LoginAudit()
114+
{
115115
LoginTimeUtc = DateTime.UtcNow,
116116
UserId = userId ?? string.Empty,
117-
UserName=userName,
118-
IpAddress= ipAddress,
119-
BrowserInfo= browserInfo,
120-
Provider= provider,
121-
Success= success};
117+
UserName = userName,
118+
IpAddress = ipAddress,
119+
BrowserInfo = browserInfo,
120+
Provider = provider,
121+
Success = success
122+
};
122123
loginAudit.AddDomainEvent(new Domain.Events.LoginAuditCreatedEvent(loginAudit));
123124
// Save to database
124125
await using var db = await _dbContextFactory.CreateAsync();
@@ -136,49 +137,15 @@ private async Task LogLoginAuditAsync(string userId,string userName, bool succes
136137
{
137138
try
138139
{
139-
// Priority order of common proxy headers
140-
// 1. Cloudflare / CDN specific
141-
var cfConnectingIp = httpContext.Request.Headers["CF-Connecting-IP"].FirstOrDefault();
142-
if (!string.IsNullOrWhiteSpace(cfConnectingIp)) return SanitizeAndNormalize(cfConnectingIp);
143-
144-
// 2. Standard X-Forwarded-For (may contain comma separated chain). Take first non-empty value.
145-
var forwardedForRaw = httpContext.Request.Headers["X-Forwarded-For"].FirstOrDefault();
146-
if (!string.IsNullOrWhiteSpace(forwardedForRaw))
147-
{
148-
var first = forwardedForRaw.Split(',').Select(s => s.Trim()).FirstOrDefault(s => !string.IsNullOrWhiteSpace(s));
149-
if (!string.IsNullOrWhiteSpace(first)) return SanitizeAndNormalize(first);
150-
}
151-
152-
// 3. True-Client-IP (Akamai, some CDNs)
153-
var trueClientIp = httpContext.Request.Headers["True-Client-IP"].FirstOrDefault();
154-
if (!string.IsNullOrWhiteSpace(trueClientIp)) return SanitizeAndNormalize(trueClientIp);
155-
156-
// 4. X-Real-IP (nginx) single value
157-
var realIp = httpContext.Request.Headers["X-Real-IP"].FirstOrDefault();
158-
if (!string.IsNullOrWhiteSpace(realIp)) return SanitizeAndNormalize(realIp);
159-
160-
// 5. Forwarded header (RFC 7239) e.g. Forwarded: for=203.0.113.195;proto=https;by=203.0.113.43
161-
var forwarded = httpContext.Request.Headers["Forwarded"].FirstOrDefault();
162-
if (!string.IsNullOrWhiteSpace(forwarded))
163-
{
164-
// Extract for= value
165-
var segments = forwarded.Split(';');
166-
foreach (var seg in segments)
167-
{
168-
var part = seg.Trim();
169-
if (part.StartsWith("for=", StringComparison.OrdinalIgnoreCase))
170-
{
171-
var ipPart = part.Substring(4).Trim('"');
172-
// Remove IPv6 brackets if present
173-
ipPart = ipPart.Trim('[', ']');
174-
if (!string.IsNullOrWhiteSpace(ipPart)) return SanitizeAndNormalize(ipPart);
175-
}
176-
}
177-
}
178-
179-
// 6. Fallback to connection remote IP
180-
var remoteIp = httpContext.Connection.RemoteIpAddress?.ToString();
181-
return SanitizeAndNormalize(remoteIp);
140+
// Simple & safe: only examine a short list of common headers, validate each with IPAddress.TryParse.
141+
// Order: CF-Connecting-IP -> X-Forwarded-For (first) -> True-Client-IP -> X-Real-IP -> fallback RemoteIpAddress
142+
if (TryGetSingleHeaderIp(httpContext, "CF-Connecting-IP", out var ip)) return ip;
143+
if (TryGetXForwardedFor(httpContext, out ip)) return ip;
144+
if (TryGetSingleHeaderIp(httpContext, "True-Client-IP", out ip)) return ip;
145+
if (TryGetSingleHeaderIp(httpContext, "X-Real-IP", out ip)) return ip;
146+
147+
var remote = httpContext.Connection.RemoteIpAddress;
148+
return NormalizeLoopback(remote);
182149
}
183150
catch (Exception ex)
184151
{
@@ -187,35 +154,70 @@ private async Task LogLoginAuditAsync(string userId,string userName, bool succes
187154
}
188155
}
189156

190-
private string SanitizeInput(string? input)
157+
private bool TryGetSingleHeaderIp(HttpContext ctx, string headerName, out string? ip)
191158
{
192-
if (string.IsNullOrEmpty(input))
193-
return string.Empty;
194-
// Remove newline characters and trim whitespace
195-
return input.Replace("\r", "").Replace("\n", "").Trim();
159+
ip = null;
160+
var raw = ctx.Request.Headers[headerName].FirstOrDefault();
161+
if (string.IsNullOrWhiteSpace(raw)) return false;
162+
raw = raw.Split(',')[0].Trim(); // if multiple, only first
163+
raw = StripPortAndBrackets(raw);
164+
if (IPAddress.TryParse(raw, out var parsed))
165+
{
166+
ip = NormalizeLoopback(parsed);
167+
return true;
168+
}
169+
return false;
196170
}
197-
private string? SanitizeAndNormalize(string? input)
171+
172+
private bool TryGetXForwardedFor(HttpContext ctx, out string? ip)
198173
{
199-
var value = SanitizeInput(input);
200-
if (string.IsNullOrEmpty(value)) return value;
201-
if (value == "::1" || value == "127.0.0.1") return "127.0.0.1";
202-
// Remove port if accidentally included (IPv4:port or [IPv6]:port)
203-
if (value.Contains(':'))
174+
ip = null;
175+
var raw = ctx.Request.Headers["X-Forwarded-For"].FirstOrDefault();
176+
if (string.IsNullOrWhiteSpace(raw)) return false;
177+
// Split chain client, proxy1, proxy2 ... choose first non-empty candidate that parses
178+
foreach (var candidate in raw.Split(',').Select(s => s.Trim()))
204179
{
205-
// For IPv6 keep colons; only strip if it's an IPv4 with a single ':' and digits afterward or bracketed IPv6 with port
206-
if (value.Count(c => c == ':') == 1 && value.Contains('.'))
180+
if (string.IsNullOrEmpty(candidate)) continue;
181+
var cleaned = StripPortAndBrackets(candidate);
182+
if (IPAddress.TryParse(cleaned, out var parsed))
207183
{
208-
// IPv4 with port
209-
value = value.Split(':')[0];
184+
ip = NormalizeLoopback(parsed);
185+
return true;
210186
}
211-
else if (value.StartsWith("[") && value.Contains("]:"))
187+
}
188+
return false;
189+
}
190+
191+
private string StripPortAndBrackets(string value)
192+
{
193+
if (string.IsNullOrEmpty(value)) return value;
194+
value = value.Trim('"');
195+
// IPv6 in brackets: [2001:db8::1]:443
196+
if (value.StartsWith("[") && value.Contains("]"))
197+
{
198+
var end = value.IndexOf(']');
199+
if (end > 0)
212200
{
213-
value = value.Substring(1, value.IndexOf(']') - 1); // Extract inside brackets
201+
var core = value.Substring(1, end - 1);
202+
// ignore trailing :port
203+
return core;
214204
}
215205
}
206+
// IPv4:port
207+
var colonIndex = value.LastIndexOf(':');
208+
if (colonIndex > -1 && value.Count(c => c == ':') == 1 && value.Contains('.'))
209+
{
210+
return value.Substring(0, colonIndex);
211+
}
216212
return value;
217213
}
218-
214+
215+
private string? NormalizeLoopback(IPAddress? ip)
216+
{
217+
if (ip == null) return null;
218+
if (IPAddress.IsLoopback(ip)) return "127.0.0.1"; // unify
219+
return ip.ToString();
220+
}
219221

220222
private string? GetBrowserInfo(HttpContext httpContext)
221223
{

0 commit comments

Comments
 (0)