diff --git a/src/GraphQL.AspNetCore3.JwtBearer/JwtWebSocketAuthenticationService.cs b/src/GraphQL.AspNetCore3.JwtBearer/JwtWebSocketAuthenticationService.cs index 89f5130..869bce6 100644 --- a/src/GraphQL.AspNetCore3.JwtBearer/JwtWebSocketAuthenticationService.cs +++ b/src/GraphQL.AspNetCore3.JwtBearer/JwtWebSocketAuthenticationService.cs @@ -129,6 +129,14 @@ public async Task AuthenticateAsync(AuthenticationRequest authenticationRequest) // set the ClaimsPrincipal for the HttpContext; authentication will take place against this object connection.HttpContext.User = principal; return; + } else if (_jwtBearerAuthenticationOptions.EnableJwtEvents) { + // If JWT events are enabled, trigger the AuthenticationFailed event + var exception = tokenValidationResult.Exception ?? new SecurityTokenValidationException($"The TokenHandler: '{tokenHandler}', was unable to validate the Token."); + var failedResult = await TriggerAuthenticationFailedEventAsync(connection.HttpContext, options, exception, scheme).ConfigureAwait(false); + if (failedResult.Handled && failedResult.Success) { + connection.HttpContext.User = failedResult.Principal!; + return; + } } } catch (Exception ex) { // If JWT events are enabled, trigger the AuthenticationFailed event diff --git a/src/Tests/JwtBearer/JwtWebSocketAuthenticationServiceTests.cs b/src/Tests/JwtBearer/JwtWebSocketAuthenticationServiceTests.cs index 670de68..d61e35f 100644 --- a/src/Tests/JwtBearer/JwtWebSocketAuthenticationServiceTests.cs +++ b/src/Tests/JwtBearer/JwtWebSocketAuthenticationServiceTests.cs @@ -17,150 +17,250 @@ public class JwtWebSocketAuthenticationServiceTests private string _issuer = "https://demo.identityserver.io"; private string _audience = "testAudience"; private readonly string _subject = "user123"; - private RSAParameters _rsaParameters; + private static readonly RSAParameters _rsaParameters; private string? _jwtAccessToken; private readonly MockHttpMessageHandler _oidcHttpMessageHandler = new(); private readonly ISchema _schema; + // Event tracking flags + private bool _messageReceived; + private bool _tokenValidated; + private bool _authenticationFailed; + private bool _enableJwtEvents; + + private readonly JwtBearerEvents _jwtBearerEvents; + private Action? _testFieldAction; + public JwtWebSocketAuthenticationServiceTests() { var query = new ObjectGraphType() { Name = "Query" }; - query.Field("test").Resolve(ctx => ctx.User?.FindFirst(ClaimTypes.NameIdentifier)?.Value); + query.Field("test").Resolve(ctx => { + _testFieldAction?.Invoke(ctx); + return ctx.User?.FindFirst(ClaimTypes.NameIdentifier)?.Value; + }); _schema = new Schema { Query = query }; + + // Initialize JwtBearerEvents + _jwtBearerEvents = new JwtBearerEvents { + OnMessageReceived = context => { + _messageReceived = true; + return Task.CompletedTask; + }, + OnTokenValidated = context => { + _tokenValidated = true; + return Task.CompletedTask; + }, + OnAuthenticationFailed = context => { + _authenticationFailed = true; + return Task.CompletedTask; + } + }; } - [Fact] - public async Task SuccessfulAuthentication() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task SuccessfulAuthentication(bool enableJwtEvents) { CreateSignedToken(); SetupOidcDiscovery(); + _enableJwtEvents = enableJwtEvents; using var testServer = CreateTestServer(); await TestGetAsync(testServer, isAuthenticated: true); await TestWebSocketAsync(testServer, isAuthenticated: true); } [Fact] - public async Task WrongKeys() + public async Task SuccessfulAuthenticationWithCustomClaim() + { + // Configure JwtBearerEvents to add the custom claim during token validation + _jwtBearerEvents.OnTokenValidated = context => { + // Add the custom claim to the user's identity + var identity = context.Principal?.Identity as ClaimsIdentity; + identity?.AddClaim(new Claim("custom:role", "admin")); + + _tokenValidated = true; + return Task.CompletedTask; + }; + + // Set up the test field action to verify the custom claim + _testFieldAction = context => { + var claim = context.User?.FindFirst("custom:role"); + claim.ShouldNotBeNull(); + claim.Value.ShouldBe("admin"); + }; + + // Create the token and set up the test server + CreateSignedToken(); + SetupOidcDiscovery(); + _enableJwtEvents = true; + using var testServer = CreateTestServer(); + await TestGetAsync(testServer, isAuthenticated: true); + await TestWebSocketAsync(testServer, isAuthenticated: true); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task WrongKeys(bool enableJwtEvents) { CreateSignedToken(); SetupOidcDiscovery(); + _enableJwtEvents = enableJwtEvents; using var testServer = CreateTestServer(); - CreateSignedToken(); // create new token with different keys + CreateSignedToken(differentKeys: true); // create new token with different keys await TestGetAsync(testServer, isAuthenticated: false); await TestWebSocketAsync(testServer, isAuthenticated: false); } - [Fact] - public async Task WrongIssuer() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task WrongIssuer(bool enableJwtEvents) { CreateSignedToken(); _issuer = "https://wrong.issuer"; SetupOidcDiscovery(); + _enableJwtEvents = enableJwtEvents; using var testServer = CreateTestServer(); await TestGetAsync(testServer, isAuthenticated: false); await TestWebSocketAsync(testServer, isAuthenticated: false); } - [Fact] - public async Task WrongAudience() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task WrongAudience(bool enableJwtEvents) { CreateSignedToken(); _audience = "wrongAudience"; SetupOidcDiscovery(); + _enableJwtEvents = enableJwtEvents; using var testServer = CreateTestServer(); await TestGetAsync(testServer, isAuthenticated: false); await TestWebSocketAsync(testServer, isAuthenticated: false); } - [Fact] - public async Task Expired() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task Expired(bool enableJwtEvents) { CreateSignedToken(expired: true); SetupOidcDiscovery(); + _enableJwtEvents = enableJwtEvents; using var testServer = CreateTestServer(); await TestGetAsync(testServer, isAuthenticated: false); await TestWebSocketAsync(testServer, isAuthenticated: false); } - [Fact] - public async Task NoDefaultScheme() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task NoDefaultScheme(bool enableJwtEvents) { CreateSignedToken(); SetupOidcDiscovery(); + _enableJwtEvents = enableJwtEvents; using var testServer = CreateTestServer(defaultScheme: false); await TestGetAsync(testServer, isAuthenticated: false); - await TestWebSocketAsync(testServer, isAuthenticated: false); + await TestWebSocketAsync(testServer, isAuthenticated: false, + expectMessageReceived: false, expectAuthenticationFailed: false); } - [Fact] - public async Task NoDefaultSchemeSpecified() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task NoDefaultSchemeSpecified(bool enableJwtEvents) { CreateSignedToken(); SetupOidcDiscovery(); + _enableJwtEvents = enableJwtEvents; using var testServer = CreateTestServer(defaultScheme: false, specifyScheme: true); await TestGetAsync(testServer, isAuthenticated: true); await TestWebSocketAsync(testServer, isAuthenticated: true); } - [Fact] - public async Task CustomScheme() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task CustomScheme(bool enableJwtEvents) { CreateSignedToken(); SetupOidcDiscovery(); + _enableJwtEvents = enableJwtEvents; using var testServer = CreateTestServer(customScheme: true); await TestGetAsync(testServer, isAuthenticated: true); await TestWebSocketAsync(testServer, isAuthenticated: true); } - [Fact] - public async Task CustomNoDefaultScheme() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task CustomNoDefaultScheme(bool enableJwtEvents) { CreateSignedToken(); SetupOidcDiscovery(); + _enableJwtEvents = enableJwtEvents; using var testServer = CreateTestServer(customScheme: true, defaultScheme: false); await TestGetAsync(testServer, isAuthenticated: false); - await TestWebSocketAsync(testServer, isAuthenticated: false); + await TestWebSocketAsync(testServer, isAuthenticated: false, + expectMessageReceived: false, expectAuthenticationFailed: false); } - [Fact] - public async Task CustomNoDefaultSchemeSpecified() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task CustomNoDefaultSchemeSpecified(bool enableJwtEvents) { CreateSignedToken(); SetupOidcDiscovery(); + _enableJwtEvents = enableJwtEvents; using var testServer = CreateTestServer(customScheme: true, defaultScheme: false, specifyScheme: true); await TestGetAsync(testServer, isAuthenticated: true); await TestWebSocketAsync(testServer, isAuthenticated: true); } - [Fact] - public async Task WrongScheme() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task WrongScheme(bool enableJwtEvents) { CreateSignedToken(); SetupOidcDiscovery(); + _enableJwtEvents = enableJwtEvents; using var testServer = CreateTestServer(specifyInvalidScheme: true); await TestGetAsync(testServer, isAuthenticated: false); - await TestWebSocketAsync(testServer, isAuthenticated: false); + await TestWebSocketAsync(testServer, isAuthenticated: false, + expectMessageReceived: false, expectAuthenticationFailed: false); } - [Fact] - public async Task MultipleSchemes() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task MultipleSchemes(bool enableJwtEvents) { CreateSignedToken(); SetupOidcDiscovery(); + _enableJwtEvents = enableJwtEvents; using var testServer = CreateTestServer(specifyInvalidScheme: true, specifyScheme: true, defaultScheme: false); await TestGetAsync(testServer, isAuthenticated: true); await TestWebSocketAsync(testServer, isAuthenticated: true); } - [Fact] - public async Task NoToken() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task NoToken(bool enableJwtEvents) { CreateSignedToken(); SetupOidcDiscovery(); + _enableJwtEvents = enableJwtEvents; using var testServer = CreateTestServer(); _jwtAccessToken = null; await TestGetAsync(testServer, isAuthenticated: false); - await TestWebSocketAsync(testServer, isAuthenticated: false); + await TestWebSocketAsync(testServer, isAuthenticated: false, + expectMessageReceived: false, expectAuthenticationFailed: false); } private async Task TestGetAsync(TestServer testServer, bool isAuthenticated) @@ -182,7 +282,8 @@ private async Task TestGetAsync(TestServer testServer, bool isAuthenticated) } } - private async Task TestWebSocketAsync(TestServer testServer, bool isAuthenticated) + private async Task TestWebSocketAsync(TestServer testServer, bool isAuthenticated, + bool expectMessageReceived = true, bool expectAuthenticationFailed = true, bool expectTokenValidated = false) { // test an authenticated request var webSocketClient = testServer.CreateWebSocketClient(); @@ -192,6 +293,11 @@ private async Task TestWebSocketAsync(TestServer testServer, bool isAuthenticate webSocketClient.SubProtocols.Add("graphql-ws"); using var webSocket = await webSocketClient.ConnectAsync(new Uri(testServer.BaseAddress, "/graphql"), default); + // reset event tracking flags after initial connection has been made (as messageReceived is called on connection by the ASP.NET Core pipeline) + _messageReceived = false; + _authenticationFailed = false; + _tokenValidated = false; + // send CONNECTION_INIT await webSocket.SendMessageAsync(new OperationMessage { Type = "connection_init", @@ -208,6 +314,18 @@ await webSocket.SendMessageAsync(new OperationMessage { // wait for websocket closure (await webSocket.ReceiveCloseAsync()).ShouldBe((WebSocketCloseStatus)4401); + + // Verify events were triggered if _enableJwtEvents is true + if (_enableJwtEvents) { + _messageReceived.ShouldBe(expectMessageReceived); + _authenticationFailed.ShouldBe(expectAuthenticationFailed); + _tokenValidated.ShouldBe(expectTokenValidated); + } else { + _messageReceived.ShouldBeFalse(); + _authenticationFailed.ShouldBeFalse(); + _tokenValidated.ShouldBeFalse(); + } + return; } @@ -231,6 +349,17 @@ await webSocket.SendMessageAsync(new OperationMessage { message.Payload.ShouldBe($$$""" {"data":{"test":"{{{_subject}}}"}} """); + + // Verify events were triggered if _enableJwtEvents is true + if (_enableJwtEvents) { + _messageReceived.ShouldBeTrue(); + _tokenValidated.ShouldBeTrue(); + _authenticationFailed.ShouldBeFalse(); + } else { + _messageReceived.ShouldBeFalse(); + _tokenValidated.ShouldBeFalse(); + _authenticationFailed.ShouldBeFalse(); + } } /// @@ -249,11 +378,12 @@ private TestServer CreateTestServer(bool defaultScheme = true, bool customScheme o.Authority = _issuer; o.Audience = _audience; o.BackchannelHttpHandler = _oidcHttpMessageHandler; + o.Events = _jwtBearerEvents; }); services.AddGraphQL(b => b .AddSchema(_schema) .AddSystemTextJson() - .AddJwtBearerAuthentication(true) + .AddJwtBearerAuthentication(_enableJwtEvents) ); }) .Configure(app => { @@ -324,14 +454,29 @@ private void SetupOidcDiscovery() } /// - /// Creates a new RSA key pair and a signed JWT token. + /// Creates a new RSA key pair. + /// + /// + /// .NET Framework can only handle around 50 RSA keys at a time. + /// + static JwtWebSocketAuthenticationServiceTests() + { + using var rsa = RSA.Create(2048); + _rsaParameters = rsa.ExportParameters(true); + } + + /// + /// Creates a signed JWT token. /// Uses the currently configured , , and . /// Overwrites the and fields. /// - private void CreateSignedToken(bool expired = false) + private void CreateSignedToken(bool expired = false, bool differentKeys = false) { - using var rsa = RSA.Create(2048); - var rsaParameters = rsa.ExportParameters(true); + RSAParameters rsaParameters = _rsaParameters; + if (differentKeys) { + using var rsa = RSA.Create(2048); + rsaParameters = rsa.ExportParameters(true); + } var key = new RsaSecurityKey(rsaParameters); var signingCredentials = new SigningCredentials(key, SecurityAlgorithms.RsaSha256); @@ -351,7 +496,6 @@ private void CreateSignedToken(bool expired = false) var tokenHandler = new JwtSecurityTokenHandler(); var token = tokenHandler.CreateToken(tokenDescriptor); var tokenStr = tokenHandler.WriteToken(token); - _rsaParameters = rsaParameters; _jwtAccessToken = tokenStr; } }