Skip to content

Commit e199bb6

Browse files
committed
limit Google.Protobuf support to a marshaller factory (to allow it to be disabled)
1 parent 057de5c commit e199bb6

File tree

5 files changed

+144
-98
lines changed

5 files changed

+144
-98
lines changed

src/protobuf-net.Grpc/Configuration/BinderConfiguration.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ namespace ProtoBuf.Grpc.Configuration
1212
public sealed class BinderConfiguration
1313
{
1414
// this *must* stay above Default - .cctor order is file order
15-
static readonly MarshallerFactory[] s_defaultFactories = new MarshallerFactory[] { ProtoBufMarshallerFactory.Default };
15+
static readonly MarshallerFactory[] s_defaultFactories = new MarshallerFactory[] { ProtoBufMarshallerFactory.Default, ProtoBufMarshallerFactory.GoogleProtobuf };
1616

1717
/// <summary>
1818
/// Use the default MarshallerFactory and ServiceBinder
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
using Grpc.Core;
2+
using System;
3+
using System.Buffers;
4+
using System.Collections.Concurrent;
5+
using System.Linq.Expressions;
6+
using System.Reflection;
7+
8+
namespace ProtoBuf.Grpc.Configuration
9+
{
10+
internal sealed class GoogleProtobufMarshallerFactory : MarshallerFactory
11+
{
12+
internal static MarshallerFactory Default { get; } = new GoogleProtobufMarshallerFactory();
13+
14+
private GoogleProtobufMarshallerFactory() { }
15+
16+
protected internal override bool CanSerialize(Type type)
17+
{
18+
if (_knownTypes.TryGetValue(type, out var existing))
19+
{
20+
return existing is not null;
21+
}
22+
var created = s_Create.MakeGenericMethod(type).Invoke(null, null);
23+
_knownTypes[type] = created;
24+
return created is not null;
25+
}
26+
static readonly MethodInfo s_Create = typeof(GoogleProtobufMarshallerFactory).GetMethod(nameof(AutoDetectProtobufMarshaller), BindingFlags.Static | BindingFlags.NonPublic);
27+
28+
static ConcurrentDictionary<Type, object?> s_KnownTypes = new();
29+
protected internal override Marshaller<T> CreateMarshaller<T>()
30+
{
31+
if (_knownTypes.TryGetValue(typeof(T), out var existing))
32+
{
33+
return (Marshaller<T>)existing!;
34+
}
35+
var created = AutoDetectProtobufMarshaller<T>();
36+
_knownTypes[typeof(T)] = created;
37+
return created!;
38+
}
39+
40+
private static ConcurrentDictionary<Type, object?> _knownTypes = new ConcurrentDictionary<Type, object?>();
41+
42+
// attempt to auto-detect the patterns exposed by Google.Protobuf types;
43+
// this is (by necessity) reflection-based and imperfect
44+
static Marshaller<T>? AutoDetectProtobufMarshaller<T>()
45+
{
46+
try
47+
{
48+
if (typeof(T).GetProperty("Parser", BindingFlags.Public | BindingFlags.Static) is { } parser
49+
&& FindIMessage(out var iBufferMessage) is { } iMessage
50+
&& iMessage.Assembly.GetType("Google.Protobuf.MessageExtensions") is { } me)
51+
{
52+
Func<DeserializationContext, T> deserializer;
53+
Action<T, global::Grpc.Core.SerializationContext> serializer;
54+
55+
if (iBufferMessage is not null)
56+
{
57+
/* we want to generate:
58+
// write
59+
context.SetPayloadLength(message.CalculateSize());
60+
global::Google.Protobuf.MessageExtensions.WriteTo(message, context.GetBufferWriter());
61+
context.Complete();
62+
63+
// read
64+
parser.ParseFrom(context.PayloadAsReadOnlySequence()
65+
*/
66+
var context = Expression.Parameter(typeof(global::Grpc.Core.DeserializationContext), "context");
67+
var parseFrom = parser.PropertyType.GetMethod("ParseFrom", new Type[] { typeof(ReadOnlySequence<byte>) });
68+
Expression body = Expression.Call(Expression.Constant(parser.GetValue(null), parser.PropertyType),
69+
parseFrom, Expression.Call(context, nameof(DeserializationContext.PayloadAsReadOnlySequence), Type.EmptyTypes));
70+
deserializer = Expression.Lambda<Func<DeserializationContext, T>>(body, context).Compile();
71+
72+
var message = Expression.Parameter(typeof(T), "message");
73+
context = Expression.Parameter(typeof(global::Grpc.Core.SerializationContext), "context");
74+
var setPayloadLength = typeof(global::Grpc.Core.SerializationContext).GetMethod(nameof(global::Grpc.Core.SerializationContext.SetPayloadLength), new Type[] { typeof(int) });
75+
var calculateSize = iMessage.GetMethod("CalculateSize", Type.EmptyTypes);
76+
var writeTo = me.GetMethod("WriteTo", new Type[] { iMessage, typeof(IBufferWriter<byte>) });
77+
body = Expression.Block(
78+
Expression.Call(context, setPayloadLength, Expression.Call(message, calculateSize)),
79+
Expression.Call(writeTo, message, Expression.Call(context, "GetBufferWriter", Type.EmptyTypes)),
80+
Expression.Call(context, "Complete", Type.EmptyTypes)
81+
);
82+
serializer = Expression.Lambda<Action<T, global::Grpc.Core.SerializationContext>>(body, message, context).Compile();
83+
}
84+
else
85+
{
86+
/* we want to generate:
87+
// write
88+
context.Complete(global::Google.Protobuf.MessageExtensions.ToByteArray(message));
89+
90+
// read
91+
parser.ParseFrom(context.PayloadAsNewBuffer());
92+
*/
93+
94+
var context = Expression.Parameter(typeof(global::Grpc.Core.DeserializationContext), "context");
95+
var parseFrom = parser.PropertyType.GetMethod("ParseFrom", new Type[] { typeof(byte[]) });
96+
Expression body = Expression.Call(Expression.Constant(parser.GetValue(null), parser.PropertyType),
97+
parseFrom, Expression.Call(context, nameof(DeserializationContext.PayloadAsNewBuffer), Type.EmptyTypes));
98+
deserializer = Expression.Lambda<Func<DeserializationContext, T>>(body, context).Compile();
99+
100+
var message = Expression.Parameter(typeof(T), "message");
101+
context = Expression.Parameter(typeof(global::Grpc.Core.SerializationContext), "context");
102+
var toByteArray = me.GetMethod("ToByteArray", new Type[] { iMessage });
103+
var complete = typeof(global::Grpc.Core.SerializationContext).GetMethod(
104+
nameof(global::Grpc.Core.SerializationContext.Complete), new Type[] { typeof(byte[]) });
105+
body = Expression.Call(context, complete, Expression.Call(toByteArray, message));
106+
serializer = Expression.Lambda<Action<T, global::Grpc.Core.SerializationContext>>(body, message, context).Compile();
107+
}
108+
return new Marshaller<T>(serializer, deserializer);
109+
}
110+
}
111+
catch { } // this is very much a best-efforts thing
112+
return null;
113+
114+
static Type? FindIMessage(out Type? iBufferMessage)
115+
{
116+
Type? iMessage = null;
117+
iBufferMessage = null;
118+
foreach (var it in typeof(T).GetInterfaces())
119+
{
120+
if (it.Name == "IBufferMessage" && it.Namespace == "Google.Protobuf" && !it.IsGenericType)
121+
{
122+
iBufferMessage = it;
123+
}
124+
else if (it.Name == "IMessage" && it.Namespace == "Google.Protobuf" && !it.IsGenericType)
125+
{
126+
iMessage = it;
127+
}
128+
}
129+
return iMessage;
130+
}
131+
}
132+
}
133+
}

src/protobuf-net.Grpc/Configuration/ProtoBufMarshallerFactory.cs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,15 @@ public enum Options
3535
}
3636

