Skip to content

Commit 388ecb9

Browse files
author
Meir Kriheli
committed
Per service contract, we will collect all the methods of inherited interfaces
and bind them as they were defined in the service contract itself. (their binding key will be based on the service contract and not based on the base-interfaces).
1 parent 1665750 commit 388ecb9

File tree

5 files changed

+120
-47
lines changed

5 files changed

+120
-47
lines changed

src/protobuf-net.Grpc.Reflection/SchemaGenerator.cs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,10 +105,9 @@ static Type ApplySubstitutes(Type type)
105105

106106
private static MethodInfo[] GetMethodsRecursively(ServiceBinder serviceBinder, Type contractType)
107107
{
108-
var includingInheritedInterfaces = ContractOperation.ExpandInterfaces(contractType);
108+
var includingInheritedInterfaces = ContractOperation.ExpandWithInterfacesMarkedAsServiceInheritable(contractType);
109109

110110
var inheritedMethods = includingInheritedInterfaces
111-
.Where(cType => serviceBinder.IsServiceContract(cType, out _)) // only the ones marked as contract type
112111
.SelectMany(t => t.GetMethods(BindingFlags.Public | BindingFlags.Instance))
113112
.ToArray();
114113

src/protobuf-net.Grpc/Configuration/ServerBinder.cs

Lines changed: 63 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -24,57 +24,74 @@ are observed and respected
2424
/// <summary>
2525
/// Initiate a bind operation, causing all service methods to be crawled for the provided type
2626
/// </summary>
27-
public int Bind<TService>(object state, BinderConfiguration? binderConfiguration = null, TService? service = null)
27+
public int Bind<TService>(object state, BinderConfiguration? binderConfiguration = null,
28+
TService? service = null)
2829
where TService : class
2930
=> Bind(state, typeof(TService), binderConfiguration, service);
3031

3132
/// <summary>
3233
/// Initiate a bind operation, causing all service methods to be crawled for the provided type
3334
/// </summary>
34-
public int Bind(object state, Type serviceType, BinderConfiguration? binderConfiguration = null, object? service = null)
35+
public int Bind(object state, Type serviceType, BinderConfiguration? binderConfiguration = null,
36+
object? service = null)
3537
{
3638
int totalCount = 0;
3739
object?[]? argsBuffer = null;
3840
Type[] typesBuffer = Array.Empty<Type>();
39-
string? serviceName;
4041
if (binderConfiguration == null) binderConfiguration = BinderConfiguration.Default;
4142
var serviceContracts = typeof(IGrpcService).IsAssignableFrom(serviceType)
42-
? new HashSet<Type> { serviceType }
43+
? new HashSet<Type> {serviceType}
4344
: ContractOperation.ExpandInterfaces(serviceType);
4445

4546
bool serviceImplSimplifiedExceptions = serviceType.IsDefined(typeof(SimpleRpcExceptionsAttribute));
4647
foreach (var serviceContract in serviceContracts)
4748
{
48-
if (!binderConfiguration.Binder.IsServiceContract(serviceContract, out serviceName)) continue;
49+
if (!binderConfiguration.Binder.IsServiceContract(serviceContract, out var serviceName)) continue;
50+
51+
var typesToBeIncludedInMethodsBinding =
52+
ContractOperation.ExpandWithInterfacesMarkedAsServiceInheritable(serviceContract);
53+
54+
// Per service contract, we will collect all the methods of inherited interfaces
55+
// and bind them as they were defined in the service contract itself.
56+
// (their binding key will be based on the service contract and not based on the base-interfaces).
4957

50-
var serviceContractSimplifiedExceptions = serviceImplSimplifiedExceptions || serviceContract.IsDefined(typeof(SimpleRpcExceptionsAttribute));
5158
int svcOpCount = 0;
52-
var bindCtx = new ServiceBindContext(serviceContract, serviceType, state, binderConfiguration.Binder);
53-
foreach (var op in ContractOperation.FindOperations(binderConfiguration, serviceContract, this))
59+
foreach (var typeToBindItsMethods in typesToBeIncludedInMethodsBinding)
5460
{
55-
if (ServerInvokerLookup.TryGetValue(op.MethodType, op.Context, op.Result, op.Void, out var invoker)
56-
&& AddMethod(op.From, op.To, op.Name, op.Method, op.MethodType, invoker, bindCtx,
57-
serviceContractSimplifiedExceptions || op.Method.IsDefined(typeof(SimpleRpcExceptionsAttribute))
58-
))
61+
var serviceContractSimplifiedExceptions = serviceImplSimplifiedExceptions ||
62+
typeToBindItsMethods.IsDefined(
63+
typeof(SimpleRpcExceptionsAttribute));
64+
var bindCtx = new ServiceBindContext(serviceContract, serviceType, state, binderConfiguration.Binder);
65+
foreach (var op in ContractOperation.FindOperations(binderConfiguration, typeToBindItsMethods, this))
5966
{
60-
// yay!
61-
totalCount++;
62-
svcOpCount++;
67+
if (ServerInvokerLookup.TryGetValue(op.MethodType, op.Context, op.Result, op.Void, out var invoker)
68+
&& AddMethod(serviceName, op.From, op.To, op.Name, op.Method, op.MethodType, invoker, bindCtx,
69+
serviceContractSimplifiedExceptions || op.Method.IsDefined(typeof(SimpleRpcExceptionsAttribute))
70+
))
71+
{
72+
// yay!
73+
totalCount++;
74+
svcOpCount++;
75+
}
6376
}
6477
}
78+
6579
OnServiceBound(state, serviceName!, serviceType, serviceContract, svcOpCount);
6680
}
81+
6782
return totalCount;
6883

69-
bool AddMethod(Type @in, Type @out, string on, MethodInfo m, MethodType t,
70-
Func<MethodInfo, ParameterExpression[], Expression>? invoker, ServiceBindContext bindContext, bool simplifiedExceptionHandling)
84+
bool AddMethod(string? serviceName, Type @in, Type @out, string on, MethodInfo m, MethodType t,
85+
Func<MethodInfo, ParameterExpression[], Expression>? invoker, ServiceBindContext bindContext,
86+
bool simplifiedExceptionHandling)
7187
{
7288
try
7389
{
7490
if (typesBuffer.Length == 0)
7591
{
76-
typesBuffer = new Type[] { serviceType, typeof(void), typeof(void) };
92+
typesBuffer = new Type[] {serviceType, typeof(void), typeof(void)};
7793
}
94+
7895
typesBuffer[1] = @in;
7996
typesBuffer[2] = @out;
8097

@@ -84,6 +101,7 @@ bool AddMethod(Type @in, Type @out, string on, MethodInfo m, MethodType t,
84101
argsBuffer[6] = binderConfiguration!.MarshallerCache;
85102
argsBuffer[7] = service is null ? null : Expression.Constant(service, serviceType);
86103
}
104+
87105
argsBuffer[0] = serviceName;
88106
argsBuffer[1] = on;
89107
argsBuffer[2] = m;
@@ -93,7 +111,7 @@ bool AddMethod(Type @in, Type @out, string on, MethodInfo m, MethodType t,
93111
// 6, 7 set during array initialization
94112
argsBuffer[8] = simplifiedExceptionHandling;
95113

96-
return (bool)s_addMethod.MakeGenericMethod(typesBuffer).Invoke(this, argsBuffer)!;
114+
return (bool) s_addMethod.MakeGenericMethod(typesBuffer).Invoke(this, argsBuffer)!;
97115
}
98116
catch (Exception fail)
99117
{
@@ -106,7 +124,10 @@ bool AddMethod(Type @in, Type @out, string on, MethodInfo m, MethodType t,
106124
/// <summary>
107125
/// Reports the number of operations available for a service
108126
/// </summary>
109-
protected virtual void OnServiceBound(object state, string serviceName, Type serviceType, Type serviceContract, int operationCount) { }
127+
protected virtual void OnServiceBound(object state, string serviceName, Type serviceType, Type serviceContract,
128+
int operationCount)
129+
{
130+
}
110131

111132
private static readonly MethodInfo s_addMethod = typeof(ServerBinder).GetMethod(
112133
nameof(AddMethod), BindingFlags.Instance | BindingFlags.NonPublic)!;
@@ -133,7 +154,8 @@ protected readonly struct MethodStub<TService>
133154
/// </summary>
134155
public MethodInfo Method { get; }
135156

136-
internal MethodStub(Func<MethodInfo, Expression[], Expression>? invoker, MethodInfo method, ConstantExpression? service, bool simpleExceptionHandling)
157+
internal MethodStub(Func<MethodInfo, Expression[], Expression>? invoker, MethodInfo method,
158+
ConstantExpression? service, bool simpleExceptionHandling)
137159
{
138160
_simpleExceptionHandling = simpleExceptionHandling;
139161
_invoker = invoker;
@@ -163,7 +185,7 @@ public TDelegate CreateDelegate<TDelegate>()
163185
else
164186
{
165187
// basic - direct call
166-
return (TDelegate)Delegate.CreateDelegate(typeof(TDelegate), _service, Method);
188+
return (TDelegate) Delegate.CreateDelegate(typeof(TDelegate), _service, Method);
167189
}
168190
}
169191
else
@@ -172,7 +194,8 @@ public TDelegate CreateDelegate<TDelegate>()
172194

173195
Expression[] mapArgs;
174196
if (_service is null)
175-
{ // if no service object, then the service is part of the signature, i.e. (svc, req) => svc.Blah();
197+
{
198+
// if no service object, then the service is part of the signature, i.e. (svc, req) => svc.Blah();
176199
mapArgs = lambdaArgs;
177200
}
178201
else
@@ -189,6 +212,7 @@ public TDelegate CreateDelegate<TDelegate>()
189212
{
190213
body = ApplySimpleExceptionHandling(body);
191214
}
215+
192216
var lambda = Expression.Lambda<TDelegate>(body, lambdaArgs);
193217

194218
return lambda.Compile();
@@ -204,17 +228,19 @@ static Expression ApplySimpleExceptionHandling(Expression body)
204228
}
205229
else if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Task<>))
206230
{
207-
body = Expression.Call(s_ReshapeWithSimpleExceptionHandling[1].MakeGenericMethod(type.GetGenericArguments()), body);
231+
body = Expression.Call(
232+
s_ReshapeWithSimpleExceptionHandling[1].MakeGenericMethod(type.GetGenericArguments()), body);
208233
}
234+
209235
return body;
210236
}
211237
}
212238

213239
#pragma warning disable CS0618
214240
private static readonly Dictionary<int, MethodInfo> s_ReshapeWithSimpleExceptionHandling =
215241
(from method in typeof(Reshape).GetMethods(BindingFlags.Public | BindingFlags.Static)
216-
where method.Name is nameof(Reshape.WithSimpleExceptionHandling)
217-
select method)
242+
where method.Name is nameof(Reshape.WithSimpleExceptionHandling)
243+
select method)
218244
.ToDictionary(method => method.IsGenericMethodDefinition ? method.GetGenericArguments().Length : 0);
219245
#pragma warning restore CS0618
220246

@@ -227,7 +253,8 @@ private bool AddMethod<TService, TRequest, TResponse>(
227253
where TRequest : class
228254
where TResponse : class
229255
{
230-
var grpcMethod = new Method<TRequest, TResponse>(methodType, serviceName, operationName, marshallerCache.GetMarshaller<TRequest>(), marshallerCache.GetMarshaller<TResponse>());
256+
var grpcMethod = new Method<TRequest, TResponse>(methodType, serviceName, operationName,
257+
marshallerCache.GetMarshaller<TRequest>(), marshallerCache.GetMarshaller<TResponse>());
231258
var stub = new MethodStub<TService>(invoker, method, service, simplfiedExceptionHandling);
232259
try
233260
{
@@ -238,13 +265,13 @@ private bool AddMethod<TService, TRequest, TResponse>(
238265
OnError(ex.Message);
239266
return false;
240267
}
241-
242268
}
243269

244270
/// <summary>
245271
/// The implementing binder should bind the method to the bind-state
246272
/// </summary>
247-
protected abstract bool TryBind<TService, TRequest, TResponse>(ServiceBindContext bindContext, Method<TRequest, TResponse> method, MethodStub<TService> stub)
273+
protected abstract bool TryBind<TService, TRequest, TResponse>(ServiceBindContext bindContext,
274+
Method<TRequest, TResponse> method, MethodStub<TService> stub)
248275
where TService : class
249276
where TRequest : class
250277
where TResponse : class;
@@ -255,12 +282,16 @@ protected abstract bool TryBind<TService, TRequest, TResponse>(ServiceBindContex
255282
/// <summary>
256283
/// Publish a warning message
257284
/// </summary>
258-
protected internal virtual void OnWarn(string message, object?[]? args = null) { }
285+
protected internal virtual void OnWarn(string message, object?[]? args = null)
286+
{
287+
}
259288

260289
/// <summary>
261290
/// Publish a warning message
262291
/// </summary>
263-
protected internal virtual void OnError(string message, object?[]? args = null) { }
292+
protected internal virtual void OnError(string message, object?[]? args = null)
293+
{
294+
}
264295

265296
/// <summary>
266297
/// Describes the relationship between a service contract and a service definition
@@ -304,4 +335,4 @@ public IList<object> GetMetadata(MethodInfo method)
304335
=> ServiceBinder.GetMetadata(method, ContractType, ServiceType);
305336
}
306337
}
307-
}
338+
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
using System;
2+
using System.ComponentModel;
3+
4+
namespace ProtoBuf.Grpc.Configuration
5+
{
6+
/// <summary>
7+
/// Indicates that this interface can be inherited by a gRPC service.
8+
/// All methods of this interface will be routed based on inherited service name.
9+
/// </summary>
10+
[AttributeUsage(AttributeTargets.Interface, AllowMultiple = false, Inherited = false)]
11+
[ImmutableObject(true)]
12+
public sealed class ServiceInheritableAttribute : Attribute
13+
{
14+
}
15+
}

src/protobuf-net.Grpc/Internal/ContractOperation.cs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,29 @@ internal static ISet<Type> ExpandInterfaces(Type type)
363363
if (type.IsInterface) set.Add(type);
364364
return set;
365365
}
366+
367+
/// <summary>
368+
/// Collect all the types to be used for extracting methods for a specific Service Contract
369+
/// </summary>
370+
/// <param name="serviceContract">Must be a service contract</param>
371+
/// <returns>types to be used for extracting methods</returns>
372+
internal static ISet<Type> ExpandWithInterfacesMarkedAsServiceInheritable(Type serviceContract)
373+
{
374+
var set = new HashSet<Type>();
375+
376+
// first add the service contract by itself
377+
set.Add(serviceContract);
378+
379+
// now add all inherited interfaces which are marked as inheritable
380+
foreach (var t in serviceContract.GetInterfaces())
381+
{
382+
if (t.IsDefined(typeof(ServiceInheritableAttribute)))
383+
{
384+
set.Add(t);
385+
}
386+
}
387+
return set;
388+
}
366389
}
367390

