Skip to content

Commit fcbb685

Browse files
authored
Handle Kestrel IO exceptions (#897)
Handle Kestrel IO exceptions on bad request, returning a formatted error response.
1 parent 233ae5f commit fcbb685

File tree

5 files changed

+75
-34
lines changed

5 files changed

+75
-34
lines changed

Directory.Build.props

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
<Project>
33
<PropertyGroup>
44
<!-- Central version prefix - applies to all nuget packages. -->
5-
<Version>0.92.0</Version>
5+
<Version>0.93.0</Version>
66

77
<!-- C# lang version, https://learn.microsoft.com/dotnet/csharp/whats-new -->
88
<LangVersion>12</LangVersion>

service/Service.AspNetCore/WebAPIEndpoints.cs

Lines changed: 46 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -25,24 +25,23 @@ public static IEndpointRouteBuilder AddKernelMemoryEndpoints(
2525
this IEndpointRouteBuilder builder,
2626
string apiPrefix = "/",
2727
KernelMemoryConfig? kmConfig = null,
28-
IEndpointFilter? authFilter = null)
28+
IEndpointFilter[]? filters = null)
2929
{
30-
builder.AddPostUploadEndpoint(apiPrefix, authFilter, kmConfig?.Service.GetMaxUploadSizeInBytes());
31-
builder.AddGetIndexesEndpoint(apiPrefix, authFilter);
32-
builder.AddDeleteIndexesEndpoint(apiPrefix, authFilter);
33-
builder.AddDeleteDocumentsEndpoint(apiPrefix, authFilter);
34-
builder.AddAskEndpoint(apiPrefix, authFilter);
35-
builder.AddSearchEndpoint(apiPrefix, authFilter);
36-
builder.AddUploadStatusEndpoint(apiPrefix, authFilter);
37-
builder.AddGetDownloadEndpoint(apiPrefix, authFilter);
30+
builder.AddPostUploadEndpoint(apiPrefix, kmConfig?.Service.GetMaxUploadSizeInBytes()).AddFilters(filters);
31+
builder.AddGetIndexesEndpoint(apiPrefix).AddFilters(filters);
32+
builder.AddDeleteIndexesEndpoint(apiPrefix).AddFilters(filters);
33+
builder.AddDeleteDocumentsEndpoint(apiPrefix).AddFilters(filters);
34+
builder.AddAskEndpoint(apiPrefix).AddFilters(filters);
35+
builder.AddSearchEndpoint(apiPrefix).AddFilters(filters);
36+
builder.AddUploadStatusEndpoint(apiPrefix).AddFilters(filters);
37+
builder.AddGetDownloadEndpoint(apiPrefix).AddFilters(filters);
3838

3939
return builder;
4040
}
4141

