Skip to content

Commit 9d21592

Browse files
committed
Merge in 'release/7.0-rc1' changes
2 parents aeb67ec + 6bc3ff4 commit 9d21592

File tree

3 files changed

+58
-19
lines changed

3 files changed

+58
-19
lines changed

src/Middleware/RateLimiting/src/LeaseContext.cs

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,19 @@ namespace Microsoft.AspNetCore.RateLimiting;
77

88
internal struct LeaseContext : IDisposable
99
{
10-
public bool? GlobalRejected { get; init; }
10+
public RequestRejectionReason? RequestRejectionReason { get; init; }
1111

12-
public required RateLimitLease Lease { get; init; }
12+
public RateLimitLease? Lease { get; init; }
1313

1414
public void Dispose()
1515
{
16-
Lease.Dispose();
16+
Lease?.Dispose();
1717
}
1818
}
19+
20+
internal enum RequestRejectionReason
21+
{
22+
EndpointLimiter,
23+
GlobalLimiter,
24+
RequestCanceled
25+
}

src/Middleware/RateLimiting/src/RateLimitingMiddleware.cs

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -78,20 +78,25 @@ public Task Invoke(HttpContext context)
7878
private async Task InvokeInternal(HttpContext context, EnableRateLimitingAttribute? enableRateLimitingAttribute)
7979
{
8080
using var leaseContext = await TryAcquireAsync(context);
81-
if (leaseContext.Lease.IsAcquired)
81+
if (leaseContext.Lease?.IsAcquired == true)
8282
{
8383
await _next(context);
8484
}
8585
else
8686
{
87+
// If the request was canceled, do not call OnRejected, just return.
88+
if (leaseContext.RequestRejectionReason == RequestRejectionReason.RequestCanceled)
89+
{
90+
return;
91+
}
8792
var thisRequestOnRejected = _defaultOnRejected;
8893
RateLimiterLog.RequestRejectedLimitsExceeded(_logger);
8994
// OnRejected "wins" over DefaultRejectionStatusCode - we set DefaultRejectionStatusCode first,
9095
// then call OnRejected in case it wants to do any further modification of the status code.
9196
context.Response.StatusCode = _rejectionStatusCode;
9297

9398
// If this request was rejected by the endpoint limiter, use its OnRejected if available.
94-
if (leaseContext.GlobalRejected == false)
99+
if (leaseContext.RequestRejectionReason == RequestRejectionReason.EndpointLimiter)
95100
{
96101
DefaultRateLimiterPolicy? policy;
97102
// Use custom policy OnRejected if available, else use OnRejected from the Options if available.
@@ -111,15 +116,16 @@ private async Task InvokeInternal(HttpContext context, EnableRateLimitingAttribu
111116
}
112117
if (thisRequestOnRejected is not null)
113118
{
114-
await thisRequestOnRejected(new OnRejectedContext() { HttpContext = context, Lease = leaseContext.Lease }, context.RequestAborted);
119+
// leaseContext.Lease will only be null when the request was canceled.
120+
await thisRequestOnRejected(new OnRejectedContext() { HttpContext = context, Lease = leaseContext.Lease! }, context.RequestAborted);
115121
}
116122
}
117123
}
118124

119125
private ValueTask<LeaseContext> TryAcquireAsync(HttpContext context)
120126
{
121127
var leaseContext = CombinedAcquire(context);
122-
if (leaseContext.Lease.IsAcquired)
128+
if (leaseContext.Lease?.IsAcquired == true)
123129
{
124130
return ValueTask.FromResult(leaseContext);
125131
}
@@ -139,14 +145,14 @@ private LeaseContext CombinedAcquire(HttpContext context)
139145
globalLease = _globalLimiter.AttemptAcquire(context);
140146
if (!globalLease.IsAcquired)
141147
{
142-
return new LeaseContext() { GlobalRejected = true, Lease = globalLease };
148+
return new LeaseContext() { RequestRejectionReason = RequestRejectionReason.GlobalLimiter, Lease = globalLease };
143149
}
144150
}
145151
endpointLease = _endpointLimiter.AttemptAcquire(context);
146152
if (!endpointLease.IsAcquired)
147153
{
148154
globalLease?.Dispose();
149-
return new LeaseContext() { GlobalRejected = false, Lease = endpointLease };
155+
return new LeaseContext() { RequestRejectionReason = RequestRejectionReason.EndpointLimiter, Lease = endpointLease };
150156
}
151157
}
152158
catch (Exception)
@@ -170,21 +176,30 @@ private async ValueTask<LeaseContext> CombinedWaitAsync(HttpContext context, Can
170176
globalLease = await _globalLimiter.AcquireAsync(context, cancellationToken: cancellationToken);
171177
if (!globalLease.IsAcquired)
172178
{
173-
return new LeaseContext() { GlobalRejected = true, Lease = globalLease };
179+
return new LeaseContext() { RequestRejectionReason = RequestRejectionReason.GlobalLimiter, Lease = globalLease };
174180
}
175181
}
176182
endpointLease = await _endpointLimiter.AcquireAsync(context, cancellationToken: cancellationToken);
177183
if (!endpointLease.IsAcquired)
178184
{
179185
globalLease?.Dispose();
180-
return new LeaseContext() { GlobalRejected = false, Lease = endpointLease };
186+
return new LeaseContext() { RequestRejectionReason = RequestRejectionReason.EndpointLimiter, Lease = endpointLease };
181187
}
182188
}
183-
catch (Exception)
189+
catch (Exception ex)
184190
{
185191
endpointLease?.Dispose();
186192
globalLease?.Dispose();
187-
throw;
193+
// Don't throw if the request was canceled - instead log.
194+
if (ex is OperationCanceledException && context.RequestAborted.IsCancellationRequested)
195+
{
196+
RateLimiterLog.RequestCanceled(_logger);
197+
return new LeaseContext() { RequestRejectionReason = RequestRejectionReason.RequestCanceled };
198+
}
199+
else
200+
{
201+
throw;
202+
}
188203
}
189204

