|
5 | 5 | using System.Security.Claims; |
6 | 6 | using Microsoft.AspNetCore.Authentication; |
7 | 7 | using Microsoft.AspNetCore.Authorization; |
| 8 | +using static System.Net.Mime.MediaTypeNames; |
8 | 9 |
|
9 | 10 | namespace GraphQL.AspNetCore3; |
10 | 11 |
|
@@ -125,6 +126,10 @@ public virtual async Task InvokeAsync(HttpContext context) |
125 | 126 | return; |
126 | 127 | } |
127 | 128 |
|
| 129 | + // Perform CSRF protection if necessary |
| 130 | + if (await HandleCsrfProtectionAsync(context, _next)) |
| 131 | + return; |
| 132 | + |
128 | 133 | // Authenticate request if necessary |
129 | 134 | if (await HandleAuthorizeAsync(context, _next)) |
130 | 135 | return; |
@@ -423,6 +428,32 @@ static void ApplyFileToRequest(IFormFile file, string target, GraphQLRequest? re |
423 | 428 | } |
424 | 429 | } |
425 | 430 |
|
| 431 | + /// <summary> |
| 432 | + /// Performs CSRF protection, if required, and returns <see langword="true"/> if the |
| 433 | + /// request was handled (typically by returning an error message). If <see langword="false"/> |
| 434 | + /// is returned, the request is processed normally. |
| 435 | + /// </summary> |
| 436 | + protected virtual async ValueTask<bool> HandleCsrfProtectionAsync(HttpContext context, RequestDelegate next) |
| 437 | + { |
| 438 | + if (!_options.CsrfProtectionEnabled) |
| 439 | + return false; |
| 440 | + if (context.Request.Headers.TryGetValue("Content-Type", out var contentTypes) && contentTypes.Count > 0 && contentTypes[0] != null) { |
| 441 | + var contentType = contentTypes[0]!; |
| 442 | + if (contentType.IndexOf(';') > 0) { |
| 443 | + contentType = contentType.Substring(0, contentType.IndexOf(';')); |
| 444 | + } |
| 445 | + contentType = contentType.Trim().ToLowerInvariant(); |
| 446 | + if (!(contentType == "text/plain" || contentType == "application/x-www-form-urlencoded" || contentType == "multipart/form-data")) |
| 447 | + return false; |
| 448 | + } |
| 449 | + foreach (var header in _options.CsrfProtectionHeaders) { |
| 450 | + if (context.Request.Headers.TryGetValue(header, out var values) && values.Count > 0 && values[0]?.Length > 0) |
| 451 | + return false; |
| 452 | + } |
| 453 | + await HandleCsrfProtectionErrorAsync(context, next); |
| 454 | + return true; |
| 455 | + } |
| 456 | + |
426 | 457 | /// <summary> |
427 | 458 | /// Perform authentication, if required, and return <see langword="true"/> if the |
428 | 459 | /// request was handled (typically by returning an error message). If <see langword="false"/> |
@@ -769,21 +800,29 @@ protected virtual Task HandleNotAuthorizedPolicyAsync(HttpContext context, Reque |
769 | 800 | /// </summary> |
770 | 801 | protected virtual async ValueTask<bool> HandleDeserializationErrorAsync(HttpContext context, RequestDelegate next, Exception exception) |
771 | 802 | { |
772 | | - await WriteErrorResponseAsync(context, HttpStatusCode.BadRequest, new JsonInvalidError(exception)); |
| 803 | + await WriteErrorResponseAsync(context, new JsonInvalidError(exception)); |
773 | 804 | return true; |
774 | 805 | } |
775 | 806 |
|
| 807 | + /// <summary> |
| 808 | + /// Writes a '.' message to the output. |
| 809 | + /// </summary> |
| 810 | + protected virtual async Task HandleCsrfProtectionErrorAsync(HttpContext context, RequestDelegate next) |
| 811 | + { |
| 812 | + await WriteErrorResponseAsync(context, new CsrfProtectionError(_options.CsrfProtectionHeaders)); |
| 813 | + } |
| 814 | + |
776 | 815 | /// <summary> |
777 | 816 | /// Writes a '400 Batched requests are not supported.' message to the output. |
778 | 817 | /// </summary> |
779 | 818 | protected virtual Task HandleBatchedRequestsNotSupportedAsync(HttpContext context, RequestDelegate next) |
780 | | - => WriteErrorResponseAsync(context, HttpStatusCode.BadRequest, new BatchedRequestsNotSupportedError()); |
| 819 | + => WriteErrorResponseAsync(context, new BatchedRequestsNotSupportedError()); |
781 | 820 |
|
782 | 821 | /// <summary> |
783 | 822 | /// Writes a '400 Invalid requested WebSocket sub-protocol(s).' message to the output. |
784 | 823 | /// </summary> |
785 | 824 | protected virtual Task HandleWebSocketSubProtocolNotSupportedAsync(HttpContext context, RequestDelegate next) |
786 | | - => WriteErrorResponseAsync(context, HttpStatusCode.BadRequest, new WebSocketSubProtocolNotSupportedError(context.WebSockets.WebSocketRequestedProtocols)); |
| 825 | + => WriteErrorResponseAsync(context, new WebSocketSubProtocolNotSupportedError(context.WebSockets.WebSocketRequestedProtocols)); |
787 | 826 |
|
788 | 827 | /// <summary> |
789 | 828 | /// Writes a '415 Invalid Content-Type header: could not be parsed.' message to the output. |
@@ -814,6 +853,12 @@ protected virtual Task HandleInvalidHttpMethodErrorAsync(HttpContext context, Re |
814 | 853 | return next(context); |
815 | 854 | } |
816 | 855 |
|
| 856 | + /// <summary> |
| 857 | + /// Writes the specified error as a JSON-formatted GraphQL response. |
| 858 | + /// </summary> |
| 859 | + protected virtual Task WriteErrorResponseAsync(HttpContext context, ExecutionError executionError) |
| 860 | + => WriteErrorResponseAsync(context, executionError is IHasPreferredStatusCode withCode ? withCode.PreferredStatusCode : HttpStatusCode.BadRequest, executionError); |
| 861 | + |
817 | 862 | /// <summary> |
818 | 863 | /// Writes the specified error message as a JSON-formatted GraphQL response, with the specified HTTP status code. |
819 | 864 | /// </summary> |
|
0 commit comments