368391
internal enum ContextKind

tests/protobuf-net.Grpc.Reflection.Test/SchemaGeneration.cs

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ service MyService {
5959
}
6060

6161
[Fact]
62-
public void CheckIngeritedInterfaceSchema()
62+
public void CheckInheritedInterfaceSchema()
6363
{
6464
var generator = new SchemaGenerator();
6565
var schema = generator.GetSchema<IMyInheritedService>();
@@ -84,19 +84,13 @@ message MyResponse {
8484
string RefId = 3; // default value could not be applied: 00000000-0000-0000-0000-000000000000
8585
}
8686
service MyInheritedService {
87-
rpc AsyncEmpty (.google.protobuf.Empty) returns (.google.protobuf.Empty);
88-
rpc ClientStreaming (stream MyRequest) returns (MyResponse);
89-
rpc FullDuplex (stream MyRequest) returns (stream MyResponse);
9087
rpc GenericUnary (MyRequest) returns (MyResponse);
9188
rpc InheritedAsyncEmpty (.google.protobuf.Empty) returns (.google.protobuf.Empty);
9289
rpc InheritedClientStreaming (stream MyRequest) returns (MyResponse);
9390
rpc InheritedFullDuplex (stream MyRequest) returns (stream MyResponse);
9491
rpc InheritedServerStreaming (MyRequest) returns (stream MyResponse);
9592
rpc InheritedSyncEmpty (.google.protobuf.Empty) returns (.google.protobuf.Empty);
9693
rpc InheritedUnary (MyRequest) returns (MyResponse);
97-
rpc ServerStreaming (MyRequest) returns (stream MyResponse);
98-
rpc SyncEmpty (.google.protobuf.Empty) returns (.google.protobuf.Empty);
99-
rpc Unary (MyRequest) returns (MyResponse);
10094
}
10195
", schema, ignoreLineEndingDifferences: true);
10296
}
@@ -131,7 +125,7 @@ public interface INotAService
131125
[Theory]
132126
[InlineData(typeof(IMyService))]
133127
[InlineData(typeof(IMyInheritedService))]
134-
[InlineData(typeof(IMyAnotherLevelOfInheritedService))]
128+
[InlineData(typeof(IMyServiceInheritTwoLevelsOfHierarchy))]
135129
public void CompareRouteTable(Type type)
136130
{
137131
// 1: use the existing binder logic to build the routes, using the server logic
@@ -184,14 +178,14 @@ public string[] Collect()
184178
}
185179
}
186180

