Skip to content

Commit 683bff0

Browse files
committed
xref #75 - simple server exception handling
1 parent 30a9682 commit 683bff0

File tree

5 files changed

+284
-26
lines changed

5 files changed

+284
-26
lines changed

src/protobuf-net.Grpc/Configuration/ServerBinder.cs

Lines changed: 79 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22
using ProtoBuf.Grpc.Internal;
33
using System;
44
using System.Collections.Generic;
5+
using System.Linq;
56
using System.Linq.Expressions;
67
using System.Reflection;
8+
using System.Threading.Tasks;
79

810
namespace ProtoBuf.Grpc.Configuration
911
{
@@ -40,16 +42,20 @@ public int Bind(object state, Type serviceType, BinderConfiguration? binderConfi
4042
? new HashSet<Type> { serviceType }
4143
: ContractOperation.ExpandInterfaces(serviceType);
4244

45+
bool serviceImplSimplifiedExceptions = serviceType.IsDefined(typeof(SimpleRpcExceptionsAttribute));
4346
foreach (var serviceContract in serviceContracts)
4447
{
4548
if (!binderConfiguration.Binder.IsServiceContract(serviceContract, out serviceName)) continue;
4649

50+
var serviceContractSimplifiedExceptions = serviceImplSimplifiedExceptions || serviceContract.IsDefined(typeof(SimpleRpcExceptionsAttribute));
4751
int svcOpCount = 0;
4852
var bindCtx = new ServiceBindContext(serviceContract, serviceType, state);
4953
foreach (var op in ContractOperation.FindOperations(binderConfiguration, serviceContract, this))
5054
{
5155
if (ServerInvokerLookup.TryGetValue(op.MethodType, op.Context, op.Result, op.Void, out var invoker)
52-
&& AddMethod(op.From, op.To, op.Name, op.Method, op.MethodType, invoker, bindCtx))
56+
&& AddMethod(op.From, op.To, op.Name, op.Method, op.MethodType, invoker, bindCtx,
57+
serviceContractSimplifiedExceptions || op.Method.IsDefined(typeof(SimpleRpcExceptionsAttribute))
58+
))
5359
{
5460
// yay!
5561
totalCount++;
@@ -61,7 +67,7 @@ public int Bind(object state, Type serviceType, BinderConfiguration? binderConfi
6167
return totalCount;
6268

6369
bool AddMethod(Type @in, Type @out, string on, MethodInfo m, MethodType t,
64-
Func<MethodInfo, ParameterExpression[], Expression>? invoker, ServiceBindContext bindContext)
70+
Func<MethodInfo, ParameterExpression[], Expression>? invoker, ServiceBindContext bindContext, bool simplifiedExceptionHandling)
6571
{
6672
try
6773
{
@@ -72,16 +78,20 @@ bool AddMethod(Type @in, Type @out, string on, MethodInfo m, MethodType t,
7278
typesBuffer[1] = @in;
7379
typesBuffer[2] = @out;
7480

75-
if (argsBuffer == null)
81+
if (argsBuffer is null)
7682
{
77-
argsBuffer = new object?[] { null, null, null, null, null, null, binderConfiguration!.MarshallerCache, service == null ? null : Expression.Constant(service, serviceType) };
83+
argsBuffer = new object?[9];
84+
argsBuffer[6] = binderConfiguration!.MarshallerCache;
85+
argsBuffer[7] = service is null ? null : Expression.Constant(service, serviceType);
7886
}
7987
argsBuffer[0] = serviceName;
8088
argsBuffer[1] = on;
8189
argsBuffer[2] = m;
8290
argsBuffer[3] = t;
8391
argsBuffer[4] = bindContext;
8492
argsBuffer[5] = invoker;
93+
// 6, 7 set during array initialization
94+
argsBuffer[8] = simplifiedExceptionHandling;
8595

8696
return (bool)s_addMethod.MakeGenericMethod(typesBuffer).Invoke(this, argsBuffer)!;
8797
}
@@ -116,14 +126,16 @@ protected readonly struct MethodStub<TService>
116126
{
117127
private readonly ConstantExpression? _service;
118128
private readonly Func<MethodInfo, Expression[], Expression>? _invoker;
129+
private readonly bool _simpleExceptionHandling;
119130

120131
/// <summary>
121132
/// The runtime method being considered
122133
/// </summary>
123134
public MethodInfo Method { get; }
124135

125-
internal MethodStub(Func<MethodInfo, Expression[], Expression>? invoker, MethodInfo method, ConstantExpression? service)
136+
internal MethodStub(Func<MethodInfo, Expression[], Expression>? invoker, MethodInfo method, ConstantExpression? service, bool simpleExceptionHandling)
126137
{
138+
_simpleExceptionHandling = simpleExceptionHandling;
127139
_invoker = invoker;
128140
_service = service;
129141
Method = method;
@@ -137,43 +149,86 @@ public TDelegate CreateDelegate<TDelegate>()
137149
{
138150
if (_invoker == null)
139151
{
140-
// basic - direct call
141-
return (TDelegate)Delegate.CreateDelegate(typeof(TDelegate), _service, Method);
142-
}
143-
var lambdaArgs = ParameterCache<TDelegate>.Parameters;
152+
if (_simpleExceptionHandling)
153+
{
154+
var lambdaArgs = ParameterCache<TDelegate>.Parameters;
155+
156+
var call = _service is null
157+
? Expression.Call(Method, lambdaArgs)
158+
: Expression.Call(_service, Method, lambdaArgs);
144159

145-
Expression[] mapArgs;
146-
if (_service == null)
147-
{ // if no service object, then the service is part of the signature, i.e. (svc, req) => svc.Blah();
148-
mapArgs = lambdaArgs;
160+
return Expression.Lambda<TDelegate>(
161+
ApplySimpleExceptionHandling(call), lambdaArgs).Compile();
162+
}
163+
else
164+
{
165+
// basic - direct call
166+
return (TDelegate)Delegate.CreateDelegate(typeof(TDelegate), _service, Method);
167+
}
149168
}
150169
else
151170
{
152-
// if there *is* a service object, then that is *not* part of the signature, i.e. (req) => svc.Blah(req)
153-
// where the svc instance comes in separately
154-
mapArgs = new Expression[lambdaArgs.Length + 1];
155-
mapArgs[0] = _service;
156-
lambdaArgs.CopyTo(mapArgs, 1);
157-
}
171+
var lambdaArgs = ParameterCache<TDelegate>.Parameters;
172+
173+
Expression[] mapArgs;
174+
if (_service is null)
175+
{ // if no service object, then the service is part of the signature, i.e. (svc, req) => svc.Blah();
176+
mapArgs = lambdaArgs;
177+
}
178+
else
179+
{
180+
// if there *is* a service object, then that is *not* part of the signature, i.e. (req) => svc.Blah(req)
181+
// where the svc instance comes in separately
182+
mapArgs = new Expression[lambdaArgs.Length + 1];
183+
mapArgs[0] = _service;
184+
lambdaArgs.CopyTo(mapArgs, 1);
185+
}
158186

159-
var body = _invoker.Invoke(Method, mapArgs);
160-
var lambda = Expression.Lambda<TDelegate>(body, lambdaArgs);
187+
var body = _invoker.Invoke(Method, mapArgs);
188+
if (_simpleExceptionHandling)
189+
{
190+
body = ApplySimpleExceptionHandling(body);
191+
}
192+
var lambda = Expression.Lambda<TDelegate>(body, lambdaArgs);
161193

162-
return lambda.Compile();
194+
return lambda.Compile();
195+
}
196+
}
197+
198+
static Expression ApplySimpleExceptionHandling(Expression body)
199+
{
200+
var type = body.Type;
201+
if (type == typeof(Task))
202+
{
203+
body = Expression.Call(s_ReshapeWithSimpleExceptionHandling[0], body);
204+
}
205+
else if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Task<>))
206+
{
207+
body = Expression.Call(s_ReshapeWithSimpleExceptionHandling[1].MakeGenericMethod(type.GetGenericArguments()), body);
208+
}
209+
return body;
163210
}
164211
}
165212

213+
#pragma warning disable CS0618
214+
private static readonly Dictionary<int, MethodInfo> s_ReshapeWithSimpleExceptionHandling =
215+
(from method in typeof(Reshape).GetMethods(BindingFlags.Public | BindingFlags.Static)
216+
where method.Name is nameof(Reshape.WithSimpleExceptionHandling)
217+
select method)
218+
.ToDictionary(method => method.IsGenericMethodDefinition ? method.GetGenericArguments().Length : 0);
219+
#pragma warning restore CS0618
220+
166221
private bool AddMethod<TService, TRequest, TResponse>(
167222
string serviceName, string operationName, MethodInfo method, MethodType methodType,
168223
ServiceBindContext bindContext,
169224
Func<MethodInfo, Expression[], Expression>? invoker, MarshallerCache marshallerCache,
170-
ConstantExpression? service)
225+
ConstantExpression? service, bool simplfiedExceptionHandling)
171226
where TService : class
172227
where TRequest : class
173228
where TResponse : class
174229
{
175230
var grpcMethod = new Method<TRequest, TResponse>(methodType, serviceName, operationName, marshallerCache.GetMarshaller<TRequest>(), marshallerCache.GetMarshaller<TResponse>());
176-
var stub = new MethodStub<TService>(invoker, method, service);
231+
var stub = new MethodStub<TService>(invoker, method, service, simplfiedExceptionHandling);
177232
try
178233
{
179234
return TryBind<TService, TRequest, TResponse>(bindContext, grpcMethod, stub);
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
using Grpc.Core;
2+
using System;
3+
4+
namespace ProtoBuf.Grpc.Configuration
5+
{
6+
/// <summary>
7+
/// Indicates that a service or method should use simplified exception handling - which means that all server exceptions are treated as <see cref="RpcException"/>; this
8+
/// will expose the <see cref="Exception.Message"/> to the caller (and the type may be interpreted as a <see cref="StatusCode"/> when possible), which should only be
9+
/// done with caution as this may present security implications. Additional exception metadata (<see cref="Exception.Data"/>, <see cref="Exception.InnerException"/>,
10+
/// <see cref="Exception.StackTrace"/>, etc) is not propagated. The exception is still exposed at the client as an <see cref="RpcException"/>.
11+
/// </summary>
12+
/// <remarks>This feature is only currently supported on <c>async</c> methods that expose their faults via the returned awaitable, not by throwing directly.</remarks>
13+
[AttributeUsage(AttributeTargets.Interface | AttributeTargets.Class | AttributeTargets.Method, AllowMultiple = false, Inherited = true)]
14+
public sealed class SimpleRpcExceptionsAttribute : Attribute
15+
{
16+
/// <summary>
17+
/// Gets a shared instance of this attribute type
18+
/// </summary>
19+
public static SimpleRpcExceptionsAttribute Default => s_Default ??= new SimpleRpcExceptionsAttribute();
20+
21+
private static SimpleRpcExceptionsAttribute? s_Default;
22+
}
23+
}

src/protobuf-net.Grpc/Internal/ContractOperation.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,8 @@ public static List<ContractOperation> FindOperations(BinderConfiguration binderC
296296
(from method in typeof(Reshape).GetMethods(BindingFlags.Public | BindingFlags.Static)
297297
where method.IsGenericMethodDefinition
298298
let parameters = method.GetParameters()
299-
where parameters[1].ParameterType == typeof(CallInvoker)
299+
where parameters.Length > 1
300+
&& parameters[1].ParameterType == typeof(CallInvoker)
300301
&& parameters[0].ParameterType == typeof(CallContext).MakeByRefType()
301302
select method).ToDictionary(x => x.Name);
302303

src/protobuf-net.Grpc/Internal/Reshape.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ static async Task<T> Awaited(Task<T> task)
178178
catch (Exception ex) when (!(ex is RpcException))
179179
{
180180
Rethrow(ex);
181-
return default!; // never reached
181+
return default!; // make compiler happy
182182
}
183183
}
184184
}

0 commit comments

Comments
 (0)