Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -78,15 +78,16 @@ public Task HandleCallAsync(HttpContext httpContext)
}
else
{
return AwaitHandleCall(serverCallContext, MethodInvoker.Method, handleCallTask);
return AwaitHandleCall(serverCallContext, handleCallTask);
}
}
catch (Exception ex)
{
return serverCallContext.ProcessHandlerErrorAsync(ex, MethodInvoker.Method.Name);
// Enhanced exception handling for deserialization errors
return HandleCallExceptionAsync(serverCallContext, ex);
}

static async Task AwaitHandleCall(HttpContextServerCallContext serverCallContext, Method<TRequest, TResponse> method, Task handleCall)
static async Task AwaitHandleCall(HttpContextServerCallContext serverCallContext, Task handleCall)
{
try
{
Expand All @@ -95,7 +96,7 @@ static async Task AwaitHandleCall(HttpContextServerCallContext serverCallContext
}
catch (Exception ex)
{
await serverCallContext.ProcessHandlerErrorAsync(ex, method.Name);
await HandleCallExceptionAsync(serverCallContext, ex);
}
}
}
Expand Down Expand Up @@ -153,6 +154,44 @@ protected void DisableRequestTimeout(HttpContext httpContext)
}
#endif

/// <summary>
/// Handles exceptions that occur during call processing, with special handling for deserialization cancellations.
/// </summary>
private static Task HandleCallExceptionAsync(HttpContextServerCallContext serverCallContext, Exception ex)
{
// If it's already an RpcException, let the existing logic handle it
if (ex is RpcException rpcEx)
{
return serverCallContext.ProcessHandlerErrorAsync(rpcEx, serverCallContext.Method);
}

// Convert specific exception types to proper RpcExceptions
var convertedException = ConvertToRpcException(ex, serverCallContext);
return serverCallContext.ProcessHandlerErrorAsync(convertedException, serverCallContext.Method);
}

/// <summary>
/// Converts framework exceptions to appropriate RpcExceptions.
/// </summary>
private static RpcException ConvertToRpcException(Exception ex, HttpContextServerCallContext _)
{
return ex switch
{
OperationCanceledException _ when _.HttpContext.RequestAborted.IsCancellationRequested =>
new RpcException(new Status(StatusCode.Cancelled, "Call canceled by the client.", ex)),
IOException ioEx when IsConnectionResetException(ioEx) =>
new RpcException(new Status(StatusCode.Cancelled, "Client disconnected.", ex)),
_ => new RpcException(new Status(StatusCode.Unknown, "Error processing call.", ex))
};
}

private static bool IsConnectionResetException(IOException ex)
{
return ex.Message.Contains("reset", StringComparison.OrdinalIgnoreCase) ||
ex.Message.Contains("aborted", StringComparison.OrdinalIgnoreCase) ||
ex.Message.Contains("disconnect", StringComparison.OrdinalIgnoreCase);
}

