Skip to content

Commit 057de5c

Browse files
authored
investigate #251 - auto-detect marshaller patterns for Google.Protobuf types (#252)
* investigate #251 - auto-detect marshaller patterns for Google.Protobuf types * release notes
1 parent e4487ee commit 057de5c

File tree

5 files changed

+369
-7
lines changed

5 files changed

+369
-7
lines changed

docs/releasenotes.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22

33
## unreleased
44

5+
- automatically resolve Google.Protobuf `IMessage` types used in APIs
6+
7+
## 1.0.171
8+
59
- try to improve blazor linker support (i.e. avoid removal of necessary APIs)
610

711
## 1.0.136

src/protobuf-net.Grpc/Internal/MarshallerCache.cs

Lines changed: 101 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
using Grpc.Core;
22
using ProtoBuf.Grpc.Configuration;
33
using System;
4+
using System.Buffers;
45
using System.Collections.Concurrent;
6+
using System.Linq.Expressions;
57
using System.Reflection;
68
using System.Runtime.CompilerServices;
9+
using System.ServiceModel.Channels;
710

811
namespace ProtoBuf.Grpc.Internal
912
{
@@ -25,11 +28,11 @@ static bool SlowImpl(MarshallerCache obj, Type type)
2528

2629
private readonly ConcurrentDictionary<Type, object?> _marshallers
2730
= new ConcurrentDictionary<Type, object?>
28-
{
31+
{
2932
#pragma warning disable CS0618 // Empty
30-
[typeof(Empty)] = Empty.Marshaller
33+
[typeof(Empty)] = Empty.Marshaller
3134
#pragma warning restore CS0618
32-
};
35+
};
3336

3437
internal Marshaller<T> GetMarshaller<T>()
3538
{
@@ -55,17 +58,108 @@ internal void SetMarshaller<T>(Marshaller<T>? marshaller)
5558
private Marshaller<T>? CreateAndAdd<T>()
5659
{
5760
object? obj = CreateMarshaller<T>();
58-
if (!_marshallers.TryAdd(typeof(T), obj)) obj= _marshallers[typeof(T)];
61+
if (!_marshallers.TryAdd(typeof(T), obj)) obj = _marshallers[typeof(T)];
5962
return obj as Marshaller<T>;
6063
}
6164
private Marshaller<T>? CreateMarshaller<T>()
6265
{
6366
foreach (var factory in _factories)
6467
{
6568
if (factory.CanSerialize(typeof(T)))
66-
return factory.CreateMarshaller<T>();
69+
return factory.CreateMarshaller<T>();
70+
}
71+
return AutoDetectProtobufMarshaller();
72+
73+
static Type? FindIMessage(out Type? iBufferMessage)
74+
{
75+
Type? iMessage = null;
76+
iBufferMessage = null;
77+
foreach (var it in typeof(T).GetInterfaces())
78+
{
79+
if (it.Name == "IBufferMessage" && it.Namespace == "Google.Protobuf" && !it.IsGenericType)
80+
{
81+
iBufferMessage = it;
82+
}
83+
else if (it.Name == "IMessage" && it.Namespace == "Google.Protobuf" && !it.IsGenericType)
84+
{
85+
iMessage = it;
86+
}
87+
}
88+
return iMessage;
89+
}
90+
91+
// attempt to auto-detect the patterns exposed by Google.Protobuf types;
92+
// this is (by necessity) reflection-based and imperfect
93+
static Marshaller<T>? AutoDetectProtobufMarshaller()
94+
{
95+
try
96+
{
97+
if (typeof(T).GetProperty("Parser", BindingFlags.Public | BindingFlags.Static) is { } parser
98+
&& FindIMessage(out var iBufferMessage) is { } iMessage
99+
&& iMessage.Assembly.GetType("Google.Protobuf.MessageExtensions") is { } me)
100+
{
101+
Func<DeserializationContext, T> deserializer;
102+
Action<T, global::Grpc.Core.SerializationContext> serializer;
103+
104+
if (iBufferMessage is not null)
105+
{
106+
/* we want to generate:
107+
// write
108+
context.SetPayloadLength(message.CalculateSize());
109+
global::Google.Protobuf.MessageExtensions.WriteTo(message, context.GetBufferWriter());
110+
context.Complete();
111+
112+
// read
113+
parser.ParseFrom(context.PayloadAsReadOnlySequence()
114+
*/
115+
var context = Expression.Parameter(typeof(global::Grpc.Core.DeserializationContext), "context");
116+
var parseFrom = parser.PropertyType.GetMethod("ParseFrom", new Type[] { typeof(ReadOnlySequence<byte>) });
117+
Expression body = Expression.Call(Expression.Constant(parser.GetValue(null), parser.PropertyType),
118+
parseFrom, Expression.Call(context, nameof(DeserializationContext.PayloadAsReadOnlySequence), Type.EmptyTypes));
119+
deserializer = Expression.Lambda<Func<DeserializationContext, T>>(body, context).Compile();
120+
121+
var message = Expression.Parameter(typeof(T), "message");
122+
context = Expression.Parameter(typeof(global::Grpc.Core.SerializationContext), "context");
123+
var setPayloadLength = typeof(global::Grpc.Core.SerializationContext).GetMethod(nameof(global::Grpc.Core.SerializationContext.SetPayloadLength), new Type[] { typeof(int) });
124+
var calculateSize = iMessage.GetMethod("CalculateSize", Type.EmptyTypes);
125+
var writeTo = me.GetMethod("WriteTo", new Type[] { iMessage, typeof(IBufferWriter<byte>) });
126+
body = Expression.Block(
127+
Expression.Call(context, setPayloadLength, Expression.Call(message, calculateSize)),
128+
Expression.Call(writeTo, message, Expression.Call(context, "GetBufferWriter", Type.EmptyTypes)),
129+
Expression.Call(context, "Complete", Type.EmptyTypes)
130+
);
131+
serializer = Expression.Lambda<Action<T, global::Grpc.Core.SerializationContext>>(body, message, context).Compile();
132+
}
133+
else
134+
{
135+
/* we want to generate:
136+
// write
137+
context.Complete(global::Google.Protobuf.MessageExtensions.ToByteArray(message));
138+
139+
// read
140+
parser.ParseFrom(context.PayloadAsNewBuffer());
141+
*/
142+
143+
var context = Expression.Parameter(typeof(global::Grpc.Core.DeserializationContext), "context");
144+
var parseFrom = parser.PropertyType.GetMethod("ParseFrom", new Type[] { typeof(byte[]) });
145+
Expression body = Expression.Call(Expression.Constant(parser.GetValue(null), parser.PropertyType),
146+
parseFrom, Expression.Call(context, nameof(DeserializationContext.PayloadAsNewBuffer), Type.EmptyTypes));
147+
deserializer = Expression.Lambda<Func<DeserializationContext, T>>(body, context).Compile();
148+
149+
var message = Expression.Parameter(typeof(T), "message");
150+
context = Expression.Parameter(typeof(global::Grpc.Core.SerializationContext), "context");
151+
var toByteArray = me.GetMethod("ToByteArray", new Type[] { iMessage });
152+
var complete = typeof(global::Grpc.Core.SerializationContext).GetMethod(
153+
nameof(global::Grpc.Core.SerializationContext.Complete), new Type[] { typeof(byte[]) });
154+
body = Expression.Call(context, complete, Expression.Call(toByteArray, message));
155+
serializer = Expression.Lambda<Action<T, global::Grpc.Core.SerializationContext>>(body, message, context).Compile();
156+
}
157+
return new Marshaller<T>(serializer, deserializer);
158+
}
159+
}
160+
catch { } // this is very much a best-efforts thing
161+
return null;
67162
}
68-
return null;
69163
}
70164

71165
internal MarshallerFactory? TryGetFactory(Type type)
@@ -79,7 +173,7 @@ internal void SetMarshaller<T>(Marshaller<T>? marshaller)
79173
}
80174

81175
internal TFactory? TryGetFactory<TFactory>()
82-
where TFactory : MarshallerFactory
176+
where TFactory : MarshallerFactory
83177
{
84178
foreach (var factory in _factories)
85179
{
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
using automarshaller;
2+
using Grpc.Core;
3+
using ProtoBuf.Grpc.Configuration;
4+
using System;
5+
using Xunit;
6+
7+
namespace protobuf_net.Grpc.Test
8+
{
9+
public class AutoMarshaller
10+
{
11+
12+
/* from:
13+
syntax = "proto3";
14+
option csharp_namespace = "automarshaller";
15+
message Foo {
16+
int32 value = 1;
17+
}
18+
*/
19+
20+
public class UnknownType { }
21+
[Fact]
22+
public void GetMarshallerOnUnknownTypeFailsInExpectedWay()
23+
{
24+
var ex = Assert.Throws<InvalidOperationException>(
25+
() =>
26+
{
27+
BinderConfiguration.Default.GetMarshaller<UnknownType>();
28+
});
29+
Assert.Equal("No marshaller available for protobuf_net.Grpc.Test.AutoMarshaller+UnknownType", ex.Message);
30+
}
31+
32+
[Fact]
33+
public void CanAutoDetectProtobufMarshaller()
34+
{
35+
var sctx = new TestSerializationContext();
36+
var marshaller = BinderConfiguration.Default.GetMarshaller<Foo>();
37+
marshaller.ContextualSerializer(new Foo { Value = 42 }, sctx);
38+
var hex = BitConverter.ToString(sctx.Payload);
39+
Assert.Equal("08-2A", hex);
40+
var dctx = new TestDeserializationContext(sctx.Payload);
41+
var obj = marshaller.ContextualDeserializer(dctx);
42+
Assert.Equal(42, obj.Value);
43+
}
44+
class TestSerializationContext : SerializationContext
45+
{
46+
public byte[] Payload { get; set; } = Array.Empty<byte>();
47+
public override void Complete(byte[] payload) => Payload = payload;
48+
}
49+
internal class TestDeserializationContext : DeserializationContext
50+
{
51+
private byte[] _payload;
52+
53+
public TestDeserializationContext(byte[] payload) => _payload = payload;
54+
55+
public override int PayloadLength => _payload.Length;
56+
public override byte[] PayloadAsNewBuffer() => _payload;
57+
}
58+
}
59+
}

0 commit comments

Comments
 (0)