187-
[Service]
188-
public interface ISomeGenericService<in TGenericRequest, TGenericResult>
181+
[ServiceInheritable]
182+
public interface ISomeInheritableGenericService<in TGenericRequest, TGenericResult>
189183
{
190184
ValueTask<TGenericResult> GenericUnary(TGenericRequest request, CallContext callContext = default);
191185
}
192186

193187
[Service]
194-
public interface IMyInheritedService : IMyService, ISomeGenericService<MyRequest, MyResponse>, INotAService
188+
public interface IMyInheritedService : ISomeInheritableGenericService<MyRequest, MyResponse>, INotAService
195189
{
196190
ValueTask<MyResponse> InheritedUnary(MyRequest request, CallContext callContext = default);
197191
ValueTask<MyResponse> InheritedClientStreaming(IAsyncEnumerable<MyRequest> request, CallContext callContext = default);
@@ -202,9 +196,20 @@ public interface IMyInheritedService : IMyService, ISomeGenericService<MyRequest
202196
void InheritedSyncEmpty();
203197
}
204198

205-
199+
[ServiceInheritable]
200+
public interface ISecondLevelInheritable : ISomeInheritableGenericService<MyRequest, MyResponse>, INotAService
201+
{
202+
ValueTask<MyResponse> InheritedUnary(MyRequest request, CallContext callContext = default);
203+
ValueTask<MyResponse> InheritedClientStreaming(IAsyncEnumerable<MyRequest> request, CallContext callContext = default);
204+
IAsyncEnumerable<MyResponse> InheritedServerStreaming(MyRequest request, CallContext callContext = default);
205+
IAsyncEnumerable<MyResponse> InheritedFullDuplex(IAsyncEnumerable<MyRequest> request, CallContext callContext = default);
206+
207+
ValueTask InheritedAsyncEmpty();
208+
void InheritedSyncEmpty();
209+
}
210+
206211
[Service]
207-
public interface IMyAnotherLevelOfInheritedService : IMyInheritedService
212+
public interface IMyServiceInheritTwoLevelsOfHierarchy : ISecondLevelInheritable
208213
{
209214
ValueTask<MyResponse> AnotherMethod(MyRequest request, CallContext callContext = default);
210215
}

0 commit comments

Comments
 (0)