diff --git a/src/Grpc.AspNetCore.Server/Internal/CallHandlers/ServerCallHandlerBase.cs b/src/Grpc.AspNetCore.Server/Internal/CallHandlers/ServerCallHandlerBase.cs index f4db8601f..f9ff9caa8 100644 --- a/src/Grpc.AspNetCore.Server/Internal/CallHandlers/ServerCallHandlerBase.cs +++ b/src/Grpc.AspNetCore.Server/Internal/CallHandlers/ServerCallHandlerBase.cs @@ -78,15 +78,16 @@ public Task HandleCallAsync(HttpContext httpContext) } else { - return AwaitHandleCall(serverCallContext, MethodInvoker.Method, handleCallTask); + return AwaitHandleCall(serverCallContext, handleCallTask); } } catch (Exception ex) { - return serverCallContext.ProcessHandlerErrorAsync(ex, MethodInvoker.Method.Name); + // Enhanced exception handling for deserialization errors + return HandleCallExceptionAsync(serverCallContext, ex); } - static async Task AwaitHandleCall(HttpContextServerCallContext serverCallContext, Method method, Task handleCall) + static async Task AwaitHandleCall(HttpContextServerCallContext serverCallContext, Task handleCall) { try { @@ -95,7 +96,7 @@ static async Task AwaitHandleCall(HttpContextServerCallContext serverCallContext } catch (Exception ex) { - await serverCallContext.ProcessHandlerErrorAsync(ex, method.Name); + await HandleCallExceptionAsync(serverCallContext, ex); } } } @@ -153,6 +154,44 @@ protected void DisableRequestTimeout(HttpContext httpContext) } #endif + /// + /// Handles exceptions that occur during call processing, with special handling for deserialization cancellations. + /// + private static Task HandleCallExceptionAsync(HttpContextServerCallContext serverCallContext, Exception ex) + { + // If it's already an RpcException, let the existing logic handle it + if (ex is RpcException rpcEx) + { + return serverCallContext.ProcessHandlerErrorAsync(rpcEx, serverCallContext.Method); + } + + // Convert specific exception types to proper RpcExceptions + var convertedException = ConvertToRpcException(ex, serverCallContext); + return serverCallContext.ProcessHandlerErrorAsync(convertedException, serverCallContext.Method); + } + + /// + /// Converts framework exceptions to appropriate RpcExceptions. + /// + private static RpcException ConvertToRpcException(Exception ex, HttpContextServerCallContext _) + { + return ex switch + { + OperationCanceledException _ when _.HttpContext.RequestAborted.IsCancellationRequested => + new RpcException(new Status(StatusCode.Cancelled, "Call canceled by the client.", ex)), + IOException ioEx when IsConnectionResetException(ioEx) => + new RpcException(new Status(StatusCode.Cancelled, "Client disconnected.", ex)), + _ => new RpcException(new Status(StatusCode.Unknown, "Error processing call.", ex)) + }; + } + + private static bool IsConnectionResetException(IOException ex) + { + return ex.Message.Contains("reset", StringComparison.OrdinalIgnoreCase) || + ex.Message.Contains("aborted", StringComparison.OrdinalIgnoreCase) || + ex.Message.Contains("disconnect", StringComparison.OrdinalIgnoreCase); + } + private Task ProcessNonHttp2Request(HttpContext httpContext) { GrpcServerLog.UnsupportedRequestProtocol(Logger, httpContext.Request.Protocol); diff --git a/src/Grpc.AspNetCore.Server/Internal/HttpContextServerCallContext.cs b/src/Grpc.AspNetCore.Server/Internal/HttpContextServerCallContext.cs index 7f26bc046..2af364450 100644 --- a/src/Grpc.AspNetCore.Server/Internal/HttpContextServerCallContext.cs +++ b/src/Grpc.AspNetCore.Server/Internal/HttpContextServerCallContext.cs @@ -581,6 +581,14 @@ internal void ValidateAcceptEncodingContainsResponseEncoding() private string DebuggerToString() => $"Method = {Method}"; + internal void EnsureRequestNotAborted() + { + if (HttpContext.RequestAborted.IsCancellationRequested) + { + throw new RpcException(new Status(StatusCode.Cancelled, "Request was aborted by the client.")); + } + } + private sealed class HttpContextServerCallContextDebugView { private readonly HttpContextServerCallContext _context; diff --git a/src/Grpc.AspNetCore.Server/Internal/PipeExtensions.cs b/src/Grpc.AspNetCore.Server/Internal/PipeExtensions.cs index e1f23326f..0beb25260 100644 --- a/src/Grpc.AspNetCore.Server/Internal/PipeExtensions.cs +++ b/src/Grpc.AspNetCore.Server/Internal/PipeExtensions.cs @@ -23,6 +23,7 @@ using System.Runtime.CompilerServices; using Grpc.Core; using Grpc.Net.Compression; +using Microsoft.AspNetCore.Connections; using Microsoft.Extensions.Logging; namespace Grpc.AspNetCore.Server.Internal; @@ -206,6 +207,9 @@ public static async ValueTask ReadSingleMessageAsync(this PipeReader input while (true) { + // Check for client disconnect before reading + serverCallContext.EnsureRequestNotAborted(); + var result = await input.ReadAsync(); var buffer = result.Buffer; @@ -216,6 +220,9 @@ public static async ValueTask ReadSingleMessageAsync(this PipeReader input throw new RpcException(MessageCancelledStatus); } + // Check for client disconnect during processing + serverCallContext.EnsureRequestNotAborted(); + if (!buffer.IsEmpty) { if (request != null) @@ -225,6 +232,9 @@ public static async ValueTask ReadSingleMessageAsync(this PipeReader input if (TryReadMessage(ref buffer, serverCallContext, out var data)) { + // Check for client disconnect before deserialization + serverCallContext.EnsureRequestNotAborted(); + // Finished and the complete message has arrived GrpcServerLog.DeserializingMessage(logger, (int)data.Length, typeof(T)); @@ -275,6 +285,21 @@ public static async ValueTask ReadSingleMessageAsync(this PipeReader input } } } + catch (OperationCanceledException ex) when (serverCallContext.HttpContext.RequestAborted.IsCancellationRequested) + { + // Convert operation canceled due to client disconnect to proper RpcException + throw new RpcException(new Status(StatusCode.Cancelled, "Call canceled by the client.", ex)); + } + catch (IOException ex) when (IsConnectionResetException(ex)) + { + // Convert connection reset to proper RpcException + throw new RpcException(new Status(StatusCode.Cancelled, "Client disconnected during request.", ex)); + } + catch (Exception ex) when (IsConnectionAbortedException(ex)) + { + // Convert connection aborted to proper RpcException + throw new RpcException(new Status(StatusCode.Cancelled, "Connection aborted during request.", ex)); + } catch (Exception ex) when (ex is not OperationCanceledException) { // Don't write error when user cancels read @@ -303,6 +328,10 @@ public static async ValueTask ReadSingleMessageAsync(this PipeReader input while (true) { var completeMessage = false; + + // Check for client disconnect before reading + serverCallContext.EnsureRequestNotAborted(); + var result = await input.ReadAsync(cancellationToken); var buffer = result.Buffer; @@ -313,12 +342,18 @@ public static async ValueTask ReadSingleMessageAsync(this PipeReader input throw new RpcException(MessageCancelledStatus); } + // Check for client disconnect during processing + serverCallContext.EnsureRequestNotAborted(); + if (!buffer.IsEmpty) { if (TryReadMessage(ref buffer, serverCallContext, out var data)) { completeMessage = true; + // Check for client disconnect before deserialization + serverCallContext.EnsureRequestNotAborted(); + GrpcServerLog.DeserializingMessage(logger, (int)data.Length, typeof(T)); serverCallContext.DeserializationContext.SetPayload(data); @@ -363,6 +398,21 @@ public static async ValueTask ReadSingleMessageAsync(this PipeReader input } } } + catch (OperationCanceledException ex) when (serverCallContext.HttpContext.RequestAborted.IsCancellationRequested) + { + // Convert operation canceled due to client disconnect to proper RpcException + throw new RpcException(new Status(StatusCode.Cancelled, "Call canceled by the client.", ex)); + } + catch (IOException ex) when (IsConnectionResetException(ex)) + { + // Convert connection reset to proper RpcException + throw new RpcException(new Status(StatusCode.Cancelled, "Client disconnected during request.", ex)); + } + catch (Exception ex) when (IsConnectionAbortedException(ex)) + { + // Convert connection aborted to proper RpcException + throw new RpcException(new Status(StatusCode.Cancelled, "Connection aborted during request.", ex)); + } catch (Exception ex) when (!(ex is OperationCanceledException && cancellationToken.IsCancellationRequested)) { // Don't write error when user cancels read @@ -371,6 +421,22 @@ public static async ValueTask ReadSingleMessageAsync(this PipeReader input } } + // Add helper methods for detecting connection issues + private static bool IsConnectionResetException(IOException ex) + { + return ex.Message.Contains("reset", StringComparison.OrdinalIgnoreCase) || + ex.Message.Contains("aborted", StringComparison.OrdinalIgnoreCase) || + ex.Message.Contains("disconnect", StringComparison.OrdinalIgnoreCase) || + ex.Message.Contains("canceled", StringComparison.OrdinalIgnoreCase); + } + + private static bool IsConnectionAbortedException(Exception ex) + { + return ex is ObjectDisposedException || + ex is ConnectionAbortedException || + (ex is IOException ioEx && ioEx.InnerException is ConnectionAbortedException); + } + private static bool TryReadMessage(ref ReadOnlySequence buffer, HttpContextServerCallContext context, out ReadOnlySequence message) { if (!TryReadHeader(buffer, out var compressed, out var messageLength))