private Task ProcessNonHttp2Request(HttpContext httpContext)
{
GrpcServerLog.UnsupportedRequestProtocol(Logger, httpContext.Request.Protocol);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,14 @@ internal void ValidateAcceptEncodingContainsResponseEncoding()

private string DebuggerToString() => $"Method = {Method}";

internal void EnsureRequestNotAborted()
{
if (HttpContext.RequestAborted.IsCancellationRequested)
{
throw new RpcException(new Status(StatusCode.Cancelled, "Request was aborted by the client."));
}
}

private sealed class HttpContextServerCallContextDebugView
{
private readonly HttpContextServerCallContext _context;
Expand Down
66 changes: 66 additions & 0 deletions src/Grpc.AspNetCore.Server/Internal/PipeExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
using System.Runtime.CompilerServices;
using Grpc.Core;
using Grpc.Net.Compression;
using Microsoft.AspNetCore.Connections;
using Microsoft.Extensions.Logging;

namespace Grpc.AspNetCore.Server.Internal;
Expand Down Expand Up @@ -206,6 +207,9 @@ public static async ValueTask<T> ReadSingleMessageAsync<T>(this PipeReader input

while (true)
{
// Check for client disconnect before reading
serverCallContext.EnsureRequestNotAborted();

var result = await input.ReadAsync();
var buffer = result.Buffer;

Expand All @@ -216,6 +220,9 @@ public static async ValueTask<T> ReadSingleMessageAsync<T>(this PipeReader input
throw new RpcException(MessageCancelledStatus);
}

// Check for client disconnect during processing
serverCallContext.EnsureRequestNotAborted();

if (!buffer.IsEmpty)
{
if (request != null)
Expand All @@ -225,6 +232,9 @@ public static async ValueTask<T> ReadSingleMessageAsync<T>(this PipeReader input

if (TryReadMessage(ref buffer, serverCallContext, out var data))
{
// Check for client disconnect before deserialization
serverCallContext.EnsureRequestNotAborted();

// Finished and the complete message has arrived
GrpcServerLog.DeserializingMessage(logger, (int)data.Length, typeof(T));

Expand Down Expand Up @@ -275,6 +285,21 @@ public static async ValueTask<T> ReadSingleMessageAsync<T>(this PipeReader input
}
}
}
catch (OperationCanceledException ex) when (serverCallContext.HttpContext.RequestAborted.IsCancellationRequested)
{
// Convert operation canceled due to client disconnect to proper RpcException
throw new RpcException(new Status(StatusCode.Cancelled, "Call canceled by the client.", ex));
}
catch (IOException ex) when (IsConnectionResetException(ex))
{
// Convert connection reset to proper RpcException
throw new RpcException(new Status(StatusCode.Cancelled, "Client disconnected during request.", ex));
}
catch (Exception ex) when (IsConnectionAbortedException(ex))
{
// Convert connection aborted to proper RpcException
throw new RpcException(new Status(StatusCode.Cancelled, "Connection aborted during request.", ex));
}
catch (Exception ex) when (ex is not OperationCanceledException)
{
// Don't write error when user cancels read
Expand Down Expand Up @@ -303,6 +328,10 @@ public static async ValueTask<T> ReadSingleMessageAsync<T>(this PipeReader input
while (true)
{
var completeMessage = false;

// Check for client disconnect before reading
serverCallContext.EnsureRequestNotAborted();

var result = await input.ReadAsync(cancellationToken);
var buffer = result.Buffer;

Expand All @@ -313,12 +342,18 @@ public static async ValueTask<T> ReadSingleMessageAsync<T>(this PipeReader input
throw new RpcException(MessageCancelledStatus);
}

// Check for client disconnect during processing
serverCallContext.EnsureRequestNotAborted();

if (!buffer.IsEmpty)
{
if (TryReadMessage(ref buffer, serverCallContext, out var data))
{
completeMessage = true;

// Check for client disconnect before deserialization
serverCallContext.EnsureRequestNotAborted();

GrpcServerLog.DeserializingMessage(logger, (int)data.Length, typeof(T));

serverCallContext.DeserializationContext.SetPayload(data);
Expand Down Expand Up @@ -363,6 +398,21 @@ public static async ValueTask<T> ReadSingleMessageAsync<T>(this PipeReader input
}
}
}
catch (OperationCanceledException ex) when (serverCallContext.HttpContext.RequestAborted.IsCancellationRequested)
{
// Convert operation canceled due to client disconnect to proper RpcException
throw new RpcException(new Status(StatusCode.Cancelled, "Call canceled by the client.", ex));
}
catch (IOException ex) when (IsConnectionResetException(ex))
{
// Convert connection reset to proper RpcException
throw new RpcException(new Status(StatusCode.Cancelled, "Client disconnected during request.", ex));
}
catch (Exception ex) when (IsConnectionAbortedException(ex))
{
// Convert connection aborted to proper RpcException
throw new RpcException(new Status(StatusCode.Cancelled, "Connection aborted during request.", ex));
}
catch (Exception ex) when (!(ex is OperationCanceledException && cancellationToken.IsCancellationRequested))
{
// Don't write error when user cancels read
Expand All @@ -371,6 +421,22 @@ public static async ValueTask<T> ReadSingleMessageAsync<T>(this PipeReader input
}
}

// Add helper methods for detecting connection issues
private static bool IsConnectionResetException(IOException ex)
{
return ex.Message.Contains("reset", StringComparison.OrdinalIgnoreCase) ||
ex.Message.Contains("aborted", StringComparison.OrdinalIgnoreCase) ||
ex.Message.Contains("disconnect", StringComparison.OrdinalIgnoreCase) ||
ex.Message.Contains("canceled", StringComparison.OrdinalIgnoreCase);
}

private static bool IsConnectionAbortedException(Exception ex)
{
return ex is ObjectDisposedException ||
ex is ConnectionAbortedException ||
(ex is IOException ioEx && ioEx.InnerException is ConnectionAbortedException);
}

private static bool TryReadMessage(ref ReadOnlySequence<byte> buffer, HttpContextServerCallContext context, out ReadOnlySequence<byte> message)
{
if (!TryReadHeader(buffer, out var compressed, out var messageLength))
Expand Down
Loading