From 01740cfd800d0fc70691e7fd2cf05826b9076dcc Mon Sep 17 00:00:00 2001 From: Kengwang Date: Wed, 17 Sep 2025 01:11:26 +0800 Subject: [PATCH 1/7] [Proto] Add reflection support for polymorphic serialization for derived classes Inspired by STJ's JsonDerivedType & JsonPolymorphic --- Lagrange.Proto.Test/ProtoPolymorphismTest.cs | 58 +++++++++++++++++ Lagrange.Proto/ProtoDerivedTypeAttribute.cs | 32 ++++++++++ Lagrange.Proto/ProtoPolymorphicAttribute.cs | 8 +++ .../Serialization/Metadata/ProtoObjectInfo.cs | 2 + .../Metadata/ProtoTypeResolver.Dynamic.cs | 62 +++++++++++++++---- .../ProtoSerializer.Deserialize.cs | 35 ++++++++++- .../ProtoSerializer.Serialize.cs | 26 +++++++- Lagrange.Proto/ThrowHelper.cs | 20 ++++++ 8 files changed, 227 insertions(+), 16 deletions(-) create mode 100644 Lagrange.Proto.Test/ProtoPolymorphismTest.cs create mode 100644 Lagrange.Proto/ProtoDerivedTypeAttribute.cs create mode 100644 Lagrange.Proto/ProtoPolymorphicAttribute.cs diff --git a/Lagrange.Proto.Test/ProtoPolymorphismTest.cs b/Lagrange.Proto.Test/ProtoPolymorphismTest.cs new file mode 100644 index 00000000..77b25883 --- /dev/null +++ b/Lagrange.Proto.Test/ProtoPolymorphismTest.cs @@ -0,0 +1,58 @@ +using Lagrange.Proto.Serialization; + +namespace Lagrange.Proto.Test; + +[TestFixture] +public class ProtoPolymorphismTest +{ + #region Basic Polymorphism + + [Test] + public void BasicPolymorphism_SerializeAndDeserialize_ReturnsCorrectDerivedType() + { + // Arrange + BaseClass original = new DerivedClassA { NameProperty = "TestName" }; + + byte[] bytes = ProtoSerializer.Serialize(original); + BaseClass deserialized = ProtoSerializer.Deserialize(bytes); + + Assert.Multiple(() => + { + Assert.That(deserialized, Is.AssignableTo()); + Assert.That(deserialized, Is.AssignableTo()); + Assert.That(deserialized.IdentifierProperty, Is.EqualTo(2)); + Assert.That(((DerivedClassA)deserialized).NameProperty, Is.EqualTo("TestName")); + }); + } + + #endregion +} + +#region Test Classes + +[ProtoPolymorphic(FieldNumber = 1)] +[ProtoDerivedType(typeof(DerivedClassA), 2)] +[ProtoDerivedType(typeof(DerivedClassB), 3)] +public class BaseClass +{ + public BaseClass() : this(-1) { } + + public BaseClass(int identifier) + { + IdentifierProperty = identifier; + } + + [ProtoMember(1)] public int IdentifierProperty { get; set; } +} + +public class DerivedClassA() : BaseClass(2) +{ + [ProtoMember(2)] public string NameProperty { get; set; } +} + +public class DerivedClassB() : BaseClass(3) +{ + [ProtoMember(2)] public float ValueProperty { get; set; } +} + +#endregion \ No newline at end of file diff --git a/Lagrange.Proto/ProtoDerivedTypeAttribute.cs b/Lagrange.Proto/ProtoDerivedTypeAttribute.cs new file mode 100644 index 00000000..ea6f235d --- /dev/null +++ b/Lagrange.Proto/ProtoDerivedTypeAttribute.cs @@ -0,0 +1,32 @@ +namespace Lagrange.Proto; + +[AttributeUsage(AttributeTargets.Class | AttributeTargets.Interface, AllowMultiple = true, Inherited = false)] +public class ProtoDerivedTypeAttribute : Attribute +{ + public ProtoDerivedTypeAttribute(Type derivedType) + { + DerivedType = derivedType; + } + + public ProtoDerivedTypeAttribute(Type derivedType, string typeDiscriminator) + { + DerivedType = derivedType; + TypeDiscriminator = typeDiscriminator; + } + + public ProtoDerivedTypeAttribute(Type derivedType, int typeDiscriminator) + { + DerivedType = derivedType; + TypeDiscriminator = typeDiscriminator; + } + + /// + /// A derived type that should be supported in polymorphic serialization of the declared base type. + /// + public Type DerivedType { get; init; } + + /// + /// The type discriminator identifier to be used for the serialization of the subtype. + /// + public object? TypeDiscriminator { get; init; } +} \ No newline at end of file diff --git a/Lagrange.Proto/ProtoPolymorphicAttribute.cs b/Lagrange.Proto/ProtoPolymorphicAttribute.cs new file mode 100644 index 00000000..2397b6c5 --- /dev/null +++ b/Lagrange.Proto/ProtoPolymorphicAttribute.cs @@ -0,0 +1,8 @@ +namespace Lagrange.Proto; + + +[AttributeUsage(AttributeTargets.Class | AttributeTargets.Interface, AllowMultiple = false, Inherited = false)] +public class ProtoPolymorphicAttribute : Attribute +{ + public uint FieldNumber { get; init; } +} \ No newline at end of file diff --git a/Lagrange.Proto/Serialization/Metadata/ProtoObjectInfo.cs b/Lagrange.Proto/Serialization/Metadata/ProtoObjectInfo.cs index 64ceb4f8..97e15e0b 100644 --- a/Lagrange.Proto/Serialization/Metadata/ProtoObjectInfo.cs +++ b/Lagrange.Proto/Serialization/Metadata/ProtoObjectInfo.cs @@ -10,4 +10,6 @@ public class ProtoObjectInfo public Func? ObjectCreator { get; init; } public bool IgnoreDefaultFields { get; init; } + public uint PolymorphicIndicateIndex { get; init; } = 0; + public Dictionary? objectCreator,Dictionary fields)>? PolymorphicFields { get; init; } } \ No newline at end of file diff --git a/Lagrange.Proto/Serialization/Metadata/ProtoTypeResolver.Dynamic.cs b/Lagrange.Proto/Serialization/Metadata/ProtoTypeResolver.Dynamic.cs index d60a1443..e4aba33d 100644 --- a/Lagrange.Proto/Serialization/Metadata/ProtoTypeResolver.Dynamic.cs +++ b/Lagrange.Proto/Serialization/Metadata/ProtoTypeResolver.Dynamic.cs @@ -66,35 +66,71 @@ internal static ProtoObjectInfo CreateObjectInfo() { var ctor = typeof(T).IsValueType ? null : typeof(T).GetConstructor(Type.EmptyTypes); bool ignoreDefaultFields = typeof(T).GetCustomAttribute()?.IgnoreDefaultFields == true; + var fields = CreateTypeFieldInfo(typeof(T)); + + var polymorphicAttributes = typeof(T).GetCustomAttributes().ToArray(); + var polymorphicFieldNumber = typeof(T).GetCustomAttribute()?.FieldNumber ?? 0; + + if (polymorphicAttributes.Length > 0) + { + if (polymorphicFieldNumber == 0) polymorphicFieldNumber = 1; // use first for default + var polymorphicFields = new Dictionary? objectCreator,Dictionary fields)>(); + foreach (var attr in polymorphicAttributes) + { + var key = attr.TypeDiscriminator; + if (key == null) ThrowHelper.ThrowInvalidOperationException_NullPolymorphicDiscriminator(attr.DerivedType); + var derivedFields = CreateTypeFieldInfo(attr.DerivedType); + if (polymorphicFields.ContainsKey(key)) ThrowHelper.ThrowInvalidOperationException_DuplicatePolymorphicDiscriminator(typeof(T), (int)polymorphicFieldNumber); + var derivedCtor = attr.DerivedType.IsValueType ? null : attr.DerivedType.GetConstructor(Type.EmptyTypes); + polymorphicFields[key] = (MemberAccessor.CreateParameterlessConstructor(derivedCtor), derivedFields); + } + + return new ProtoObjectInfo + { + ObjectCreator = MemberAccessor.CreateParameterlessConstructor(ctor), + IgnoreDefaultFields = ignoreDefaultFields, + PolymorphicIndicateIndex = polymorphicFieldNumber, + PolymorphicFields = polymorphicFields, + Fields = fields.OrderBy(x => x.Key).ToDictionary(x => x.Key, x => x.Value) + }; + } + + + + return new ProtoObjectInfo + { + ObjectCreator = MemberAccessor.CreateParameterlessConstructor(ctor), + IgnoreDefaultFields = ignoreDefaultFields, + Fields = fields.OrderBy(x => x.Key).ToDictionary(x => x.Key, x => x.Value) + }; + } + + internal static Dictionary CreateTypeFieldInfo(Type type) + { var fields = new Dictionary(); - foreach (var field in typeof(T).GetFields(BindingFlags.Public | BindingFlags.Instance)) + foreach (var field in type.GetFields(BindingFlags.Public | BindingFlags.Instance)) { if (field.IsStatic) continue; - var fieldInfo = CreateFieldInfo(typeof(T), field); + var fieldInfo = CreateFieldInfo(type, field); if (fieldInfo == null) continue; uint tag = ((uint)fieldInfo.Field << 3) | (byte)fieldInfo.WireType; - if (fields.ContainsKey(tag)) ThrowHelper.ThrowInvalidOperationException_DuplicateField(typeof(T), fieldInfo.Field); + if (fields.ContainsKey(tag)) ThrowHelper.ThrowInvalidOperationException_DuplicateField(type, fieldInfo.Field); fields[tag] = fieldInfo; } - foreach (var field in typeof(T).GetProperties(BindingFlags.Public | BindingFlags.Instance)) + foreach (var field in type.GetProperties(BindingFlags.Public | BindingFlags.Instance)) { - var fieldInfo = CreateFieldInfo(typeof(T), field); + var fieldInfo = CreateFieldInfo(type, field); if (fieldInfo == null) continue; uint tag = ((uint)fieldInfo.Field << 3) | (byte)fieldInfo.WireType; - if (fields.ContainsKey(tag)) ThrowHelper.ThrowInvalidOperationException_DuplicateField(typeof(T), fieldInfo.Field); + if (fields.ContainsKey(tag)) ThrowHelper.ThrowInvalidOperationException_DuplicateField(type, fieldInfo.Field); fields[tag] = fieldInfo; } - - return new ProtoObjectInfo - { - ObjectCreator = MemberAccessor.CreateParameterlessConstructor(ctor), - IgnoreDefaultFields = ignoreDefaultFields, - Fields = fields.OrderBy(x => x.Key).ToDictionary(x => x.Key, x => x.Value) - }; + + return fields; } diff --git a/Lagrange.Proto/Serialization/ProtoSerializer.Deserialize.cs b/Lagrange.Proto/Serialization/ProtoSerializer.Deserialize.cs index 417ed77a..71317a4b 100644 --- a/Lagrange.Proto/Serialization/ProtoSerializer.Deserialize.cs +++ b/Lagrange.Proto/Serialization/ProtoSerializer.Deserialize.cs @@ -84,10 +84,43 @@ private static T DeserializeProtoPackableCore(ref ProtoReader reader) where T var boxed = (object?)target; // avoid multiple times of boxing if (boxed is null) ThrowHelper.ThrowInvalidOperationException_CanNotCreateObject(typeof(T)); + var fieldInfos = converter.ObjectInfo.Fields; + + // polymorphic type + if (converter.ObjectInfo.PolymorphicIndicateIndex != 0) + { + // has polymorphic type, read the first field to determine the actual type + uint firstTag = reader.DecodeVarIntUnsafe(); + + if (firstTag >>> 3 != converter.ObjectInfo.PolymorphicIndicateIndex) + { + ThrowHelper.ThrowInvalidOperationException_PolymorphicFieldNotFirst(typeof(T), converter.ObjectInfo.PolymorphicIndicateIndex, firstTag >>> 3); + } + var firstField = converter.ObjectInfo.Fields[firstTag]; + firstField.Read(ref reader, boxed); + var polyTypeKey = firstField.Get?.Invoke(boxed); + if (polyTypeKey is null) + { + ThrowHelper.ThrowInvalidOperationException_FailedParsePolymorphicType(typeof(T), firstTag); + } + if (converter.ObjectInfo.PolymorphicFields?.TryGetValue(polyTypeKey, out var polyTypeInfo) is not true) + { + ThrowHelper.ThrowInvalidOperationException_UnknownPolymorphicType(typeof(T), polyTypeKey); + return default; // never reach this, make compiler happy + } + + fieldInfos = polyTypeInfo.fields; + Debug.Assert(polyTypeInfo.objectCreator != null); + target = polyTypeInfo.objectCreator(); + boxed = (object?)target; // boxing + if (boxed is null) ThrowHelper.ThrowInvalidOperationException_CanNotCreateObject(typeof(T)); + } + + while (!reader.IsCompleted) { uint tag = reader.DecodeVarIntUnsafe(); - if (converter.ObjectInfo.Fields.TryGetValue(tag, out var fieldInfo)) + if (fieldInfos.TryGetValue(tag, out var fieldInfo)) { fieldInfo.Read(ref reader, boxed); } diff --git a/Lagrange.Proto/Serialization/ProtoSerializer.Serialize.cs b/Lagrange.Proto/Serialization/ProtoSerializer.Serialize.cs index e182aecc..0bfc049d 100644 --- a/Lagrange.Proto/Serialization/ProtoSerializer.Serialize.cs +++ b/Lagrange.Proto/Serialization/ProtoSerializer.Serialize.cs @@ -124,10 +124,32 @@ private static void SerializeProtoPackableCore(ProtoWriter writer, T obj) whe var objectInfo = converter.ObjectInfo; object? boxed = obj; // avoid multiple times of boxing if (boxed is null) return; + var fields = objectInfo.Fields; + uint skipTag = 0; - foreach (var (tag, info) in objectInfo.Fields) + // check polymorphic type + if (converter.ObjectInfo.PolymorphicIndicateIndex != 0) { - if (info.ShouldSerialize(boxed, objectInfo.IgnoreDefaultFields)) + // has polymorphic type + var index = converter.ObjectInfo.PolymorphicIndicateIndex; + var fieldInfo = objectInfo.Fields.FirstOrDefault(t=>t.Value.Field == index); + if (fieldInfo.Value is null) ThrowHelper.ThrowInvalidOperationException_NullPolymorphicDiscriminator(typeof(T)); + var discriminator = fieldInfo.Value.Get?.Invoke(boxed); + if (discriminator is null) ThrowHelper.ThrowInvalidOperationException_NullPolymorphicDiscriminator(typeof(T)); + if (objectInfo.PolymorphicFields?.TryGetValue(discriminator, out var derivedTypeInfo) is not true) + { + ThrowHelper.ThrowInvalidOperationException_NullPolymorphicDiscriminator(typeof(T)); + return; // make compiler happy + } + skipTag = fieldInfo.Key; + writer.EncodeVarInt(fieldInfo.Key); + fieldInfo.Value.Write(writer, boxed); + fields = derivedTypeInfo.fields; + } + + foreach (var (tag, info) in fields) + { + if (skipTag != tag && info.ShouldSerialize(boxed, objectInfo.IgnoreDefaultFields)) { writer.EncodeVarInt(tag); info.Write(writer, boxed); diff --git a/Lagrange.Proto/ThrowHelper.cs b/Lagrange.Proto/ThrowHelper.cs index 491106c9..3e5c089b 100644 --- a/Lagrange.Proto/ThrowHelper.cs +++ b/Lagrange.Proto/ThrowHelper.cs @@ -66,4 +66,24 @@ public static void ThrowInvalidOperationException_NodeWrongType(params ReadOnlyS [DoesNotReturn] [MethodImpl(MethodImplOptions.NoInlining)] public static void ThrowInvalidOperationException_InvalidNodesWireType(string fieldName) => throw new InvalidOperationException($"The wire type must be explicitly set for field {fieldName} as the wire type for the ProtoNode, ProtoValue, and ProtoArray types is not known at compile time, to set the wire type, use the NodesWireType Property in ProtoMember attribute"); + + [DoesNotReturn] + [MethodImpl(MethodImplOptions.NoInlining)] + public static void ThrowInvalidOperationException_NullPolymorphicDiscriminator(Type type) => throw new InvalidOperationException($"The polymorphic discriminator field for type {type.Name} cannot be null. Please ensure that the field is set and has a valid value."); + + [DoesNotReturn] + [MethodImpl(MethodImplOptions.NoInlining)] + public static void ThrowInvalidOperationException_DuplicatePolymorphicDiscriminator(Type type, object key) => throw new InvalidOperationException($"The polymorphic discriminator key '{key}' for type {type.Name} is duplicated. Please ensure that the keys are unique."); + + [DoesNotReturn] + [MethodImpl(MethodImplOptions.NoInlining)] + public static void ThrowInvalidOperationException_PolymorphicFieldNotFirst(Type type, uint expected, uint actual) => throw new InvalidOperationException($"The polymorphic discriminator field for type {type.Name} must be the first field in the message. Expected field number {expected}, but found {actual}."); + + [DoesNotReturn] + [MethodImpl(MethodImplOptions.NoInlining)] + public static void ThrowInvalidOperationException_FailedParsePolymorphicType(Type type, uint index) => throw new InvalidOperationException($"Failed to parse the polymorphic type from proto for type {type.Name} at '{index}'."); + + [DoesNotReturn] + [MethodImpl(MethodImplOptions.NoInlining)] + public static void ThrowInvalidOperationException_UnknownPolymorphicType(Type type, object polyTypeKey) => throw new InvalidOperationException($"Unknown polymorphic type '{polyTypeKey}' for base type {type.Name}. Please ensure that the polymorphic type is registered."); } \ No newline at end of file From a2ded7b6de247ef57f47dbb21c7cb501032f94ee Mon Sep 17 00:00:00 2001 From: Kengwang Date: Wed, 17 Sep 2025 01:25:10 +0800 Subject: [PATCH 2/7] [Proto] Allow fallback to base type --- Lagrange.Proto/ProtoPolymorphicAttribute.cs | 1 + .../Serialization/Metadata/ProtoObjectInfo.cs | 1 + .../Metadata/ProtoTypeResolver.Dynamic.cs | 5 +- .../ProtoSerializer.Deserialize.cs | 49 +++++++++++-------- 4 files changed, 34 insertions(+), 22 deletions(-) diff --git a/Lagrange.Proto/ProtoPolymorphicAttribute.cs b/Lagrange.Proto/ProtoPolymorphicAttribute.cs index 2397b6c5..3dde88d1 100644 --- a/Lagrange.Proto/ProtoPolymorphicAttribute.cs +++ b/Lagrange.Proto/ProtoPolymorphicAttribute.cs @@ -5,4 +5,5 @@ public class ProtoPolymorphicAttribute : Attribute { public uint FieldNumber { get; init; } + public bool FallbackToBaseType { get; init; } = true; } \ No newline at end of file diff --git a/Lagrange.Proto/Serialization/Metadata/ProtoObjectInfo.cs b/Lagrange.Proto/Serialization/Metadata/ProtoObjectInfo.cs index 97e15e0b..33a03ed1 100644 --- a/Lagrange.Proto/Serialization/Metadata/ProtoObjectInfo.cs +++ b/Lagrange.Proto/Serialization/Metadata/ProtoObjectInfo.cs @@ -11,5 +11,6 @@ public class ProtoObjectInfo public bool IgnoreDefaultFields { get; init; } public uint PolymorphicIndicateIndex { get; init; } = 0; + public bool PolymorphicFallbackToBaseType { get; init; } = true; public Dictionary? objectCreator,Dictionary fields)>? PolymorphicFields { get; init; } } \ No newline at end of file diff --git a/Lagrange.Proto/Serialization/Metadata/ProtoTypeResolver.Dynamic.cs b/Lagrange.Proto/Serialization/Metadata/ProtoTypeResolver.Dynamic.cs index e4aba33d..4688e59b 100644 --- a/Lagrange.Proto/Serialization/Metadata/ProtoTypeResolver.Dynamic.cs +++ b/Lagrange.Proto/Serialization/Metadata/ProtoTypeResolver.Dynamic.cs @@ -69,7 +69,9 @@ internal static ProtoObjectInfo CreateObjectInfo() var fields = CreateTypeFieldInfo(typeof(T)); var polymorphicAttributes = typeof(T).GetCustomAttributes().ToArray(); - var polymorphicFieldNumber = typeof(T).GetCustomAttribute()?.FieldNumber ?? 0; + var polymorphicConfigAttribute = typeof(T).GetCustomAttribute(); + var polymorphicFieldNumber = polymorphicConfigAttribute?.FieldNumber ?? 0; + var fallbackToBaseType = polymorphicConfigAttribute?.FallbackToBaseType ?? true; if (polymorphicAttributes.Length > 0) { @@ -90,6 +92,7 @@ internal static ProtoObjectInfo CreateObjectInfo() ObjectCreator = MemberAccessor.CreateParameterlessConstructor(ctor), IgnoreDefaultFields = ignoreDefaultFields, PolymorphicIndicateIndex = polymorphicFieldNumber, + PolymorphicFallbackToBaseType = fallbackToBaseType, PolymorphicFields = polymorphicFields, Fields = fields.OrderBy(x => x.Key).ToDictionary(x => x.Key, x => x.Value) }; diff --git a/Lagrange.Proto/Serialization/ProtoSerializer.Deserialize.cs b/Lagrange.Proto/Serialization/ProtoSerializer.Deserialize.cs index 71317a4b..097b2a42 100644 --- a/Lagrange.Proto/Serialization/ProtoSerializer.Deserialize.cs +++ b/Lagrange.Proto/Serialization/ProtoSerializer.Deserialize.cs @@ -19,12 +19,12 @@ public static T DeserializeProtoPackable(ReadOnlySpan data) where T : I var reader = new ProtoReader(data); return DeserializeProtoPackableCore(ref reader); } - + private static T DeserializeProtoPackableCore(ref ProtoReader reader) where T : IProtoSerializable { var objectInfo = T.TypeInfo; Debug.Assert(objectInfo.ObjectCreator != null); - + T target = objectInfo.ObjectCreator(); while (!reader.IsCompleted) @@ -39,10 +39,10 @@ private static T DeserializeProtoPackableCore(ref ProtoReader reader) where T reader.SkipField((WireType)(tag & 0x07)); } } - + return target; } - + /// /// Deserialize the ProtoPackable Object from the source buffer, based on reflection /// @@ -51,15 +51,17 @@ private static T DeserializeProtoPackableCore(ref ProtoReader reader) where T /// The deserialized object [RequiresUnreferencedCode(SerializationUnreferencedCodeMessage)] [RequiresDynamicCode(SerializationRequiresDynamicCodeMessage)] - public static T Deserialize<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.All)] T>(ReadOnlySpan data) + public static T Deserialize<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.All)] T>( + ReadOnlySpan data) { var reader = new ProtoReader(data); return DeserializeCore(ref reader); } - + [RequiresUnreferencedCode(SerializationUnreferencedCodeMessage)] [RequiresDynamicCode(SerializationRequiresDynamicCodeMessage)] - private static T DeserializeCore<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.All)] T>(ref ProtoReader reader) + private static T DeserializeCore<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.All)] T>( + ref ProtoReader reader) { ProtoObjectConverter converter; if (ProtoTypeResolver.IsRegistered()) @@ -78,6 +80,7 @@ private static T DeserializeProtoPackableCore(ref ProtoReader reader) where T { ProtoTypeResolver.Register(converter = new ProtoObjectConverter()); } + Debug.Assert(converter.ObjectInfo.ObjectCreator != null); T target = converter.ObjectInfo.ObjectCreator(); @@ -85,17 +88,19 @@ private static T DeserializeProtoPackableCore(ref ProtoReader reader) where T if (boxed is null) ThrowHelper.ThrowInvalidOperationException_CanNotCreateObject(typeof(T)); var fieldInfos = converter.ObjectInfo.Fields; - + // polymorphic type if (converter.ObjectInfo.PolymorphicIndicateIndex != 0) { // has polymorphic type, read the first field to determine the actual type uint firstTag = reader.DecodeVarIntUnsafe(); - + if (firstTag >>> 3 != converter.ObjectInfo.PolymorphicIndicateIndex) { - ThrowHelper.ThrowInvalidOperationException_PolymorphicFieldNotFirst(typeof(T), converter.ObjectInfo.PolymorphicIndicateIndex, firstTag >>> 3); + ThrowHelper.ThrowInvalidOperationException_PolymorphicFieldNotFirst(typeof(T), + converter.ObjectInfo.PolymorphicIndicateIndex, firstTag >>> 3); } + var firstField = converter.ObjectInfo.Fields[firstTag]; firstField.Read(ref reader, boxed); var polyTypeKey = firstField.Get?.Invoke(boxed); @@ -103,20 +108,22 @@ private static T DeserializeProtoPackableCore(ref ProtoReader reader) where T { ThrowHelper.ThrowInvalidOperationException_FailedParsePolymorphicType(typeof(T), firstTag); } - if (converter.ObjectInfo.PolymorphicFields?.TryGetValue(polyTypeKey, out var polyTypeInfo) is not true) + + if (converter.ObjectInfo.PolymorphicFields?.TryGetValue(polyTypeKey, out var polyTypeInfo) is true) + { + fieldInfos = polyTypeInfo.fields; + Debug.Assert(polyTypeInfo.objectCreator != null); + target = polyTypeInfo.objectCreator(); + boxed = (object?)target; // boxing + if (boxed is null) ThrowHelper.ThrowInvalidOperationException_CanNotCreateObject(typeof(T)); + } + else if (!converter.ObjectInfo.PolymorphicFallbackToBaseType) { ThrowHelper.ThrowInvalidOperationException_UnknownPolymorphicType(typeof(T), polyTypeKey); - return default; // never reach this, make compiler happy } - - fieldInfos = polyTypeInfo.fields; - Debug.Assert(polyTypeInfo.objectCreator != null); - target = polyTypeInfo.objectCreator(); - boxed = (object?)target; // boxing - if (boxed is null) ThrowHelper.ThrowInvalidOperationException_CanNotCreateObject(typeof(T)); } - - + + while (!reader.IsCompleted) { uint tag = reader.DecodeVarIntUnsafe(); @@ -129,7 +136,7 @@ private static T DeserializeProtoPackableCore(ref ProtoReader reader) where T reader.SkipField((WireType)(tag & 0x07)); } } - + return target; } } \ No newline at end of file From 3d45b86cb8446ddf0c5513f144e91398010163c0 Mon Sep 17 00:00:00 2001 From: Kengwang Date: Thu, 18 Sep 2025 02:18:48 +0800 Subject: [PATCH 3/7] [Proto] Use generic attribute to avoid boxing --- Lagrange.Proto.Test/ProtoPolymorphismTest.cs | 30 ++++--- Lagrange.Proto/ProtoDerivedTypeAttribute.cs | 29 +++---- Lagrange.Proto/ProtoPolymorphicAttribute.cs | 2 +- .../Serialization/Metadata/ProtoObjectInfo.cs | 8 +- .../Metadata/ProtoPolymorphicInfo.cs | 41 ++++++++++ .../Metadata/ProtoTypeResolver.Dynamic.cs | 81 +++++++++++-------- .../ProtoSerializer.Deserialize.cs | 16 ++-- .../ProtoSerializer.Serialize.cs | 8 +- 8 files changed, 139 insertions(+), 76 deletions(-) create mode 100644 Lagrange.Proto/Serialization/Metadata/ProtoPolymorphicInfo.cs diff --git a/Lagrange.Proto.Test/ProtoPolymorphismTest.cs b/Lagrange.Proto.Test/ProtoPolymorphismTest.cs index 77b25883..c0a83a38 100644 --- a/Lagrange.Proto.Test/ProtoPolymorphismTest.cs +++ b/Lagrange.Proto.Test/ProtoPolymorphismTest.cs @@ -11,17 +11,27 @@ public class ProtoPolymorphismTest public void BasicPolymorphism_SerializeAndDeserialize_ReturnsCorrectDerivedType() { // Arrange - BaseClass original = new DerivedClassA { NameProperty = "TestName" }; + BaseClass originalA = new DerivedClassA { NameProperty = "TestName" }; + BaseClass originalB = new DerivedClassB { ValueProperty = 114514f }; - byte[] bytes = ProtoSerializer.Serialize(original); - BaseClass deserialized = ProtoSerializer.Deserialize(bytes); + byte[] bytesA = ProtoSerializer.Serialize(originalA); + BaseClass deserializedA = ProtoSerializer.Deserialize(bytesA); + + byte[] bytesB = ProtoSerializer.Serialize(originalB); + BaseClass deserializedB = ProtoSerializer.Deserialize(bytesB); Assert.Multiple(() => { - Assert.That(deserialized, Is.AssignableTo()); - Assert.That(deserialized, Is.AssignableTo()); - Assert.That(deserialized.IdentifierProperty, Is.EqualTo(2)); - Assert.That(((DerivedClassA)deserialized).NameProperty, Is.EqualTo("TestName")); + Assert.That(deserializedA, Is.AssignableTo()); + Assert.That(deserializedA, Is.AssignableTo()); + Assert.That(deserializedA.IdentifierProperty, Is.EqualTo(2)); + Assert.That(((DerivedClassA)deserializedA).NameProperty, Is.EqualTo("TestName")); + + Assert.That(deserializedB, Is.AssignableTo()); + Assert.That(deserializedB, Is.AssignableTo()); + Assert.That(deserializedB.IdentifierProperty, Is.EqualTo(3)); + Assert.That(((DerivedClassB)deserializedB).ValueProperty, Is.EqualTo(114514f)); + }); } @@ -31,8 +41,8 @@ public void BasicPolymorphism_SerializeAndDeserialize_ReturnsCorrectDerivedType( #region Test Classes [ProtoPolymorphic(FieldNumber = 1)] -[ProtoDerivedType(typeof(DerivedClassA), 2)] -[ProtoDerivedType(typeof(DerivedClassB), 3)] +[ProtoDerivedType(typeof(DerivedClassA), 2)] +[ProtoDerivedType(typeof(DerivedClassB), 3)] public class BaseClass { public BaseClass() : this(-1) { } @@ -52,7 +62,7 @@ public class DerivedClassA() : BaseClass(2) public class DerivedClassB() : BaseClass(3) { - [ProtoMember(2)] public float ValueProperty { get; set; } + [ProtoMember(2)] public float ValueProperty { get; set; } = 0f; } #endregion \ No newline at end of file diff --git a/Lagrange.Proto/ProtoDerivedTypeAttribute.cs b/Lagrange.Proto/ProtoDerivedTypeAttribute.cs index ea6f235d..70f42ba0 100644 --- a/Lagrange.Proto/ProtoDerivedTypeAttribute.cs +++ b/Lagrange.Proto/ProtoDerivedTypeAttribute.cs @@ -1,32 +1,29 @@ namespace Lagrange.Proto; [AttributeUsage(AttributeTargets.Class | AttributeTargets.Interface, AllowMultiple = true, Inherited = false)] -public class ProtoDerivedTypeAttribute : Attribute +public class ProtoDerivedTypeAttribute : ProtoDerivedTypeAttribute where T : IEquatable { - public ProtoDerivedTypeAttribute(Type derivedType) - { - DerivedType = derivedType; - } - - public ProtoDerivedTypeAttribute(Type derivedType, string typeDiscriminator) + public ProtoDerivedTypeAttribute(Type derivedType, T typeDiscriminator) : base(derivedType) { - DerivedType = derivedType; TypeDiscriminator = typeDiscriminator; } - - public ProtoDerivedTypeAttribute(Type derivedType, int typeDiscriminator) + + /// + /// The type discriminator identifier to be used for the serialization of the subtype. + /// + public T TypeDiscriminator { get; init; } +} + +[AttributeUsage(AttributeTargets.Class | AttributeTargets.Interface, AllowMultiple = true, Inherited = false)] +public class ProtoDerivedTypeAttribute : Attribute +{ + public ProtoDerivedTypeAttribute(Type derivedType) { DerivedType = derivedType; - TypeDiscriminator = typeDiscriminator; } /// /// A derived type that should be supported in polymorphic serialization of the declared base type. /// public Type DerivedType { get; init; } - - /// - /// The type discriminator identifier to be used for the serialization of the subtype. - /// - public object? TypeDiscriminator { get; init; } } \ No newline at end of file diff --git a/Lagrange.Proto/ProtoPolymorphicAttribute.cs b/Lagrange.Proto/ProtoPolymorphicAttribute.cs index 3dde88d1..e9b29010 100644 --- a/Lagrange.Proto/ProtoPolymorphicAttribute.cs +++ b/Lagrange.Proto/ProtoPolymorphicAttribute.cs @@ -1,7 +1,7 @@ namespace Lagrange.Proto; -[AttributeUsage(AttributeTargets.Class | AttributeTargets.Interface, AllowMultiple = false, Inherited = false)] +[AttributeUsage(AttributeTargets.Class | AttributeTargets.Interface, Inherited = false)] public class ProtoPolymorphicAttribute : Attribute { public uint FieldNumber { get; init; } diff --git a/Lagrange.Proto/Serialization/Metadata/ProtoObjectInfo.cs b/Lagrange.Proto/Serialization/Metadata/ProtoObjectInfo.cs index 33a03ed1..ac1f421e 100644 --- a/Lagrange.Proto/Serialization/Metadata/ProtoObjectInfo.cs +++ b/Lagrange.Proto/Serialization/Metadata/ProtoObjectInfo.cs @@ -6,11 +6,9 @@ namespace Lagrange.Proto.Serialization.Metadata; public class ProtoObjectInfo { public Dictionary Fields { get; init; } = new(); - + public Func? ObjectCreator { get; init; } - + public bool IgnoreDefaultFields { get; init; } - public uint PolymorphicIndicateIndex { get; init; } = 0; - public bool PolymorphicFallbackToBaseType { get; init; } = true; - public Dictionary? objectCreator,Dictionary fields)>? PolymorphicFields { get; init; } + public ProtoPolymorphicInfoBase? PolymorphicInfo { get; init; } } \ No newline at end of file diff --git a/Lagrange.Proto/Serialization/Metadata/ProtoPolymorphicInfo.cs b/Lagrange.Proto/Serialization/Metadata/ProtoPolymorphicInfo.cs new file mode 100644 index 00000000..49103fc0 --- /dev/null +++ b/Lagrange.Proto/Serialization/Metadata/ProtoPolymorphicInfo.cs @@ -0,0 +1,41 @@ +namespace Lagrange.Proto.Serialization.Metadata; + + +public class ProtoPolymorphicInfoBase +{ + public uint PolymorphicIndicateIndex { get; set; } = 0; + public bool PolymorphicFallbackToBaseType { get; set; } = true; + + public virtual ProtoPolymorphicDerivedTypeInfo? GetTypeFromDiscriminator(object discriminator) + { + return null; + } + + public virtual bool SetTypeDiscriminator(object discriminator, ProtoPolymorphicDerivedTypeInfo info) + { + return false; + } +} + +public class ProtoPolymorphicDerivedTypeInfo +{ + public required Type DerivedType { get; init; } + public Func? ObjectCreator { get; init; } + public Dictionary Fields { get; init; } = new(); +} + +public class ProtoPolymorphicObjectInfo : ProtoPolymorphicInfoBase where TKey : IEquatable +{ + public override ProtoPolymorphicDerivedTypeInfo? GetTypeFromDiscriminator(object discriminator) + { + return PolymorphicDerivedTypes.GetValueOrDefault((TKey)discriminator); + } + + public override bool SetTypeDiscriminator(object discriminator, ProtoPolymorphicDerivedTypeInfo info) + { + PolymorphicDerivedTypes[(TKey)discriminator] = info; + return true; + } + + public Dictionary> PolymorphicDerivedTypes { get; } = []; +} \ No newline at end of file diff --git a/Lagrange.Proto/Serialization/Metadata/ProtoTypeResolver.Dynamic.cs b/Lagrange.Proto/Serialization/Metadata/ProtoTypeResolver.Dynamic.cs index 4688e59b..b9550828 100644 --- a/Lagrange.Proto/Serialization/Metadata/ProtoTypeResolver.Dynamic.cs +++ b/Lagrange.Proto/Serialization/Metadata/ProtoTypeResolver.Dynamic.cs @@ -64,48 +64,65 @@ static MemberAccessor Initialize() [RequiresDynamicCode(ProtoSerializer.SerializationRequiresDynamicCodeMessage)] internal static ProtoObjectInfo CreateObjectInfo() { - var ctor = typeof(T).IsValueType ? null : typeof(T).GetConstructor(Type.EmptyTypes); - bool ignoreDefaultFields = typeof(T).GetCustomAttribute()?.IgnoreDefaultFields == true; - var fields = CreateTypeFieldInfo(typeof(T)); - - var polymorphicAttributes = typeof(T).GetCustomAttributes().ToArray(); - var polymorphicConfigAttribute = typeof(T).GetCustomAttribute(); - var polymorphicFieldNumber = polymorphicConfigAttribute?.FieldNumber ?? 0; - var fallbackToBaseType = polymorphicConfigAttribute?.FallbackToBaseType ?? true; + var objType = typeof(T); + var ctor = objType.IsValueType ? null : objType.GetConstructor(Type.EmptyTypes); + bool ignoreDefaultFields = objType.GetCustomAttribute()?.IgnoreDefaultFields == true; + var fields = CreateTypeFieldInfo(objType); + return new ProtoObjectInfo + { + ObjectCreator = MemberAccessor.CreateParameterlessConstructor(ctor), + IgnoreDefaultFields = ignoreDefaultFields, + Fields = fields.OrderBy(x => x.Key).ToDictionary(x => x.Key, x => x.Value), + PolymorphicInfo = PopulatePolymorphicInfo() + }; + } + + internal static ProtoPolymorphicInfoBase? PopulatePolymorphicInfo() + { + var type = typeof(T); + var polymorphicAttributes = type.GetCustomAttributes(typeof(ProtoDerivedTypeAttribute<>)) + .OfType().ToArray(); if (polymorphicAttributes.Length > 0) { + var polymorphicConfigAttribute = type.GetCustomAttribute(); + var polymorphicFieldNumber = polymorphicConfigAttribute?.FieldNumber ?? 0; + var fallbackToBaseType = polymorphicConfigAttribute?.FallbackToBaseType ?? true; if (polymorphicFieldNumber == 0) polymorphicFieldNumber = 1; // use first for default - var polymorphicFields = new Dictionary? objectCreator,Dictionary fields)>(); + + // get the TKey from first + var firstAttr = polymorphicAttributes[0]; + var keyType = firstAttr.GetType().GetGenericArguments()[0]; + var objectInfo = MemberAccessor.CreateParameterlessConstructor>( + typeof(ProtoPolymorphicObjectInfo<,>).MakeGenericType(typeof(T), keyType) + .GetConstructor(Type.EmptyTypes))?.Invoke(); + + Debug.Assert(objectInfo != null); + objectInfo.PolymorphicIndicateIndex = polymorphicFieldNumber; + objectInfo.PolymorphicFallbackToBaseType = fallbackToBaseType; + foreach (var attr in polymorphicAttributes) { - var key = attr.TypeDiscriminator; - if (key == null) ThrowHelper.ThrowInvalidOperationException_NullPolymorphicDiscriminator(attr.DerivedType); var derivedFields = CreateTypeFieldInfo(attr.DerivedType); - if (polymorphicFields.ContainsKey(key)) ThrowHelper.ThrowInvalidOperationException_DuplicatePolymorphicDiscriminator(typeof(T), (int)polymorphicFieldNumber); - var derivedCtor = attr.DerivedType.IsValueType ? null : attr.DerivedType.GetConstructor(Type.EmptyTypes); - polymorphicFields[key] = (MemberAccessor.CreateParameterlessConstructor(derivedCtor), derivedFields); + var derivedCtor = attr.DerivedType.IsValueType + ? null + : attr.DerivedType.GetConstructor(Type.EmptyTypes); + var key = attr.GetType().GetProperty(nameof(ProtoDerivedTypeAttribute.TypeDiscriminator)) + ?.GetValue(attr); + if (key == null) ThrowHelper.ThrowInvalidOperationException_UnknownPolymorphicType(type, attr.DerivedType); + objectInfo.SetTypeDiscriminator(key, + new ProtoPolymorphicDerivedTypeInfo + { + DerivedType = attr.DerivedType, + ObjectCreator = MemberAccessor.CreateParameterlessConstructor(derivedCtor), + Fields = derivedFields + }); } - - return new ProtoObjectInfo - { - ObjectCreator = MemberAccessor.CreateParameterlessConstructor(ctor), - IgnoreDefaultFields = ignoreDefaultFields, - PolymorphicIndicateIndex = polymorphicFieldNumber, - PolymorphicFallbackToBaseType = fallbackToBaseType, - PolymorphicFields = polymorphicFields, - Fields = fields.OrderBy(x => x.Key).ToDictionary(x => x.Key, x => x.Value) - }; + + return objectInfo; } - - - return new ProtoObjectInfo - { - ObjectCreator = MemberAccessor.CreateParameterlessConstructor(ctor), - IgnoreDefaultFields = ignoreDefaultFields, - Fields = fields.OrderBy(x => x.Key).ToDictionary(x => x.Key, x => x.Value) - }; + return null; } internal static Dictionary CreateTypeFieldInfo(Type type) diff --git a/Lagrange.Proto/Serialization/ProtoSerializer.Deserialize.cs b/Lagrange.Proto/Serialization/ProtoSerializer.Deserialize.cs index 097b2a42..df2b8893 100644 --- a/Lagrange.Proto/Serialization/ProtoSerializer.Deserialize.cs +++ b/Lagrange.Proto/Serialization/ProtoSerializer.Deserialize.cs @@ -90,15 +90,15 @@ private static T DeserializeProtoPackableCore(ref ProtoReader reader) where T var fieldInfos = converter.ObjectInfo.Fields; // polymorphic type - if (converter.ObjectInfo.PolymorphicIndicateIndex != 0) + if (converter.ObjectInfo.PolymorphicInfo?.PolymorphicIndicateIndex is > 0) { // has polymorphic type, read the first field to determine the actual type uint firstTag = reader.DecodeVarIntUnsafe(); - if (firstTag >>> 3 != converter.ObjectInfo.PolymorphicIndicateIndex) + if (firstTag >>> 3 != converter.ObjectInfo.PolymorphicInfo.PolymorphicIndicateIndex) { ThrowHelper.ThrowInvalidOperationException_PolymorphicFieldNotFirst(typeof(T), - converter.ObjectInfo.PolymorphicIndicateIndex, firstTag >>> 3); + converter.ObjectInfo.PolymorphicInfo.PolymorphicIndicateIndex, firstTag >>> 3); } var firstField = converter.ObjectInfo.Fields[firstTag]; @@ -109,15 +109,15 @@ private static T DeserializeProtoPackableCore(ref ProtoReader reader) where T ThrowHelper.ThrowInvalidOperationException_FailedParsePolymorphicType(typeof(T), firstTag); } - if (converter.ObjectInfo.PolymorphicFields?.TryGetValue(polyTypeKey, out var polyTypeInfo) is true) + if (converter.ObjectInfo.PolymorphicInfo.GetTypeFromDiscriminator(polyTypeKey) is { } polyTypeInfo) { - fieldInfos = polyTypeInfo.fields; - Debug.Assert(polyTypeInfo.objectCreator != null); - target = polyTypeInfo.objectCreator(); + fieldInfos = polyTypeInfo.Fields; + Debug.Assert(polyTypeInfo.ObjectCreator != null); + target = polyTypeInfo.ObjectCreator(); boxed = (object?)target; // boxing if (boxed is null) ThrowHelper.ThrowInvalidOperationException_CanNotCreateObject(typeof(T)); } - else if (!converter.ObjectInfo.PolymorphicFallbackToBaseType) + else if (!converter.ObjectInfo.PolymorphicInfo.PolymorphicFallbackToBaseType) { ThrowHelper.ThrowInvalidOperationException_UnknownPolymorphicType(typeof(T), polyTypeKey); } diff --git a/Lagrange.Proto/Serialization/ProtoSerializer.Serialize.cs b/Lagrange.Proto/Serialization/ProtoSerializer.Serialize.cs index 0bfc049d..831e3737 100644 --- a/Lagrange.Proto/Serialization/ProtoSerializer.Serialize.cs +++ b/Lagrange.Proto/Serialization/ProtoSerializer.Serialize.cs @@ -128,15 +128,15 @@ private static void SerializeProtoPackableCore(ProtoWriter writer, T obj) whe uint skipTag = 0; // check polymorphic type - if (converter.ObjectInfo.PolymorphicIndicateIndex != 0) + if (converter.ObjectInfo.PolymorphicInfo?.PolymorphicIndicateIndex is > 0) { // has polymorphic type - var index = converter.ObjectInfo.PolymorphicIndicateIndex; + var index = converter.ObjectInfo.PolymorphicInfo.PolymorphicIndicateIndex; var fieldInfo = objectInfo.Fields.FirstOrDefault(t=>t.Value.Field == index); if (fieldInfo.Value is null) ThrowHelper.ThrowInvalidOperationException_NullPolymorphicDiscriminator(typeof(T)); var discriminator = fieldInfo.Value.Get?.Invoke(boxed); if (discriminator is null) ThrowHelper.ThrowInvalidOperationException_NullPolymorphicDiscriminator(typeof(T)); - if (objectInfo.PolymorphicFields?.TryGetValue(discriminator, out var derivedTypeInfo) is not true) + if (objectInfo.PolymorphicInfo!.GetTypeFromDiscriminator(discriminator) is not { } derivedTypeInfo) { ThrowHelper.ThrowInvalidOperationException_NullPolymorphicDiscriminator(typeof(T)); return; // make compiler happy @@ -144,7 +144,7 @@ private static void SerializeProtoPackableCore(ProtoWriter writer, T obj) whe skipTag = fieldInfo.Key; writer.EncodeVarInt(fieldInfo.Key); fieldInfo.Value.Write(writer, boxed); - fields = derivedTypeInfo.fields; + fields = derivedTypeInfo.Fields; } foreach (var (tag, info) in fields) From c4c5d0c75868ec6e776d5b4a370a84e925272505 Mon Sep 17 00:00:00 2001 From: Kengwang Date: Thu, 18 Sep 2025 03:27:31 +0800 Subject: [PATCH 4/7] [Proto] Polymorphic derived type reuse converter logic --- .../Metadata/ProtoPolymorphicInfo.cs | 18 +++---- .../Metadata/ProtoTypeResolver.Dynamic.cs | 12 +---- .../ProtoSerializer.Deserialize.cs | 30 +++--------- .../Serialization/ProtoSerializer.Helpers.cs | 49 +++++++++++++++++++ .../ProtoSerializer.Serialize.cs | 24 ++------- 5 files changed, 66 insertions(+), 67 deletions(-) diff --git a/Lagrange.Proto/Serialization/Metadata/ProtoPolymorphicInfo.cs b/Lagrange.Proto/Serialization/Metadata/ProtoPolymorphicInfo.cs index 49103fc0..3d296a55 100644 --- a/Lagrange.Proto/Serialization/Metadata/ProtoPolymorphicInfo.cs +++ b/Lagrange.Proto/Serialization/Metadata/ProtoPolymorphicInfo.cs @@ -6,36 +6,30 @@ public class ProtoPolymorphicInfoBase public uint PolymorphicIndicateIndex { get; set; } = 0; public bool PolymorphicFallbackToBaseType { get; set; } = true; - public virtual ProtoPolymorphicDerivedTypeInfo? GetTypeFromDiscriminator(object discriminator) + public virtual Type? GetTypeFromDiscriminator(object discriminator) { return null; } - public virtual bool SetTypeDiscriminator(object discriminator, ProtoPolymorphicDerivedTypeInfo info) + public virtual bool SetTypeDiscriminator(object discriminator, Type type) { return false; } } -public class ProtoPolymorphicDerivedTypeInfo -{ - public required Type DerivedType { get; init; } - public Func? ObjectCreator { get; init; } - public Dictionary Fields { get; init; } = new(); -} public class ProtoPolymorphicObjectInfo : ProtoPolymorphicInfoBase where TKey : IEquatable { - public override ProtoPolymorphicDerivedTypeInfo? GetTypeFromDiscriminator(object discriminator) + public override Type? GetTypeFromDiscriminator(object discriminator) { return PolymorphicDerivedTypes.GetValueOrDefault((TKey)discriminator); } - public override bool SetTypeDiscriminator(object discriminator, ProtoPolymorphicDerivedTypeInfo info) + public override bool SetTypeDiscriminator(object discriminator, Type type) { - PolymorphicDerivedTypes[(TKey)discriminator] = info; + PolymorphicDerivedTypes[(TKey)discriminator] = type; return true; } - public Dictionary> PolymorphicDerivedTypes { get; } = []; + public Dictionary PolymorphicDerivedTypes { get; } = []; } \ No newline at end of file diff --git a/Lagrange.Proto/Serialization/Metadata/ProtoTypeResolver.Dynamic.cs b/Lagrange.Proto/Serialization/Metadata/ProtoTypeResolver.Dynamic.cs index b9550828..a2f4d906 100644 --- a/Lagrange.Proto/Serialization/Metadata/ProtoTypeResolver.Dynamic.cs +++ b/Lagrange.Proto/Serialization/Metadata/ProtoTypeResolver.Dynamic.cs @@ -103,20 +103,10 @@ internal static ProtoObjectInfo CreateObjectInfo() foreach (var attr in polymorphicAttributes) { - var derivedFields = CreateTypeFieldInfo(attr.DerivedType); - var derivedCtor = attr.DerivedType.IsValueType - ? null - : attr.DerivedType.GetConstructor(Type.EmptyTypes); var key = attr.GetType().GetProperty(nameof(ProtoDerivedTypeAttribute.TypeDiscriminator)) ?.GetValue(attr); if (key == null) ThrowHelper.ThrowInvalidOperationException_UnknownPolymorphicType(type, attr.DerivedType); - objectInfo.SetTypeDiscriminator(key, - new ProtoPolymorphicDerivedTypeInfo - { - DerivedType = attr.DerivedType, - ObjectCreator = MemberAccessor.CreateParameterlessConstructor(derivedCtor), - Fields = derivedFields - }); + objectInfo.SetTypeDiscriminator(key, attr.DerivedType); } return objectInfo; diff --git a/Lagrange.Proto/Serialization/ProtoSerializer.Deserialize.cs b/Lagrange.Proto/Serialization/ProtoSerializer.Deserialize.cs index df2b8893..e92666aa 100644 --- a/Lagrange.Proto/Serialization/ProtoSerializer.Deserialize.cs +++ b/Lagrange.Proto/Serialization/ProtoSerializer.Deserialize.cs @@ -63,24 +63,7 @@ private static T DeserializeProtoPackableCore(ref ProtoReader reader) where T private static T DeserializeCore<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.All)] T>( ref ProtoReader reader) { - ProtoObjectConverter converter; - if (ProtoTypeResolver.IsRegistered()) - { - if (ProtoTypeResolver.GetConverter() as ProtoObjectConverter is not { } c) - { - converter = new ProtoObjectConverter(ProtoTypeResolver.CreateObjectInfo()); - ProtoTypeResolver.Register(converter); - } - else - { - converter = c; - } - } - else - { - ProtoTypeResolver.Register(converter = new ProtoObjectConverter()); - } - + var converter = GetConverterOf(); Debug.Assert(converter.ObjectInfo.ObjectCreator != null); T target = converter.ObjectInfo.ObjectCreator(); @@ -109,13 +92,12 @@ private static T DeserializeProtoPackableCore(ref ProtoReader reader) where T ThrowHelper.ThrowInvalidOperationException_FailedParsePolymorphicType(typeof(T), firstTag); } - if (converter.ObjectInfo.PolymorphicInfo.GetTypeFromDiscriminator(polyTypeKey) is { } polyTypeInfo) + if (converter.ObjectInfo.PolymorphicInfo.GetTypeFromDiscriminator(polyTypeKey) is { } polyType) { - fieldInfos = polyTypeInfo.Fields; - Debug.Assert(polyTypeInfo.ObjectCreator != null); - target = polyTypeInfo.ObjectCreator(); - boxed = (object?)target; // boxing - if (boxed is null) ThrowHelper.ThrowInvalidOperationException_CanNotCreateObject(typeof(T)); + (fieldInfos, var objectCreator ) = GetObjectInfoReflection(polyType); + target = objectCreator(); + boxed = target; + if (boxed is null) ThrowHelper.ThrowInvalidOperationException_CanNotCreateObject(polyType); } else if (!converter.ObjectInfo.PolymorphicInfo.PolymorphicFallbackToBaseType) { diff --git a/Lagrange.Proto/Serialization/ProtoSerializer.Helpers.cs b/Lagrange.Proto/Serialization/ProtoSerializer.Helpers.cs index 9958b536..874697f7 100644 --- a/Lagrange.Proto/Serialization/ProtoSerializer.Helpers.cs +++ b/Lagrange.Proto/Serialization/ProtoSerializer.Helpers.cs @@ -1,7 +1,56 @@ +using System.Diagnostics; +using System.Reflection; +using Lagrange.Proto.Serialization.Converter; +using Lagrange.Proto.Serialization.Metadata; + namespace Lagrange.Proto.Serialization; public static partial class ProtoSerializer { internal const string SerializationUnreferencedCodeMessage = "Proto serialization and deserialization might require types that cannot be statically analyzed. Use the SerializePackable that takes a IProtoSerializable to ensure generated code is used, or make sure all of the required types are preserved."; internal const string SerializationRequiresDynamicCodeMessage = "Proto serialization and deserialization might require types that cannot be statically analyzed and might need runtime code generation. Use Lagrange.Proto source generation for native AOT applications."; + + internal static ProtoObjectConverter GetConverterOf() + { + ProtoObjectConverter converter; + if (ProtoTypeResolver.IsRegistered()) + { + if (ProtoTypeResolver.GetConverter() as ProtoObjectConverter is not { } c) + { + converter = new ProtoObjectConverter(ProtoTypeResolver.CreateObjectInfo()); + ProtoTypeResolver.Register(converter); + } + else + { + converter = c; + } + } + else + { + ProtoTypeResolver.Register(converter = new ProtoObjectConverter()); + } + + return converter; + } + + internal static (Dictionary Fields, Func ObjectCreator) GetObjectInfoReflection(Type polyType) + { + Debug.Assert(polyType != typeof(T)); + Debug.Assert(polyType.IsAssignableTo(typeof(T))); + var method = typeof(ProtoSerializer).GetMethod(nameof(GetConverterOf), + BindingFlags.Static | BindingFlags.NonPublic); + Debug.Assert(method != null); + var genericMethod = method.MakeGenericMethod(polyType); + var polyConverter = genericMethod.Invoke(null, null)!; + + // get creator and fields, oh my reflection! + var polyObjectInfo = polyConverter.GetType() + .GetField("ObjectInfo",BindingFlags.NonPublic | BindingFlags.Instance)!.GetValue(polyConverter)!; + var polyCreator = polyObjectInfo.GetType() + .GetProperty("ObjectCreator")!.GetValue(polyObjectInfo)!; + var fieldInfos = (Dictionary)polyObjectInfo.GetType() + .GetProperty("Fields")!.GetValue(polyObjectInfo)!; + return (fieldInfos, ObjectCreator); + T ObjectCreator() => (T)polyCreator.GetType().GetMethod("Invoke")!.Invoke(polyCreator, null)!; + } } \ No newline at end of file diff --git a/Lagrange.Proto/Serialization/ProtoSerializer.Serialize.cs b/Lagrange.Proto/Serialization/ProtoSerializer.Serialize.cs index 831e3737..22e332d2 100644 --- a/Lagrange.Proto/Serialization/ProtoSerializer.Serialize.cs +++ b/Lagrange.Proto/Serialization/ProtoSerializer.Serialize.cs @@ -102,25 +102,8 @@ private static void SerializeProtoPackableCore(ProtoWriter writer, T obj) whe [RequiresUnreferencedCode(SerializationUnreferencedCodeMessage)] [RequiresDynamicCode(SerializationRequiresDynamicCodeMessage)] private static void SerializeCore<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.All)] T>(ProtoWriter writer, T obj) - { - ProtoObjectConverter converter; - if (ProtoTypeResolver.IsRegistered()) - { - if (ProtoTypeResolver.GetConverter() as ProtoObjectConverter is not { } c) - { - converter = new ProtoObjectConverter(ProtoTypeResolver.CreateObjectInfo()); - ProtoTypeResolver.Register(converter); - } - else - { - converter = c; - } - } - else - { - ProtoTypeResolver.Register(converter = new ProtoObjectConverter()); - } - + { + var converter = GetConverterOf(); var objectInfo = converter.ObjectInfo; object? boxed = obj; // avoid multiple times of boxing if (boxed is null) return; @@ -144,7 +127,8 @@ private static void SerializeProtoPackableCore(ProtoWriter writer, T obj) whe skipTag = fieldInfo.Key; writer.EncodeVarInt(fieldInfo.Key); fieldInfo.Value.Write(writer, boxed); - fields = derivedTypeInfo.Fields; + + (fields, _) = GetObjectInfoReflection(derivedTypeInfo); } foreach (var (tag, info) in fields) From 1adeaa6e640b1825862fb05c563e495133968352 Mon Sep 17 00:00:00 2001 From: Kengwang Date: Fri, 19 Sep 2025 01:59:06 +0800 Subject: [PATCH 5/7] [Proto] Add .Generator support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Allow [Base,Derived]Serialize × [Base,Derived]Deserialize (4 cases in total) --- .../Entity/PolymorphicTypeInfo.cs | 20 ++ .../ProtoSourceGenerator.Emitter.Serialize.cs | 42 +++- .../ProtoSourceGenerator.Emitter.TypeInfo.cs | 76 ++++++ .../ProtoSourceGenerator.Parser.cs | 84 ++++++- Lagrange.Proto.Runner/Program.cs | 35 ++- Lagrange.Proto.Test/ProtoPolymorphismTest.cs | 220 +++++++++++++++++- Lagrange.Proto/IProtoSerializable.cs | 2 + Lagrange.Proto/ProtoDerivedTypeAttribute.cs | 5 +- .../Serialization/Metadata/ProtoObjectInfo.cs | 4 +- .../Metadata/ProtoPolymorphicInfo.cs | 36 +-- .../Metadata/ProtoTypeResolver.Dynamic.cs | 9 +- .../ProtoSerializer.Deserialize.cs | 48 +++- .../Serialization/ProtoSerializer.Helpers.cs | 6 +- .../ProtoSerializer.Serialize.cs | 8 +- 14 files changed, 537 insertions(+), 58 deletions(-) create mode 100644 Lagrange.Proto.Generator/Entity/PolymorphicTypeInfo.cs diff --git a/Lagrange.Proto.Generator/Entity/PolymorphicTypeInfo.cs b/Lagrange.Proto.Generator/Entity/PolymorphicTypeInfo.cs new file mode 100644 index 00000000..c86daf3d --- /dev/null +++ b/Lagrange.Proto.Generator/Entity/PolymorphicTypeInfo.cs @@ -0,0 +1,20 @@ +using Microsoft.CodeAnalysis; + +namespace Lagrange.Proto.Generator.Entity; + +public class PolymorphicTypeInfo +{ + public INamedTypeSymbol PolymorphicKeyType { get; internal set; } = null!; + + public uint PolymorphicIndicateIndex { get; internal set; } = 0; + + public bool PolymorphicFallbackToBaseType { get; internal set; } = true; + + public List PolymorphicTypes { get; } = []; +} + +public class PolymorphicDerivedTypeInfo +{ + public INamedTypeSymbol DerivedType { get; internal set; } = null!; + public TypedConstant Key { get; internal set; } +} \ No newline at end of file diff --git a/Lagrange.Proto.Generator/ProtoSourceGenerator.Emitter.Serialize.cs b/Lagrange.Proto.Generator/ProtoSourceGenerator.Emitter.Serialize.cs index f224fa96..a69ba97a 100644 --- a/Lagrange.Proto.Generator/ProtoSourceGenerator.Emitter.Serialize.cs +++ b/Lagrange.Proto.Generator/ProtoSourceGenerator.Emitter.Serialize.cs @@ -25,22 +25,62 @@ private partial class Emitter private const string EncodeStringMethodName = "EncodeString"; private const string EncodeBytesMethodName = "EncodeBytes"; private const string EncodeResolvableMethodName = "EncodeResolvable"; + private void EmitSerializeMethod(SourceWriter source) { source.WriteLine($"public static void SerializeHandler({_fullQualifiedName} {ObjectVarName}, {ProtoWriterTypeRef} {WriterVarName})"); source.WriteLine("{"); source.Indentation++; - + + if (parser.BaseTypeSymbol is not null) + { + source.WriteLine($"{parser.BaseTypeSymbol.GetFullName()}.SerializeHandler({ObjectVarName},{WriterVarName});"); + } + + source.WriteLine($"SerializeHandlerCore({ObjectVarName}, {WriterVarName});"); + source.Indentation--; + source.WriteLine("}"); + source.WriteLine(); + + source.WriteLine($"private static void SerializeHandlerCore({_fullQualifiedName} {ObjectVarName}, {ProtoWriterTypeRef} {WriterVarName})"); + source.WriteLine("{"); + source.Indentation++; + if (parser.PolymorphicInfo.PolymorphicIndicateIndex > 0) + { + // write indicate field first + var idx = (int)parser.PolymorphicInfo.PolymorphicIndicateIndex; + var indicatorField = parser.Fields[idx]; + EmitMembers(source, idx, indicatorField); + source.WriteLine(); + + source.WriteLine($"switch ({ObjectVarName})"); + source.WriteLine("{"); + source.Indentation++; + for (var index = 0; index < parser.PolymorphicInfo.PolymorphicTypes.Count; index++) + { + var kv = parser.PolymorphicInfo.PolymorphicTypes[index]; + source.WriteLine($"case {kv.DerivedType.GetFullName()} derived{index}:"); + source.Indentation++; + source.WriteLine($"{kv.DerivedType.GetFullName()}.SerializeHandlerCore(derived{index}, {WriterVarName});"); + source.WriteLine("break;"); + source.Indentation--; + } + source.Indentation--; + source.WriteLine("}"); + source.WriteLine(); + } foreach (var kv in parser.Fields) { int field = kv.Key; + if (parser.PolymorphicInfo.PolymorphicIndicateIndex == field) continue; // already written var info = kv.Value; EmitMembers(source, field, info); source.WriteLine(); } + source.Indentation--; source.WriteLine("}"); } diff --git a/Lagrange.Proto.Generator/ProtoSourceGenerator.Emitter.TypeInfo.cs b/Lagrange.Proto.Generator/ProtoSourceGenerator.Emitter.TypeInfo.cs index 968ef61c..c292481b 100644 --- a/Lagrange.Proto.Generator/ProtoSourceGenerator.Emitter.TypeInfo.cs +++ b/Lagrange.Proto.Generator/ProtoSourceGenerator.Emitter.TypeInfo.cs @@ -2,6 +2,7 @@ using Lagrange.Proto.Generator.Utility; using Lagrange.Proto.Generator.Utility.Extension; using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; namespace Lagrange.Proto.Generator; @@ -13,6 +14,8 @@ private partial class Emitter private const string TypeInfoPropertyName = "TypeInfo"; private const string ProtoObjectInfoTypeRef = "global::Lagrange.Proto.Serialization.Metadata.ProtoObjectInfo<{0}>"; + private const string ProtoPolymorphicObjectInfoTypeRef = "global::Lagrange.Proto.Serialization.Metadata.ProtoPolymorphicObjectInfo<{0}>"; + private const string ProtoPolymorphicDerivedTypeDescriptorTypeRef = "global::Lagrange.Proto.Serialization.Metadata.ProtoPolymorphicDerivedTypeDescriptor<{0}>"; private const string ProtoFieldInfoTypeRef = "global::Lagrange.Proto.Serialization.Metadata.ProtoFieldInfo"; private const string ProtoFieldInfoGenericTypeRef = "global::Lagrange.Proto.Serialization.Metadata.ProtoFieldInfo<{0}>"; private const string ProtoMapFieldInfoGenericTypeRef = "global::Lagrange.Proto.Serialization.Metadata.ProtoMapFieldInfo<{0}, {1}, {2}>"; @@ -37,11 +40,15 @@ private partial class Emitter private void EmitTypeInfo(SourceWriter source) { + source.WriteLine("#pragma warning disable CS0108"); + source.WriteLine(); source.WriteLine($"public static {ProtoObjectInfoTypeRefGeneric}? {TypeInfoFieldName};"); source.WriteLine(); source.WriteLine($"public static {ProtoObjectInfoTypeRefGeneric} {TypeInfoPropertyName} => {TypeInfoFieldName} ??= GetTypeInfo();"); source.WriteLine(); + source.WriteLine("#pragma warning restore CS0108"); + source.WriteLine(); source.WriteLine($"private static {ProtoObjectInfoTypeRefGeneric} GetTypeInfo()"); source.WriteLine('{'); @@ -76,12 +83,16 @@ private void EmitTypeInfo(SourceWriter source) source.WriteLine("},"); source.WriteLine($"ObjectCreator = () => new {_fullQualifiedName}(),"); + EmitPolymorphicInfo(source, parser.PolymorphicInfo); source.WriteLine($"IgnoreDefaultFields = {parser.IgnoreDefaultFields.ToString().ToLower()}"); source.Indentation--; source.WriteLine("};"); source.Indentation--; source.WriteLine('}'); + + EmitPolymorphicDerivedTypeDescriptor(source, parser.PolymorphicInfo); + return; static void EmitByTypeSymbol(SourceWriter source, ITypeSymbol typeSymbol) @@ -142,5 +153,70 @@ private void EmitMapFieldInfo(SourceWriter source, int field, ProtoFieldInfo inf source.Indentation--; source.WriteLine("},"); } + + private void EmitPolymorphicInfo(SourceWriter source, PolymorphicTypeInfo polymorphicInfo) + { + if (polymorphicInfo.PolymorphicIndicateIndex == 0) return; + source.WriteLine($"PolymorphicInfo = new {string.Format(ProtoPolymorphicObjectInfoTypeRef, polymorphicInfo.PolymorphicKeyType.GetFullName())}()"); + source.WriteLine('{'); + source.Indentation++; + + source.WriteLine($"PolymorphicIndicateIndex = {polymorphicInfo.PolymorphicIndicateIndex},"); + source.WriteLine($"PolymorphicFallbackToBaseType = {parser.IgnoreDefaultFields.ToString().ToLower()},"); + source.WriteLine( + $"PolymorphicDerivedTypes = new global::System.Collections.Generic.Dictionary<{polymorphicInfo.PolymorphicKeyType.GetFullName()}, global::System.Type>()"); + source.WriteLine('{'); + source.Indentation++; + + foreach (var derivedTypeInfo in polymorphicInfo.PolymorphicTypes) + { + source.WriteLine($"[{derivedTypeInfo.Key.ToCSharpString()}] = typeof({derivedTypeInfo.DerivedType.GetFullName()}),"); + } + + source.Indentation--; + source.WriteLine("}"); + source.Indentation--; + source.WriteLine("},"); + } + + private void EmitPolymorphicDerivedTypeDescriptor(SourceWriter source, PolymorphicTypeInfo polymorphicInfo) + { + source.WriteLine("#pragma warning disable CS0108"); + source.WriteLine( + $"public static {string.Format(ProtoPolymorphicDerivedTypeDescriptorTypeRef,_fullQualifiedName)}? GetPolymorphicTypeDescriptor(TKey discriminator)"); + source.WriteLine('{'); + source.Indentation++; + if (polymorphicInfo.PolymorphicIndicateIndex > 0) + { + source.WriteLine("switch (discriminator)"); + source.WriteLine('{'); + source.Indentation++; + + foreach (var derivedTypeInfo in polymorphicInfo.PolymorphicTypes) + { + source.WriteLine($"case {derivedTypeInfo.Key.ToCSharpString()}: return new {string.Format(ProtoPolymorphicDerivedTypeDescriptorTypeRef, _fullQualifiedName)}()"); + source.WriteLine('{'); + source.Indentation++; + source.WriteLine($"Fields = {derivedTypeInfo.DerivedType.GetFullName()}.{TypeInfoPropertyName}.Fields,"); + source.WriteLine($"ObjectCreator = {derivedTypeInfo.DerivedType.GetFullName()}.{TypeInfoPropertyName}.ObjectCreator,"); + source.WriteLine($"IgnoreDefaultFields = {derivedTypeInfo.DerivedType.GetFullName()}.{TypeInfoPropertyName}.IgnoreDefaultFields"); + source.Indentation--; + source.WriteLine("};"); + } + + source.WriteLine("default: return null;"); + source.Indentation--; + source.WriteLine("}"); + } + else + { + source.WriteLine("return null;"); + } + + source.Indentation--; + source.WriteLine('}'); + source.WriteLine("#pragma warning restore CS0108"); + source.WriteLine(); + } } } \ No newline at end of file diff --git a/Lagrange.Proto.Generator/ProtoSourceGenerator.Parser.cs b/Lagrange.Proto.Generator/ProtoSourceGenerator.Parser.cs index 1ab8f999..23f17f8f 100644 --- a/Lagrange.Proto.Generator/ProtoSourceGenerator.Parser.cs +++ b/Lagrange.Proto.Generator/ProtoSourceGenerator.Parser.cs @@ -30,6 +30,10 @@ private class Parser(ClassDeclarationSyntax context, SemanticModel model) public bool IgnoreDefaultFields { get; private set; } public Dictionary Fields { get; } = new(); + + public PolymorphicTypeInfo PolymorphicInfo { get; } = new(); + + public INamedTypeSymbol? BaseTypeSymbol { get; private set; } = null; public void Parse(CancellationToken token = default) { @@ -41,22 +45,88 @@ public void Parse(CancellationToken token = default) ReportDiagnostics(UnableToGetSymbol, context.GetLocation(), context.Identifier.Text); return; } - + + var checkingBaseTypeSymbol = classSymbol.BaseType; + while (checkingBaseTypeSymbol?.BaseType is not null) + { + if (checkingBaseTypeSymbol.GetAttributes() + .Any(t => t.AttributeClass?.Name == "ProtoPackableAttribute") is true) + { + BaseTypeSymbol = checkingBaseTypeSymbol; + break; + } + checkingBaseTypeSymbol = checkingBaseTypeSymbol?.BaseType; + } + if (!classSymbol.Constructors.Any(x => x is { Parameters.Length: 0, DeclaredAccessibility: Accessibility.Public })) { ReportDiagnostics(MustContainParameterlessConstructor, context.GetLocation(), context.Identifier.Text); return; } - foreach (var argument in classSymbol.GetAttributes().SelectMany(x => x.NamedArguments)) + foreach (var attribute in classSymbol.GetAttributes()) { - switch (argument.Key) + switch (attribute.AttributeClass?.Name) { - case "IgnoreDefaultFields": - { - IgnoreDefaultFields = (bool)(argument.Value.Value ?? false); + case "ProtoPackableAttribute": + foreach (var argument in attribute.NamedArguments) + { + switch (argument.Key) + { + case "IgnoreDefaultFields": + IgnoreDefaultFields = (bool)(argument.Value.Value ?? false); + break; + } + } + break; + case "ProtoDerivedTypeAttribute": + if (PolymorphicInfo.PolymorphicIndicateIndex == 0) + PolymorphicInfo.PolymorphicIndicateIndex = 1; // set to default + // get key type + if (attribute.AttributeClass.TypeArguments.First() is not INamedTypeSymbol keyType){ + ReportDiagnostics(UnableToGetSymbol, context.GetLocation(), + attribute.AttributeClass.TypeArguments.First().ToDisplayString()); + return; + } + PolymorphicInfo.PolymorphicKeyType = keyType; + + // get derived type, in typeof + var derivedTypeConstant = attribute.ConstructorArguments.First(); + if (derivedTypeConstant.Kind != TypedConstantKind.Type) + { + ReportDiagnostics(UnableToGetSymbol, context.GetLocation(), + derivedTypeConstant.ToCSharpString()); + return; + } + + if (derivedTypeConstant.Value is not INamedTypeSymbol derivedType) + { + ReportDiagnostics(UnableToGetSymbol, context.GetLocation(), + derivedTypeConstant.ToCSharpString()); + return; + } + + // get type discriminator + var typeDiscriminatorConstant = attribute.ConstructorArguments.ElementAtOrDefault(1); + PolymorphicInfo.PolymorphicTypes.Add(new PolymorphicDerivedTypeInfo + { + DerivedType = derivedType, Key = typeDiscriminatorConstant + }); + break; + case "ProtoPolymorphicAttribute": + foreach (var argument in attribute.NamedArguments) + { + switch (argument.Key) + { + case "FieldNumber": + PolymorphicInfo.PolymorphicIndicateIndex = (uint)(argument.Value.Value ?? 0); + break; + case "FallbackToBaseType": + PolymorphicInfo.PolymorphicFallbackToBaseType = (bool)(argument.Value.Value ?? true); + break; + } + } break; - } } } diff --git a/Lagrange.Proto.Runner/Program.cs b/Lagrange.Proto.Runner/Program.cs index 9ceeda98..d01e95d4 100644 --- a/Lagrange.Proto.Runner/Program.cs +++ b/Lagrange.Proto.Runner/Program.cs @@ -20,4 +20,37 @@ private static void Main(string[] args) int value = parsed[1][0][1].GetValue(); } -} \ No newline at end of file +} + +#region Test Classes +[ProtoPackable] +[ProtoPolymorphic(FieldNumber = 1)] +[ProtoDerivedType(typeof(DerivedClassA), 4)] +[ProtoDerivedType(typeof(DerivedClassB), 3)] +public partial class BaseClass +{ + public BaseClass() : this(-1) { } + + public BaseClass(int identifier) + { + IdentifierProperty = identifier; + } + + [ProtoMember(1)] public int IdentifierProperty { get; set; } +} + + +[ProtoPackable] +public partial class DerivedClassA() : BaseClass(2) +{ + [ProtoMember(2)] public string? NameProperty { get; set; } +} + + +[ProtoPackable] +public partial class DerivedClassB() : BaseClass(3) +{ + [ProtoMember(2)] public float ValueProperty { get; set; } = 0f; +} + +#endregion \ No newline at end of file diff --git a/Lagrange.Proto.Test/ProtoPolymorphismTest.cs b/Lagrange.Proto.Test/ProtoPolymorphismTest.cs index c0a83a38..1367c7c5 100644 --- a/Lagrange.Proto.Test/ProtoPolymorphismTest.cs +++ b/Lagrange.Proto.Test/ProtoPolymorphismTest.cs @@ -1,4 +1,5 @@ -using Lagrange.Proto.Serialization; +using System.Reflection; +using Lagrange.Proto.Serialization; namespace Lagrange.Proto.Test; @@ -8,7 +9,7 @@ public class ProtoPolymorphismTest #region Basic Polymorphism [Test] - public void BasicPolymorphism_SerializeAndDeserialize_ReturnsCorrectDerivedType() + public void ReflectionPolymorphism_SerializeBaseAndDeserializeBase_ReturnsCorrectDerivedType() { // Arrange BaseClass originalA = new DerivedClassA { NameProperty = "TestName" }; @@ -34,16 +35,218 @@ public void BasicPolymorphism_SerializeAndDeserialize_ReturnsCorrectDerivedType( }); } + + [Test] + public void ReflectionPolymorphism_SerializeDerivedAndDeserializeBase_ReturnsCorrectDerivedType() + { + // Arrange + DerivedClassA originalA = new DerivedClassA { NameProperty = "TestName" }; + DerivedClassB originalB = new DerivedClassB { ValueProperty = 114514f }; + + byte[] bytesA = ProtoSerializer.Serialize(originalA); + BaseClass deserializedA = ProtoSerializer.Deserialize(bytesA); + + byte[] bytesB = ProtoSerializer.Serialize(originalB); + BaseClass deserializedB = ProtoSerializer.Deserialize(bytesB); + + Assert.Multiple(() => + { + Assert.That(deserializedA, Is.AssignableTo()); + Assert.That(deserializedA, Is.AssignableTo()); + Assert.That(deserializedA.IdentifierProperty, Is.EqualTo(2)); + Assert.That(((DerivedClassA)deserializedA).NameProperty, Is.EqualTo("TestName")); + + Assert.That(deserializedB, Is.AssignableTo()); + Assert.That(deserializedB, Is.AssignableTo()); + Assert.That(deserializedB.IdentifierProperty, Is.EqualTo(3)); + Assert.That(((DerivedClassB)deserializedB).ValueProperty, Is.EqualTo(114514f)); + + }); + } + + [Test] + public void ReflectionPolymorphism_SerializeBaseAndDeserializeDerived_ReturnsCorrectDerivedType() + { + // Arrange + BaseClass originalA = new DerivedClassA { NameProperty = "TestName" }; + BaseClass originalB = new DerivedClassB { ValueProperty = 114514f }; + + byte[] bytesA = ProtoSerializer.Serialize(originalA); + BaseClass deserializedA = ProtoSerializer.Deserialize(bytesA); + + byte[] bytesB = ProtoSerializer.Serialize(originalB); + BaseClass deserializedB = ProtoSerializer.Deserialize(bytesB); + + Assert.Multiple(() => + { + Assert.That(deserializedA, Is.AssignableTo()); + Assert.That(deserializedA, Is.AssignableTo()); + Assert.That(deserializedA.IdentifierProperty, Is.EqualTo(2)); + Assert.That(((DerivedClassA)deserializedA).NameProperty, Is.EqualTo("TestName")); + + Assert.That(deserializedB, Is.AssignableTo()); + Assert.That(deserializedB, Is.AssignableTo()); + Assert.That(deserializedB.IdentifierProperty, Is.EqualTo(3)); + Assert.That(((DerivedClassB)deserializedB).ValueProperty, Is.EqualTo(114514f)); + + }); + } + + [Test] + public void ReflectionPolymorphism_SerializeDerivedAndDeserializeDerived_ReturnsCorrectDerivedType() + { + // Arrange + DerivedClassA originalA = new DerivedClassA { NameProperty = "TestName" }; + DerivedClassB originalB = new DerivedClassB { ValueProperty = 114514f }; + + byte[] bytesA = ProtoSerializer.Serialize(originalA); + BaseClass deserializedA = ProtoSerializer.Deserialize(bytesA); + + byte[] bytesB = ProtoSerializer.Serialize(originalB); + BaseClass deserializedB = ProtoSerializer.Deserialize(bytesB); + + Assert.Multiple(() => + { + Assert.That(deserializedA, Is.AssignableTo()); + Assert.That(deserializedA, Is.AssignableTo()); + Assert.That(deserializedA.IdentifierProperty, Is.EqualTo(2)); + Assert.That(((DerivedClassA)deserializedA).NameProperty, Is.EqualTo("TestName")); + + Assert.That(deserializedB, Is.AssignableTo()); + Assert.That(deserializedB, Is.AssignableTo()); + Assert.That(deserializedB.IdentifierProperty, Is.EqualTo(3)); + Assert.That(((DerivedClassB)deserializedB).ValueProperty, Is.EqualTo(114514f)); + + }); + } + + #endregion + + #region ProtoPackable + + [Test] + public void ProtoPackablePolymorphism_SerializeBaseAndDeserializeDerived_ReturnsCorrectDerivedType() + { + // Arrange + BaseClass originalA = new DerivedClassA { NameProperty = "TestName" }; + BaseClass originalB = new DerivedClassB { ValueProperty = 114514f }; + + byte[] bytesA = ProtoSerializer.SerializeProtoPackable(originalA); + BaseClass deserializedA = ProtoSerializer.DeserializeProtoPackable(bytesA); + + byte[] bytesB = ProtoSerializer.SerializeProtoPackable(originalB); + BaseClass deserializedB = ProtoSerializer.DeserializeProtoPackable(bytesB); + + Assert.Multiple(() => + { + Assert.That(deserializedA, Is.AssignableTo()); + Assert.That(deserializedA, Is.AssignableTo()); + Assert.That(deserializedA.IdentifierProperty, Is.EqualTo(2)); + Assert.That(((DerivedClassA)deserializedA).NameProperty, Is.EqualTo("TestName")); + + Assert.That(deserializedB, Is.AssignableTo()); + Assert.That(deserializedB, Is.AssignableTo()); + Assert.That(deserializedB.IdentifierProperty, Is.EqualTo(3)); + Assert.That(((DerivedClassB)deserializedB).ValueProperty, Is.EqualTo(114514f)); + + }); + } + + [Test] + public void ProtoPackablePolymorphism_SerializeDerivedAndDeserializeDerived_ReturnsCorrectDerivedType() + { + // Arrange + DerivedClassA originalA = new DerivedClassA { NameProperty = "TestName" }; + DerivedClassB originalB = new DerivedClassB { ValueProperty = 114514f }; + + byte[] bytesA = ProtoSerializer.SerializeProtoPackable(originalA); + BaseClass deserializedA = ProtoSerializer.DeserializeProtoPackable(bytesA); + + byte[] bytesB = ProtoSerializer.SerializeProtoPackable(originalB); + BaseClass deserializedB = ProtoSerializer.DeserializeProtoPackable(bytesB); + + Assert.Multiple(() => + { + Assert.That(deserializedA, Is.AssignableTo()); + Assert.That(deserializedA, Is.AssignableTo()); + Assert.That(deserializedA.IdentifierProperty, Is.EqualTo(2)); + Assert.That(((DerivedClassA)deserializedA).NameProperty, Is.EqualTo("TestName")); + + Assert.That(deserializedB, Is.AssignableTo()); + Assert.That(deserializedB, Is.AssignableTo()); + Assert.That(deserializedB.IdentifierProperty, Is.EqualTo(3)); + Assert.That(((DerivedClassB)deserializedB).ValueProperty, Is.EqualTo(114514f)); + + }); + } + + [Test] + public void ProtoPackablePolymorphism_SerializeBaseAndDeserializeBase_ReturnsCorrectDerivedType() + { + // Arrange + BaseClass originalA = new DerivedClassA { NameProperty = "TestName" }; + BaseClass originalB = new DerivedClassB { ValueProperty = 114514f }; + + byte[] bytesA = ProtoSerializer.SerializeProtoPackable(originalA); + BaseClass deserializedA = ProtoSerializer.DeserializeProtoPackable(bytesA); + + byte[] bytesB = ProtoSerializer.SerializeProtoPackable(originalB); + BaseClass deserializedB = ProtoSerializer.DeserializeProtoPackable(bytesB); + + Assert.Multiple(() => + { + Assert.That(deserializedA, Is.AssignableTo()); + Assert.That(deserializedA, Is.AssignableTo()); + Assert.That(deserializedA.IdentifierProperty, Is.EqualTo(2)); + Assert.That(((DerivedClassA)deserializedA).NameProperty, Is.EqualTo("TestName")); + + Assert.That(deserializedB, Is.AssignableTo()); + Assert.That(deserializedB, Is.AssignableTo()); + Assert.That(deserializedB.IdentifierProperty, Is.EqualTo(3)); + Assert.That(((DerivedClassB)deserializedB).ValueProperty, Is.EqualTo(114514f)); + + }); + } + + [Test] + public void ProtoPackablePolymorphism_SerializeDerivedAndDeserializeBase_ReturnsCorrectDerivedType() + { + // Arrange + DerivedClassA originalA = new DerivedClassA { NameProperty = "TestName" }; + DerivedClassB originalB = new DerivedClassB { ValueProperty = 114514f }; + + byte[] bytesA = ProtoSerializer.SerializeProtoPackable(originalA); + BaseClass deserializedA = ProtoSerializer.DeserializeProtoPackable(bytesA); + + byte[] bytesB = ProtoSerializer.SerializeProtoPackable(originalB); + BaseClass deserializedB = ProtoSerializer.DeserializeProtoPackable(bytesB); + + Assert.Multiple(() => + { + Assert.That(deserializedA, Is.AssignableTo()); + Assert.That(deserializedA, Is.AssignableTo()); + Assert.That(deserializedA.IdentifierProperty, Is.EqualTo(2)); + Assert.That(((DerivedClassA)deserializedA).NameProperty, Is.EqualTo("TestName")); + + Assert.That(deserializedB, Is.AssignableTo()); + Assert.That(deserializedB, Is.AssignableTo()); + Assert.That(deserializedB.IdentifierProperty, Is.EqualTo(3)); + Assert.That(((DerivedClassB)deserializedB).ValueProperty, Is.EqualTo(114514f)); + }); + + } + + #endregion } #region Test Classes - +[ProtoPackable] [ProtoPolymorphic(FieldNumber = 1)] [ProtoDerivedType(typeof(DerivedClassA), 2)] [ProtoDerivedType(typeof(DerivedClassB), 3)] -public class BaseClass +public partial class BaseClass { public BaseClass() : this(-1) { } @@ -55,12 +258,15 @@ public BaseClass(int identifier) [ProtoMember(1)] public int IdentifierProperty { get; set; } } -public class DerivedClassA() : BaseClass(2) +[ProtoPackable] +public partial class DerivedClassA() : BaseClass(2) { - [ProtoMember(2)] public string NameProperty { get; set; } + [ProtoMember(2)] public string? NameProperty { get; set; } } -public class DerivedClassB() : BaseClass(3) + +[ProtoPackable] +public partial class DerivedClassB() : BaseClass(3) { [ProtoMember(2)] public float ValueProperty { get; set; } = 0f; } diff --git a/Lagrange.Proto/IProtoSerializable.cs b/Lagrange.Proto/IProtoSerializable.cs index 0b8ef9f3..efee9dfd 100644 --- a/Lagrange.Proto/IProtoSerializable.cs +++ b/Lagrange.Proto/IProtoSerializable.cs @@ -10,4 +10,6 @@ public interface IProtoSerializable public static abstract void SerializeHandler(T obj, ProtoWriter writer); public static abstract int MeasureHandler(T obj); + + public static abstract ProtoPolymorphicDerivedTypeDescriptor? GetPolymorphicTypeDescriptor(TKey discriminator); } \ No newline at end of file diff --git a/Lagrange.Proto/ProtoDerivedTypeAttribute.cs b/Lagrange.Proto/ProtoDerivedTypeAttribute.cs index 70f42ba0..2c9bac37 100644 --- a/Lagrange.Proto/ProtoDerivedTypeAttribute.cs +++ b/Lagrange.Proto/ProtoDerivedTypeAttribute.cs @@ -11,10 +11,9 @@ public ProtoDerivedTypeAttribute(Type derivedType, T typeDiscriminator) : base(d /// /// The type discriminator identifier to be used for the serialization of the subtype. /// - public T TypeDiscriminator { get; init; } + internal T TypeDiscriminator { get; init; } } -[AttributeUsage(AttributeTargets.Class | AttributeTargets.Interface, AllowMultiple = true, Inherited = false)] public class ProtoDerivedTypeAttribute : Attribute { public ProtoDerivedTypeAttribute(Type derivedType) @@ -25,5 +24,5 @@ public ProtoDerivedTypeAttribute(Type derivedType) /// /// A derived type that should be supported in polymorphic serialization of the declared base type. /// - public Type DerivedType { get; init; } + internal Type DerivedType { get; init; } } \ No newline at end of file diff --git a/Lagrange.Proto/Serialization/Metadata/ProtoObjectInfo.cs b/Lagrange.Proto/Serialization/Metadata/ProtoObjectInfo.cs index ac1f421e..a062ab8f 100644 --- a/Lagrange.Proto/Serialization/Metadata/ProtoObjectInfo.cs +++ b/Lagrange.Proto/Serialization/Metadata/ProtoObjectInfo.cs @@ -6,9 +6,7 @@ namespace Lagrange.Proto.Serialization.Metadata; public class ProtoObjectInfo { public Dictionary Fields { get; init; } = new(); - public Func? ObjectCreator { get; init; } - public bool IgnoreDefaultFields { get; init; } - public ProtoPolymorphicInfoBase? PolymorphicInfo { get; init; } + public IProtoPolymorphicInfoBase? PolymorphicInfo { get; init; } } \ No newline at end of file diff --git a/Lagrange.Proto/Serialization/Metadata/ProtoPolymorphicInfo.cs b/Lagrange.Proto/Serialization/Metadata/ProtoPolymorphicInfo.cs index 3d296a55..4ae56b59 100644 --- a/Lagrange.Proto/Serialization/Metadata/ProtoPolymorphicInfo.cs +++ b/Lagrange.Proto/Serialization/Metadata/ProtoPolymorphicInfo.cs @@ -1,35 +1,39 @@ namespace Lagrange.Proto.Serialization.Metadata; -public class ProtoPolymorphicInfoBase +public interface IProtoPolymorphicInfoBase { - public uint PolymorphicIndicateIndex { get; set; } = 0; - public bool PolymorphicFallbackToBaseType { get; set; } = true; + public uint PolymorphicIndicateIndex { get; set; } + public bool PolymorphicFallbackToBaseType { get; set; } - public virtual Type? GetTypeFromDiscriminator(object discriminator) - { - return null; - } - - public virtual bool SetTypeDiscriminator(object discriminator, Type type) - { - return false; - } + public Type? GetTypeFromDiscriminator(object discriminator); + + public bool SetTypeDiscriminator(object discriminator, Type type); } -public class ProtoPolymorphicObjectInfo : ProtoPolymorphicInfoBase where TKey : IEquatable +public class ProtoPolymorphicObjectInfo : IProtoPolymorphicInfoBase where TKey : IEquatable { - public override Type? GetTypeFromDiscriminator(object discriminator) + public uint PolymorphicIndicateIndex { get; set; } = 0; + public bool PolymorphicFallbackToBaseType { get; set; } = true; + + public Type? GetTypeFromDiscriminator(object discriminator) { return PolymorphicDerivedTypes.GetValueOrDefault((TKey)discriminator); } - public override bool SetTypeDiscriminator(object discriminator, Type type) + public bool SetTypeDiscriminator(object discriminator, Type type) { PolymorphicDerivedTypes[(TKey)discriminator] = type; return true; } - public Dictionary PolymorphicDerivedTypes { get; } = []; + public Dictionary PolymorphicDerivedTypes { get; init; } = []; +} + +public class ProtoPolymorphicDerivedTypeDescriptor +{ + public Dictionary Fields { get; init; } = []; + public Func ObjectCreator { get; init; } = null!; + public bool IgnoreDefaultFields { get; init; } } \ No newline at end of file diff --git a/Lagrange.Proto/Serialization/Metadata/ProtoTypeResolver.Dynamic.cs b/Lagrange.Proto/Serialization/Metadata/ProtoTypeResolver.Dynamic.cs index a2f4d906..df62e50f 100644 --- a/Lagrange.Proto/Serialization/Metadata/ProtoTypeResolver.Dynamic.cs +++ b/Lagrange.Proto/Serialization/Metadata/ProtoTypeResolver.Dynamic.cs @@ -78,7 +78,7 @@ internal static ProtoObjectInfo CreateObjectInfo() }; } - internal static ProtoPolymorphicInfoBase? PopulatePolymorphicInfo() + internal static IProtoPolymorphicInfoBase? PopulatePolymorphicInfo() { var type = typeof(T); var polymorphicAttributes = type.GetCustomAttributes(typeof(ProtoDerivedTypeAttribute<>)) @@ -93,9 +93,8 @@ internal static ProtoObjectInfo CreateObjectInfo() // get the TKey from first var firstAttr = polymorphicAttributes[0]; var keyType = firstAttr.GetType().GetGenericArguments()[0]; - var objectInfo = MemberAccessor.CreateParameterlessConstructor>( - typeof(ProtoPolymorphicObjectInfo<,>).MakeGenericType(typeof(T), keyType) - .GetConstructor(Type.EmptyTypes))?.Invoke(); + var objectInfo = (IProtoPolymorphicInfoBase?) typeof(ProtoPolymorphicObjectInfo<>).MakeGenericType(keyType) + .GetConstructor(Type.EmptyTypes)?.Invoke(null); Debug.Assert(objectInfo != null); objectInfo.PolymorphicIndicateIndex = polymorphicFieldNumber; @@ -103,7 +102,7 @@ internal static ProtoObjectInfo CreateObjectInfo() foreach (var attr in polymorphicAttributes) { - var key = attr.GetType().GetProperty(nameof(ProtoDerivedTypeAttribute.TypeDiscriminator)) + var key = attr.GetType().GetProperty(nameof(ProtoDerivedTypeAttribute.TypeDiscriminator), BindingFlags.NonPublic | BindingFlags.Instance) ?.GetValue(attr); if (key == null) ThrowHelper.ThrowInvalidOperationException_UnknownPolymorphicType(type, attr.DerivedType); objectInfo.SetTypeDiscriminator(key, attr.DerivedType); diff --git a/Lagrange.Proto/Serialization/ProtoSerializer.Deserialize.cs b/Lagrange.Proto/Serialization/ProtoSerializer.Deserialize.cs index e92666aa..90469b04 100644 --- a/Lagrange.Proto/Serialization/ProtoSerializer.Deserialize.cs +++ b/Lagrange.Proto/Serialization/ProtoSerializer.Deserialize.cs @@ -26,11 +26,39 @@ private static T DeserializeProtoPackableCore(ref ProtoReader reader) where T Debug.Assert(objectInfo.ObjectCreator != null); T target = objectInfo.ObjectCreator(); + var fields = objectInfo.Fields; + if (objectInfo.PolymorphicInfo?.PolymorphicIndicateIndex is > 0) + { + // has polymorphic type, read the first field to determine the actual type + uint firstTag = reader.DecodeVarIntUnsafe(); + + if (firstTag >>> 3 != objectInfo.PolymorphicInfo.PolymorphicIndicateIndex) + { + ThrowHelper.ThrowInvalidOperationException_PolymorphicFieldNotFirst(typeof(T), + objectInfo.PolymorphicInfo.PolymorphicIndicateIndex, firstTag >>> 3); + } + + var firstField = objectInfo.Fields[firstTag]; + firstField.Read(ref reader, target); + var polyTypeKey = firstField.Get?.Invoke(target); + var typeDescriptor = T.GetPolymorphicTypeDescriptor(polyTypeKey); + if (typeDescriptor is not null) + { + fields = typeDescriptor.Fields; + target = typeDescriptor.ObjectCreator.Invoke(); + } + else if (!objectInfo.PolymorphicInfo.PolymorphicFallbackToBaseType) + { + ThrowHelper.ThrowInvalidOperationException_UnknownPolymorphicType(typeof(T), polyTypeKey!); + } + } + + while (!reader.IsCompleted) { uint tag = reader.DecodeVarIntUnsafe(); - if (objectInfo.Fields.TryGetValue(tag, out var fieldInfo)) + if (fields.TryGetValue(tag, out var fieldInfo)) { fieldInfo.Read(ref reader, target); } @@ -69,19 +97,19 @@ private static T DeserializeProtoPackableCore(ref ProtoReader reader) where T T target = converter.ObjectInfo.ObjectCreator(); var boxed = (object?)target; // avoid multiple times of boxing if (boxed is null) ThrowHelper.ThrowInvalidOperationException_CanNotCreateObject(typeof(T)); - var fieldInfos = converter.ObjectInfo.Fields; - + var polymorphicInfo = converter.ObjectInfo.PolymorphicInfo; + startDeserialize: // polymorphic type - if (converter.ObjectInfo.PolymorphicInfo?.PolymorphicIndicateIndex is > 0) + if (polymorphicInfo?.PolymorphicIndicateIndex is > 0) { // has polymorphic type, read the first field to determine the actual type uint firstTag = reader.DecodeVarIntUnsafe(); - if (firstTag >>> 3 != converter.ObjectInfo.PolymorphicInfo.PolymorphicIndicateIndex) + if (firstTag >>> 3 != polymorphicInfo.PolymorphicIndicateIndex) { ThrowHelper.ThrowInvalidOperationException_PolymorphicFieldNotFirst(typeof(T), - converter.ObjectInfo.PolymorphicInfo.PolymorphicIndicateIndex, firstTag >>> 3); + polymorphicInfo.PolymorphicIndicateIndex, firstTag >>> 3); } var firstField = converter.ObjectInfo.Fields[firstTag]; @@ -92,17 +120,19 @@ private static T DeserializeProtoPackableCore(ref ProtoReader reader) where T ThrowHelper.ThrowInvalidOperationException_FailedParsePolymorphicType(typeof(T), firstTag); } - if (converter.ObjectInfo.PolymorphicInfo.GetTypeFromDiscriminator(polyTypeKey) is { } polyType) + if (polymorphicInfo.GetTypeFromDiscriminator(polyTypeKey) is { } polyType) { - (fieldInfos, var objectCreator ) = GetObjectInfoReflection(polyType); + (fieldInfos, var objectCreator, polymorphicInfo) = GetObjectInfoReflection(polyType); target = objectCreator(); boxed = target; if (boxed is null) ThrowHelper.ThrowInvalidOperationException_CanNotCreateObject(polyType); } - else if (!converter.ObjectInfo.PolymorphicInfo.PolymorphicFallbackToBaseType) + else if (!polymorphicInfo.PolymorphicFallbackToBaseType) { ThrowHelper.ThrowInvalidOperationException_UnknownPolymorphicType(typeof(T), polyTypeKey); } + + goto startDeserialize; } diff --git a/Lagrange.Proto/Serialization/ProtoSerializer.Helpers.cs b/Lagrange.Proto/Serialization/ProtoSerializer.Helpers.cs index 874697f7..9694faf8 100644 --- a/Lagrange.Proto/Serialization/ProtoSerializer.Helpers.cs +++ b/Lagrange.Proto/Serialization/ProtoSerializer.Helpers.cs @@ -33,7 +33,7 @@ internal static ProtoObjectConverter GetConverterOf() return converter; } - internal static (Dictionary Fields, Func ObjectCreator) GetObjectInfoReflection(Type polyType) + internal static (Dictionary Fields, Func ObjectCreator, IProtoPolymorphicInfoBase? polymorphicInfo) GetObjectInfoReflection(Type polyType) { Debug.Assert(polyType != typeof(T)); Debug.Assert(polyType.IsAssignableTo(typeof(T))); @@ -50,7 +50,9 @@ internal static (Dictionary Fields, Func ObjectCreator) .GetProperty("ObjectCreator")!.GetValue(polyObjectInfo)!; var fieldInfos = (Dictionary)polyObjectInfo.GetType() .GetProperty("Fields")!.GetValue(polyObjectInfo)!; - return (fieldInfos, ObjectCreator); + var polymorphicInfo = (IProtoPolymorphicInfoBase)polyObjectInfo.GetType() + .GetProperty("PolymorphicInfo")!.GetValue(polyObjectInfo)!; + return (fieldInfos, ObjectCreator,polymorphicInfo); T ObjectCreator() => (T)polyCreator.GetType().GetMethod("Invoke")!.Invoke(polyCreator, null)!; } } \ No newline at end of file diff --git a/Lagrange.Proto/Serialization/ProtoSerializer.Serialize.cs b/Lagrange.Proto/Serialization/ProtoSerializer.Serialize.cs index 22e332d2..ccee312f 100644 --- a/Lagrange.Proto/Serialization/ProtoSerializer.Serialize.cs +++ b/Lagrange.Proto/Serialization/ProtoSerializer.Serialize.cs @@ -109,12 +109,12 @@ private static void SerializeProtoPackableCore(ProtoWriter writer, T obj) whe if (boxed is null) return; var fields = objectInfo.Fields; uint skipTag = 0; - + var polymorphicInfo = converter.ObjectInfo.PolymorphicInfo; // check polymorphic type - if (converter.ObjectInfo.PolymorphicInfo?.PolymorphicIndicateIndex is > 0) + if (polymorphicInfo?.PolymorphicIndicateIndex is > 0) { // has polymorphic type - var index = converter.ObjectInfo.PolymorphicInfo.PolymorphicIndicateIndex; + var index = polymorphicInfo.PolymorphicIndicateIndex; var fieldInfo = objectInfo.Fields.FirstOrDefault(t=>t.Value.Field == index); if (fieldInfo.Value is null) ThrowHelper.ThrowInvalidOperationException_NullPolymorphicDiscriminator(typeof(T)); var discriminator = fieldInfo.Value.Get?.Invoke(boxed); @@ -128,7 +128,7 @@ private static void SerializeProtoPackableCore(ProtoWriter writer, T obj) whe writer.EncodeVarInt(fieldInfo.Key); fieldInfo.Value.Write(writer, boxed); - (fields, _) = GetObjectInfoReflection(derivedTypeInfo); + (fields, _, _) = GetObjectInfoReflection(derivedTypeInfo); } foreach (var (tag, info) in fields) From 14b739cb0fcfc37e3bce47ec1b7ebb20f6a870a7 Mon Sep 17 00:00:00 2001 From: Kengwang Date: Mon, 6 Oct 2025 00:39:12 +0800 Subject: [PATCH 6/7] [Proto] Add derived type (de)serialization support --- .../Entity/PolymorphicTypeInfo.cs | 8 + .../ProtoSourceGenerator.Emitter.Serialize.cs | 15 +- .../ProtoSourceGenerator.Emitter.TypeInfo.cs | 73 ++++++--- .../ProtoSourceGenerator.Parser.cs | 131 ++++++++++----- Lagrange.Proto.Test/ProtoPolymorphismTest.cs | 150 ++++++++++++++---- Lagrange.Proto/IProtoSerializable.cs | 2 - .../Serialization/Metadata/ProtoObjectInfo.cs | 2 +- .../Metadata/ProtoPolymorphicInfo.cs | 72 ++++++--- .../Metadata/ProtoTypeResolver.Dynamic.cs | 36 ++++- .../ProtoSerializer.Deserialize.cs | 99 ++++++------ .../Serialization/ProtoSerializer.Helpers.cs | 19 ++- .../ProtoSerializer.Serialize.cs | 36 +++-- 12 files changed, 449 insertions(+), 194 deletions(-) diff --git a/Lagrange.Proto.Generator/Entity/PolymorphicTypeInfo.cs b/Lagrange.Proto.Generator/Entity/PolymorphicTypeInfo.cs index c86daf3d..47391da0 100644 --- a/Lagrange.Proto.Generator/Entity/PolymorphicTypeInfo.cs +++ b/Lagrange.Proto.Generator/Entity/PolymorphicTypeInfo.cs @@ -17,4 +17,12 @@ public class PolymorphicDerivedTypeInfo { public INamedTypeSymbol DerivedType { get; internal set; } = null!; public TypedConstant Key { get; internal set; } +} + +public class BaseTypeInfo +{ + public INamedTypeSymbol BaseType { get; internal set; } = null!; + public PolymorphicTypeInfo PolymorphicInfo { get; internal set; } = new(); + public bool IgnoreDefaultFields { get; internal set; } + public Dictionary Fields { get; } = new(); } \ No newline at end of file diff --git a/Lagrange.Proto.Generator/ProtoSourceGenerator.Emitter.Serialize.cs b/Lagrange.Proto.Generator/ProtoSourceGenerator.Emitter.Serialize.cs index a69ba97a..b636c60f 100644 --- a/Lagrange.Proto.Generator/ProtoSourceGenerator.Emitter.Serialize.cs +++ b/Lagrange.Proto.Generator/ProtoSourceGenerator.Emitter.Serialize.cs @@ -33,17 +33,20 @@ private void EmitSerializeMethod(SourceWriter source) source.WriteLine("{"); source.Indentation++; - if (parser.BaseTypeSymbol is not null) + if (parser.BaseTypeInfo.BaseType.GetFullName() != _fullQualifiedName && parser.BaseTypeInfo?.PolymorphicInfo.PolymorphicIndicateIndex is > 0) { - source.WriteLine($"{parser.BaseTypeSymbol.GetFullName()}.SerializeHandler({ObjectVarName},{WriterVarName});"); + source.WriteLine($"{parser.BaseTypeInfo.BaseType.GetFullName()}.SerializeHandler({ObjectVarName},{WriterVarName});"); + } + else + { + source.WriteLine($"SerializeHandlerCore({ObjectVarName}, {WriterVarName});"); } - source.WriteLine($"SerializeHandlerCore({ObjectVarName}, {WriterVarName});"); source.Indentation--; source.WriteLine("}"); source.WriteLine(); - source.WriteLine($"private static void SerializeHandlerCore({_fullQualifiedName} {ObjectVarName}, {ProtoWriterTypeRef} {WriterVarName})"); + source.WriteLine($"public static void SerializeHandlerCore({_fullQualifiedName} {ObjectVarName}, {ProtoWriterTypeRef} {WriterVarName})"); source.WriteLine("{"); source.Indentation++; if (parser.PolymorphicInfo.PolymorphicIndicateIndex > 0) @@ -60,9 +63,9 @@ private void EmitSerializeMethod(SourceWriter source) for (var index = 0; index < parser.PolymorphicInfo.PolymorphicTypes.Count; index++) { var kv = parser.PolymorphicInfo.PolymorphicTypes[index]; - source.WriteLine($"case {kv.DerivedType.GetFullName()} derived{index}:"); + source.WriteLine($"case {kv.DerivedType.GetFullName()} _:"); source.Indentation++; - source.WriteLine($"{kv.DerivedType.GetFullName()}.SerializeHandlerCore(derived{index}, {WriterVarName});"); + source.WriteLine($"{kv.DerivedType.GetFullName()}.SerializeHandlerCore(({kv.DerivedType.GetFullName()}){ObjectVarName}, {WriterVarName});"); source.WriteLine("break;"); source.Indentation--; } diff --git a/Lagrange.Proto.Generator/ProtoSourceGenerator.Emitter.TypeInfo.cs b/Lagrange.Proto.Generator/ProtoSourceGenerator.Emitter.TypeInfo.cs index c292481b..76852adb 100644 --- a/Lagrange.Proto.Generator/ProtoSourceGenerator.Emitter.TypeInfo.cs +++ b/Lagrange.Proto.Generator/ProtoSourceGenerator.Emitter.TypeInfo.cs @@ -16,6 +16,7 @@ private partial class Emitter private const string ProtoObjectInfoTypeRef = "global::Lagrange.Proto.Serialization.Metadata.ProtoObjectInfo<{0}>"; private const string ProtoPolymorphicObjectInfoTypeRef = "global::Lagrange.Proto.Serialization.Metadata.ProtoPolymorphicObjectInfo<{0}>"; private const string ProtoPolymorphicDerivedTypeDescriptorTypeRef = "global::Lagrange.Proto.Serialization.Metadata.ProtoPolymorphicDerivedTypeDescriptor<{0}>"; + private const string ProtoPolymorphicDerivedTypeDescriptorBaseTypeRef = "global::Lagrange.Proto.Serialization.Metadata.ProtoPolymorphicDerivedTypeDescriptor"; private const string ProtoFieldInfoTypeRef = "global::Lagrange.Proto.Serialization.Metadata.ProtoFieldInfo"; private const string ProtoFieldInfoGenericTypeRef = "global::Lagrange.Proto.Serialization.Metadata.ProtoFieldInfo<{0}>"; private const string ProtoMapFieldInfoGenericTypeRef = "global::Lagrange.Proto.Serialization.Metadata.ProtoMapFieldInfo<{0}, {1}, {2}>"; @@ -67,23 +68,11 @@ private void EmitTypeInfo(SourceWriter source) source.WriteLine($"return new {ProtoObjectInfoTypeRefGeneric}()"); source.WriteLine('{'); source.Indentation++; - - source.WriteLine($"Fields = new global::System.Collections.Generic.Dictionary()"); - source.WriteLine('{'); - source.Indentation++; - foreach (var kv in parser.Fields) - { - int field = kv.Key; - var info = kv.Value; - if (info.ExtraTypeInfo.Count == 0) EmitFieldInfo(source, field, info); - else EmitMapFieldInfo(source, field, info); - } - source.Indentation--; - source.WriteLine("},"); + EmitFieldsInfo(source, parser.Fields); source.WriteLine($"ObjectCreator = () => new {_fullQualifiedName}(),"); - EmitPolymorphicInfo(source, parser.PolymorphicInfo); + EmitPolymorphicInfo(source, parser.PolymorphicInfo, parser.BaseTypeInfo); source.WriteLine($"IgnoreDefaultFields = {parser.IgnoreDefaultFields.ToString().ToLower()}"); source.Indentation--; @@ -117,6 +106,23 @@ static void EmitByTypeSymbol(SourceWriter source, ITypeSymbol typeSymbol) } } + private void EmitFieldsInfo(SourceWriter source,Dictionary fields ) + { + source.WriteLine($"Fields = new global::System.Collections.Generic.Dictionary()"); + source.WriteLine('{'); + source.Indentation++; + foreach (var kv in fields) + { + int field = kv.Key; + var info = kv.Value; + + if (info.ExtraTypeInfo.Count == 0) EmitFieldInfo(source, field, info); + else EmitMapFieldInfo(source, field, info); + } + source.Indentation--; + source.WriteLine("},"); + } + private void EmitFieldInfo(SourceWriter source, int field, ProtoFieldInfo info) { int tag = field << 3 | (byte)info.WireType; @@ -154,7 +160,7 @@ private void EmitMapFieldInfo(SourceWriter source, int field, ProtoFieldInfo inf source.WriteLine("},"); } - private void EmitPolymorphicInfo(SourceWriter source, PolymorphicTypeInfo polymorphicInfo) + private void EmitPolymorphicInfo(SourceWriter source, PolymorphicTypeInfo polymorphicInfo, BaseTypeInfo baseTypeInfo) { if (polymorphicInfo.PolymorphicIndicateIndex == 0) return; source.WriteLine($"PolymorphicInfo = new {string.Format(ProtoPolymorphicObjectInfoTypeRef, polymorphicInfo.PolymorphicKeyType.GetFullName())}()"); @@ -162,15 +168,37 @@ private void EmitPolymorphicInfo(SourceWriter source, PolymorphicTypeInfo polymo source.Indentation++; source.WriteLine($"PolymorphicIndicateIndex = {polymorphicInfo.PolymorphicIndicateIndex},"); - source.WriteLine($"PolymorphicFallbackToBaseType = {parser.IgnoreDefaultFields.ToString().ToLower()},"); + source.WriteLine($"PolymorphicFallbackToBaseType = {polymorphicInfo.PolymorphicFallbackToBaseType.ToString().ToLower()},"); + if (parser.BaseTypeInfo?.BaseType is not null && parser.BaseTypeInfo.BaseType.GetFullName() != _fullQualifiedName) + { + source.WriteLine($"RootTypeDescriptorGetter = () => new {string.Format(ProtoPolymorphicDerivedTypeDescriptorTypeRef, parser.BaseTypeInfo.BaseType.GetFullName())}()"); + source.WriteLine('{'); + source.Indentation++; + source.WriteLine($"FieldsGetter = () => {parser.BaseTypeInfo.BaseType.GetFullName()}.{TypeInfoPropertyName}.Fields,"); + source.WriteLine($"ObjectCreator = () => new {parser.BaseTypeInfo.BaseType.GetFullName()}(),"); + source.WriteLine($"IgnoreDefaultFieldsGetter = () => {parser.BaseTypeInfo.BaseType.GetFullName()}.{TypeInfoPropertyName}.IgnoreDefaultFields,"); + source.WriteLine($"PolymorphicInfoGetter = () => {parser.BaseTypeInfo.BaseType.GetFullName()}.{TypeInfoPropertyName}.PolymorphicInfo,"); + source.WriteLine($"CurrentType = typeof({parser.BaseTypeInfo.BaseType.GetFullName()})"); + source.Indentation--; + source.WriteLine("},"); + } source.WriteLine( - $"PolymorphicDerivedTypes = new global::System.Collections.Generic.Dictionary<{polymorphicInfo.PolymorphicKeyType.GetFullName()}, global::System.Type>()"); + $"PolymorphicDerivedTypes = new global::System.Collections.Generic.Dictionary<{polymorphicInfo.PolymorphicKeyType.GetFullName()}, {ProtoPolymorphicDerivedTypeDescriptorBaseTypeRef}>()"); source.WriteLine('{'); source.Indentation++; foreach (var derivedTypeInfo in polymorphicInfo.PolymorphicTypes) { - source.WriteLine($"[{derivedTypeInfo.Key.ToCSharpString()}] = typeof({derivedTypeInfo.DerivedType.GetFullName()}),"); + source.WriteLine($"[{derivedTypeInfo.Key.ToCSharpString()}] = new {string.Format(ProtoPolymorphicDerivedTypeDescriptorTypeRef, baseTypeInfo.BaseType.GetFullName())}()"); + source.WriteLine('{'); + source.Indentation++; + source.WriteLine($"CurrentType = typeof({derivedTypeInfo.DerivedType.GetFullName()}),"); + source.WriteLine($"FieldsGetter = () => {derivedTypeInfo.DerivedType.GetFullName()}.{TypeInfoPropertyName}.Fields,"); + source.WriteLine($"ObjectCreator = () => ({baseTypeInfo.BaseType.GetFullName()})new {derivedTypeInfo.DerivedType.GetFullName()}(),"); + source.WriteLine($"IgnoreDefaultFieldsGetter = () => {derivedTypeInfo.DerivedType.GetFullName()}.{TypeInfoPropertyName}.IgnoreDefaultFields,"); + source.WriteLine($"PolymorphicInfoGetter = () => {derivedTypeInfo.DerivedType.GetFullName()}.{TypeInfoPropertyName}.PolymorphicInfo"); + source.Indentation--; + source.WriteLine("},"); } source.Indentation--; @@ -181,6 +209,7 @@ private void EmitPolymorphicInfo(SourceWriter source, PolymorphicTypeInfo polymo private void EmitPolymorphicDerivedTypeDescriptor(SourceWriter source, PolymorphicTypeInfo polymorphicInfo) { + source.WriteLine("#pragma warning disable CS0108"); source.WriteLine( $"public static {string.Format(ProtoPolymorphicDerivedTypeDescriptorTypeRef,_fullQualifiedName)}? GetPolymorphicTypeDescriptor(TKey discriminator)"); @@ -197,9 +226,11 @@ private void EmitPolymorphicDerivedTypeDescriptor(SourceWriter source, Polymorph source.WriteLine($"case {derivedTypeInfo.Key.ToCSharpString()}: return new {string.Format(ProtoPolymorphicDerivedTypeDescriptorTypeRef, _fullQualifiedName)}()"); source.WriteLine('{'); source.Indentation++; - source.WriteLine($"Fields = {derivedTypeInfo.DerivedType.GetFullName()}.{TypeInfoPropertyName}.Fields,"); - source.WriteLine($"ObjectCreator = {derivedTypeInfo.DerivedType.GetFullName()}.{TypeInfoPropertyName}.ObjectCreator,"); - source.WriteLine($"IgnoreDefaultFields = {derivedTypeInfo.DerivedType.GetFullName()}.{TypeInfoPropertyName}.IgnoreDefaultFields"); + source.WriteLine($"FieldsGetter = () => {derivedTypeInfo.DerivedType.GetFullName()}.{TypeInfoPropertyName}.Fields,"); + source.WriteLine($"ObjectCreator = () => ({_fullQualifiedName})new {derivedTypeInfo.DerivedType.GetFullName()}(),"); + source.WriteLine($"IgnoreDefaultFieldsGetter = () => {derivedTypeInfo.DerivedType.GetFullName()}.{TypeInfoPropertyName}.IgnoreDefaultFields,"); + source.WriteLine($"PolymorphicInfoGetter = () => {derivedTypeInfo.DerivedType.GetFullName()}.{TypeInfoPropertyName}.PolymorphicInfo,"); + source.WriteLine($"CurrentType = typeof({derivedTypeInfo.DerivedType.GetFullName()})"); source.Indentation--; source.WriteLine("};"); } diff --git a/Lagrange.Proto.Generator/ProtoSourceGenerator.Parser.cs b/Lagrange.Proto.Generator/ProtoSourceGenerator.Parser.cs index 23f17f8f..dc185599 100644 --- a/Lagrange.Proto.Generator/ProtoSourceGenerator.Parser.cs +++ b/Lagrange.Proto.Generator/ProtoSourceGenerator.Parser.cs @@ -33,7 +33,7 @@ private class Parser(ClassDeclarationSyntax context, SemanticModel model) public PolymorphicTypeInfo PolymorphicInfo { get; } = new(); - public INamedTypeSymbol? BaseTypeSymbol { get; private set; } = null; + public BaseTypeInfo? BaseTypeInfo { get; private set; } = null; public void Parse(CancellationToken token = default) { @@ -46,18 +46,6 @@ public void Parse(CancellationToken token = default) return; } - var checkingBaseTypeSymbol = classSymbol.BaseType; - while (checkingBaseTypeSymbol?.BaseType is not null) - { - if (checkingBaseTypeSymbol.GetAttributes() - .Any(t => t.AttributeClass?.Name == "ProtoPackableAttribute") is true) - { - BaseTypeSymbol = checkingBaseTypeSymbol; - break; - } - checkingBaseTypeSymbol = checkingBaseTypeSymbol?.BaseType; - } - if (!classSymbol.Constructors.Any(x => x is { Parameters.Length: 0, DeclaredAccessibility: Accessibility.Public })) { ReportDiagnostics(MustContainParameterlessConstructor, context.GetLocation(), context.Identifier.Text); @@ -79,16 +67,79 @@ public void Parse(CancellationToken token = default) } } break; + } + } + + if (!TryGetNestedTypeDeclarations(context, Model, token, out var typeDeclarations)) + { + ReportDiagnostics(MustBePartialClass, context.GetLocation(), context.Identifier.Text); + return; + } + TypeDeclarations.AddRange(typeDeclarations); + + PopulateFieldInfo(classSymbol, Fields, identifier, token); + PopulatePolymorphicInfo(classSymbol, PolymorphicInfo, token); + + // Handling BaseType + var checkingBaseTypeSymbol = classSymbol; + while (checkingBaseTypeSymbol?.BaseType is not null) + { + token.ThrowIfCancellationRequested(); + if (!checkingBaseTypeSymbol.BaseType.GetAttributes() + .Any(t => t.AttributeClass?.Name == "ProtoPackableAttribute")) + { + break; + } + + checkingBaseTypeSymbol = checkingBaseTypeSymbol.BaseType; + PopulateFieldInfo(checkingBaseTypeSymbol, Fields, identifier, token); + } + + if (checkingBaseTypeSymbol is not null) + { + BaseTypeInfo = new BaseTypeInfo(); + BaseTypeInfo.BaseType = checkingBaseTypeSymbol; + foreach (var attribute in checkingBaseTypeSymbol.GetAttributes()) + { + switch (attribute.AttributeClass?.Name) + { + case "ProtoPackableAttribute": + foreach (var argument in attribute.NamedArguments) + { + switch (argument.Key) + { + case "IgnoreDefaultFields": + BaseTypeInfo.IgnoreDefaultFields = (bool)(argument.Value.Value ?? false); + break; + } + } + break; + } + } + PopulateFieldInfo(BaseTypeInfo.BaseType, BaseTypeInfo.Fields, BaseTypeInfo.BaseType.GetFullName(), token); + PopulatePolymorphicInfo(BaseTypeInfo.BaseType, BaseTypeInfo.PolymorphicInfo, token); + } + } + + private void PopulatePolymorphicInfo(INamedTypeSymbol classSymbol, PolymorphicTypeInfo polymorphicInfo, CancellationToken cancellationToken = default) + { + foreach (var attribute in classSymbol.GetAttributes()) + { + cancellationToken.ThrowIfCancellationRequested(); + switch (attribute.AttributeClass?.Name) + { case "ProtoDerivedTypeAttribute": - if (PolymorphicInfo.PolymorphicIndicateIndex == 0) - PolymorphicInfo.PolymorphicIndicateIndex = 1; // set to default + if (polymorphicInfo.PolymorphicIndicateIndex == 0) + polymorphicInfo.PolymorphicIndicateIndex = 1; // set to default // get key type - if (attribute.AttributeClass.TypeArguments.First() is not INamedTypeSymbol keyType){ + if (attribute.AttributeClass.TypeArguments.First() is not INamedTypeSymbol keyType) + { ReportDiagnostics(UnableToGetSymbol, context.GetLocation(), attribute.AttributeClass.TypeArguments.First().ToDisplayString()); return; } - PolymorphicInfo.PolymorphicKeyType = keyType; + + polymorphicInfo.PolymorphicKeyType = keyType; // get derived type, in typeof var derivedTypeConstant = attribute.ConstructorArguments.First(); @@ -108,7 +159,7 @@ public void Parse(CancellationToken token = default) // get type discriminator var typeDiscriminatorConstant = attribute.ConstructorArguments.ElementAtOrDefault(1); - PolymorphicInfo.PolymorphicTypes.Add(new PolymorphicDerivedTypeInfo + polymorphicInfo.PolymorphicTypes.Add(new PolymorphicDerivedTypeInfo { DerivedType = derivedType, Key = typeDiscriminatorConstant }); @@ -119,40 +170,32 @@ public void Parse(CancellationToken token = default) switch (argument.Key) { case "FieldNumber": - PolymorphicInfo.PolymorphicIndicateIndex = (uint)(argument.Value.Value ?? 0); + polymorphicInfo.PolymorphicIndicateIndex = (uint)(argument.Value.Value ?? 0); break; case "FallbackToBaseType": - PolymorphicInfo.PolymorphicFallbackToBaseType = (bool)(argument.Value.Value ?? true); + polymorphicInfo.PolymorphicFallbackToBaseType = + (bool)(argument.Value.Value ?? true); break; } } + break; } } - if (!TryGetNestedTypeDeclarations(context, Model, token, out var typeDeclarations)) - { - ReportDiagnostics(MustBePartialClass, context.GetLocation(), context.Identifier.Text); - return; - } - TypeDeclarations.AddRange(typeDeclarations); - - var members = context.ChildNodes() - .Where(x => x is FieldDeclarationSyntax or PropertyDeclarationSyntax) - .Cast() - .Where(x => x.ContainsAttribute("ProtoMember")); + } + + private void PopulateFieldInfo(INamedTypeSymbol classSymbol, Dictionary fields, string identifier = "" ,CancellationToken token = default) + { + var members = classSymbol.GetMembers() + .Where(x => x is IPropertySymbol or IFieldSymbol) + .Where(x => x.GetAttributes().Any(t=>t.AttributeClass?.Name == "ProtoMemberAttribute")); - foreach (var member in members) + foreach (var symbol in members) { token.ThrowIfCancellationRequested(); - var symbol = classSymbol.GetMembers().First(x => x.Name == member switch - { - FieldDeclarationSyntax fieldDeclaration => fieldDeclaration.Declaration.Variables[0].Identifier.ToString(), - PropertyDeclarationSyntax propertyDeclaration => propertyDeclaration.Identifier.ToString(), - _ => throw new InvalidOperationException("Unsupported member type.") - }); - + var member = symbol.DeclaringSyntaxReferences.FirstOrDefault()!.GetSyntax(token); if (symbol.IsStatic) { ReportDiagnostics(MustNotBeStatic, member.GetLocation(), symbol.Name, identifier); @@ -161,7 +204,7 @@ public void Parse(CancellationToken token = default) var attribute = symbol.GetAttributes().First(x => x.AttributeClass?.Name == "ProtoMemberAttribute"); int field = (int)(attribute.ConstructorArguments[0].Value ?? throw new InvalidOperationException("Unable to get field number.")); - if (Fields.ContainsKey(field)) + if (fields.ContainsKey(field)) { ReportDiagnostics(DuplicateFieldNumber, member.GetLocation(), field, identifier); continue; @@ -196,7 +239,7 @@ public void Parse(CancellationToken token = default) var valueAttribute = symbol.GetAttributes().FirstOrDefault(x => x.AttributeClass?.ToDisplayString() == ProtoValueMemberAttributeFullName); if (valueAttribute != null) ReadProtoMemberAttribute(valueAttribute, typeSymbol, ref valueWireType, member, field, identifier, ref valueSigned); - Fields[field] = new ProtoFieldInfo(symbol, typeSymbol, wireType, signed) + fields[field] = new ProtoFieldInfo(symbol, typeSymbol, wireType, signed) { ExtraTypeInfo = { @@ -208,12 +251,12 @@ public void Parse(CancellationToken token = default) else { ReadProtoMemberAttribute(attribute, typeSymbol, ref wireType, member, field, identifier, ref signed); - Fields[field] = new ProtoFieldInfo(symbol, typeSymbol, wireType, signed); + fields[field] = new ProtoFieldInfo(symbol, typeSymbol, wireType, signed); } } } - - private void ReadProtoMemberAttribute(AttributeData attribute, ITypeSymbol typeSymbol, ref WireType wireType, MemberDeclarationSyntax member, int field, string identifier, ref bool signed) + + private void ReadProtoMemberAttribute(AttributeData attribute, ITypeSymbol typeSymbol, ref WireType wireType, SyntaxNode member, int field, string identifier, ref bool signed) { foreach (var argument in attribute.NamedArguments) { diff --git a/Lagrange.Proto.Test/ProtoPolymorphismTest.cs b/Lagrange.Proto.Test/ProtoPolymorphismTest.cs index 1367c7c5..6661fe19 100644 --- a/Lagrange.Proto.Test/ProtoPolymorphismTest.cs +++ b/Lagrange.Proto.Test/ProtoPolymorphismTest.cs @@ -55,15 +55,14 @@ public void ReflectionPolymorphism_SerializeDerivedAndDeserializeBase_ReturnsCor Assert.That(deserializedA, Is.AssignableTo()); Assert.That(deserializedA.IdentifierProperty, Is.EqualTo(2)); Assert.That(((DerivedClassA)deserializedA).NameProperty, Is.EqualTo("TestName")); - + Assert.That(deserializedB, Is.AssignableTo()); Assert.That(deserializedB, Is.AssignableTo()); Assert.That(deserializedB.IdentifierProperty, Is.EqualTo(3)); Assert.That(((DerivedClassB)deserializedB).ValueProperty, Is.EqualTo(114514f)); - }); } - + [Test] public void ReflectionPolymorphism_SerializeBaseAndDeserializeDerived_ReturnsCorrectDerivedType() { @@ -73,7 +72,7 @@ public void ReflectionPolymorphism_SerializeBaseAndDeserializeDerived_ReturnsCor byte[] bytesA = ProtoSerializer.Serialize(originalA); BaseClass deserializedA = ProtoSerializer.Deserialize(bytesA); - + byte[] bytesB = ProtoSerializer.Serialize(originalB); BaseClass deserializedB = ProtoSerializer.Deserialize(bytesB); @@ -83,15 +82,14 @@ public void ReflectionPolymorphism_SerializeBaseAndDeserializeDerived_ReturnsCor Assert.That(deserializedA, Is.AssignableTo()); Assert.That(deserializedA.IdentifierProperty, Is.EqualTo(2)); Assert.That(((DerivedClassA)deserializedA).NameProperty, Is.EqualTo("TestName")); - + Assert.That(deserializedB, Is.AssignableTo()); Assert.That(deserializedB, Is.AssignableTo()); Assert.That(deserializedB.IdentifierProperty, Is.EqualTo(3)); Assert.That(((DerivedClassB)deserializedB).ValueProperty, Is.EqualTo(114514f)); - }); } - + [Test] public void ReflectionPolymorphism_SerializeDerivedAndDeserializeDerived_ReturnsCorrectDerivedType() { @@ -101,7 +99,7 @@ public void ReflectionPolymorphism_SerializeDerivedAndDeserializeDerived_Returns byte[] bytesA = ProtoSerializer.Serialize(originalA); BaseClass deserializedA = ProtoSerializer.Deserialize(bytesA); - + byte[] bytesB = ProtoSerializer.Serialize(originalB); BaseClass deserializedB = ProtoSerializer.Deserialize(bytesB); @@ -111,12 +109,11 @@ public void ReflectionPolymorphism_SerializeDerivedAndDeserializeDerived_Returns Assert.That(deserializedA, Is.AssignableTo()); Assert.That(deserializedA.IdentifierProperty, Is.EqualTo(2)); Assert.That(((DerivedClassA)deserializedA).NameProperty, Is.EqualTo("TestName")); - + Assert.That(deserializedB, Is.AssignableTo()); Assert.That(deserializedB, Is.AssignableTo()); Assert.That(deserializedB.IdentifierProperty, Is.EqualTo(3)); Assert.That(((DerivedClassB)deserializedB).ValueProperty, Is.EqualTo(114514f)); - }); } @@ -132,10 +129,10 @@ public void ProtoPackablePolymorphism_SerializeBaseAndDeserializeDerived_Returns BaseClass originalB = new DerivedClassB { ValueProperty = 114514f }; byte[] bytesA = ProtoSerializer.SerializeProtoPackable(originalA); - BaseClass deserializedA = ProtoSerializer.DeserializeProtoPackable(bytesA); - + DerivedClassA deserializedA = ProtoSerializer.DeserializeProtoPackable(bytesA); + byte[] bytesB = ProtoSerializer.SerializeProtoPackable(originalB); - BaseClass deserializedB = ProtoSerializer.DeserializeProtoPackable(bytesB); + DerivedClassB deserializedB = ProtoSerializer.DeserializeProtoPackable(bytesB); Assert.Multiple(() => { @@ -143,15 +140,14 @@ public void ProtoPackablePolymorphism_SerializeBaseAndDeserializeDerived_Returns Assert.That(deserializedA, Is.AssignableTo()); Assert.That(deserializedA.IdentifierProperty, Is.EqualTo(2)); Assert.That(((DerivedClassA)deserializedA).NameProperty, Is.EqualTo("TestName")); - + Assert.That(deserializedB, Is.AssignableTo()); Assert.That(deserializedB, Is.AssignableTo()); Assert.That(deserializedB.IdentifierProperty, Is.EqualTo(3)); Assert.That(((DerivedClassB)deserializedB).ValueProperty, Is.EqualTo(114514f)); - }); } - + [Test] public void ProtoPackablePolymorphism_SerializeDerivedAndDeserializeDerived_ReturnsCorrectDerivedType() { @@ -161,7 +157,7 @@ public void ProtoPackablePolymorphism_SerializeDerivedAndDeserializeDerived_Retu byte[] bytesA = ProtoSerializer.SerializeProtoPackable(originalA); BaseClass deserializedA = ProtoSerializer.DeserializeProtoPackable(bytesA); - + byte[] bytesB = ProtoSerializer.SerializeProtoPackable(originalB); BaseClass deserializedB = ProtoSerializer.DeserializeProtoPackable(bytesB); @@ -171,15 +167,14 @@ public void ProtoPackablePolymorphism_SerializeDerivedAndDeserializeDerived_Retu Assert.That(deserializedA, Is.AssignableTo()); Assert.That(deserializedA.IdentifierProperty, Is.EqualTo(2)); Assert.That(((DerivedClassA)deserializedA).NameProperty, Is.EqualTo("TestName")); - + Assert.That(deserializedB, Is.AssignableTo()); Assert.That(deserializedB, Is.AssignableTo()); Assert.That(deserializedB.IdentifierProperty, Is.EqualTo(3)); Assert.That(((DerivedClassB)deserializedB).ValueProperty, Is.EqualTo(114514f)); - }); } - + [Test] public void ProtoPackablePolymorphism_SerializeBaseAndDeserializeBase_ReturnsCorrectDerivedType() { @@ -189,7 +184,7 @@ public void ProtoPackablePolymorphism_SerializeBaseAndDeserializeBase_ReturnsCor byte[] bytesA = ProtoSerializer.SerializeProtoPackable(originalA); BaseClass deserializedA = ProtoSerializer.DeserializeProtoPackable(bytesA); - + byte[] bytesB = ProtoSerializer.SerializeProtoPackable(originalB); BaseClass deserializedB = ProtoSerializer.DeserializeProtoPackable(bytesB); @@ -199,12 +194,11 @@ public void ProtoPackablePolymorphism_SerializeBaseAndDeserializeBase_ReturnsCor Assert.That(deserializedA, Is.AssignableTo()); Assert.That(deserializedA.IdentifierProperty, Is.EqualTo(2)); Assert.That(((DerivedClassA)deserializedA).NameProperty, Is.EqualTo("TestName")); - + Assert.That(deserializedB, Is.AssignableTo()); Assert.That(deserializedB, Is.AssignableTo()); Assert.That(deserializedB.IdentifierProperty, Is.EqualTo(3)); Assert.That(((DerivedClassB)deserializedB).ValueProperty, Is.EqualTo(114514f)); - }); } @@ -214,56 +208,152 @@ public void ProtoPackablePolymorphism_SerializeDerivedAndDeserializeBase_Returns // Arrange DerivedClassA originalA = new DerivedClassA { NameProperty = "TestName" }; DerivedClassB originalB = new DerivedClassB { ValueProperty = 114514f }; - + byte[] bytesA = ProtoSerializer.SerializeProtoPackable(originalA); BaseClass deserializedA = ProtoSerializer.DeserializeProtoPackable(bytesA); - + byte[] bytesB = ProtoSerializer.SerializeProtoPackable(originalB); BaseClass deserializedB = ProtoSerializer.DeserializeProtoPackable(bytesB); - + Assert.Multiple(() => { Assert.That(deserializedA, Is.AssignableTo()); Assert.That(deserializedA, Is.AssignableTo()); Assert.That(deserializedA.IdentifierProperty, Is.EqualTo(2)); Assert.That(((DerivedClassA)deserializedA).NameProperty, Is.EqualTo("TestName")); - + Assert.That(deserializedB, Is.AssignableTo()); Assert.That(deserializedB, Is.AssignableTo()); Assert.That(deserializedB.IdentifierProperty, Is.EqualTo(3)); Assert.That(((DerivedClassB)deserializedB).ValueProperty, Is.EqualTo(114514f)); }); + } + + #endregion + #region Nested Polymorphism Tests + + [Test] + public void NestedPolymorphism_Reflection_SerializeAndDeserialize_ReturnsCorrectDerivedClassC() + { + // Arrange + BaseClass baseRef = + new DerivedClassC { NameProperty = "NestedName", Cannon = "BigCannon" }; + DerivedClassA classARef = (DerivedClassA)baseRef; + DerivedClassC classCRef = (DerivedClassC)baseRef; + + // Serialize from BaseClass + byte[] bytesFromBase = ProtoSerializer.Serialize(baseRef); + var deserializedFromBase = ProtoSerializer.Deserialize(bytesFromBase); + + // Serialize from DerivedClassA + byte[] bytesFromA = ProtoSerializer.Serialize(classARef); + var deserializedFromA = ProtoSerializer.Deserialize(bytesFromA); + + // Serialize from DerivedClassC + byte[] bytesFromC = ProtoSerializer.Serialize(classCRef); + var deserializedFromC = ProtoSerializer.Deserialize(bytesFromC); + + Assert.Multiple(() => + { + Assert.That(deserializedFromBase, Is.AssignableTo()); + Assert.That(((DerivedClassC)deserializedFromBase).NameProperty, Is.EqualTo("NestedName")); + Assert.That(((DerivedClassC)deserializedFromBase).Cannon, Is.EqualTo("BigCannon")); + Assert.That(((DerivedClassC)deserializedFromBase).AnotherIdentifier, Is.EqualTo(114514)); + + Assert.That(deserializedFromA, Is.AssignableTo()); + Assert.That(((DerivedClassC)deserializedFromA).NameProperty, Is.EqualTo("NestedName")); + Assert.That(((DerivedClassC)deserializedFromA).Cannon, Is.EqualTo("BigCannon")); + Assert.That(((DerivedClassC)deserializedFromA).AnotherIdentifier, Is.EqualTo(114514)); + + Assert.That(deserializedFromC, Is.AssignableTo()); + Assert.That(deserializedFromC.NameProperty, Is.EqualTo("NestedName")); + Assert.That(deserializedFromC.Cannon, Is.EqualTo("BigCannon")); + Assert.That(deserializedFromC.AnotherIdentifier, Is.EqualTo(114514)); + }); } + [Test] + public void NestedPolymorphism_ProtoPackable_SerializeAndDeserialize_ReturnsCorrectDerivedClassC() + { + // Arrange + BaseClass baseRef = + new DerivedClassC { NameProperty = "NestedName", Cannon = "BigCannon" }; + DerivedClassA classARef = (DerivedClassA)baseRef; + DerivedClassC classCRef = (DerivedClassC)baseRef; + + // Serialize from BaseClass + byte[] bytesFromBase = ProtoSerializer.SerializeProtoPackable(baseRef); + var deserializedFromBase = ProtoSerializer.DeserializeProtoPackable(bytesFromBase); + // Serialize from DerivedClassA + byte[] bytesFromA = ProtoSerializer.SerializeProtoPackable(classARef); + var deserializedFromA = ProtoSerializer.DeserializeProtoPackable(bytesFromA); + + // Serialize from DerivedClassC + byte[] bytesFromC = ProtoSerializer.SerializeProtoPackable(classCRef); + var deserializedFromC = ProtoSerializer.DeserializeProtoPackable(bytesFromC); + + Assert.Multiple(() => + { + Assert.That(deserializedFromBase, Is.AssignableTo()); + Assert.That(((DerivedClassC)deserializedFromBase).NameProperty, Is.EqualTo("NestedName")); + Assert.That(((DerivedClassC)deserializedFromBase).Cannon, Is.EqualTo("BigCannon")); + Assert.That(((DerivedClassC)deserializedFromBase).AnotherIdentifier, Is.EqualTo(114514)); + + Assert.That(deserializedFromA, Is.AssignableTo()); + Assert.That(((DerivedClassC)deserializedFromA).NameProperty, Is.EqualTo("NestedName")); + Assert.That(((DerivedClassC)deserializedFromA).Cannon, Is.EqualTo("BigCannon")); + Assert.That(((DerivedClassC)deserializedFromA).AnotherIdentifier, Is.EqualTo(114514)); + + Assert.That(deserializedFromC, Is.AssignableTo()); + Assert.That(deserializedFromC.NameProperty, Is.EqualTo("NestedName")); + Assert.That(deserializedFromC.Cannon, Is.EqualTo("BigCannon")); + Assert.That(deserializedFromC.AnotherIdentifier, Is.EqualTo(114514)); + }); + } #endregion + } #region Test Classes + [ProtoPackable] [ProtoPolymorphic(FieldNumber = 1)] [ProtoDerivedType(typeof(DerivedClassA), 2)] [ProtoDerivedType(typeof(DerivedClassB), 3)] public partial class BaseClass { - public BaseClass() : this(-1) { } + public BaseClass() { } public BaseClass(int identifier) { IdentifierProperty = identifier; } - [ProtoMember(1)] public int IdentifierProperty { get; set; } + [ProtoMember(1, NumberHandling = ProtoNumberHandling.Fixed32)] public int IdentifierProperty { get; set; } = -1; } [ProtoPackable] +[ProtoPolymorphic(FieldNumber = 4, FallbackToBaseType = true)] +[ProtoDerivedType(typeof(DerivedClassC), 114514)] public partial class DerivedClassA() : BaseClass(2) { - [ProtoMember(2)] public string? NameProperty { get; set; } + public DerivedClassA(int anotherIdentifier) : this() + { + AnotherIdentifier = anotherIdentifier; + } + + [ProtoMember(4, NumberHandling = ProtoNumberHandling.Fixed32)] public int AnotherIdentifier { get; set; } = -1; + [ProtoMember(2)] public string NameProperty { get; set; } = string.Empty; } +[ProtoPackable] +public partial class DerivedClassC() : DerivedClassA(114514) +{ + [ProtoMember(10)] public string Cannon { get; set; } = string.Empty; +} [ProtoPackable] public partial class DerivedClassB() : BaseClass(3) diff --git a/Lagrange.Proto/IProtoSerializable.cs b/Lagrange.Proto/IProtoSerializable.cs index efee9dfd..0b8ef9f3 100644 --- a/Lagrange.Proto/IProtoSerializable.cs +++ b/Lagrange.Proto/IProtoSerializable.cs @@ -10,6 +10,4 @@ public interface IProtoSerializable public static abstract void SerializeHandler(T obj, ProtoWriter writer); public static abstract int MeasureHandler(T obj); - - public static abstract ProtoPolymorphicDerivedTypeDescriptor? GetPolymorphicTypeDescriptor(TKey discriminator); } \ No newline at end of file diff --git a/Lagrange.Proto/Serialization/Metadata/ProtoObjectInfo.cs b/Lagrange.Proto/Serialization/Metadata/ProtoObjectInfo.cs index a062ab8f..9bd7c7ef 100644 --- a/Lagrange.Proto/Serialization/Metadata/ProtoObjectInfo.cs +++ b/Lagrange.Proto/Serialization/Metadata/ProtoObjectInfo.cs @@ -8,5 +8,5 @@ public class ProtoObjectInfo public Dictionary Fields { get; init; } = new(); public Func? ObjectCreator { get; init; } public bool IgnoreDefaultFields { get; init; } - public IProtoPolymorphicInfoBase? PolymorphicInfo { get; init; } + public ProtoPolymorphicInfoBase? PolymorphicInfo { get; init; } } \ No newline at end of file diff --git a/Lagrange.Proto/Serialization/Metadata/ProtoPolymorphicInfo.cs b/Lagrange.Proto/Serialization/Metadata/ProtoPolymorphicInfo.cs index 4ae56b59..15322062 100644 --- a/Lagrange.Proto/Serialization/Metadata/ProtoPolymorphicInfo.cs +++ b/Lagrange.Proto/Serialization/Metadata/ProtoPolymorphicInfo.cs @@ -1,39 +1,75 @@ -namespace Lagrange.Proto.Serialization.Metadata; +using Lagrange.Proto.Primitives; +using Lagrange.Proto.Utility; +namespace Lagrange.Proto.Serialization.Metadata; -public interface IProtoPolymorphicInfoBase + +public class ProtoPolymorphicInfoBase { public uint PolymorphicIndicateIndex { get; set; } public bool PolymorphicFallbackToBaseType { get; set; } - - public Type? GetTypeFromDiscriminator(object discriminator); - - public bool SetTypeDiscriminator(object discriminator, Type type); + public virtual ProtoPolymorphicDerivedTypeDescriptor? GetDerivedTypeDescriptorFromReader(ref ProtoReader reader) => null; + public virtual void SetDerivedTypeDescriptor(object key, ProtoPolymorphicDerivedTypeDescriptor descriptor) => throw new MissingMethodException(); + public Func? RootTypeDescriptorGetter { get; set; } + public virtual IEnumerable GetAllDerivedTypeDescriptors() => []; } -public class ProtoPolymorphicObjectInfo : IProtoPolymorphicInfoBase where TKey : IEquatable +public class ProtoPolymorphicObjectInfo : ProtoPolymorphicInfoBase where TKey : IEquatable { - public uint PolymorphicIndicateIndex { get; set; } = 0; - public bool PolymorphicFallbackToBaseType { get; set; } = true; + public Dictionary PolymorphicDerivedTypes { get; init; } = []; + public override ProtoPolymorphicDerivedTypeDescriptor? GetDerivedTypeDescriptorFromReader(ref ProtoReader reader) + { + uint tag = reader.DecodeVarIntUnsafe(); + int field = (int)(tag >> 3); + var wireType = (WireType)(tag & 0x7); + if (field != PolymorphicIndicateIndex) + { + reader.Rewind(-ProtoHelper.GetVarIntLength(tag)); + return null; + } + var converter = ProtoTypeResolver.GetConverter(); + var key = converter.Read(field, wireType, ref reader); + var rst = PolymorphicDerivedTypes.GetValueOrDefault(key); + if (rst == null && !PolymorphicFallbackToBaseType) + ThrowHelper.ThrowInvalidOperationException_UnknownPolymorphicType(typeof(TKey), key); + return rst; + } - public Type? GetTypeFromDiscriminator(object discriminator) + public override void SetDerivedTypeDescriptor(object o, ProtoPolymorphicDerivedTypeDescriptor descriptor) { - return PolymorphicDerivedTypes.GetValueOrDefault((TKey)discriminator); + if (o is TKey key) + PolymorphicDerivedTypes[key] = descriptor; + else + ThrowHelper.ThrowInvalidOperationException_UnknownPolymorphicType(typeof(TKey), o); } - public bool SetTypeDiscriminator(object discriminator, Type type) + public override IEnumerable GetAllDerivedTypeDescriptors() { - PolymorphicDerivedTypes[(TKey)discriminator] = type; - return true; + return PolymorphicDerivedTypes.Values; } +} + - public Dictionary PolymorphicDerivedTypes { get; init; } = []; +public class ProtoPolymorphicDerivedTypeDescriptor +{ + public required Type CurrentType { get; init; } + public Func> FieldsGetter { get; init; } = () => throw new MissingMethodException(); + public Func IgnoreDefaultFieldsGetter { get; init; } = () => false; + public Func PolymorphicInfoGetter { get; init; } = () => null; + + public virtual object? CreateObject() + { + return null; + } } -public class ProtoPolymorphicDerivedTypeDescriptor +public class ProtoPolymorphicDerivedTypeDescriptor : ProtoPolymorphicDerivedTypeDescriptor { - public Dictionary Fields { get; init; } = []; public Func ObjectCreator { get; init; } = null!; - public bool IgnoreDefaultFields { get; init; } + + public override object? CreateObject() + { + return ObjectCreator(); + } } \ No newline at end of file diff --git a/Lagrange.Proto/Serialization/Metadata/ProtoTypeResolver.Dynamic.cs b/Lagrange.Proto/Serialization/Metadata/ProtoTypeResolver.Dynamic.cs index df62e50f..7f917aca 100644 --- a/Lagrange.Proto/Serialization/Metadata/ProtoTypeResolver.Dynamic.cs +++ b/Lagrange.Proto/Serialization/Metadata/ProtoTypeResolver.Dynamic.cs @@ -78,9 +78,23 @@ internal static ProtoObjectInfo CreateObjectInfo() }; } - internal static IProtoPolymorphicInfoBase? PopulatePolymorphicInfo() + internal static ProtoPolymorphicInfoBase? PopulatePolymorphicInfo() { var type = typeof(T); + ProtoPolymorphicInfoBase? objectInfo = null; + // goto root + var checkingType = type; + while (checkingType.BaseType != null) + { + // check if base type has polymorphic attribute + var basePolymorphicAttributes = checkingType.BaseType.GetCustomAttributes(typeof(ProtoDerivedTypeAttribute<>)) + .OfType().ToArray(); + if (basePolymorphicAttributes.Length > 0) checkingType = checkingType.BaseType; + else break; + } + + if (checkingType == type) checkingType = null; + var polymorphicAttributes = type.GetCustomAttributes(typeof(ProtoDerivedTypeAttribute<>)) .OfType().ToArray(); if (polymorphicAttributes.Length > 0) @@ -93,25 +107,33 @@ internal static ProtoObjectInfo CreateObjectInfo() // get the TKey from first var firstAttr = polymorphicAttributes[0]; var keyType = firstAttr.GetType().GetGenericArguments()[0]; - var objectInfo = (IProtoPolymorphicInfoBase?) typeof(ProtoPolymorphicObjectInfo<>).MakeGenericType(keyType) + objectInfo = (ProtoPolymorphicInfoBase?) typeof(ProtoPolymorphicObjectInfo<>).MakeGenericType(keyType) .GetConstructor(Type.EmptyTypes)?.Invoke(null); Debug.Assert(objectInfo != null); objectInfo.PolymorphicIndicateIndex = polymorphicFieldNumber; objectInfo.PolymorphicFallbackToBaseType = fallbackToBaseType; - + foreach (var attr in polymorphicAttributes) { var key = attr.GetType().GetProperty(nameof(ProtoDerivedTypeAttribute.TypeDiscriminator), BindingFlags.NonPublic | BindingFlags.Instance) ?.GetValue(attr); if (key == null) ThrowHelper.ThrowInvalidOperationException_UnknownPolymorphicType(type, attr.DerivedType); - objectInfo.SetTypeDiscriminator(key, attr.DerivedType); + objectInfo.SetDerivedTypeDescriptor(key, ProtoSerializer.GetObjectInfoReflection(attr.DerivedType)); } - - return objectInfo; } - return null; + if (checkingType is not null) + { + objectInfo ??= new ProtoPolymorphicInfoBase() + { + PolymorphicIndicateIndex = 0, + PolymorphicFallbackToBaseType = true + }; + objectInfo.RootTypeDescriptorGetter = () => ProtoSerializer.GetObjectInfoReflection(checkingType); + } + + return objectInfo; } internal static Dictionary CreateTypeFieldInfo(Type type) diff --git a/Lagrange.Proto/Serialization/ProtoSerializer.Deserialize.cs b/Lagrange.Proto/Serialization/ProtoSerializer.Deserialize.cs index 90469b04..feb809ab 100644 --- a/Lagrange.Proto/Serialization/ProtoSerializer.Deserialize.cs +++ b/Lagrange.Proto/Serialization/ProtoSerializer.Deserialize.cs @@ -27,30 +27,39 @@ private static T DeserializeProtoPackableCore(ref ProtoReader reader) where T T target = objectInfo.ObjectCreator(); var fields = objectInfo.Fields; - if (objectInfo.PolymorphicInfo?.PolymorphicIndicateIndex is > 0) - { - // has polymorphic type, read the first field to determine the actual type - uint firstTag = reader.DecodeVarIntUnsafe(); + var polymorphicInfo = objectInfo.PolymorphicInfo; - if (firstTag >>> 3 != objectInfo.PolymorphicInfo.PolymorphicIndicateIndex) + if (polymorphicInfo?.PolymorphicIndicateIndex is > 0) + { + var root = polymorphicInfo.RootTypeDescriptorGetter?.Invoke(); + if (root is not null) { - ThrowHelper.ThrowInvalidOperationException_PolymorphicFieldNotFirst(typeof(T), - objectInfo.PolymorphicInfo.PolymorphicIndicateIndex, firstTag >>> 3); + fields = root.FieldsGetter(); + polymorphicInfo = root.PolymorphicInfoGetter(); } - - var firstField = objectInfo.Fields[firstTag]; - firstField.Read(ref reader, target); - var polyTypeKey = firstField.Get?.Invoke(target); - - var typeDescriptor = T.GetPolymorphicTypeDescriptor(polyTypeKey); + } + + polyDeserialize: + if (polymorphicInfo?.PolymorphicIndicateIndex is > 0) + { + var typeDescriptor = polymorphicInfo.GetDerivedTypeDescriptorFromReader(ref reader); if (typeDescriptor is not null) { - fields = typeDescriptor.Fields; - target = typeDescriptor.ObjectCreator.Invoke(); - } - else if (!objectInfo.PolymorphicInfo.PolymorphicFallbackToBaseType) - { - ThrowHelper.ThrowInvalidOperationException_UnknownPolymorphicType(typeof(T), polyTypeKey!); + fields = typeDescriptor.FieldsGetter(); + polymorphicInfo = typeDescriptor.PolymorphicInfoGetter(); + if (!typeof(T).IsAssignableTo(typeDescriptor.CurrentType)) + { + if (typeDescriptor.CreateObject() is T newObj) + { + target = newObj; + } + else + { + ThrowHelper.ThrowInvalidOperationException_CanNotCreateObject(typeDescriptor.CurrentType); + } + } + + goto polyDeserialize; } } @@ -98,41 +107,35 @@ private static T DeserializeProtoPackableCore(ref ProtoReader reader) where T var boxed = (object?)target; // avoid multiple times of boxing if (boxed is null) ThrowHelper.ThrowInvalidOperationException_CanNotCreateObject(typeof(T)); var fieldInfos = converter.ObjectInfo.Fields; - var polymorphicInfo = converter.ObjectInfo.PolymorphicInfo; - startDeserialize: - // polymorphic type + var polymorphicInfo = converter.ObjectInfo.PolymorphicInfo; + if (polymorphicInfo?.PolymorphicIndicateIndex is > 0) { - // has polymorphic type, read the first field to determine the actual type - uint firstTag = reader.DecodeVarIntUnsafe(); - - if (firstTag >>> 3 != polymorphicInfo.PolymorphicIndicateIndex) - { - ThrowHelper.ThrowInvalidOperationException_PolymorphicFieldNotFirst(typeof(T), - polymorphicInfo.PolymorphicIndicateIndex, firstTag >>> 3); - } - - var firstField = converter.ObjectInfo.Fields[firstTag]; - firstField.Read(ref reader, boxed); - var polyTypeKey = firstField.Get?.Invoke(boxed); - if (polyTypeKey is null) - { - ThrowHelper.ThrowInvalidOperationException_FailedParsePolymorphicType(typeof(T), firstTag); - } - - if (polymorphicInfo.GetTypeFromDiscriminator(polyTypeKey) is { } polyType) + var root = polymorphicInfo.RootTypeDescriptorGetter?.Invoke(); + if (root is not null) { - (fieldInfos, var objectCreator, polymorphicInfo) = GetObjectInfoReflection(polyType); - target = objectCreator(); - boxed = target; - if (boxed is null) ThrowHelper.ThrowInvalidOperationException_CanNotCreateObject(polyType); + fieldInfos = root.FieldsGetter(); + polymorphicInfo = root.PolymorphicInfoGetter(); } - else if (!polymorphicInfo.PolymorphicFallbackToBaseType) + } + + startDeserialize: + if (polymorphicInfo?.PolymorphicIndicateIndex is > 0) + { + var polymorphicDescriptor = polymorphicInfo.GetDerivedTypeDescriptorFromReader(ref reader); + if (polymorphicDescriptor is not null) { - ThrowHelper.ThrowInvalidOperationException_UnknownPolymorphicType(typeof(T), polyTypeKey); + fieldInfos = polymorphicDescriptor.FieldsGetter(); + polymorphicInfo = polymorphicDescriptor.PolymorphicInfoGetter(); + + if (!typeof(T).IsAssignableTo(polymorphicDescriptor.CurrentType)) + { + boxed = polymorphicDescriptor.CreateObject(); + target = (T)boxed!; + if (boxed is null) ThrowHelper.ThrowInvalidOperationException_CanNotCreateObject(polymorphicDescriptor.CurrentType); + } + goto startDeserialize; } - - goto startDeserialize; } diff --git a/Lagrange.Proto/Serialization/ProtoSerializer.Helpers.cs b/Lagrange.Proto/Serialization/ProtoSerializer.Helpers.cs index 9694faf8..b3c43303 100644 --- a/Lagrange.Proto/Serialization/ProtoSerializer.Helpers.cs +++ b/Lagrange.Proto/Serialization/ProtoSerializer.Helpers.cs @@ -33,10 +33,8 @@ internal static ProtoObjectConverter GetConverterOf() return converter; } - internal static (Dictionary Fields, Func ObjectCreator, IProtoPolymorphicInfoBase? polymorphicInfo) GetObjectInfoReflection(Type polyType) + internal static ProtoPolymorphicDerivedTypeDescriptor GetObjectInfoReflection(Type polyType) { - Debug.Assert(polyType != typeof(T)); - Debug.Assert(polyType.IsAssignableTo(typeof(T))); var method = typeof(ProtoSerializer).GetMethod(nameof(GetConverterOf), BindingFlags.Static | BindingFlags.NonPublic); Debug.Assert(method != null); @@ -50,9 +48,18 @@ internal static (Dictionary Fields, Func ObjectCreator, .GetProperty("ObjectCreator")!.GetValue(polyObjectInfo)!; var fieldInfos = (Dictionary)polyObjectInfo.GetType() .GetProperty("Fields")!.GetValue(polyObjectInfo)!; - var polymorphicInfo = (IProtoPolymorphicInfoBase)polyObjectInfo.GetType() + var polymorphicInfo = (ProtoPolymorphicInfoBase)polyObjectInfo.GetType() .GetProperty("PolymorphicInfo")!.GetValue(polyObjectInfo)!; - return (fieldInfos, ObjectCreator,polymorphicInfo); - T ObjectCreator() => (T)polyCreator.GetType().GetMethod("Invoke")!.Invoke(polyCreator, null)!; + var ignoreDefaultFields = polyObjectInfo.GetType() + .GetProperty("IgnoreDefaultFields")!.GetValue(polyObjectInfo) is true; + return new ProtoPolymorphicDerivedTypeDescriptor + { + CurrentType = polyType, + FieldsGetter = () => fieldInfos, + IgnoreDefaultFieldsGetter = () => ignoreDefaultFields, + PolymorphicInfoGetter = () => polymorphicInfo, + ObjectCreator = ObjectCreator + }; + object? ObjectCreator() => (object?)polyCreator.GetType().GetMethod("Invoke")!.Invoke(polyCreator, null)!; } } \ No newline at end of file diff --git a/Lagrange.Proto/Serialization/ProtoSerializer.Serialize.cs b/Lagrange.Proto/Serialization/ProtoSerializer.Serialize.cs index ccee312f..5ea60778 100644 --- a/Lagrange.Proto/Serialization/ProtoSerializer.Serialize.cs +++ b/Lagrange.Proto/Serialization/ProtoSerializer.Serialize.cs @@ -108,32 +108,46 @@ private static void SerializeProtoPackableCore(ProtoWriter writer, T obj) whe object? boxed = obj; // avoid multiple times of boxing if (boxed is null) return; var fields = objectInfo.Fields; - uint skipTag = 0; + List skipTags = []; var polymorphicInfo = converter.ObjectInfo.PolymorphicInfo; + var ignoreDefaultFields = objectInfo.IgnoreDefaultFields; + if (polymorphicInfo?.RootTypeDescriptorGetter is not null) + { + var discriminator = polymorphicInfo.RootTypeDescriptorGetter(); + fields = discriminator.FieldsGetter(); + polymorphicInfo = discriminator.PolymorphicInfoGetter(); + ignoreDefaultFields = discriminator.IgnoreDefaultFieldsGetter(); + } + + var actualType = obj?.GetType(); // check polymorphic type + startSerialize: if (polymorphicInfo?.PolymorphicIndicateIndex is > 0) { // has polymorphic type var index = polymorphicInfo.PolymorphicIndicateIndex; - var fieldInfo = objectInfo.Fields.FirstOrDefault(t=>t.Value.Field == index); + var fieldInfo = fields.FirstOrDefault(t=>t.Value.Field == index); if (fieldInfo.Value is null) ThrowHelper.ThrowInvalidOperationException_NullPolymorphicDiscriminator(typeof(T)); var discriminator = fieldInfo.Value.Get?.Invoke(boxed); if (discriminator is null) ThrowHelper.ThrowInvalidOperationException_NullPolymorphicDiscriminator(typeof(T)); - if (objectInfo.PolymorphicInfo!.GetTypeFromDiscriminator(discriminator) is not { } derivedTypeInfo) - { - ThrowHelper.ThrowInvalidOperationException_NullPolymorphicDiscriminator(typeof(T)); - return; // make compiler happy - } - skipTag = fieldInfo.Key; + skipTags.Add(fieldInfo.Key); writer.EncodeVarInt(fieldInfo.Key); fieldInfo.Value.Write(writer, boxed); - - (fields, _, _) = GetObjectInfoReflection(derivedTypeInfo); + foreach (var descriptor in polymorphicInfo.GetAllDerivedTypeDescriptors()) + { + if (actualType?.IsAssignableTo(descriptor.CurrentType) is true) + { + fields = descriptor.FieldsGetter(); + ignoreDefaultFields = descriptor.IgnoreDefaultFieldsGetter(); + polymorphicInfo = descriptor.PolymorphicInfoGetter(); + goto startSerialize; + } + } } foreach (var (tag, info) in fields) { - if (skipTag != tag && info.ShouldSerialize(boxed, objectInfo.IgnoreDefaultFields)) + if (!skipTags.Contains(tag) && info.ShouldSerialize(boxed, ignoreDefaultFields)) { writer.EncodeVarInt(tag); info.Write(writer, boxed); From d72fa4fcb5a2ee30d7b788a8b27a1043308c6dd1 Mon Sep 17 00:00:00 2001 From: Kengwang Date: Mon, 6 Oct 2025 00:43:36 +0800 Subject: [PATCH 7/7] [Proto] Fix warnings --- .../ProtoSourceGenerator.Emitter.Serialize.cs | 2 +- .../ProtoSourceGenerator.Emitter.TypeInfo.cs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Lagrange.Proto.Generator/ProtoSourceGenerator.Emitter.Serialize.cs b/Lagrange.Proto.Generator/ProtoSourceGenerator.Emitter.Serialize.cs index b636c60f..06f26d5f 100644 --- a/Lagrange.Proto.Generator/ProtoSourceGenerator.Emitter.Serialize.cs +++ b/Lagrange.Proto.Generator/ProtoSourceGenerator.Emitter.Serialize.cs @@ -33,7 +33,7 @@ private void EmitSerializeMethod(SourceWriter source) source.WriteLine("{"); source.Indentation++; - if (parser.BaseTypeInfo.BaseType.GetFullName() != _fullQualifiedName && parser.BaseTypeInfo?.PolymorphicInfo.PolymorphicIndicateIndex is > 0) + if (parser.BaseTypeInfo?.BaseType.GetFullName() != _fullQualifiedName && parser.BaseTypeInfo?.PolymorphicInfo.PolymorphicIndicateIndex is > 0) { source.WriteLine($"{parser.BaseTypeInfo.BaseType.GetFullName()}.SerializeHandler({ObjectVarName},{WriterVarName});"); } diff --git a/Lagrange.Proto.Generator/ProtoSourceGenerator.Emitter.TypeInfo.cs b/Lagrange.Proto.Generator/ProtoSourceGenerator.Emitter.TypeInfo.cs index 76852adb..efd096c1 100644 --- a/Lagrange.Proto.Generator/ProtoSourceGenerator.Emitter.TypeInfo.cs +++ b/Lagrange.Proto.Generator/ProtoSourceGenerator.Emitter.TypeInfo.cs @@ -72,7 +72,7 @@ private void EmitTypeInfo(SourceWriter source) EmitFieldsInfo(source, parser.Fields); source.WriteLine($"ObjectCreator = () => new {_fullQualifiedName}(),"); - EmitPolymorphicInfo(source, parser.PolymorphicInfo, parser.BaseTypeInfo); + EmitPolymorphicInfo(source, parser.PolymorphicInfo, parser.BaseTypeInfo!); source.WriteLine($"IgnoreDefaultFields = {parser.IgnoreDefaultFields.ToString().ToLower()}"); source.Indentation--;