|
| 1 | +using Grpc.Core; |
| 2 | +using System; |
| 3 | +using System.Buffers; |
| 4 | +using System.Collections.Concurrent; |
| 5 | +using System.Linq.Expressions; |
| 6 | +using System.Reflection; |
| 7 | + |
| 8 | +namespace ProtoBuf.Grpc.Configuration |
| 9 | +{ |
| 10 | + internal sealed class GoogleProtobufMarshallerFactory : MarshallerFactory |
| 11 | + { |
| 12 | + internal static MarshallerFactory Default { get; } = new GoogleProtobufMarshallerFactory(); |
| 13 | + |
| 14 | + private GoogleProtobufMarshallerFactory() { } |
| 15 | + |
| 16 | + protected internal override bool CanSerialize(Type type) |
| 17 | + { |
| 18 | + if (_knownTypes.TryGetValue(type, out var existing)) |
| 19 | + { |
| 20 | + return existing is not null; |
| 21 | + } |
| 22 | + var created = s_Create.MakeGenericMethod(type).Invoke(null, null); |
| 23 | + _knownTypes[type] = created; |
| 24 | + return created is not null; |
| 25 | + } |
| 26 | + static readonly MethodInfo s_Create = typeof(GoogleProtobufMarshallerFactory).GetMethod(nameof(AutoDetectProtobufMarshaller), BindingFlags.Static | BindingFlags.NonPublic); |
| 27 | + |
| 28 | + static ConcurrentDictionary<Type, object?> s_KnownTypes = new(); |
| 29 | + protected internal override Marshaller<T> CreateMarshaller<T>() |
| 30 | + { |
| 31 | + if (_knownTypes.TryGetValue(typeof(T), out var existing)) |
| 32 | + { |
| 33 | + return (Marshaller<T>)existing!; |
| 34 | + } |
| 35 | + var created = AutoDetectProtobufMarshaller<T>(); |
| 36 | + _knownTypes[typeof(T)] = created; |
| 37 | + return created!; |
| 38 | + } |
| 39 | + |
| 40 | + private static ConcurrentDictionary<Type, object?> _knownTypes = new ConcurrentDictionary<Type, object?>(); |
| 41 | + |
| 42 | + // attempt to auto-detect the patterns exposed by Google.Protobuf types; |
| 43 | + // this is (by necessity) reflection-based and imperfect |
| 44 | + static Marshaller<T>? AutoDetectProtobufMarshaller<T>() |
| 45 | + { |
| 46 | + try |
| 47 | + { |
| 48 | + if (typeof(T).GetProperty("Parser", BindingFlags.Public | BindingFlags.Static) is { } parser |
| 49 | + && FindIMessage(out var iBufferMessage) is { } iMessage |
| 50 | + && iMessage.Assembly.GetType("Google.Protobuf.MessageExtensions") is { } me) |
| 51 | + { |
| 52 | + Func<DeserializationContext, T> deserializer; |
| 53 | + Action<T, global::Grpc.Core.SerializationContext> serializer; |
| 54 | + |
| 55 | + if (iBufferMessage is not null) |
| 56 | + { |
| 57 | + /* we want to generate: |
| 58 | +// write |
| 59 | +context.SetPayloadLength(message.CalculateSize()); |
| 60 | +global::Google.Protobuf.MessageExtensions.WriteTo(message, context.GetBufferWriter()); |
| 61 | +context.Complete(); |
| 62 | +
|
| 63 | +// read |
| 64 | +parser.ParseFrom(context.PayloadAsReadOnlySequence() |
| 65 | +*/ |
| 66 | + var context = Expression.Parameter(typeof(global::Grpc.Core.DeserializationContext), "context"); |
| 67 | + var parseFrom = parser.PropertyType.GetMethod("ParseFrom", new Type[] { typeof(ReadOnlySequence<byte>) }); |
| 68 | + Expression body = Expression.Call(Expression.Constant(parser.GetValue(null), parser.PropertyType), |
| 69 | + parseFrom, Expression.Call(context, nameof(DeserializationContext.PayloadAsReadOnlySequence), Type.EmptyTypes)); |
| 70 | + deserializer = Expression.Lambda<Func<DeserializationContext, T>>(body, context).Compile(); |
| 71 | + |
| 72 | + var message = Expression.Parameter(typeof(T), "message"); |
| 73 | + context = Expression.Parameter(typeof(global::Grpc.Core.SerializationContext), "context"); |
| 74 | + var setPayloadLength = typeof(global::Grpc.Core.SerializationContext).GetMethod(nameof(global::Grpc.Core.SerializationContext.SetPayloadLength), new Type[] { typeof(int) }); |
| 75 | + var calculateSize = iMessage.GetMethod("CalculateSize", Type.EmptyTypes); |
| 76 | + var writeTo = me.GetMethod("WriteTo", new Type[] { iMessage, typeof(IBufferWriter<byte>) }); |
| 77 | + body = Expression.Block( |
| 78 | + Expression.Call(context, setPayloadLength, Expression.Call(message, calculateSize)), |
| 79 | + Expression.Call(writeTo, message, Expression.Call(context, "GetBufferWriter", Type.EmptyTypes)), |
| 80 | + Expression.Call(context, "Complete", Type.EmptyTypes) |
| 81 | + ); |
| 82 | + serializer = Expression.Lambda<Action<T, global::Grpc.Core.SerializationContext>>(body, message, context).Compile(); |
| 83 | + } |
| 84 | + else |
| 85 | + { |
| 86 | + /* we want to generate: |
| 87 | +// write |
| 88 | +context.Complete(global::Google.Protobuf.MessageExtensions.ToByteArray(message)); |
| 89 | +
|
| 90 | +// read |
| 91 | +parser.ParseFrom(context.PayloadAsNewBuffer()); |
| 92 | +*/ |
| 93 | + |
| 94 | + var context = Expression.Parameter(typeof(global::Grpc.Core.DeserializationContext), "context"); |
| 95 | + var parseFrom = parser.PropertyType.GetMethod("ParseFrom", new Type[] { typeof(byte[]) }); |
| 96 | + Expression body = Expression.Call(Expression.Constant(parser.GetValue(null), parser.PropertyType), |
| 97 | + parseFrom, Expression.Call(context, nameof(DeserializationContext.PayloadAsNewBuffer), Type.EmptyTypes)); |
| 98 | + deserializer = Expression.Lambda<Func<DeserializationContext, T>>(body, context).Compile(); |
| 99 | + |
| 100 | + var message = Expression.Parameter(typeof(T), "message"); |
| 101 | + context = Expression.Parameter(typeof(global::Grpc.Core.SerializationContext), "context"); |
| 102 | + var toByteArray = me.GetMethod("ToByteArray", new Type[] { iMessage }); |
| 103 | + var complete = typeof(global::Grpc.Core.SerializationContext).GetMethod( |
| 104 | + nameof(global::Grpc.Core.SerializationContext.Complete), new Type[] { typeof(byte[]) }); |
| 105 | + body = Expression.Call(context, complete, Expression.Call(toByteArray, message)); |
| 106 | + serializer = Expression.Lambda<Action<T, global::Grpc.Core.SerializationContext>>(body, message, context).Compile(); |
| 107 | + } |
| 108 | + return new Marshaller<T>(serializer, deserializer); |
| 109 | + } |
| 110 | + } |
| 111 | + catch { } // this is very much a best-efforts thing |
| 112 | + return null; |
| 113 | + |
| 114 | + static Type? FindIMessage(out Type? iBufferMessage) |
| 115 | + { |
| 116 | + Type? iMessage = null; |
| 117 | + iBufferMessage = null; |
| 118 | + foreach (var it in typeof(T).GetInterfaces()) |
| 119 | + { |
| 120 | + if (it.Name == "IBufferMessage" && it.Namespace == "Google.Protobuf" && !it.IsGenericType) |
| 121 | + { |
| 122 | + iBufferMessage = it; |
| 123 | + } |
| 124 | + else if (it.Name == "IMessage" && it.Namespace == "Google.Protobuf" && !it.IsGenericType) |
| 125 | + { |
| 126 | + iMessage = it; |
| 127 | + } |
| 128 | + } |
| 129 | + return iMessage; |
| 130 | + } |
| 131 | + } |
| 132 | + } |
| 133 | +} |
0 commit comments