Skip to content

Commit 388d775

Browse files
authored
refactor: Moved INetworkMessage discovery to ILPP (#1276)
1 parent 67ec68f commit 388d775

File tree

4 files changed

+196
-77
lines changed

4 files changed

+196
-77
lines changed

com.unity.netcode.gameobjects/Editor/CodeGen/CodeGenHelpers.cs

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ internal static class CodeGenHelpers
2424
public static readonly string ClientRpcParams_FullName = typeof(ClientRpcParams).FullName;
2525
public static readonly string INetworkSerializable_FullName = typeof(INetworkSerializable).FullName;
2626
public static readonly string INetworkSerializable_NetworkSerialize_Name = nameof(INetworkSerializable.NetworkSerialize);
27+
public static readonly string IgnoreMessageIfSystemOwnerIsNotOfTypeAttribute_FullName = typeof(IgnoreMessageIfSystemOwnerIsNotOfTypeAttribute).FullName;
2728
public static readonly string UnityColor_FullName = typeof(Color).FullName;
2829
public static readonly string UnityColor32_FullName = typeof(Color32).FullName;
2930
public static readonly string UnityVector2_FullName = typeof(Vector2).FullName;
@@ -265,6 +266,42 @@ public static void AddError(this List<DiagnosticMessage> diagnostics, SequencePo
265266
});
266267
}
267268

