Skip to content

Commit c429beb

Browse files
authored
Add CSRF protection (#66)
* Add CSRF protection * Update * Update
1 parent 5d4f0ee commit c429beb

File tree

5 files changed

+92
-3
lines changed

5 files changed

+92
-3
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -564,6 +564,8 @@ endpoint.
564564
| `AuthorizationRequired` | Requires `HttpContext.User` to represent an authenticated user. | False |
565565
| `AuthorizedPolicy` | If set, requires `HttpContext.User` to pass authorization of the specified policy. | |
566566
| `AuthorizedRoles` | If set, requires `HttpContext.User` to be a member of any one of a list of roles. | |
567+
| `CsrfProtectionEnabled` | Enables cross-site request forgery (CSRF) protection for both GET and POST requests. | True |
568+
| `CsrfProtectionHeaders` | Sets the headers used for CSRF protection when necessary. | `GraphQL-Require-Preflight` |
567569
| `EnableBatchedRequests` | Enables handling of batched GraphQL requests for POST requests when formatted as JSON. | True |
568570
| `ExecuteBatchedRequestsInParallel` | Enables parallel execution of batched GraphQL requests. | True |
569571
| `HandleGet` | Enables handling of GET requests. | True |
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
namespace GraphQL.AspNetCore3.Errors;
2+
3+
/// <summary>
4+
/// Represents an error indicating that the request may not have triggered a CORS preflight request.
5+
/// </summary>
6+
public class CsrfProtectionError : RequestError
7+
{
8+
/// <inheritdoc cref="CsrfProtectionError"/>
9+
public CsrfProtectionError(IEnumerable<string> headersRequired) : base($"This request requires a non-empty header from the following list: {FormatHeaders(headersRequired)}.") { }
10+
11+
/// <inheritdoc cref="CsrfProtectionError"/>
12+
public CsrfProtectionError(IEnumerable<string> headersRequired, Exception innerException) : base($"This request requires a non-empty header from the following list: {FormatHeaders(headersRequired)}. {innerException.Message}") { }
13+
14+
private static string FormatHeaders(IEnumerable<string> headersRequired)
15+
=> string.Join(", ", headersRequired.Select(x => $"'{x}'"));
16+
}

src/GraphQL.AspNetCore3/GraphQLHttpMiddleware.cs

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
using System.Security.Claims;
66
using Microsoft.AspNetCore.Authentication;
77
using Microsoft.AspNetCore.Authorization;
8+
using static System.Net.Mime.MediaTypeNames;
89

910
namespace GraphQL.AspNetCore3;
1011

@@ -125,6 +126,10 @@ public virtual async Task InvokeAsync(HttpContext context)
125126
return;
126127
}
127128

129+
// Perform CSRF protection if necessary
130+
if (await HandleCsrfProtectionAsync(context, _next))
131+
return;
132+
128133
// Authenticate request if necessary
129134
if (await HandleAuthorizeAsync(context, _next))
130135
return;
@@ -423,6 +428,32 @@ static void ApplyFileToRequest(IFormFile file, string target, GraphQLRequest? re
423428
}
424429
}
425430