3737
/// <summary>
38-
/// Uses the default protobuf-net serializer
38+
/// Uses the default protobuf-net serializer.
3939
/// </summary>
4040
public static MarshallerFactory Default { get; } = new ProtoBufMarshallerFactory(RuntimeTypeModel.Default, Options.None, default);
4141

42+
/// <summary>
43+
/// Provides support for <a href="https://www.nuget.org/packages/Google.Protobuf/">Google.Protobuf</a> types.
44+
/// </summary>
45+
public static MarshallerFactory GoogleProtobuf => GoogleProtobufMarshallerFactory.Default;
46+
4247
/// <summary>
4348
/// Gets the model used by this instance.
4449
/// </summary>

src/protobuf-net.Grpc/Internal/MarshallerCache.cs

Lines changed: 1 addition & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,9 @@
11
using Grpc.Core;
22
using ProtoBuf.Grpc.Configuration;
33
using System;
4-
using System.Buffers;
54
using System.Collections.Concurrent;
6-
using System.Linq.Expressions;
75
using System.Reflection;
86
using System.Runtime.CompilerServices;
9-
using System.ServiceModel.Channels;
107

118
namespace ProtoBuf.Grpc.Internal
129
{
@@ -68,100 +65,8 @@ internal void SetMarshaller<T>(Marshaller<T>? marshaller)
6865
if (factory.CanSerialize(typeof(T)))
6966
return factory.CreateMarshaller<T>();
7067
}
71-
return AutoDetectProtobufMarshaller();
72-
73-
static Type? FindIMessage(out Type? iBufferMessage)
74-
{
75-
Type? iMessage = null;
76-
iBufferMessage = null;
77-
foreach (var it in typeof(T).GetInterfaces())
78-
{
79-
if (it.Name == "IBufferMessage" && it.Namespace == "Google.Protobuf" && !it.IsGenericType)
80-
{
81-
iBufferMessage = it;
82-
}
83-
else if (it.Name == "IMessage" && it.Namespace == "Google.Protobuf" && !it.IsGenericType)
84-
{
85-
iMessage = it;
86-
}
87-
}
88-
return iMessage;
89-
}
90-
91-
// attempt to auto-detect the patterns exposed by Google.Protobuf types;
92-
// this is (by necessity) reflection-based and imperfect
93-
static Marshaller<T>? AutoDetectProtobufMarshaller()
94-
{
95-
try
96-
{
97-
if (typeof(T).GetProperty("Parser", BindingFlags.Public | BindingFlags.Static) is { } parser
98-
&& FindIMessage(out var iBufferMessage) is { } iMessage
99-
&& iMessage.Assembly.GetType("Google.Protobuf.MessageExtensions") is { } me)
100-
{
101-
Func<DeserializationContext, T> deserializer;
102-
Action<T, global::Grpc.Core.SerializationContext> serializer;
103-
104-
if (iBufferMessage is not null)
105-
{
106-
/* we want to generate:
107-
// write
108-
context.SetPayloadLength(message.CalculateSize());
109-
global::Google.Protobuf.MessageExtensions.WriteTo(message, context.GetBufferWriter());
110-
context.Complete();
111-
112-
// read
113-
parser.ParseFrom(context.PayloadAsReadOnlySequence()
114-
*/
115-
var context = Expression.Parameter(typeof(global::Grpc.Core.DeserializationContext), "context");
116-
var parseFrom = parser.PropertyType.GetMethod("ParseFrom", new Type[] { typeof(ReadOnlySequence<byte>) });
117-
Expression body = Expression.Call(Expression.Constant(parser.GetValue(null), parser.PropertyType),
118-
parseFrom, Expression.Call(context, nameof(DeserializationContext.PayloadAsReadOnlySequence), Type.EmptyTypes));
119-
deserializer = Expression.Lambda<Func<DeserializationContext, T>>(body, context).Compile();
120-
121-
var message = Expression.Parameter(typeof(T), "message");
122-
context = Expression.Parameter(typeof(global::Grpc.Core.SerializationContext), "context");
123-
var setPayloadLength = typeof(global::Grpc.Core.SerializationContext).GetMethod(nameof(global::Grpc.Core.SerializationContext.SetPayloadLength), new Type[] { typeof(int) });
124-
var calculateSize = iMessage.GetMethod("CalculateSize", Type.EmptyTypes);
125-
var writeTo = me.GetMethod("WriteTo", new Type[] { iMessage, typeof(IBufferWriter<byte>) });
126-
body = Expression.Block(
127-
Expression.Call(context, setPayloadLength, Expression.Call(message, calculateSize)),
128-
Expression.Call(writeTo, message, Expression.Call(context, "GetBufferWriter", Type.EmptyTypes)),
129-
Expression.Call(context, "Complete", Type.EmptyTypes)
130-
);
131-
serializer = Expression.Lambda<Action<T, global::Grpc.Core.SerializationContext>>(body, message, context).Compile();
132-
}
133-
else
134-
{
135-
/* we want to generate:
136-
// write
137-
context.Complete(global::Google.Protobuf.MessageExtensions.ToByteArray(message));
138-
139-
// read
140-
parser.ParseFrom(context.PayloadAsNewBuffer());
141-
*/
142-
143-
var context = Expression.Parameter(typeof(global::Grpc.Core.DeserializationContext), "context");
144-
var parseFrom = parser.PropertyType.GetMethod("ParseFrom", new Type[] { typeof(byte[]) });
145-
Expression body = Expression.Call(Expression.Constant(parser.GetValue(null), parser.PropertyType),
146-
parseFrom, Expression.Call(context, nameof(DeserializationContext.PayloadAsNewBuffer), Type.EmptyTypes));
147-
deserializer = Expression.Lambda<Func<DeserializationContext, T>>(body, context).Compile();
148-
149-
var message = Expression.Parameter(typeof(T), "message");
150-
context = Expression.Parameter(typeof(global::Grpc.Core.SerializationContext), "context");
151-
var toByteArray = me.GetMethod("ToByteArray", new Type[] { iMessage });
152-
var complete = typeof(global::Grpc.Core.SerializationContext).GetMethod(
153-
nameof(global::Grpc.Core.SerializationContext.Complete), new Type[] { typeof(byte[]) });
154-
body = Expression.Call(context, complete, Expression.Call(toByteArray, message));
155-
serializer = Expression.Lambda<Action<T, global::Grpc.Core.SerializationContext>>(body, message, context).Compile();
156-
}
157-
return new Marshaller<T>(serializer, deserializer);
158-
}
159-
}
160-
catch { } // this is very much a best-efforts thing
161-
return null;
162-
}
68+
return null;
16369
}
164-
16570
internal MarshallerFactory? TryGetFactory(Type type)
16671
{
16772
foreach (var factory in _factories)

tests/protobuf-net.Grpc.Test/AutoMarshaller.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ public class UnknownType { }
2121
[Fact]
2222
public void GetMarshallerOnUnknownTypeFailsInExpectedWay()
2323
{
24+
Assert.False(BinderConfiguration.Default.MarshallerCache.CanSerializeType(typeof(UnknownType)));
2425
var ex = Assert.Throws<InvalidOperationException>(
2526
() =>
2627
{
@@ -33,7 +34,9 @@ public void GetMarshallerOnUnknownTypeFailsInExpectedWay()
3334
public void CanAutoDetectProtobufMarshaller()
3435
{
3536
var sctx = new TestSerializationContext();
37+
Assert.True(BinderConfiguration.Default.MarshallerCache.CanSerializeType(typeof(Foo)));
3638
var marshaller = BinderConfiguration.Default.GetMarshaller<Foo>();
39+
3740
marshaller.ContextualSerializer(new Foo { Value = 42 }, sctx);
3841
var hex = BitConverter.ToString(sctx.Payload);
3942
Assert.Equal("08-2A", hex);

0 commit comments

Comments
 (0)