269+
public static void RemoveRecursiveReferences(this ModuleDefinition moduleDefinition)
270+
{
271+
// Weird behavior from Cecil: When importing a reference to a specific implementation of a generic
272+
// method, it's importing the main module as a reference into itself. This causes Unity to have issues
273+
// when attempting to iterate the assemblies to discover unit tests, as it goes into infinite recursion
274+
// and eventually hits a stack overflow. I wasn't able to find any way to stop Cecil from importing the module
275+
// into itself, so at the end of it all, we're just going to go back and remove it again.
276+
var moduleName = moduleDefinition.Name;
277+
if (moduleName.EndsWith(".dll") || moduleName.EndsWith(".exe"))
278+
{
279+
moduleName = moduleName.Substring(0, moduleName.Length - 4);
280+
}
281+
282+
foreach (var reference in moduleDefinition.AssemblyReferences)
283+
{
284+
var referenceName = reference.Name.Split(',')[0];
285+
if (referenceName.EndsWith(".dll") || referenceName.EndsWith(".exe"))
286+
{
287+
referenceName = referenceName.Substring(0, referenceName.Length - 4);
288+
}
289+
290+
if (moduleName == referenceName)
291+
{
292+
try
293+
{
294+
moduleDefinition.AssemblyReferences.Remove(reference);
295+
break;
296+
}
297+
catch (Exception)
298+
{
299+
//
300+
}
301+
}
302+
}
303+
}
304+
268305
public static AssemblyDefinition AssemblyDefinitionFor(ICompiledAssembly compiledAssembly, out PostProcessorAssemblyResolver assemblyResolver)
269306
{
270307
assemblyResolver = new PostProcessorAssemblyResolver(compiledAssembly);

com.unity.netcode.gameobjects/Editor/CodeGen/INetworkMessageILPP.cs

Lines changed: 122 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
1+
using System;
12
using System.IO;
23
using System.Linq;
34
using System.Collections.Generic;
5+
using System.Reflection;
46
using Mono.Cecil;
57
using Mono.Cecil.Cil;
8+
using Mono.Cecil.Rocks;
69
using Unity.CompilationPipeline.Common.Diagnostics;
710
using Unity.CompilationPipeline.Common.ILPostProcessing;
811
using ILPPInterface = Unity.CompilationPipeline.Common.ILPostProcessing.ILPostProcessor;
12+
using MethodAttributes = Mono.Cecil.MethodAttributes;
913

1014
namespace Unity.Netcode.Editor.CodeGen
1115
{
@@ -14,7 +18,9 @@ internal sealed class INetworkMessageILPP : ILPPInterface
1418
{
1519
public override ILPPInterface GetInstance() => this;
1620

17-
public override bool WillProcess(ICompiledAssembly compiledAssembly) => compiledAssembly.References.Any(filePath => Path.GetFileNameWithoutExtension(filePath) == CodeGenHelpers.RuntimeAssemblyName);
21+
public override bool WillProcess(ICompiledAssembly compiledAssembly) =>
22+
compiledAssembly.Name == CodeGenHelpers.RuntimeAssemblyName ||
23+
compiledAssembly.References.Any(filePath => Path.GetFileNameWithoutExtension(filePath) == CodeGenHelpers.RuntimeAssemblyName);
1824

1925
private readonly List<DiagnosticMessage> m_Diagnostics = new List<DiagnosticMessage>();
2026

@@ -42,11 +48,24 @@ public override ILPostProcessResult Process(ICompiledAssembly compiledAssembly)
4248
{
4349
if (ImportReferences(mainModule))
4450
{
51+
var types = mainModule.GetTypes()
52+
.Where(t => t.Resolve().HasInterface(CodeGenHelpers.INetworkMessage_FullName) && !t.Resolve().IsAbstract)
53+
.ToList();
4554
// process `INetworkMessage` types
46-
mainModule.GetTypes()
47-
.Where(t => t.HasInterface(CodeGenHelpers.INetworkMessage_FullName))
48-
.ToList()
49-
.ForEach(b => ProcessINetworkMessage(b));
55+
if (types.Count == 0)
56+
{
57+
return null;
58+
}
59+
60+
try
61+
{
62+
types.ForEach(b => ProcessINetworkMessage(b));
63+
CreateModuleInitializer(assemblyDefinition, types);
64+
}
65+
catch (Exception e)
66+
{
67+
m_Diagnostics.AddError((e.ToString() + e.StackTrace.ToString()).Replace("\n", "|").Replace("\r", "|"));
68+
}
5069
}
5170
else
5271
{
@@ -58,6 +77,8 @@ public override ILPostProcessResult Process(ICompiledAssembly compiledAssembly)
5877
m_Diagnostics.AddError($"Cannot get main module from assembly definition: {compiledAssembly.Name}");
5978
}
6079

80+
mainModule.RemoveRecursiveReferences();
81+
6182
// write
6283
var pe = new MemoryStream();
6384
var pdb = new MemoryStream();
@@ -77,12 +98,50 @@ public override ILPostProcessResult Process(ICompiledAssembly compiledAssembly)
7798

7899
private TypeReference m_FastBufferReader_TypeRef;
79100
private TypeReference m_NetworkContext_TypeRef;
101+
private FieldReference m_MessagingSystem___network_message_types_FieldRef;
102+
private MethodReference m_Type_GetTypeFromHandle_MethodRef;
103+
104+
private MethodReference m_List_Add_MethodRef;
80105

81106
private bool ImportReferences(ModuleDefinition moduleDefinition)
82107
{
83108
m_FastBufferReader_TypeRef = moduleDefinition.ImportReference(typeof(FastBufferReader));
84109
m_NetworkContext_TypeRef = moduleDefinition.ImportReference(typeof(NetworkContext));
85110

111+
var typeType = typeof(Type);
112+
foreach (var methodInfo in typeType.GetMethods())
113+
{
114+
switch (methodInfo.Name)
115+
{
116+
case nameof(Type.GetTypeFromHandle):
117+
m_Type_GetTypeFromHandle_MethodRef = moduleDefinition.ImportReference(methodInfo);
118+
break;
119+
}
120+
}
121+
122+
var messagingSystemType = typeof(MessagingSystem);
123+
foreach (var fieldInfo in messagingSystemType.GetFields(BindingFlags.Static | BindingFlags.NonPublic))
124+
{
125+
switch (fieldInfo.Name)
126+
{
127+
case nameof(MessagingSystem.__network_message_types):
128+
m_MessagingSystem___network_message_types_FieldRef = moduleDefinition.ImportReference(fieldInfo);
129+
break;
130+
}
131+
}
132+
133+
var listType = typeof(List<Type>);
134+
foreach (var methodInfo in listType.GetMethods())
135+
{
136+
switch (methodInfo.Name)
137+
{
138+
case nameof(List<Type>.Add):
139+
m_List_Add_MethodRef = moduleDefinition.ImportReference(methodInfo);
140+
break;
141+
}
142+
}
143+
144+
86145
return true;
87146
}
88147

@@ -98,6 +157,7 @@ private void ProcessINetworkMessage(TypeDefinition typeDefinition)
98157
{
99158
typeSequence = methodSequence;
100159
}
160+
101161
if (resolved.IsStatic && resolved.IsPublic && resolved.Name == "Receive" && resolved.Parameters.Count == 2
102162
&& !resolved.Parameters[0].IsIn
103163
&& !resolved.Parameters[0].ParameterType.IsByReference
@@ -118,5 +178,62 @@ private void ProcessINetworkMessage(TypeDefinition typeDefinition)
118178
m_Diagnostics.AddError(typeSequence, $"Class {typeDefinition.FullName} does not implement required function: `public static void Receive(FastBufferReader, in NetworkContext)`");
119179
}
120180
}
181+
182+
private MethodDefinition GetOrCreateStaticConstructor(TypeDefinition typeDefinition)
183+
{
184+
var staticCtorMethodDef = typeDefinition.GetStaticConstructor();
185+
if (staticCtorMethodDef == null)
186+
{
187+
staticCtorMethodDef = new MethodDefinition(
188+
".cctor", // Static Constructor (constant-constructor)
189+
MethodAttributes.HideBySig |
190+
MethodAttributes.SpecialName |
191+
MethodAttributes.RTSpecialName |
192+
MethodAttributes.Static,
193+
typeDefinition.Module.TypeSystem.Void);
194+
staticCtorMethodDef.Body.Instructions.Add(Instruction.Create(OpCodes.Ret));
195+
typeDefinition.Methods.Add(staticCtorMethodDef);
196+
}
197+
198+
return staticCtorMethodDef;
199+
}
200+
201+
private void CreateInstructionsToRegisterType(ILProcessor processor, List<Instruction> instructions, TypeReference type)
202+
{
203+
// MessagingSystem.__network_message_types.Add(typeof(type));
204+
instructions.Add(processor.Create(OpCodes.Ldsfld, m_MessagingSystem___network_message_types_FieldRef));
205+
instructions.Add(processor.Create(OpCodes.Ldtoken, type));
206+
instructions.Add(processor.Create(OpCodes.Call, m_Type_GetTypeFromHandle_MethodRef));
207+
instructions.Add(processor.Create(OpCodes.Callvirt, m_List_Add_MethodRef));
208+
}
209+
210+
// Creates a static module constructor (which is executed when the module is loaded) that registers all the
211+
// message types in the assembly with MessagingSystem.
212+
// This is the same behavior as annotating a static method with [ModuleInitializer] in standardized
213+
// C# (that attribute doesn't exist in Unity, but the static module constructor still works)
214+
// https://docs.microsoft.com/en-us/dotnet/api/system.runtime.compilerservices.moduleinitializerattribute?view=net-5.0
215+
// https://web.archive.org/web/20100212140402/http://blogs.msdn.com/junfeng/archive/2005/11/19/494914.aspx
216+
private void CreateModuleInitializer(AssemblyDefinition assembly, List<TypeDefinition> networkMessageTypes)
217+
{
218+
foreach (var typeDefinition in assembly.MainModule.Types)
219+
{
220+
if (typeDefinition.FullName == "<Module>")
221+
{
222+
var staticCtorMethodDef = GetOrCreateStaticConstructor(typeDefinition);
223+
224+
var processor = staticCtorMethodDef.Body.GetILProcessor();
225+
226+
var instructions = new List<Instruction>();
227+
228+
foreach (var type in networkMessageTypes)
229+
{
230+
CreateInstructionsToRegisterType(processor, instructions, type);
231+
}
232+
233+
instructions.ForEach(instruction => processor.Body.Instructions.Insert(processor.Body.Instructions.Count - 1, instruction));
234+
break;
235+
}
236+
}
237+
}
121238
}
122239
}

com.unity.netcode.gameobjects/Editor/CodeGen/NetworkBehaviourILPP.cs

Lines changed: 1 addition & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -467,38 +467,7 @@ private void ProcessNetworkBehaviour(TypeDefinition typeDefinition, string[] ass
467467
typeDefinition.Methods.Add(newGetTypeNameMethod);
468468
}
469469

470-
// Weird behavior from Cecil: When importing a reference to a specific implementation of a generic
471-
// method, it's importing the main module as a reference into itself. This causes Unity to have issues
472-
// when attempting to iterate the assemblies to discover unit tests, as it goes into infinite recursion
473-
// and eventually hits a stack overflow. I wasn't able to find any way to stop Cecil from importing the module
474-
// into itself, so at the end of it all, we're just going to go back and remove it again.
475-
var moduleName = m_MainModule.Name;
476-
if (moduleName.EndsWith(".dll") || moduleName.EndsWith(".exe"))
477-
{
478-
moduleName = moduleName.Substring(0, moduleName.Length - 4);
479-
}
480-
481-
foreach (var reference in m_MainModule.AssemblyReferences)
482-
{
483-
var referenceName = reference.Name.Split(',')[0];
484-
if (referenceName.EndsWith(".dll") || referenceName.EndsWith(".exe"))
485-
{
486-
referenceName = referenceName.Substring(0, referenceName.Length - 4);
487-
}
488-
489-
if (moduleName == referenceName)
490-
{
491-
try
492-
{
493-
m_MainModule.AssemblyReferences.Remove(reference);
494-
break;
495-
}
496-
catch (Exception)
497-
{
498-
//
499-
}
500-
}
501-
}
470+
m_MainModule.RemoveRecursiveReferences();
502471
}
503472