431+
/// <summary>
432+
/// Performs CSRF protection, if required, and returns <see langword="true"/> if the
433+
/// request was handled (typically by returning an error message). If <see langword="false"/>
434+
/// is returned, the request is processed normally.
435+
/// </summary>
436+
protected virtual async ValueTask<bool> HandleCsrfProtectionAsync(HttpContext context, RequestDelegate next)
437+
{
438+
if (!_options.CsrfProtectionEnabled)
439+
return false;
440+
if (context.Request.Headers.TryGetValue("Content-Type", out var contentTypes) && contentTypes.Count > 0 && contentTypes[0] != null) {
441+
var contentType = contentTypes[0]!;
442+
if (contentType.IndexOf(';') > 0) {
443+
contentType = contentType.Substring(0, contentType.IndexOf(';'));
444+
}
445+
contentType = contentType.Trim().ToLowerInvariant();
446+
if (!(contentType == "text/plain" || contentType == "application/x-www-form-urlencoded" || contentType == "multipart/form-data"))
447+
return false;
448+
}
449+
foreach (var header in _options.CsrfProtectionHeaders) {
450+
if (context.Request.Headers.TryGetValue(header, out var values) && values.Count > 0 && values[0]?.Length > 0)
451+
return false;
452+
}
453+
await HandleCsrfProtectionErrorAsync(context, next);
454+
return true;
455+
}
456+
426457
/// <summary>
427458
/// Perform authentication, if required, and return <see langword="true"/> if the
428459
/// request was handled (typically by returning an error message). If <see langword="false"/>
@@ -769,21 +800,29 @@ protected virtual Task HandleNotAuthorizedPolicyAsync(HttpContext context, Reque
769800
/// </summary>
770801
protected virtual async ValueTask<bool> HandleDeserializationErrorAsync(HttpContext context, RequestDelegate next, Exception exception)
771802
{
772-
await WriteErrorResponseAsync(context, HttpStatusCode.BadRequest, new JsonInvalidError(exception));
803+
await WriteErrorResponseAsync(context, new JsonInvalidError(exception));
773804
return true;
774805
}
775806

807+
/// <summary>
808+
/// Writes a '.' message to the output.
809+
/// </summary>
810+
protected virtual async Task HandleCsrfProtectionErrorAsync(HttpContext context, RequestDelegate next)
811+
{
812+
await WriteErrorResponseAsync(context, new CsrfProtectionError(_options.CsrfProtectionHeaders));
813+
}
814+
776815
/// <summary>
777816
/// Writes a '400 Batched requests are not supported.' message to the output.
778817
/// </summary>
779818
protected virtual Task HandleBatchedRequestsNotSupportedAsync(HttpContext context, RequestDelegate next)
780-
=> WriteErrorResponseAsync(context, HttpStatusCode.BadRequest, new BatchedRequestsNotSupportedError());
819+
=> WriteErrorResponseAsync(context, new BatchedRequestsNotSupportedError());
781820

782821
/// <summary>
783822
/// Writes a '400 Invalid requested WebSocket sub-protocol(s).' message to the output.
784823
/// </summary>
785824
protected virtual Task HandleWebSocketSubProtocolNotSupportedAsync(HttpContext context, RequestDelegate next)
786-
=> WriteErrorResponseAsync(context, HttpStatusCode.BadRequest, new WebSocketSubProtocolNotSupportedError(context.WebSockets.WebSocketRequestedProtocols));
825+
=> WriteErrorResponseAsync(context, new WebSocketSubProtocolNotSupportedError(context.WebSockets.WebSocketRequestedProtocols));
787826

788827
/// <summary>
789828
/// Writes a '415 Invalid Content-Type header: could not be parsed.' message to the output.
@@ -814,6 +853,12 @@ protected virtual Task HandleInvalidHttpMethodErrorAsync(HttpContext context, Re
814853
return next(context);
815854
}
816855

856+
/// <summary>
857+
/// Writes the specified error as a JSON-formatted GraphQL response.
858+
/// </summary>
859+
protected virtual Task WriteErrorResponseAsync(HttpContext context, ExecutionError executionError)
860+
=> WriteErrorResponseAsync(context, executionError is IHasPreferredStatusCode withCode ? withCode.PreferredStatusCode : HttpStatusCode.BadRequest, executionError);
861+
817862
/// <summary>
818863
/// Writes the specified error message as a JSON-formatted GraphQL response, with the specified HTTP status code.
819864
/// </summary>

src/GraphQL.AspNetCore3/GraphQLHttpMiddlewareOptions.cs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,22 @@ public class GraphQLHttpMiddlewareOptions : IAuthorizationOptions
6161
/// </remarks>
6262
public bool ReadFormOnPost { get; set; } = true;
6363

64+
/// <summary>
65+
/// Enables cross-site request forgery (CSRF) protection for both GET and POST requests.
66+
/// Requires a non-empty header from the <see cref="CsrfProtectionHeaders"/> list to be
67+
/// present, or a POST request with a Content-Type header that is not <c>text/plain</c>,
68+
/// <c>application/x-www-form-urlencoded</c>, or <c>multipart/form-data</c>.
69+
/// </summary>
70+
public bool CsrfProtectionEnabled { get; set; }
71+
72+
/// <summary>
73+
/// When <see cref="CsrfProtectionEnabled"/> is enabled, requests require a non-empty
74+
/// header from this list or a POST request with a Content-Type header that is not
75+
/// <c>text/plain</c>, <c>application/x-www-form-urlencoded</c>, or <c>multipart/form-data</c>.
76+
/// Defaults to <c>GraphQL-Require-Preflight</c>.
77+
/// </summary>
78+
public List<string> CsrfProtectionHeaders { get; set; } = new() { "GraphQL-Require-Preflight" }; // see https://github.com/graphql/graphql-over-http/pull/303
79+
6480
/// <summary>
6581
/// Enables reading variables from the query string.
6682
/// Variables are interpreted as JSON and deserialized before being

src/Tests.ApiApprovals/GraphQL.AspNetCore3.approved.txt

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,8 @@ namespace GraphQL.AspNetCore3
132132
protected virtual System.Threading.Tasks.Task HandleBatchRequestAsync(Microsoft.AspNetCore.Http.HttpContext context, Microsoft.AspNetCore.Http.RequestDelegate next, System.Collections.Generic.IList<GraphQL.Transport.GraphQLRequest?> gqlRequests) { }
133133
protected virtual System.Threading.Tasks.Task HandleBatchedRequestsNotSupportedAsync(Microsoft.AspNetCore.Http.HttpContext context, Microsoft.AspNetCore.Http.RequestDelegate next) { }
134134
protected virtual System.Threading.Tasks.Task HandleContentTypeCouldNotBeParsedErrorAsync(Microsoft.AspNetCore.Http.HttpContext context, Microsoft.AspNetCore.Http.RequestDelegate next) { }
135+
protected virtual System.Threading.Tasks.ValueTask<bool> HandleCsrfProtectionAsync(Microsoft.AspNetCore.Http.HttpContext context, Microsoft.AspNetCore.Http.RequestDelegate next) { }
136+
protected virtual System.Threading.Tasks.Task HandleCsrfProtectionErrorAsync(Microsoft.AspNetCore.Http.HttpContext context, Microsoft.AspNetCore.Http.RequestDelegate next) { }
135137
protected virtual System.Threading.Tasks.ValueTask<bool> HandleDeserializationErrorAsync(Microsoft.AspNetCore.Http.HttpContext context, Microsoft.AspNetCore.Http.RequestDelegate next, System.Exception exception) { }
136138
protected virtual System.Threading.Tasks.Task HandleInvalidContentTypeErrorAsync(Microsoft.AspNetCore.Http.HttpContext context, Microsoft.AspNetCore.Http.RequestDelegate next) { }
137139
protected virtual System.Threading.Tasks.Task HandleInvalidHttpMethodErrorAsync(Microsoft.AspNetCore.Http.HttpContext context, Microsoft.AspNetCore.Http.RequestDelegate next) { }
@@ -147,6 +149,7 @@ namespace GraphQL.AspNetCore3
147149
"BatchRequest"})]
148150
protected virtual System.Threading.Tasks.Task<System.ValueTuple<GraphQL.Transport.GraphQLRequest?, System.Collections.Generic.IList<GraphQL.Transport.GraphQLRequest?>?>?> ReadPostContentAsync(Microsoft.AspNetCore.Http.HttpContext context, Microsoft.AspNetCore.Http.RequestDelegate next, string? mediaType, System.Text.Encoding? sourceEncoding) { }
149151
protected virtual string SelectResponseContentType(Microsoft.AspNetCore.Http.HttpContext context) { }
152+
protected virtual System.Threading.Tasks.Task WriteErrorResponseAsync(Microsoft.AspNetCore.Http.HttpContext context, GraphQL.ExecutionError executionError) { }
150153
protected virtual System.Threading.Tasks.Task WriteErrorResponseAsync(Microsoft.AspNetCore.Http.HttpContext context, System.Net.HttpStatusCode httpStatusCode, GraphQL.ExecutionError executionError) { }
151154
protected virtual System.Threading.Tasks.Task WriteErrorResponseAsync(Microsoft.AspNetCore.Http.HttpContext context, System.Net.HttpStatusCode httpStatusCode, string errorMessage) { }
152155
protected virtual System.Threading.Tasks.Task WriteJsonResponseAsync<TResult>(Microsoft.AspNetCore.Http.HttpContext context, System.Net.HttpStatusCode httpStatusCode, TResult result) { }
@@ -158,6 +161,8 @@ namespace GraphQL.AspNetCore3
158161
public bool AuthorizationRequired { get; set; }
159162
public string? AuthorizedPolicy { get; set; }
160163
public System.Collections.Generic.List<string> AuthorizedRoles { get; set; }
164+
public bool CsrfProtectionEnabled { get; set; }
165+
public System.Collections.Generic.List<string> CsrfProtectionHeaders { get; set; }
161166
public bool EnableBatchedRequests { get; set; }
162167
public bool ExecuteBatchedRequestsInParallel { get; set; }
163168
public bool HandleGet { get; set; }
@@ -224,6 +229,11 @@ namespace GraphQL.AspNetCore3.Errors
224229
{
225230
public BatchedRequestsNotSupportedError() { }
226231
}
232+
public class CsrfProtectionError : GraphQL.Execution.RequestError
233+
{
234+
public CsrfProtectionError(System.Collections.Generic.IEnumerable<string> headersRequired) { }
235+
public CsrfProtectionError(System.Collections.Generic.IEnumerable<string> headersRequired, System.Exception innerException) { }
236+
}
227237
public class FileCountExceededError : GraphQL.Execution.RequestError, GraphQL.AspNetCore3.Errors.IHasPreferredStatusCode
228238
{
229239
public FileCountExceededError() { }

0 commit comments

Comments
 (0)