42-
public static void AddPostUploadEndpoint(
42+
public static RouteHandlerBuilder AddPostUploadEndpoint(
4343
this IEndpointRouteBuilder builder,
4444
string apiPrefix = "/",
45-
IEndpointFilter? authFilter = null,
4645
long? maxUploadSizeInBytes = null)
4746
{
4847
RouteGroupBuilder group = builder.MapGroup(apiPrefix);
@@ -109,11 +108,11 @@ public static void AddPostUploadEndpoint(
109108
.Produces<ProblemDetails>(StatusCodes.Status403Forbidden)
110109
.Produces<ProblemDetails>(StatusCodes.Status503ServiceUnavailable);
111110

112-
if (authFilter != null) { route.AddEndpointFilter(authFilter); }
111+
return route;
113112
}
114113

115-
public static void AddGetIndexesEndpoint(
116-
this IEndpointRouteBuilder builder, string apiPrefix = "/", IEndpointFilter? authFilter = null)
114+
public static RouteHandlerBuilder AddGetIndexesEndpoint(
115+
this IEndpointRouteBuilder builder, string apiPrefix = "/")
117116
{
118117
RouteGroupBuilder group = builder.MapGroup(apiPrefix);
119118

@@ -141,11 +140,11 @@ async Task<IResult> (
141140
.Produces<ProblemDetails>(StatusCodes.Status401Unauthorized)
142141
.Produces<ProblemDetails>(StatusCodes.Status403Forbidden);
143142

144-
if (authFilter != null) { route.AddEndpointFilter(authFilter); }
143+
return route;
145144
}
146145

147-
public static void AddDeleteIndexesEndpoint(
148-
this IEndpointRouteBuilder builder, string apiPrefix = "/", IEndpointFilter? authFilter = null)
146+
public static RouteHandlerBuilder AddDeleteIndexesEndpoint(
147+
this IEndpointRouteBuilder builder, string apiPrefix = "/", IEndpointFilter[]? filters = null)
149148
{
150149
RouteGroupBuilder group = builder.MapGroup(apiPrefix);
151150

@@ -173,11 +172,11 @@ await service.DeleteIndexAsync(index: index, cancellationToken)
173172
.Produces<ProblemDetails>(StatusCodes.Status401Unauthorized)
174173
.Produces<ProblemDetails>(StatusCodes.Status403Forbidden);
175174

176-
if (authFilter != null) { route.AddEndpointFilter(authFilter); }
175+
return route;
177176
}
178177

179-
public static void AddDeleteDocumentsEndpoint(
180-
this IEndpointRouteBuilder builder, string apiPrefix = "/", IEndpointFilter? authFilter = null)
178+
public static RouteHandlerBuilder AddDeleteDocumentsEndpoint(
179+
this IEndpointRouteBuilder builder, string apiPrefix = "/", IEndpointFilter[]? filters = null)
181180
{
182181
RouteGroupBuilder group = builder.MapGroup(apiPrefix);
183182

@@ -209,11 +208,11 @@ await service.DeleteDocumentAsync(documentId: documentId, index: index, cancella
209208
.Produces<ProblemDetails>(StatusCodes.Status401Unauthorized)
210209
.Produces<ProblemDetails>(StatusCodes.Status403Forbidden);
211210

212-
if (authFilter != null) { route.AddEndpointFilter(authFilter); }
211+
return route;
213212
}
214213

215-
public static void AddAskEndpoint(
216-
this IEndpointRouteBuilder builder, string apiPrefix = "/", IEndpointFilter? authFilter = null)
214+
public static RouteHandlerBuilder AddAskEndpoint(
215+
this IEndpointRouteBuilder builder, string apiPrefix = "/", IEndpointFilter[]? filters = null)
217216
{
218217
RouteGroupBuilder group = builder.MapGroup(apiPrefix);
219218

@@ -244,11 +243,11 @@ async Task<IResult> (
244243
.Produces<ProblemDetails>(StatusCodes.Status401Unauthorized)
245244
.Produces<ProblemDetails>(StatusCodes.Status403Forbidden);
246245

247-
if (authFilter != null) { route.AddEndpointFilter(authFilter); }
246+
return route;
248247
}
249248

250-
public static void AddSearchEndpoint(
251-
this IEndpointRouteBuilder builder, string apiPrefix = "/", IEndpointFilter? authFilter = null)
249+
public static RouteHandlerBuilder AddSearchEndpoint(
250+
this IEndpointRouteBuilder builder, string apiPrefix = "/", IEndpointFilter[]? filters = null)
252251
{
253252
RouteGroupBuilder group = builder.MapGroup(apiPrefix);
254253

@@ -280,11 +279,11 @@ async Task<IResult> (
280279
.Produces<ProblemDetails>(StatusCodes.Status401Unauthorized)
281280
.Produces<ProblemDetails>(StatusCodes.Status403Forbidden);
282281

283-
if (authFilter != null) { route.AddEndpointFilter(authFilter); }
282+
return route;
284283
}
285284

286-
public static void AddUploadStatusEndpoint(
287-
this IEndpointRouteBuilder builder, string apiPrefix = "/", IEndpointFilter? authFilter = null)
285+
public static RouteHandlerBuilder AddUploadStatusEndpoint(
286+
this IEndpointRouteBuilder builder, string apiPrefix = "/", IEndpointFilter[]? filters = null)
288287
{
289288
RouteGroupBuilder group = builder.MapGroup(apiPrefix);
290289

@@ -325,10 +324,11 @@ async Task<IResult> (
325324
.Produces<ProblemDetails>(StatusCodes.Status403Forbidden)
326325
.Produces<ProblemDetails>(StatusCodes.Status404NotFound);
327326

328-
if (authFilter != null) { route.AddEndpointFilter(authFilter); }
327+
return route;
329328
}
330329

331-
public static void AddGetDownloadEndpoint(this IEndpointRouteBuilder builder, string apiPrefix = "/", IEndpointFilter? authFilter = null)
330+
public static RouteHandlerBuilder AddGetDownloadEndpoint(
331+
this IEndpointRouteBuilder builder, string apiPrefix = "/", IEndpointFilter[]? filters = null)
332332
{
333333
RouteGroupBuilder group = builder.MapGroup(apiPrefix);
334334

@@ -406,11 +406,25 @@ public static void AddGetDownloadEndpoint(this IEndpointRouteBuilder builder, st
406406
.Produces<ProblemDetails>(StatusCodes.Status403Forbidden)
407407
.Produces<ProblemDetails>(StatusCodes.Status503ServiceUnavailable);
408408

409-
if (authFilter != null) { route.AddEndpointFilter(authFilter); }
409+
return route;
410410
}
411411

412412
#pragma warning disable CA1812 // used by logger, can't be static
413413
// Class used to tag log entries and allow log filtering
414+
// ReSharper disable once ClassNeverInstantiated.Local
414415
private sealed class KernelMemoryWebAPI;
415416
#pragma warning restore CA1812
416417
}
418+
419+
internal static class EndpointConventionBuilderExtensions
420+
{
421+
internal static void AddFilters(this IEndpointConventionBuilder route, IEndpointFilter[]? filters = null)
422+
{
423+
if (filters == null || filters.Length == 0) { return; }
424+
425+
foreach (var filter in filters)
426+
{
427+
route.AddEndpointFilter(filter);
428+
}
429+
}
430+
}
File renamed without changes.
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// Copyright (c) Microsoft. All rights reserved.
2+
3+
using System.Threading.Tasks;
4+
using Microsoft.AspNetCore.Http;
5+
6+
namespace Microsoft.KernelMemory.Service;
7+
8+
public sealed class HttpErrorsEndpointFilter : IEndpointFilter
9+
{
10+
public async ValueTask<object?> InvokeAsync(
11+
EndpointFilterInvocationContext context,
12+
EndpointFilterDelegate next)
13+
{
14+
try
15+
{
16+
return await next(context);
17+
}
18+
catch (BadHttpRequestException e) when (e.StatusCode == 413)
19+
{
20+
return Results.Problem(
21+
statusCode: e.StatusCode,
22+
detail: e.Message);
23+
}
24+
}
25+
}

service/Service/Program.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,18 +142,20 @@ public static void Main(string[] args)
142142
if (enableCORS) { app.UseCors(CORSPolicyName); }
143143

144144
app.UseSwagger(config);
145+
var errorFilter = new HttpErrorsEndpointFilter();
145146
var authFilter = new HttpAuthEndpointFilter(config.ServiceAuthorization);
146147
app.MapGet("/", () => Results.Ok("Ingestion service is running. " +
147148
"Uptime: " + (DateTimeOffset.UtcNow.ToUnixTimeSeconds()
148149
- s_start.ToUnixTimeSeconds()) + " secs " +
149150
$"- Environment: {Environment.GetEnvironmentVariable("ASPNETCORE_ENVIRONMENT")}"))
151+
.AddEndpointFilter(errorFilter)
150152
.AddEndpointFilter(authFilter)
151153
.Produces<string>(StatusCodes.Status200OK)
152154
.Produces<ProblemDetails>(StatusCodes.Status401Unauthorized)
153155
.Produces<ProblemDetails>(StatusCodes.Status403Forbidden);
154156

155157
// Add HTTP endpoints using minimal API (https://learn.microsoft.com/aspnet/core/fundamentals/minimal-apis)
156-
app.AddKernelMemoryEndpoints("/", config, authFilter);
158+
app.AddKernelMemoryEndpoints("/", config, [errorFilter, authFilter]);
157159

158160
// Health probe
159161
app.MapGet("/health", () => Results.Ok("Service is running."))

0 commit comments

Comments
 (0)