504473
private CustomAttribute CheckAndGetRpcAttribute(MethodDefinition methodDefinition)

com.unity.netcode.gameobjects/Runtime/Messaging/MessagingSystem.cs

Lines changed: 36 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,13 @@ public InvalidMessageStructureException(string issue) : base(issue) { }
1717

1818
internal class MessagingSystem : IDisposable
1919
{
20+
21+
22+
#pragma warning disable IDE1006 // disable naming rule violation check
23+
// This is NOT modified by RuntimeAccessModifiersILPP right now, but is populated by ILPP.
24+
internal static readonly List<Type> __network_message_types = new List<Type>();
25+
#pragma warning restore IDE1006 // restore naming rule violation check
26+
2027
private struct ReceiveQueueItem
2128
{
2229
public FastBufferReader Reader;
@@ -77,54 +84,43 @@ public MessagingSystem(IMessageSender messageSender, object owner, ulong localCl
7784
m_MessageSender = messageSender;
7885
m_Owner = owner;
7986

80-
var interfaceType = typeof(INetworkMessage);
81-
var implementationTypes = new List<Type>();
82-
foreach (var assembly in AppDomain.CurrentDomain.GetAssemblies())
87+
var allowedTypes = new List<Type>();
88+
foreach (var type in __network_message_types)
8389
{
84-
foreach (var type in assembly.GetTypes())
90+
var attributes = type.GetCustomAttributes(typeof(IgnoreMessageIfSystemOwnerIsNotOfTypeAttribute), false);
91+
// If [IgnoreMessageIfSystemOwnerIsNotOfTypeAttribute(ownerType)] isn't provided, it defaults
92+
// to being bound to NetworkManager. This is technically a breach of domain by having
93+
// MessagingSystem know about the existence of NetworkManager... but ultimately,
94+
// IgnoreMessageIfSystemOwnerIsNotOfTypeAttribute is provided to support testing, not to support
95+
// general use of MessagingSystem outside of Netcode for GameObjects, so having MessagingSystem
96+
// know about NetworkManager isn't so bad. Especially since it's just a default value.
97+
// This is just a convenience to keep us and our users from having to use
98+
// [Bind(typeof(NetworkManager))] on every message - only tests that don't want to use
99+
// the full NetworkManager need to worry about it.
100+
var shouldSkip = attributes.Length != 0 || !(m_Owner is NetworkManager);
101+
for (var i = 0; i < attributes.Length; ++i)
85102
{
86-
if (type.IsInterface || type.IsAbstract)
103+
var bindAttribute = (IgnoreMessageIfSystemOwnerIsNotOfTypeAttribute)attributes[i];
104+
if (
105+
(bindAttribute.BoundType != null &&
106+
bindAttribute.BoundType.IsInstanceOfType(m_Owner)) ||
107+
(m_Owner == null && bindAttribute.BoundType == null))
87108
{
88-
continue;
109+
shouldSkip = false;
110+
break;
89111
}
112+
}
90113

91-
if (interfaceType.IsAssignableFrom(type))
92-
{
93-
var attributes = type.GetCustomAttributes(typeof(IgnoreMessageIfSystemOwnerIsNotOfTypeAttribute), false);
94-
// If [Bind(ownerType)] isn't provided, it defaults to being bound to NetworkManager
95-
// This is technically a breach of domain by having MessagingSystem know about the existence
96-
// of NetworkManager... but ultimately, Bind is provided to support testing, not to support
97-
// general use of MessagingSystem outside of Netcode for GameObjects, so having MessagingSystem
98-
// know about NetworkManager isn't so bad. Especially since it's just a default value.
99-
// This is just a convenience to keep us and our users from having to use
100-
// [Bind(typeof(NetworkManager))] on every message - only tests that don't want to use
101-
// the full NetworkManager need to worry about it.
102-
var allowedToBind = attributes.Length == 0 && m_Owner is NetworkManager;
103-
for (var i = 0; i < attributes.Length; ++i)
104-
{
105-
var bindAttribute = (IgnoreMessageIfSystemOwnerIsNotOfTypeAttribute)attributes[i];
106-
if (
107-
(bindAttribute.BoundType != null &&
108-
bindAttribute.BoundType.IsInstanceOfType(m_Owner)) ||
109-
(m_Owner == null && bindAttribute.BoundType == null))
110-
{
111-
allowedToBind = true;
112-
break;
113-
}
114-
}
115-
116-
if (!allowedToBind)
117-
{
118-
continue;
119-
}
120-
121-
implementationTypes.Add(type);
122-
}
114+
if (shouldSkip)
115+
{
116+
continue;
123117
}
118+
119+
allowedTypes.Add(type);
124120
}
125121

126-
implementationTypes.Sort((a, b) => string.CompareOrdinal(a.FullName, b.FullName));
127-
foreach (var type in implementationTypes)
122+
allowedTypes.Sort((a, b) => string.CompareOrdinal(a.FullName, b.FullName));
123+
foreach (var type in allowedTypes)
128124
{
129125
RegisterMessageType(type);
130126
}

0 commit comments

Comments
 (0)