diff --git a/Lagrange.Proto.Generator/Entity/PolymorphicTypeInfo.cs b/Lagrange.Proto.Generator/Entity/PolymorphicTypeInfo.cs new file mode 100644 index 00000000..47391da0 --- /dev/null +++ b/Lagrange.Proto.Generator/Entity/PolymorphicTypeInfo.cs @@ -0,0 +1,28 @@ +using Microsoft.CodeAnalysis; + +namespace Lagrange.Proto.Generator.Entity; + +public class PolymorphicTypeInfo +{ + public INamedTypeSymbol PolymorphicKeyType { get; internal set; } = null!; + + public uint PolymorphicIndicateIndex { get; internal set; } = 0; + + public bool PolymorphicFallbackToBaseType { get; internal set; } = true; + + public List PolymorphicTypes { get; } = []; +} + +public class PolymorphicDerivedTypeInfo +{ + public INamedTypeSymbol DerivedType { get; internal set; } = null!; + public TypedConstant Key { get; internal set; } +} + +public class BaseTypeInfo +{ + public INamedTypeSymbol BaseType { get; internal set; } = null!; + public PolymorphicTypeInfo PolymorphicInfo { get; internal set; } = new(); + public bool IgnoreDefaultFields { get; internal set; } + public Dictionary Fields { get; } = new(); +} \ No newline at end of file diff --git a/Lagrange.Proto.Generator/ProtoSourceGenerator.Emitter.Serialize.cs b/Lagrange.Proto.Generator/ProtoSourceGenerator.Emitter.Serialize.cs index f224fa96..06f26d5f 100644 --- a/Lagrange.Proto.Generator/ProtoSourceGenerator.Emitter.Serialize.cs +++ b/Lagrange.Proto.Generator/ProtoSourceGenerator.Emitter.Serialize.cs @@ -25,22 +25,65 @@ private partial class Emitter private const string EncodeStringMethodName = "EncodeString"; private const string EncodeBytesMethodName = "EncodeBytes"; private const string EncodeResolvableMethodName = "EncodeResolvable"; + private void EmitSerializeMethod(SourceWriter source) { source.WriteLine($"public static void SerializeHandler({_fullQualifiedName} {ObjectVarName}, {ProtoWriterTypeRef} {WriterVarName})"); source.WriteLine("{"); source.Indentation++; - + + if (parser.BaseTypeInfo?.BaseType.GetFullName() != _fullQualifiedName && parser.BaseTypeInfo?.PolymorphicInfo.PolymorphicIndicateIndex is > 0) + { + source.WriteLine($"{parser.BaseTypeInfo.BaseType.GetFullName()}.SerializeHandler({ObjectVarName},{WriterVarName});"); + } + else + { + source.WriteLine($"SerializeHandlerCore({ObjectVarName}, {WriterVarName});"); + } + + source.Indentation--; + source.WriteLine("}"); + source.WriteLine(); + + source.WriteLine($"public static void SerializeHandlerCore({_fullQualifiedName} {ObjectVarName}, {ProtoWriterTypeRef} {WriterVarName})"); + source.WriteLine("{"); + source.Indentation++; + if (parser.PolymorphicInfo.PolymorphicIndicateIndex > 0) + { + // write indicate field first + var idx = (int)parser.PolymorphicInfo.PolymorphicIndicateIndex; + var indicatorField = parser.Fields[idx]; + EmitMembers(source, idx, indicatorField); + source.WriteLine(); + + source.WriteLine($"switch ({ObjectVarName})"); + source.WriteLine("{"); + source.Indentation++; + for (var index = 0; index < parser.PolymorphicInfo.PolymorphicTypes.Count; index++) + { + var kv = parser.PolymorphicInfo.PolymorphicTypes[index]; + source.WriteLine($"case {kv.DerivedType.GetFullName()} _:"); + source.Indentation++; + source.WriteLine($"{kv.DerivedType.GetFullName()}.SerializeHandlerCore(({kv.DerivedType.GetFullName()}){ObjectVarName}, {WriterVarName});"); + source.WriteLine("break;"); + source.Indentation--; + } + source.Indentation--; + source.WriteLine("}"); + source.WriteLine(); + } foreach (var kv in parser.Fields) { int field = kv.Key; + if (parser.PolymorphicInfo.PolymorphicIndicateIndex == field) continue; // already written var info = kv.Value; EmitMembers(source, field, info); source.WriteLine(); } + source.Indentation--; source.WriteLine("}"); } diff --git a/Lagrange.Proto.Generator/ProtoSourceGenerator.Emitter.TypeInfo.cs b/Lagrange.Proto.Generator/ProtoSourceGenerator.Emitter.TypeInfo.cs index 968ef61c..efd096c1 100644 --- a/Lagrange.Proto.Generator/ProtoSourceGenerator.Emitter.TypeInfo.cs +++ b/Lagrange.Proto.Generator/ProtoSourceGenerator.Emitter.TypeInfo.cs @@ -2,6 +2,7 @@ using Lagrange.Proto.Generator.Utility; using Lagrange.Proto.Generator.Utility.Extension; using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; namespace Lagrange.Proto.Generator; @@ -13,6 +14,9 @@ private partial class Emitter private const string TypeInfoPropertyName = "TypeInfo"; private const string ProtoObjectInfoTypeRef = "global::Lagrange.Proto.Serialization.Metadata.ProtoObjectInfo<{0}>"; + private const string ProtoPolymorphicObjectInfoTypeRef = "global::Lagrange.Proto.Serialization.Metadata.ProtoPolymorphicObjectInfo<{0}>"; + private const string ProtoPolymorphicDerivedTypeDescriptorTypeRef = "global::Lagrange.Proto.Serialization.Metadata.ProtoPolymorphicDerivedTypeDescriptor<{0}>"; + private const string ProtoPolymorphicDerivedTypeDescriptorBaseTypeRef = "global::Lagrange.Proto.Serialization.Metadata.ProtoPolymorphicDerivedTypeDescriptor"; private const string ProtoFieldInfoTypeRef = "global::Lagrange.Proto.Serialization.Metadata.ProtoFieldInfo"; private const string ProtoFieldInfoGenericTypeRef = "global::Lagrange.Proto.Serialization.Metadata.ProtoFieldInfo<{0}>"; private const string ProtoMapFieldInfoGenericTypeRef = "global::Lagrange.Proto.Serialization.Metadata.ProtoMapFieldInfo<{0}, {1}, {2}>"; @@ -37,11 +41,15 @@ private partial class Emitter private void EmitTypeInfo(SourceWriter source) { + source.WriteLine("#pragma warning disable CS0108"); + source.WriteLine(); source.WriteLine($"public static {ProtoObjectInfoTypeRefGeneric}? {TypeInfoFieldName};"); source.WriteLine(); source.WriteLine($"public static {ProtoObjectInfoTypeRefGeneric} {TypeInfoPropertyName} => {TypeInfoFieldName} ??= GetTypeInfo();"); source.WriteLine(); + source.WriteLine("#pragma warning restore CS0108"); + source.WriteLine(); source.WriteLine($"private static {ProtoObjectInfoTypeRefGeneric} GetTypeInfo()"); source.WriteLine('{'); @@ -60,28 +68,20 @@ private void EmitTypeInfo(SourceWriter source) source.WriteLine($"return new {ProtoObjectInfoTypeRefGeneric}()"); source.WriteLine('{'); source.Indentation++; - - source.WriteLine($"Fields = new global::System.Collections.Generic.Dictionary()"); - source.WriteLine('{'); - source.Indentation++; - foreach (var kv in parser.Fields) - { - int field = kv.Key; - var info = kv.Value; - if (info.ExtraTypeInfo.Count == 0) EmitFieldInfo(source, field, info); - else EmitMapFieldInfo(source, field, info); - } - source.Indentation--; - source.WriteLine("},"); + EmitFieldsInfo(source, parser.Fields); source.WriteLine($"ObjectCreator = () => new {_fullQualifiedName}(),"); + EmitPolymorphicInfo(source, parser.PolymorphicInfo, parser.BaseTypeInfo!); source.WriteLine($"IgnoreDefaultFields = {parser.IgnoreDefaultFields.ToString().ToLower()}"); source.Indentation--; source.WriteLine("};"); source.Indentation--; source.WriteLine('}'); + + EmitPolymorphicDerivedTypeDescriptor(source, parser.PolymorphicInfo); + return; static void EmitByTypeSymbol(SourceWriter source, ITypeSymbol typeSymbol) @@ -106,6 +106,23 @@ static void EmitByTypeSymbol(SourceWriter source, ITypeSymbol typeSymbol) } } + private void EmitFieldsInfo(SourceWriter source,Dictionary fields ) + { + source.WriteLine($"Fields = new global::System.Collections.Generic.Dictionary()"); + source.WriteLine('{'); + source.Indentation++; + foreach (var kv in fields) + { + int field = kv.Key; + var info = kv.Value; + + if (info.ExtraTypeInfo.Count == 0) EmitFieldInfo(source, field, info); + else EmitMapFieldInfo(source, field, info); + } + source.Indentation--; + source.WriteLine("},"); + } + private void EmitFieldInfo(SourceWriter source, int field, ProtoFieldInfo info) { int tag = field << 3 | (byte)info.WireType; @@ -142,5 +159,95 @@ private void EmitMapFieldInfo(SourceWriter source, int field, ProtoFieldInfo inf source.Indentation--; source.WriteLine("},"); } + + private void EmitPolymorphicInfo(SourceWriter source, PolymorphicTypeInfo polymorphicInfo, BaseTypeInfo baseTypeInfo) + { + if (polymorphicInfo.PolymorphicIndicateIndex == 0) return; + source.WriteLine($"PolymorphicInfo = new {string.Format(ProtoPolymorphicObjectInfoTypeRef, polymorphicInfo.PolymorphicKeyType.GetFullName())}()"); + source.WriteLine('{'); + source.Indentation++; + + source.WriteLine($"PolymorphicIndicateIndex = {polymorphicInfo.PolymorphicIndicateIndex},"); + source.WriteLine($"PolymorphicFallbackToBaseType = {polymorphicInfo.PolymorphicFallbackToBaseType.ToString().ToLower()},"); + if (parser.BaseTypeInfo?.BaseType is not null && parser.BaseTypeInfo.BaseType.GetFullName() != _fullQualifiedName) + { + source.WriteLine($"RootTypeDescriptorGetter = () => new {string.Format(ProtoPolymorphicDerivedTypeDescriptorTypeRef, parser.BaseTypeInfo.BaseType.GetFullName())}()"); + source.WriteLine('{'); + source.Indentation++; + source.WriteLine($"FieldsGetter = () => {parser.BaseTypeInfo.BaseType.GetFullName()}.{TypeInfoPropertyName}.Fields,"); + source.WriteLine($"ObjectCreator = () => new {parser.BaseTypeInfo.BaseType.GetFullName()}(),"); + source.WriteLine($"IgnoreDefaultFieldsGetter = () => {parser.BaseTypeInfo.BaseType.GetFullName()}.{TypeInfoPropertyName}.IgnoreDefaultFields,"); + source.WriteLine($"PolymorphicInfoGetter = () => {parser.BaseTypeInfo.BaseType.GetFullName()}.{TypeInfoPropertyName}.PolymorphicInfo,"); + source.WriteLine($"CurrentType = typeof({parser.BaseTypeInfo.BaseType.GetFullName()})"); + source.Indentation--; + source.WriteLine("},"); + } + source.WriteLine( + $"PolymorphicDerivedTypes = new global::System.Collections.Generic.Dictionary<{polymorphicInfo.PolymorphicKeyType.GetFullName()}, {ProtoPolymorphicDerivedTypeDescriptorBaseTypeRef}>()"); + source.WriteLine('{'); + source.Indentation++; + + foreach (var derivedTypeInfo in polymorphicInfo.PolymorphicTypes) + { + source.WriteLine($"[{derivedTypeInfo.Key.ToCSharpString()}] = new {string.Format(ProtoPolymorphicDerivedTypeDescriptorTypeRef, baseTypeInfo.BaseType.GetFullName())}()"); + source.WriteLine('{'); + source.Indentation++; + source.WriteLine($"CurrentType = typeof({derivedTypeInfo.DerivedType.GetFullName()}),"); + source.WriteLine($"FieldsGetter = () => {derivedTypeInfo.DerivedType.GetFullName()}.{TypeInfoPropertyName}.Fields,"); + source.WriteLine($"ObjectCreator = () => ({baseTypeInfo.BaseType.GetFullName()})new {derivedTypeInfo.DerivedType.GetFullName()}(),"); + source.WriteLine($"IgnoreDefaultFieldsGetter = () => {derivedTypeInfo.DerivedType.GetFullName()}.{TypeInfoPropertyName}.IgnoreDefaultFields,"); + source.WriteLine($"PolymorphicInfoGetter = () => {derivedTypeInfo.DerivedType.GetFullName()}.{TypeInfoPropertyName}.PolymorphicInfo"); + source.Indentation--; + source.WriteLine("},"); + } + + source.Indentation--; + source.WriteLine("}"); + source.Indentation--; + source.WriteLine("},"); + } + + private void EmitPolymorphicDerivedTypeDescriptor(SourceWriter source, PolymorphicTypeInfo polymorphicInfo) + { + + source.WriteLine("#pragma warning disable CS0108"); + source.WriteLine( + $"public static {string.Format(ProtoPolymorphicDerivedTypeDescriptorTypeRef,_fullQualifiedName)}? GetPolymorphicTypeDescriptor(TKey discriminator)"); + source.WriteLine('{'); + source.Indentation++; + if (polymorphicInfo.PolymorphicIndicateIndex > 0) + { + source.WriteLine("switch (discriminator)"); + source.WriteLine('{'); + source.Indentation++; + + foreach (var derivedTypeInfo in polymorphicInfo.PolymorphicTypes) + { + source.WriteLine($"case {derivedTypeInfo.Key.ToCSharpString()}: return new {string.Format(ProtoPolymorphicDerivedTypeDescriptorTypeRef, _fullQualifiedName)}()"); + source.WriteLine('{'); + source.Indentation++; + source.WriteLine($"FieldsGetter = () => {derivedTypeInfo.DerivedType.GetFullName()}.{TypeInfoPropertyName}.Fields,"); + source.WriteLine($"ObjectCreator = () => ({_fullQualifiedName})new {derivedTypeInfo.DerivedType.GetFullName()}(),"); + source.WriteLine($"IgnoreDefaultFieldsGetter = () => {derivedTypeInfo.DerivedType.GetFullName()}.{TypeInfoPropertyName}.IgnoreDefaultFields,"); + source.WriteLine($"PolymorphicInfoGetter = () => {derivedTypeInfo.DerivedType.GetFullName()}.{TypeInfoPropertyName}.PolymorphicInfo,"); + source.WriteLine($"CurrentType = typeof({derivedTypeInfo.DerivedType.GetFullName()})"); + source.Indentation--; + source.WriteLine("};"); + } + + source.WriteLine("default: return null;"); + source.Indentation--; + source.WriteLine("}"); + } + else + { + source.WriteLine("return null;"); + } + + source.Indentation--; + source.WriteLine('}'); + source.WriteLine("#pragma warning restore CS0108"); + source.WriteLine(); + } } } \ No newline at end of file diff --git a/Lagrange.Proto.Generator/ProtoSourceGenerator.Parser.cs b/Lagrange.Proto.Generator/ProtoSourceGenerator.Parser.cs index 1ab8f999..dc185599 100644 --- a/Lagrange.Proto.Generator/ProtoSourceGenerator.Parser.cs +++ b/Lagrange.Proto.Generator/ProtoSourceGenerator.Parser.cs @@ -30,6 +30,10 @@ private class Parser(ClassDeclarationSyntax context, SemanticModel model) public bool IgnoreDefaultFields { get; private set; } public Dictionary Fields { get; } = new(); + + public PolymorphicTypeInfo PolymorphicInfo { get; } = new(); + + public BaseTypeInfo? BaseTypeInfo { get; private set; } = null; public void Parse(CancellationToken token = default) { @@ -41,22 +45,28 @@ public void Parse(CancellationToken token = default) ReportDiagnostics(UnableToGetSymbol, context.GetLocation(), context.Identifier.Text); return; } - + if (!classSymbol.Constructors.Any(x => x is { Parameters.Length: 0, DeclaredAccessibility: Accessibility.Public })) { ReportDiagnostics(MustContainParameterlessConstructor, context.GetLocation(), context.Identifier.Text); return; } - foreach (var argument in classSymbol.GetAttributes().SelectMany(x => x.NamedArguments)) + foreach (var attribute in classSymbol.GetAttributes()) { - switch (argument.Key) + switch (attribute.AttributeClass?.Name) { - case "IgnoreDefaultFields": - { - IgnoreDefaultFields = (bool)(argument.Value.Value ?? false); + case "ProtoPackableAttribute": + foreach (var argument in attribute.NamedArguments) + { + switch (argument.Key) + { + case "IgnoreDefaultFields": + IgnoreDefaultFields = (bool)(argument.Value.Value ?? false); + break; + } + } break; - } } } @@ -66,23 +76,126 @@ public void Parse(CancellationToken token = default) return; } TypeDeclarations.AddRange(typeDeclarations); - - var members = context.ChildNodes() - .Where(x => x is FieldDeclarationSyntax or PropertyDeclarationSyntax) - .Cast() - .Where(x => x.ContainsAttribute("ProtoMember")); - foreach (var member in members) + PopulateFieldInfo(classSymbol, Fields, identifier, token); + PopulatePolymorphicInfo(classSymbol, PolymorphicInfo, token); + + // Handling BaseType + var checkingBaseTypeSymbol = classSymbol; + while (checkingBaseTypeSymbol?.BaseType is not null) { token.ThrowIfCancellationRequested(); + if (!checkingBaseTypeSymbol.BaseType.GetAttributes() + .Any(t => t.AttributeClass?.Name == "ProtoPackableAttribute")) + { + break; + } - var symbol = classSymbol.GetMembers().First(x => x.Name == member switch + checkingBaseTypeSymbol = checkingBaseTypeSymbol.BaseType; + PopulateFieldInfo(checkingBaseTypeSymbol, Fields, identifier, token); + } + + if (checkingBaseTypeSymbol is not null) + { + BaseTypeInfo = new BaseTypeInfo(); + BaseTypeInfo.BaseType = checkingBaseTypeSymbol; + foreach (var attribute in checkingBaseTypeSymbol.GetAttributes()) { - FieldDeclarationSyntax fieldDeclaration => fieldDeclaration.Declaration.Variables[0].Identifier.ToString(), - PropertyDeclarationSyntax propertyDeclaration => propertyDeclaration.Identifier.ToString(), - _ => throw new InvalidOperationException("Unsupported member type.") - }); + switch (attribute.AttributeClass?.Name) + { + case "ProtoPackableAttribute": + foreach (var argument in attribute.NamedArguments) + { + switch (argument.Key) + { + case "IgnoreDefaultFields": + BaseTypeInfo.IgnoreDefaultFields = (bool)(argument.Value.Value ?? false); + break; + } + } + break; + } + } + PopulateFieldInfo(BaseTypeInfo.BaseType, BaseTypeInfo.Fields, BaseTypeInfo.BaseType.GetFullName(), token); + PopulatePolymorphicInfo(BaseTypeInfo.BaseType, BaseTypeInfo.PolymorphicInfo, token); + } + } + + private void PopulatePolymorphicInfo(INamedTypeSymbol classSymbol, PolymorphicTypeInfo polymorphicInfo, CancellationToken cancellationToken = default) + { + foreach (var attribute in classSymbol.GetAttributes()) + { + cancellationToken.ThrowIfCancellationRequested(); + switch (attribute.AttributeClass?.Name) + { + case "ProtoDerivedTypeAttribute": + if (polymorphicInfo.PolymorphicIndicateIndex == 0) + polymorphicInfo.PolymorphicIndicateIndex = 1; // set to default + // get key type + if (attribute.AttributeClass.TypeArguments.First() is not INamedTypeSymbol keyType) + { + ReportDiagnostics(UnableToGetSymbol, context.GetLocation(), + attribute.AttributeClass.TypeArguments.First().ToDisplayString()); + return; + } + + polymorphicInfo.PolymorphicKeyType = keyType; + + // get derived type, in typeof + var derivedTypeConstant = attribute.ConstructorArguments.First(); + if (derivedTypeConstant.Kind != TypedConstantKind.Type) + { + ReportDiagnostics(UnableToGetSymbol, context.GetLocation(), + derivedTypeConstant.ToCSharpString()); + return; + } + + if (derivedTypeConstant.Value is not INamedTypeSymbol derivedType) + { + ReportDiagnostics(UnableToGetSymbol, context.GetLocation(), + derivedTypeConstant.ToCSharpString()); + return; + } + + // get type discriminator + var typeDiscriminatorConstant = attribute.ConstructorArguments.ElementAtOrDefault(1); + polymorphicInfo.PolymorphicTypes.Add(new PolymorphicDerivedTypeInfo + { + DerivedType = derivedType, Key = typeDiscriminatorConstant + }); + break; + case "ProtoPolymorphicAttribute": + foreach (var argument in attribute.NamedArguments) + { + switch (argument.Key) + { + case "FieldNumber": + polymorphicInfo.PolymorphicIndicateIndex = (uint)(argument.Value.Value ?? 0); + break; + case "FallbackToBaseType": + polymorphicInfo.PolymorphicFallbackToBaseType = + (bool)(argument.Value.Value ?? true); + break; + } + } + break; + } + } + + } + + private void PopulateFieldInfo(INamedTypeSymbol classSymbol, Dictionary fields, string identifier = "" ,CancellationToken token = default) + { + var members = classSymbol.GetMembers() + .Where(x => x is IPropertySymbol or IFieldSymbol) + .Where(x => x.GetAttributes().Any(t=>t.AttributeClass?.Name == "ProtoMemberAttribute")); + + foreach (var symbol in members) + { + token.ThrowIfCancellationRequested(); + + var member = symbol.DeclaringSyntaxReferences.FirstOrDefault()!.GetSyntax(token); if (symbol.IsStatic) { ReportDiagnostics(MustNotBeStatic, member.GetLocation(), symbol.Name, identifier); @@ -91,7 +204,7 @@ public void Parse(CancellationToken token = default) var attribute = symbol.GetAttributes().First(x => x.AttributeClass?.Name == "ProtoMemberAttribute"); int field = (int)(attribute.ConstructorArguments[0].Value ?? throw new InvalidOperationException("Unable to get field number.")); - if (Fields.ContainsKey(field)) + if (fields.ContainsKey(field)) { ReportDiagnostics(DuplicateFieldNumber, member.GetLocation(), field, identifier); continue; @@ -126,7 +239,7 @@ public void Parse(CancellationToken token = default) var valueAttribute = symbol.GetAttributes().FirstOrDefault(x => x.AttributeClass?.ToDisplayString() == ProtoValueMemberAttributeFullName); if (valueAttribute != null) ReadProtoMemberAttribute(valueAttribute, typeSymbol, ref valueWireType, member, field, identifier, ref valueSigned); - Fields[field] = new ProtoFieldInfo(symbol, typeSymbol, wireType, signed) + fields[field] = new ProtoFieldInfo(symbol, typeSymbol, wireType, signed) { ExtraTypeInfo = { @@ -138,12 +251,12 @@ public void Parse(CancellationToken token = default) else { ReadProtoMemberAttribute(attribute, typeSymbol, ref wireType, member, field, identifier, ref signed); - Fields[field] = new ProtoFieldInfo(symbol, typeSymbol, wireType, signed); + fields[field] = new ProtoFieldInfo(symbol, typeSymbol, wireType, signed); } } } - - private void ReadProtoMemberAttribute(AttributeData attribute, ITypeSymbol typeSymbol, ref WireType wireType, MemberDeclarationSyntax member, int field, string identifier, ref bool signed) + + private void ReadProtoMemberAttribute(AttributeData attribute, ITypeSymbol typeSymbol, ref WireType wireType, SyntaxNode member, int field, string identifier, ref bool signed) { foreach (var argument in attribute.NamedArguments) { diff --git a/Lagrange.Proto.Runner/Program.cs b/Lagrange.Proto.Runner/Program.cs index 9ceeda98..d01e95d4 100644 --- a/Lagrange.Proto.Runner/Program.cs +++ b/Lagrange.Proto.Runner/Program.cs @@ -20,4 +20,37 @@ private static void Main(string[] args) int value = parsed[1][0][1].GetValue(); } -} \ No newline at end of file +} + +#region Test Classes +[ProtoPackable] +[ProtoPolymorphic(FieldNumber = 1)] +[ProtoDerivedType(typeof(DerivedClassA), 4)] +[ProtoDerivedType(typeof(DerivedClassB), 3)] +public partial class BaseClass +{ + public BaseClass() : this(-1) { } + + public BaseClass(int identifier) + { + IdentifierProperty = identifier; + } + + [ProtoMember(1)] public int IdentifierProperty { get; set; } +} + + +[ProtoPackable] +public partial class DerivedClassA() : BaseClass(2) +{ + [ProtoMember(2)] public string? NameProperty { get; set; } +} + + +[ProtoPackable] +public partial class DerivedClassB() : BaseClass(3) +{ + [ProtoMember(2)] public float ValueProperty { get; set; } = 0f; +} + +#endregion \ No newline at end of file diff --git a/Lagrange.Proto.Test/ProtoPolymorphismTest.cs b/Lagrange.Proto.Test/ProtoPolymorphismTest.cs new file mode 100644 index 00000000..6661fe19 --- /dev/null +++ b/Lagrange.Proto.Test/ProtoPolymorphismTest.cs @@ -0,0 +1,364 @@ +using System.Reflection; +using Lagrange.Proto.Serialization; + +namespace Lagrange.Proto.Test; + +[TestFixture] +public class ProtoPolymorphismTest +{ + #region Basic Polymorphism + + [Test] + public void ReflectionPolymorphism_SerializeBaseAndDeserializeBase_ReturnsCorrectDerivedType() + { + // Arrange + BaseClass originalA = new DerivedClassA { NameProperty = "TestName" }; + BaseClass originalB = new DerivedClassB { ValueProperty = 114514f }; + + byte[] bytesA = ProtoSerializer.Serialize(originalA); + BaseClass deserializedA = ProtoSerializer.Deserialize(bytesA); + + byte[] bytesB = ProtoSerializer.Serialize(originalB); + BaseClass deserializedB = ProtoSerializer.Deserialize(bytesB); + + Assert.Multiple(() => + { + Assert.That(deserializedA, Is.AssignableTo()); + Assert.That(deserializedA, Is.AssignableTo()); + Assert.That(deserializedA.IdentifierProperty, Is.EqualTo(2)); + Assert.That(((DerivedClassA)deserializedA).NameProperty, Is.EqualTo("TestName")); + + Assert.That(deserializedB, Is.AssignableTo()); + Assert.That(deserializedB, Is.AssignableTo()); + Assert.That(deserializedB.IdentifierProperty, Is.EqualTo(3)); + Assert.That(((DerivedClassB)deserializedB).ValueProperty, Is.EqualTo(114514f)); + + }); + } + + [Test] + public void ReflectionPolymorphism_SerializeDerivedAndDeserializeBase_ReturnsCorrectDerivedType() + { + // Arrange + DerivedClassA originalA = new DerivedClassA { NameProperty = "TestName" }; + DerivedClassB originalB = new DerivedClassB { ValueProperty = 114514f }; + + byte[] bytesA = ProtoSerializer.Serialize(originalA); + BaseClass deserializedA = ProtoSerializer.Deserialize(bytesA); + + byte[] bytesB = ProtoSerializer.Serialize(originalB); + BaseClass deserializedB = ProtoSerializer.Deserialize(bytesB); + + Assert.Multiple(() => + { + Assert.That(deserializedA, Is.AssignableTo()); + Assert.That(deserializedA, Is.AssignableTo()); + Assert.That(deserializedA.IdentifierProperty, Is.EqualTo(2)); + Assert.That(((DerivedClassA)deserializedA).NameProperty, Is.EqualTo("TestName")); + + Assert.That(deserializedB, Is.AssignableTo()); + Assert.That(deserializedB, Is.AssignableTo()); + Assert.That(deserializedB.IdentifierProperty, Is.EqualTo(3)); + Assert.That(((DerivedClassB)deserializedB).ValueProperty, Is.EqualTo(114514f)); + }); + } + + [Test] + public void ReflectionPolymorphism_SerializeBaseAndDeserializeDerived_ReturnsCorrectDerivedType() + { + // Arrange + BaseClass originalA = new DerivedClassA { NameProperty = "TestName" }; + BaseClass originalB = new DerivedClassB { ValueProperty = 114514f }; + + byte[] bytesA = ProtoSerializer.Serialize(originalA); + BaseClass deserializedA = ProtoSerializer.Deserialize(bytesA); + + byte[] bytesB = ProtoSerializer.Serialize(originalB); + BaseClass deserializedB = ProtoSerializer.Deserialize(bytesB); + + Assert.Multiple(() => + { + Assert.That(deserializedA, Is.AssignableTo()); + Assert.That(deserializedA, Is.AssignableTo()); + Assert.That(deserializedA.IdentifierProperty, Is.EqualTo(2)); + Assert.That(((DerivedClassA)deserializedA).NameProperty, Is.EqualTo("TestName")); + + Assert.That(deserializedB, Is.AssignableTo()); + Assert.That(deserializedB, Is.AssignableTo()); + Assert.That(deserializedB.IdentifierProperty, Is.EqualTo(3)); + Assert.That(((DerivedClassB)deserializedB).ValueProperty, Is.EqualTo(114514f)); + }); + } + + [Test] + public void ReflectionPolymorphism_SerializeDerivedAndDeserializeDerived_ReturnsCorrectDerivedType() + { + // Arrange + DerivedClassA originalA = new DerivedClassA { NameProperty = "TestName" }; + DerivedClassB originalB = new DerivedClassB { ValueProperty = 114514f }; + + byte[] bytesA = ProtoSerializer.Serialize(originalA); + BaseClass deserializedA = ProtoSerializer.Deserialize(bytesA); + + byte[] bytesB = ProtoSerializer.Serialize(originalB); + BaseClass deserializedB = ProtoSerializer.Deserialize(bytesB); + + Assert.Multiple(() => + { + Assert.That(deserializedA, Is.AssignableTo()); + Assert.That(deserializedA, Is.AssignableTo()); + Assert.That(deserializedA.IdentifierProperty, Is.EqualTo(2)); + Assert.That(((DerivedClassA)deserializedA).NameProperty, Is.EqualTo("TestName")); + + Assert.That(deserializedB, Is.AssignableTo()); + Assert.That(deserializedB, Is.AssignableTo()); + Assert.That(deserializedB.IdentifierProperty, Is.EqualTo(3)); + Assert.That(((DerivedClassB)deserializedB).ValueProperty, Is.EqualTo(114514f)); + }); + } + + #endregion + + #region ProtoPackable + + [Test] + public void ProtoPackablePolymorphism_SerializeBaseAndDeserializeDerived_ReturnsCorrectDerivedType() + { + // Arrange + BaseClass originalA = new DerivedClassA { NameProperty = "TestName" }; + BaseClass originalB = new DerivedClassB { ValueProperty = 114514f }; + + byte[] bytesA = ProtoSerializer.SerializeProtoPackable(originalA); + DerivedClassA deserializedA = ProtoSerializer.DeserializeProtoPackable(bytesA); + + byte[] bytesB = ProtoSerializer.SerializeProtoPackable(originalB); + DerivedClassB deserializedB = ProtoSerializer.DeserializeProtoPackable(bytesB); + + Assert.Multiple(() => + { + Assert.That(deserializedA, Is.AssignableTo()); + Assert.That(deserializedA, Is.AssignableTo()); + Assert.That(deserializedA.IdentifierProperty, Is.EqualTo(2)); + Assert.That(((DerivedClassA)deserializedA).NameProperty, Is.EqualTo("TestName")); + + Assert.That(deserializedB, Is.AssignableTo()); + Assert.That(deserializedB, Is.AssignableTo()); + Assert.That(deserializedB.IdentifierProperty, Is.EqualTo(3)); + Assert.That(((DerivedClassB)deserializedB).ValueProperty, Is.EqualTo(114514f)); + }); + } + + [Test] + public void ProtoPackablePolymorphism_SerializeDerivedAndDeserializeDerived_ReturnsCorrectDerivedType() + { + // Arrange + DerivedClassA originalA = new DerivedClassA { NameProperty = "TestName" }; + DerivedClassB originalB = new DerivedClassB { ValueProperty = 114514f }; + + byte[] bytesA = ProtoSerializer.SerializeProtoPackable(originalA); + BaseClass deserializedA = ProtoSerializer.DeserializeProtoPackable(bytesA); + + byte[] bytesB = ProtoSerializer.SerializeProtoPackable(originalB); + BaseClass deserializedB = ProtoSerializer.DeserializeProtoPackable(bytesB); + + Assert.Multiple(() => + { + Assert.That(deserializedA, Is.AssignableTo()); + Assert.That(deserializedA, Is.AssignableTo()); + Assert.That(deserializedA.IdentifierProperty, Is.EqualTo(2)); + Assert.That(((DerivedClassA)deserializedA).NameProperty, Is.EqualTo("TestName")); + + Assert.That(deserializedB, Is.AssignableTo()); + Assert.That(deserializedB, Is.AssignableTo()); + Assert.That(deserializedB.IdentifierProperty, Is.EqualTo(3)); + Assert.That(((DerivedClassB)deserializedB).ValueProperty, Is.EqualTo(114514f)); + }); + } + + [Test] + public void ProtoPackablePolymorphism_SerializeBaseAndDeserializeBase_ReturnsCorrectDerivedType() + { + // Arrange + BaseClass originalA = new DerivedClassA { NameProperty = "TestName" }; + BaseClass originalB = new DerivedClassB { ValueProperty = 114514f }; + + byte[] bytesA = ProtoSerializer.SerializeProtoPackable(originalA); + BaseClass deserializedA = ProtoSerializer.DeserializeProtoPackable(bytesA); + + byte[] bytesB = ProtoSerializer.SerializeProtoPackable(originalB); + BaseClass deserializedB = ProtoSerializer.DeserializeProtoPackable(bytesB); + + Assert.Multiple(() => + { + Assert.That(deserializedA, Is.AssignableTo()); + Assert.That(deserializedA, Is.AssignableTo()); + Assert.That(deserializedA.IdentifierProperty, Is.EqualTo(2)); + Assert.That(((DerivedClassA)deserializedA).NameProperty, Is.EqualTo("TestName")); + + Assert.That(deserializedB, Is.AssignableTo()); + Assert.That(deserializedB, Is.AssignableTo()); + Assert.That(deserializedB.IdentifierProperty, Is.EqualTo(3)); + Assert.That(((DerivedClassB)deserializedB).ValueProperty, Is.EqualTo(114514f)); + }); + } + + [Test] + public void ProtoPackablePolymorphism_SerializeDerivedAndDeserializeBase_ReturnsCorrectDerivedType() + { + // Arrange + DerivedClassA originalA = new DerivedClassA { NameProperty = "TestName" }; + DerivedClassB originalB = new DerivedClassB { ValueProperty = 114514f }; + + byte[] bytesA = ProtoSerializer.SerializeProtoPackable(originalA); + BaseClass deserializedA = ProtoSerializer.DeserializeProtoPackable(bytesA); + + byte[] bytesB = ProtoSerializer.SerializeProtoPackable(originalB); + BaseClass deserializedB = ProtoSerializer.DeserializeProtoPackable(bytesB); + + Assert.Multiple(() => + { + Assert.That(deserializedA, Is.AssignableTo()); + Assert.That(deserializedA, Is.AssignableTo()); + Assert.That(deserializedA.IdentifierProperty, Is.EqualTo(2)); + Assert.That(((DerivedClassA)deserializedA).NameProperty, Is.EqualTo("TestName")); + + Assert.That(deserializedB, Is.AssignableTo()); + Assert.That(deserializedB, Is.AssignableTo()); + Assert.That(deserializedB.IdentifierProperty, Is.EqualTo(3)); + Assert.That(((DerivedClassB)deserializedB).ValueProperty, Is.EqualTo(114514f)); + }); + } + + #endregion + + #region Nested Polymorphism Tests + + [Test] + public void NestedPolymorphism_Reflection_SerializeAndDeserialize_ReturnsCorrectDerivedClassC() + { + // Arrange + BaseClass baseRef = + new DerivedClassC { NameProperty = "NestedName", Cannon = "BigCannon" }; + DerivedClassA classARef = (DerivedClassA)baseRef; + DerivedClassC classCRef = (DerivedClassC)baseRef; + + // Serialize from BaseClass + byte[] bytesFromBase = ProtoSerializer.Serialize(baseRef); + var deserializedFromBase = ProtoSerializer.Deserialize(bytesFromBase); + + // Serialize from DerivedClassA + byte[] bytesFromA = ProtoSerializer.Serialize(classARef); + var deserializedFromA = ProtoSerializer.Deserialize(bytesFromA); + + // Serialize from DerivedClassC + byte[] bytesFromC = ProtoSerializer.Serialize(classCRef); + var deserializedFromC = ProtoSerializer.Deserialize(bytesFromC); + + Assert.Multiple(() => + { + Assert.That(deserializedFromBase, Is.AssignableTo()); + Assert.That(((DerivedClassC)deserializedFromBase).NameProperty, Is.EqualTo("NestedName")); + Assert.That(((DerivedClassC)deserializedFromBase).Cannon, Is.EqualTo("BigCannon")); + Assert.That(((DerivedClassC)deserializedFromBase).AnotherIdentifier, Is.EqualTo(114514)); + + Assert.That(deserializedFromA, Is.AssignableTo()); + Assert.That(((DerivedClassC)deserializedFromA).NameProperty, Is.EqualTo("NestedName")); + Assert.That(((DerivedClassC)deserializedFromA).Cannon, Is.EqualTo("BigCannon")); + Assert.That(((DerivedClassC)deserializedFromA).AnotherIdentifier, Is.EqualTo(114514)); + + Assert.That(deserializedFromC, Is.AssignableTo()); + Assert.That(deserializedFromC.NameProperty, Is.EqualTo("NestedName")); + Assert.That(deserializedFromC.Cannon, Is.EqualTo("BigCannon")); + Assert.That(deserializedFromC.AnotherIdentifier, Is.EqualTo(114514)); + }); + } + + [Test] + public void NestedPolymorphism_ProtoPackable_SerializeAndDeserialize_ReturnsCorrectDerivedClassC() + { + // Arrange + BaseClass baseRef = + new DerivedClassC { NameProperty = "NestedName", Cannon = "BigCannon" }; + DerivedClassA classARef = (DerivedClassA)baseRef; + DerivedClassC classCRef = (DerivedClassC)baseRef; + + // Serialize from BaseClass + byte[] bytesFromBase = ProtoSerializer.SerializeProtoPackable(baseRef); + var deserializedFromBase = ProtoSerializer.DeserializeProtoPackable(bytesFromBase); + + // Serialize from DerivedClassA + byte[] bytesFromA = ProtoSerializer.SerializeProtoPackable(classARef); + var deserializedFromA = ProtoSerializer.DeserializeProtoPackable(bytesFromA); + + // Serialize from DerivedClassC + byte[] bytesFromC = ProtoSerializer.SerializeProtoPackable(classCRef); + var deserializedFromC = ProtoSerializer.DeserializeProtoPackable(bytesFromC); + + Assert.Multiple(() => + { + Assert.That(deserializedFromBase, Is.AssignableTo()); + Assert.That(((DerivedClassC)deserializedFromBase).NameProperty, Is.EqualTo("NestedName")); + Assert.That(((DerivedClassC)deserializedFromBase).Cannon, Is.EqualTo("BigCannon")); + Assert.That(((DerivedClassC)deserializedFromBase).AnotherIdentifier, Is.EqualTo(114514)); + + Assert.That(deserializedFromA, Is.AssignableTo()); + Assert.That(((DerivedClassC)deserializedFromA).NameProperty, Is.EqualTo("NestedName")); + Assert.That(((DerivedClassC)deserializedFromA).Cannon, Is.EqualTo("BigCannon")); + Assert.That(((DerivedClassC)deserializedFromA).AnotherIdentifier, Is.EqualTo(114514)); + + Assert.That(deserializedFromC, Is.AssignableTo()); + Assert.That(deserializedFromC.NameProperty, Is.EqualTo("NestedName")); + Assert.That(deserializedFromC.Cannon, Is.EqualTo("BigCannon")); + Assert.That(deserializedFromC.AnotherIdentifier, Is.EqualTo(114514)); + }); + } + + #endregion + +} + +#region Test Classes + +[ProtoPackable] +[ProtoPolymorphic(FieldNumber = 1)] +[ProtoDerivedType(typeof(DerivedClassA), 2)] +[ProtoDerivedType(typeof(DerivedClassB), 3)] +public partial class BaseClass +{ + public BaseClass() { } + + public BaseClass(int identifier) + { + IdentifierProperty = identifier; + } + + [ProtoMember(1, NumberHandling = ProtoNumberHandling.Fixed32)] public int IdentifierProperty { get; set; } = -1; +} + +[ProtoPackable] +[ProtoPolymorphic(FieldNumber = 4, FallbackToBaseType = true)] +[ProtoDerivedType(typeof(DerivedClassC), 114514)] +public partial class DerivedClassA() : BaseClass(2) +{ + public DerivedClassA(int anotherIdentifier) : this() + { + AnotherIdentifier = anotherIdentifier; + } + + [ProtoMember(4, NumberHandling = ProtoNumberHandling.Fixed32)] public int AnotherIdentifier { get; set; } = -1; + [ProtoMember(2)] public string NameProperty { get; set; } = string.Empty; +} + +[ProtoPackable] +public partial class DerivedClassC() : DerivedClassA(114514) +{ + [ProtoMember(10)] public string Cannon { get; set; } = string.Empty; +} + +[ProtoPackable] +public partial class DerivedClassB() : BaseClass(3) +{ + [ProtoMember(2)] public float ValueProperty { get; set; } = 0f; +} + +#endregion \ No newline at end of file diff --git a/Lagrange.Proto/ProtoDerivedTypeAttribute.cs b/Lagrange.Proto/ProtoDerivedTypeAttribute.cs new file mode 100644 index 00000000..2c9bac37 --- /dev/null +++ b/Lagrange.Proto/ProtoDerivedTypeAttribute.cs @@ -0,0 +1,28 @@ +namespace Lagrange.Proto; + +[AttributeUsage(AttributeTargets.Class | AttributeTargets.Interface, AllowMultiple = true, Inherited = false)] +public class ProtoDerivedTypeAttribute : ProtoDerivedTypeAttribute where T : IEquatable +{ + public ProtoDerivedTypeAttribute(Type derivedType, T typeDiscriminator) : base(derivedType) + { + TypeDiscriminator = typeDiscriminator; + } + + /// + /// The type discriminator identifier to be used for the serialization of the subtype. + /// + internal T TypeDiscriminator { get; init; } +} + +public class ProtoDerivedTypeAttribute : Attribute +{ + public ProtoDerivedTypeAttribute(Type derivedType) + { + DerivedType = derivedType; + } + + /// + /// A derived type that should be supported in polymorphic serialization of the declared base type. + /// + internal Type DerivedType { get; init; } +} \ No newline at end of file diff --git a/Lagrange.Proto/ProtoPolymorphicAttribute.cs b/Lagrange.Proto/ProtoPolymorphicAttribute.cs new file mode 100644 index 00000000..e9b29010 --- /dev/null +++ b/Lagrange.Proto/ProtoPolymorphicAttribute.cs @@ -0,0 +1,9 @@ +namespace Lagrange.Proto; + + +[AttributeUsage(AttributeTargets.Class | AttributeTargets.Interface, Inherited = false)] +public class ProtoPolymorphicAttribute : Attribute +{ + public uint FieldNumber { get; init; } + public bool FallbackToBaseType { get; init; } = true; +} \ No newline at end of file diff --git a/Lagrange.Proto/Serialization/Metadata/ProtoObjectInfo.cs b/Lagrange.Proto/Serialization/Metadata/ProtoObjectInfo.cs index 64ceb4f8..9bd7c7ef 100644 --- a/Lagrange.Proto/Serialization/Metadata/ProtoObjectInfo.cs +++ b/Lagrange.Proto/Serialization/Metadata/ProtoObjectInfo.cs @@ -6,8 +6,7 @@ namespace Lagrange.Proto.Serialization.Metadata; public class ProtoObjectInfo { public Dictionary Fields { get; init; } = new(); - public Func? ObjectCreator { get; init; } - public bool IgnoreDefaultFields { get; init; } + public ProtoPolymorphicInfoBase? PolymorphicInfo { get; init; } } \ No newline at end of file diff --git a/Lagrange.Proto/Serialization/Metadata/ProtoPolymorphicInfo.cs b/Lagrange.Proto/Serialization/Metadata/ProtoPolymorphicInfo.cs new file mode 100644 index 00000000..15322062 --- /dev/null +++ b/Lagrange.Proto/Serialization/Metadata/ProtoPolymorphicInfo.cs @@ -0,0 +1,75 @@ +using Lagrange.Proto.Primitives; +using Lagrange.Proto.Utility; + +namespace Lagrange.Proto.Serialization.Metadata; + + +public class ProtoPolymorphicInfoBase +{ + public uint PolymorphicIndicateIndex { get; set; } + public bool PolymorphicFallbackToBaseType { get; set; } + public virtual ProtoPolymorphicDerivedTypeDescriptor? GetDerivedTypeDescriptorFromReader(ref ProtoReader reader) => null; + public virtual void SetDerivedTypeDescriptor(object key, ProtoPolymorphicDerivedTypeDescriptor descriptor) => throw new MissingMethodException(); + public Func? RootTypeDescriptorGetter { get; set; } + public virtual IEnumerable GetAllDerivedTypeDescriptors() => []; +} + + +public class ProtoPolymorphicObjectInfo : ProtoPolymorphicInfoBase where TKey : IEquatable +{ + public Dictionary PolymorphicDerivedTypes { get; init; } = []; + public override ProtoPolymorphicDerivedTypeDescriptor? GetDerivedTypeDescriptorFromReader(ref ProtoReader reader) + { + uint tag = reader.DecodeVarIntUnsafe(); + int field = (int)(tag >> 3); + var wireType = (WireType)(tag & 0x7); + if (field != PolymorphicIndicateIndex) + { + reader.Rewind(-ProtoHelper.GetVarIntLength(tag)); + return null; + } + var converter = ProtoTypeResolver.GetConverter(); + var key = converter.Read(field, wireType, ref reader); + var rst = PolymorphicDerivedTypes.GetValueOrDefault(key); + if (rst == null && !PolymorphicFallbackToBaseType) + ThrowHelper.ThrowInvalidOperationException_UnknownPolymorphicType(typeof(TKey), key); + return rst; + } + + public override void SetDerivedTypeDescriptor(object o, ProtoPolymorphicDerivedTypeDescriptor descriptor) + { + if (o is TKey key) + PolymorphicDerivedTypes[key] = descriptor; + else + ThrowHelper.ThrowInvalidOperationException_UnknownPolymorphicType(typeof(TKey), o); + } + + public override IEnumerable GetAllDerivedTypeDescriptors() + { + return PolymorphicDerivedTypes.Values; + } +} + + +public class ProtoPolymorphicDerivedTypeDescriptor +{ + public required Type CurrentType { get; init; } + public Func> FieldsGetter { get; init; } = () => throw new MissingMethodException(); + public Func IgnoreDefaultFieldsGetter { get; init; } = () => false; + public Func PolymorphicInfoGetter { get; init; } = () => null; + + public virtual object? CreateObject() + { + return null; + } +} + +public class ProtoPolymorphicDerivedTypeDescriptor : ProtoPolymorphicDerivedTypeDescriptor +{ + public Func ObjectCreator { get; init; } = null!; + + public override object? CreateObject() + { + return ObjectCreator(); + } +} \ No newline at end of file diff --git a/Lagrange.Proto/Serialization/Metadata/ProtoTypeResolver.Dynamic.cs b/Lagrange.Proto/Serialization/Metadata/ProtoTypeResolver.Dynamic.cs index d60a1443..7f917aca 100644 --- a/Lagrange.Proto/Serialization/Metadata/ProtoTypeResolver.Dynamic.cs +++ b/Lagrange.Proto/Serialization/Metadata/ProtoTypeResolver.Dynamic.cs @@ -64,37 +64,104 @@ static MemberAccessor Initialize() [RequiresDynamicCode(ProtoSerializer.SerializationRequiresDynamicCodeMessage)] internal static ProtoObjectInfo CreateObjectInfo() { - var ctor = typeof(T).IsValueType ? null : typeof(T).GetConstructor(Type.EmptyTypes); - bool ignoreDefaultFields = typeof(T).GetCustomAttribute()?.IgnoreDefaultFields == true; + var objType = typeof(T); + var ctor = objType.IsValueType ? null : objType.GetConstructor(Type.EmptyTypes); + bool ignoreDefaultFields = objType.GetCustomAttribute()?.IgnoreDefaultFields == true; + var fields = CreateTypeFieldInfo(objType); + + return new ProtoObjectInfo + { + ObjectCreator = MemberAccessor.CreateParameterlessConstructor(ctor), + IgnoreDefaultFields = ignoreDefaultFields, + Fields = fields.OrderBy(x => x.Key).ToDictionary(x => x.Key, x => x.Value), + PolymorphicInfo = PopulatePolymorphicInfo() + }; + } + + internal static ProtoPolymorphicInfoBase? PopulatePolymorphicInfo() + { + var type = typeof(T); + ProtoPolymorphicInfoBase? objectInfo = null; + // goto root + var checkingType = type; + while (checkingType.BaseType != null) + { + // check if base type has polymorphic attribute + var basePolymorphicAttributes = checkingType.BaseType.GetCustomAttributes(typeof(ProtoDerivedTypeAttribute<>)) + .OfType().ToArray(); + if (basePolymorphicAttributes.Length > 0) checkingType = checkingType.BaseType; + else break; + } + + if (checkingType == type) checkingType = null; + + var polymorphicAttributes = type.GetCustomAttributes(typeof(ProtoDerivedTypeAttribute<>)) + .OfType().ToArray(); + if (polymorphicAttributes.Length > 0) + { + var polymorphicConfigAttribute = type.GetCustomAttribute(); + var polymorphicFieldNumber = polymorphicConfigAttribute?.FieldNumber ?? 0; + var fallbackToBaseType = polymorphicConfigAttribute?.FallbackToBaseType ?? true; + if (polymorphicFieldNumber == 0) polymorphicFieldNumber = 1; // use first for default + + // get the TKey from first + var firstAttr = polymorphicAttributes[0]; + var keyType = firstAttr.GetType().GetGenericArguments()[0]; + objectInfo = (ProtoPolymorphicInfoBase?) typeof(ProtoPolymorphicObjectInfo<>).MakeGenericType(keyType) + .GetConstructor(Type.EmptyTypes)?.Invoke(null); + + Debug.Assert(objectInfo != null); + objectInfo.PolymorphicIndicateIndex = polymorphicFieldNumber; + objectInfo.PolymorphicFallbackToBaseType = fallbackToBaseType; + + foreach (var attr in polymorphicAttributes) + { + var key = attr.GetType().GetProperty(nameof(ProtoDerivedTypeAttribute.TypeDiscriminator), BindingFlags.NonPublic | BindingFlags.Instance) + ?.GetValue(attr); + if (key == null) ThrowHelper.ThrowInvalidOperationException_UnknownPolymorphicType(type, attr.DerivedType); + objectInfo.SetDerivedTypeDescriptor(key, ProtoSerializer.GetObjectInfoReflection(attr.DerivedType)); + } + } + + if (checkingType is not null) + { + objectInfo ??= new ProtoPolymorphicInfoBase() + { + PolymorphicIndicateIndex = 0, + PolymorphicFallbackToBaseType = true + }; + objectInfo.RootTypeDescriptorGetter = () => ProtoSerializer.GetObjectInfoReflection(checkingType); + } + + return objectInfo; + } + + internal static Dictionary CreateTypeFieldInfo(Type type) + { var fields = new Dictionary(); - foreach (var field in typeof(T).GetFields(BindingFlags.Public | BindingFlags.Instance)) + foreach (var field in type.GetFields(BindingFlags.Public | BindingFlags.Instance)) { if (field.IsStatic) continue; - var fieldInfo = CreateFieldInfo(typeof(T), field); + var fieldInfo = CreateFieldInfo(type, field); if (fieldInfo == null) continue; uint tag = ((uint)fieldInfo.Field << 3) | (byte)fieldInfo.WireType; - if (fields.ContainsKey(tag)) ThrowHelper.ThrowInvalidOperationException_DuplicateField(typeof(T), fieldInfo.Field); + if (fields.ContainsKey(tag)) ThrowHelper.ThrowInvalidOperationException_DuplicateField(type, fieldInfo.Field); fields[tag] = fieldInfo; } - foreach (var field in typeof(T).GetProperties(BindingFlags.Public | BindingFlags.Instance)) + foreach (var field in type.GetProperties(BindingFlags.Public | BindingFlags.Instance)) { - var fieldInfo = CreateFieldInfo(typeof(T), field); + var fieldInfo = CreateFieldInfo(type, field); if (fieldInfo == null) continue; uint tag = ((uint)fieldInfo.Field << 3) | (byte)fieldInfo.WireType; - if (fields.ContainsKey(tag)) ThrowHelper.ThrowInvalidOperationException_DuplicateField(typeof(T), fieldInfo.Field); + if (fields.ContainsKey(tag)) ThrowHelper.ThrowInvalidOperationException_DuplicateField(type, fieldInfo.Field); fields[tag] = fieldInfo; } - - return new ProtoObjectInfo - { - ObjectCreator = MemberAccessor.CreateParameterlessConstructor(ctor), - IgnoreDefaultFields = ignoreDefaultFields, - Fields = fields.OrderBy(x => x.Key).ToDictionary(x => x.Key, x => x.Value) - }; + + return fields; } diff --git a/Lagrange.Proto/Serialization/ProtoSerializer.Deserialize.cs b/Lagrange.Proto/Serialization/ProtoSerializer.Deserialize.cs index 417ed77a..feb809ab 100644 --- a/Lagrange.Proto/Serialization/ProtoSerializer.Deserialize.cs +++ b/Lagrange.Proto/Serialization/ProtoSerializer.Deserialize.cs @@ -19,18 +19,55 @@ public static T DeserializeProtoPackable(ReadOnlySpan data) where T : I var reader = new ProtoReader(data); return DeserializeProtoPackableCore(ref reader); } - + private static T DeserializeProtoPackableCore(ref ProtoReader reader) where T : IProtoSerializable { var objectInfo = T.TypeInfo; Debug.Assert(objectInfo.ObjectCreator != null); - + T target = objectInfo.ObjectCreator(); + var fields = objectInfo.Fields; + var polymorphicInfo = objectInfo.PolymorphicInfo; + if (polymorphicInfo?.PolymorphicIndicateIndex is > 0) + { + var root = polymorphicInfo.RootTypeDescriptorGetter?.Invoke(); + if (root is not null) + { + fields = root.FieldsGetter(); + polymorphicInfo = root.PolymorphicInfoGetter(); + } + } + + polyDeserialize: + if (polymorphicInfo?.PolymorphicIndicateIndex is > 0) + { + var typeDescriptor = polymorphicInfo.GetDerivedTypeDescriptorFromReader(ref reader); + if (typeDescriptor is not null) + { + fields = typeDescriptor.FieldsGetter(); + polymorphicInfo = typeDescriptor.PolymorphicInfoGetter(); + if (!typeof(T).IsAssignableTo(typeDescriptor.CurrentType)) + { + if (typeDescriptor.CreateObject() is T newObj) + { + target = newObj; + } + else + { + ThrowHelper.ThrowInvalidOperationException_CanNotCreateObject(typeDescriptor.CurrentType); + } + } + + goto polyDeserialize; + } + } + + while (!reader.IsCompleted) { uint tag = reader.DecodeVarIntUnsafe(); - if (objectInfo.Fields.TryGetValue(tag, out var fieldInfo)) + if (fields.TryGetValue(tag, out var fieldInfo)) { fieldInfo.Read(ref reader, target); } @@ -39,10 +76,10 @@ private static T DeserializeProtoPackableCore(ref ProtoReader reader) where T reader.SkipField((WireType)(tag & 0x07)); } } - + return target; } - + /// /// Deserialize the ProtoPackable Object from the source buffer, based on reflection /// @@ -51,43 +88,61 @@ private static T DeserializeProtoPackableCore(ref ProtoReader reader) where T /// The deserialized object [RequiresUnreferencedCode(SerializationUnreferencedCodeMessage)] [RequiresDynamicCode(SerializationRequiresDynamicCodeMessage)] - public static T Deserialize<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.All)] T>(ReadOnlySpan data) + public static T Deserialize<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.All)] T>( + ReadOnlySpan data) { var reader = new ProtoReader(data); return DeserializeCore(ref reader); } - + [RequiresUnreferencedCode(SerializationUnreferencedCodeMessage)] [RequiresDynamicCode(SerializationRequiresDynamicCodeMessage)] - private static T DeserializeCore<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.All)] T>(ref ProtoReader reader) + private static T DeserializeCore<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.All)] T>( + ref ProtoReader reader) { - ProtoObjectConverter converter; - if (ProtoTypeResolver.IsRegistered()) + var converter = GetConverterOf(); + Debug.Assert(converter.ObjectInfo.ObjectCreator != null); + + T target = converter.ObjectInfo.ObjectCreator(); + var boxed = (object?)target; // avoid multiple times of boxing + if (boxed is null) ThrowHelper.ThrowInvalidOperationException_CanNotCreateObject(typeof(T)); + var fieldInfos = converter.ObjectInfo.Fields; + var polymorphicInfo = converter.ObjectInfo.PolymorphicInfo; + + if (polymorphicInfo?.PolymorphicIndicateIndex is > 0) { - if (ProtoTypeResolver.GetConverter() as ProtoObjectConverter is not { } c) - { - converter = new ProtoObjectConverter(ProtoTypeResolver.CreateObjectInfo()); - ProtoTypeResolver.Register(converter); - } - else + var root = polymorphicInfo.RootTypeDescriptorGetter?.Invoke(); + if (root is not null) { - converter = c; + fieldInfos = root.FieldsGetter(); + polymorphicInfo = root.PolymorphicInfoGetter(); } } - else + + startDeserialize: + if (polymorphicInfo?.PolymorphicIndicateIndex is > 0) { - ProtoTypeResolver.Register(converter = new ProtoObjectConverter()); + var polymorphicDescriptor = polymorphicInfo.GetDerivedTypeDescriptorFromReader(ref reader); + if (polymorphicDescriptor is not null) + { + fieldInfos = polymorphicDescriptor.FieldsGetter(); + polymorphicInfo = polymorphicDescriptor.PolymorphicInfoGetter(); + + if (!typeof(T).IsAssignableTo(polymorphicDescriptor.CurrentType)) + { + boxed = polymorphicDescriptor.CreateObject(); + target = (T)boxed!; + if (boxed is null) ThrowHelper.ThrowInvalidOperationException_CanNotCreateObject(polymorphicDescriptor.CurrentType); + } + goto startDeserialize; + } } - Debug.Assert(converter.ObjectInfo.ObjectCreator != null); - T target = converter.ObjectInfo.ObjectCreator(); - var boxed = (object?)target; // avoid multiple times of boxing - if (boxed is null) ThrowHelper.ThrowInvalidOperationException_CanNotCreateObject(typeof(T)); while (!reader.IsCompleted) { uint tag = reader.DecodeVarIntUnsafe(); - if (converter.ObjectInfo.Fields.TryGetValue(tag, out var fieldInfo)) + if (fieldInfos.TryGetValue(tag, out var fieldInfo)) { fieldInfo.Read(ref reader, boxed); } @@ -96,7 +151,7 @@ private static T DeserializeProtoPackableCore(ref ProtoReader reader) where T reader.SkipField((WireType)(tag & 0x07)); } } - + return target; } } \ No newline at end of file diff --git a/Lagrange.Proto/Serialization/ProtoSerializer.Helpers.cs b/Lagrange.Proto/Serialization/ProtoSerializer.Helpers.cs index 9958b536..b3c43303 100644 --- a/Lagrange.Proto/Serialization/ProtoSerializer.Helpers.cs +++ b/Lagrange.Proto/Serialization/ProtoSerializer.Helpers.cs @@ -1,7 +1,65 @@ +using System.Diagnostics; +using System.Reflection; +using Lagrange.Proto.Serialization.Converter; +using Lagrange.Proto.Serialization.Metadata; + namespace Lagrange.Proto.Serialization; public static partial class ProtoSerializer { internal const string SerializationUnreferencedCodeMessage = "Proto serialization and deserialization might require types that cannot be statically analyzed. Use the SerializePackable that takes a IProtoSerializable to ensure generated code is used, or make sure all of the required types are preserved."; internal const string SerializationRequiresDynamicCodeMessage = "Proto serialization and deserialization might require types that cannot be statically analyzed and might need runtime code generation. Use Lagrange.Proto source generation for native AOT applications."; + + internal static ProtoObjectConverter GetConverterOf() + { + ProtoObjectConverter converter; + if (ProtoTypeResolver.IsRegistered()) + { + if (ProtoTypeResolver.GetConverter() as ProtoObjectConverter is not { } c) + { + converter = new ProtoObjectConverter(ProtoTypeResolver.CreateObjectInfo()); + ProtoTypeResolver.Register(converter); + } + else + { + converter = c; + } + } + else + { + ProtoTypeResolver.Register(converter = new ProtoObjectConverter()); + } + + return converter; + } + + internal static ProtoPolymorphicDerivedTypeDescriptor GetObjectInfoReflection(Type polyType) + { + var method = typeof(ProtoSerializer).GetMethod(nameof(GetConverterOf), + BindingFlags.Static | BindingFlags.NonPublic); + Debug.Assert(method != null); + var genericMethod = method.MakeGenericMethod(polyType); + var polyConverter = genericMethod.Invoke(null, null)!; + + // get creator and fields, oh my reflection! + var polyObjectInfo = polyConverter.GetType() + .GetField("ObjectInfo",BindingFlags.NonPublic | BindingFlags.Instance)!.GetValue(polyConverter)!; + var polyCreator = polyObjectInfo.GetType() + .GetProperty("ObjectCreator")!.GetValue(polyObjectInfo)!; + var fieldInfos = (Dictionary)polyObjectInfo.GetType() + .GetProperty("Fields")!.GetValue(polyObjectInfo)!; + var polymorphicInfo = (ProtoPolymorphicInfoBase)polyObjectInfo.GetType() + .GetProperty("PolymorphicInfo")!.GetValue(polyObjectInfo)!; + var ignoreDefaultFields = polyObjectInfo.GetType() + .GetProperty("IgnoreDefaultFields")!.GetValue(polyObjectInfo) is true; + return new ProtoPolymorphicDerivedTypeDescriptor + { + CurrentType = polyType, + FieldsGetter = () => fieldInfos, + IgnoreDefaultFieldsGetter = () => ignoreDefaultFields, + PolymorphicInfoGetter = () => polymorphicInfo, + ObjectCreator = ObjectCreator + }; + object? ObjectCreator() => (object?)polyCreator.GetType().GetMethod("Invoke")!.Invoke(polyCreator, null)!; + } } \ No newline at end of file diff --git a/Lagrange.Proto/Serialization/ProtoSerializer.Serialize.cs b/Lagrange.Proto/Serialization/ProtoSerializer.Serialize.cs index e182aecc..5ea60778 100644 --- a/Lagrange.Proto/Serialization/ProtoSerializer.Serialize.cs +++ b/Lagrange.Proto/Serialization/ProtoSerializer.Serialize.cs @@ -102,32 +102,52 @@ private static void SerializeProtoPackableCore(ProtoWriter writer, T obj) whe [RequiresUnreferencedCode(SerializationUnreferencedCodeMessage)] [RequiresDynamicCode(SerializationRequiresDynamicCodeMessage)] private static void SerializeCore<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.All)] T>(ProtoWriter writer, T obj) - { - ProtoObjectConverter converter; - if (ProtoTypeResolver.IsRegistered()) + { + var converter = GetConverterOf(); + var objectInfo = converter.ObjectInfo; + object? boxed = obj; // avoid multiple times of boxing + if (boxed is null) return; + var fields = objectInfo.Fields; + List skipTags = []; + var polymorphicInfo = converter.ObjectInfo.PolymorphicInfo; + var ignoreDefaultFields = objectInfo.IgnoreDefaultFields; + if (polymorphicInfo?.RootTypeDescriptorGetter is not null) { - if (ProtoTypeResolver.GetConverter() as ProtoObjectConverter is not { } c) - { - converter = new ProtoObjectConverter(ProtoTypeResolver.CreateObjectInfo()); - ProtoTypeResolver.Register(converter); - } - else - { - converter = c; - } + var discriminator = polymorphicInfo.RootTypeDescriptorGetter(); + fields = discriminator.FieldsGetter(); + polymorphicInfo = discriminator.PolymorphicInfoGetter(); + ignoreDefaultFields = discriminator.IgnoreDefaultFieldsGetter(); } - else + + var actualType = obj?.GetType(); + // check polymorphic type + startSerialize: + if (polymorphicInfo?.PolymorphicIndicateIndex is > 0) { - ProtoTypeResolver.Register(converter = new ProtoObjectConverter()); + // has polymorphic type + var index = polymorphicInfo.PolymorphicIndicateIndex; + var fieldInfo = fields.FirstOrDefault(t=>t.Value.Field == index); + if (fieldInfo.Value is null) ThrowHelper.ThrowInvalidOperationException_NullPolymorphicDiscriminator(typeof(T)); + var discriminator = fieldInfo.Value.Get?.Invoke(boxed); + if (discriminator is null) ThrowHelper.ThrowInvalidOperationException_NullPolymorphicDiscriminator(typeof(T)); + skipTags.Add(fieldInfo.Key); + writer.EncodeVarInt(fieldInfo.Key); + fieldInfo.Value.Write(writer, boxed); + foreach (var descriptor in polymorphicInfo.GetAllDerivedTypeDescriptors()) + { + if (actualType?.IsAssignableTo(descriptor.CurrentType) is true) + { + fields = descriptor.FieldsGetter(); + ignoreDefaultFields = descriptor.IgnoreDefaultFieldsGetter(); + polymorphicInfo = descriptor.PolymorphicInfoGetter(); + goto startSerialize; + } + } } - var objectInfo = converter.ObjectInfo; - object? boxed = obj; // avoid multiple times of boxing - if (boxed is null) return; - - foreach (var (tag, info) in objectInfo.Fields) + foreach (var (tag, info) in fields) { - if (info.ShouldSerialize(boxed, objectInfo.IgnoreDefaultFields)) + if (!skipTags.Contains(tag) && info.ShouldSerialize(boxed, ignoreDefaultFields)) { writer.EncodeVarInt(tag); info.Write(writer, boxed); diff --git a/Lagrange.Proto/ThrowHelper.cs b/Lagrange.Proto/ThrowHelper.cs index 491106c9..3e5c089b 100644 --- a/Lagrange.Proto/ThrowHelper.cs +++ b/Lagrange.Proto/ThrowHelper.cs @@ -66,4 +66,24 @@ public static void ThrowInvalidOperationException_NodeWrongType(params ReadOnlyS [DoesNotReturn] [MethodImpl(MethodImplOptions.NoInlining)] public static void ThrowInvalidOperationException_InvalidNodesWireType(string fieldName) => throw new InvalidOperationException($"The wire type must be explicitly set for field {fieldName} as the wire type for the ProtoNode, ProtoValue, and ProtoArray types is not known at compile time, to set the wire type, use the NodesWireType Property in ProtoMember attribute"); + + [DoesNotReturn] + [MethodImpl(MethodImplOptions.NoInlining)] + public static void ThrowInvalidOperationException_NullPolymorphicDiscriminator(Type type) => throw new InvalidOperationException($"The polymorphic discriminator field for type {type.Name} cannot be null. Please ensure that the field is set and has a valid value."); + + [DoesNotReturn] + [MethodImpl(MethodImplOptions.NoInlining)] + public static void ThrowInvalidOperationException_DuplicatePolymorphicDiscriminator(Type type, object key) => throw new InvalidOperationException($"The polymorphic discriminator key '{key}' for type {type.Name} is duplicated. Please ensure that the keys are unique."); + + [DoesNotReturn] + [MethodImpl(MethodImplOptions.NoInlining)] + public static void ThrowInvalidOperationException_PolymorphicFieldNotFirst(Type type, uint expected, uint actual) => throw new InvalidOperationException($"The polymorphic discriminator field for type {type.Name} must be the first field in the message. Expected field number {expected}, but found {actual}."); + + [DoesNotReturn] + [MethodImpl(MethodImplOptions.NoInlining)] + public static void ThrowInvalidOperationException_FailedParsePolymorphicType(Type type, uint index) => throw new InvalidOperationException($"Failed to parse the polymorphic type from proto for type {type.Name} at '{index}'."); + + [DoesNotReturn] + [MethodImpl(MethodImplOptions.NoInlining)] + public static void ThrowInvalidOperationException_UnknownPolymorphicType(Type type, object polyTypeKey) => throw new InvalidOperationException($"Unknown polymorphic type '{polyTypeKey}' for base type {type.Name}. Please ensure that the polymorphic type is registered."); } \ No newline at end of file