190205
return globalLease is null ? new LeaseContext() { Lease = endpointLease } : new LeaseContext() { Lease = new DefaultCombinedLease(globalLease, endpointLease) };
@@ -234,5 +249,8 @@ private static partial class RateLimiterLog
234249

235250
[LoggerMessage(2, LogLevel.Debug, "This endpoint requires a rate limiting policy with name {PolicyName}, but no such policy exists.", EventName = "WarnMissingPolicy")]
236251
internal static partial void WarnMissingPolicy(ILogger logger, string policyName);
252+
253+
[LoggerMessage(3, LogLevel.Debug, "The request was canceled.", EventName = "RequestCanceled")]
254+
internal static partial void RequestCanceled(ILogger logger);
237255
}
238-
}
256+
}

src/Middleware/RateLimiting/test/RateLimitingMiddlewareTests.cs

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,13 @@
66
using Microsoft.AspNetCore.Testing;
77
using Microsoft.Extensions.Logging;
88
using Microsoft.Extensions.Logging.Abstractions;
9+
using Microsoft.Extensions.Logging.Testing;
910
using Microsoft.Extensions.Options;
1011
using Moq;
1112

1213
namespace Microsoft.AspNetCore.RateLimiting;
1314

14-
public class RateLimitingMiddlewareTests : LoggedTest
15+
public class RateLimitingMiddlewareTests
1516
{
1617
[Fact]
1718
public void Ctor_ThrowsExceptionsWhenNullArgs()
@@ -115,22 +116,35 @@ public async Task RequestRejected_WinsOverDefaultStatusCode()
115116
}
116117

117118
[Fact]
118-
public async Task RequestAborted_ThrowsTaskCanceledException()
119+
public async Task RequestAborted_DoesNotThrowTaskCanceledException()
119120
{
121+
var sink = new TestSink(
122+
TestSink.EnableWithTypeName<RateLimitingMiddleware>,
123+
TestSink.EnableWithTypeName<RateLimitingMiddleware>);
124+
var loggerFactory = new TestLoggerFactory(sink, enabled: true);
125+
120126
var options = CreateOptionsAccessor();
121127
options.Value.GlobalLimiter = new TestPartitionedRateLimiter<HttpContext>(new TestRateLimiter(false));
122128

123129
var middleware = new RateLimitingMiddleware(c =>
124130
{
125131
return Task.CompletedTask;
126132
},
127-
new NullLoggerFactory().CreateLogger<RateLimitingMiddleware>(),
133+
loggerFactory.CreateLogger<RateLimitingMiddleware>(),
128134
options,
129135
Mock.Of<IServiceProvider>());
130136

131137
var context = new DefaultHttpContext();
132138
context.RequestAborted = new CancellationToken(true);
133-
await Assert.ThrowsAsync<TaskCanceledException>(() => middleware.Invoke(context)).DefaultTimeout();
139+
await middleware.Invoke(context);
140+
Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
141+
142+
var logMessages = sink.Writes.ToList();
143+
144+
Assert.Single(logMessages);
145+
var message = logMessages.First();
146+
Assert.Equal(LogLevel.Debug, message.LogLevel);
147+
Assert.Equal("The request was canceled.", message.State.ToString());
134148
}
135149

136150
[Fact]
@@ -609,4 +623,4 @@ public async Task MultipleEndpointPolicies_LastOneWins()
609623

610624
private IOptions<RateLimiterOptions> CreateOptionsAccessor() => Options.Create(new RateLimiterOptions());
611625

612-
}
626+
}

0 commit comments

Comments
 (0)