diff --git a/NWN.Anvil/src/main/Services/Core/Hooking/FunctionHook.cs b/NWN.Anvil/src/main/Services/Core/Hooking/FunctionHook.cs index 75546cb97..81fa06c09 100644 --- a/NWN.Anvil/src/main/Services/Core/Hooking/FunctionHook.cs +++ b/NWN.Anvil/src/main/Services/Core/Hooking/FunctionHook.cs @@ -15,26 +15,37 @@ public sealed unsafe class FunctionHook : IDisposable where T : Delegate // We hold a reference to the delegate to prevent clean up from the garbage collector. [UsedImplicitly] - private readonly T? handler; + private readonly T? managedHandle; private readonly HookService hookService; private readonly FunctionHook* functionHook; - internal FunctionHook(HookService hookService, FunctionHook* functionHook, T? handler = null) + internal FunctionHook(HookService hookService, FunctionHook* functionHook, T? managedHandle = null) { this.hookService = hookService; this.functionHook = functionHook; - this.handler = handler; + this.managedHandle = managedHandle; CallOriginal = Marshal.GetDelegateForFunctionPointer((IntPtr)functionHook->m_trampoline); } + private void ReleaseUnmanagedResources() + { + NWNXAPI.ReturnFunctionHook(functionHook); + } + /// /// Releases the FunctionHook, restoring the previous behaviour. /// public void Dispose() { - NWNXAPI.ReturnFunctionHook(functionHook); + ReleaseUnmanagedResources(); + GC.SuppressFinalize(this); hookService.RemoveHook(this); } + + ~FunctionHook() + { + ReleaseUnmanagedResources(); + } } } diff --git a/NWN.Anvil/src/main/Services/Core/Hooking/HookService.cs b/NWN.Anvil/src/main/Services/Core/Hooking/HookService.cs index 94b8b2471..adabeaad9 100644 --- a/NWN.Anvil/src/main/Services/Core/Hooking/HookService.cs +++ b/NWN.Anvil/src/main/Services/Core/Hooking/HookService.cs @@ -17,6 +17,7 @@ public sealed unsafe class HookService : ICoreService private static readonly Logger Log = LogManager.GetCurrentClassLogger(); private readonly HashSet hooks = []; + private readonly HashSet persistentHooks = []; /// /// Requests a hook for a native function. @@ -28,11 +29,7 @@ public sealed unsafe class HookService : ICoreService public FunctionHook RequestHook(T handler, int order = HookOrder.Default) where T : Delegate { IntPtr managedFuncPtr = Marshal.GetFunctionPointerForDelegate(handler); - FunctionHook* functionHook = CreateHook(managedFuncPtr, order); - FunctionHook hook = new FunctionHook(this, functionHook, handler); - hooks.Add(hook); - - return hook; + return CreateHook(managedFuncPtr, false, order, handler); } /// @@ -44,14 +41,21 @@ public FunctionHook RequestHook(T handler, int order = HookOrder.Default) /// A wrapper object containing a delegate to the original function. The wrapped object can be disposed to release the hook. public FunctionHook RequestHook(void* handler, int order = HookOrder.Default) where T : Delegate { - FunctionHook* functionHook = CreateHook((IntPtr)handler, order); - FunctionHook hook = new FunctionHook(this, functionHook); - hooks.Add(hook); + return CreateHook((IntPtr)handler, false, order); + } - return hook; + internal FunctionHook RequestCoreHook(T handler, int order = HookOrder.Default) where T : Delegate + { + IntPtr managedFuncPtr = Marshal.GetFunctionPointerForDelegate(handler); + return CreateHook(managedFuncPtr, true, order, handler); } - private FunctionHook* CreateHook(IntPtr handler, int order) + internal FunctionHook RequestCoreHook(void* handler, int order = HookOrder.Default) where T : Delegate + { + return CreateHook((IntPtr)handler, true, order); + } + + private FunctionHook CreateHook(IntPtr managedFuncPtr, bool persist, int order = HookOrder.Default, T? managedFunc = null) where T : Delegate { NativeFunctionAttribute? info = typeof(T).GetCustomAttribute(); if (info == null) @@ -60,14 +64,34 @@ public FunctionHook RequestHook(void* handler, int order = HookOrder.Defau } Log.Debug("Requesting function hook for {HookType}, address {Address}", typeof(T).Name, $"0x{info.Address:X}"); - return NWNXAPI.RequestFunctionHook(info.Address, handler, order); + FunctionHook* nativeHook = NWNXAPI.RequestFunctionHook(info.Address, managedFuncPtr, order); + FunctionHook hook = new FunctionHook(this, nativeHook, managedFunc); + + if (persist) + { + persistentHooks.Add(hook); + } + else + { + hooks.Add(hook); + } + + return hook; } void ICoreService.Init() {} void ICoreService.Load() {} - void ICoreService.Shutdown() {} + void ICoreService.Shutdown() + { + foreach (IDisposable hook in persistentHooks.ToList()) + { + hook.Dispose(); + } + + persistentHooks.Clear(); + } void ICoreService.Start() {} @@ -83,7 +107,10 @@ void ICoreService.Unload() internal void RemoveHook(FunctionHook hook) where T : Delegate { - hooks.Remove(hook); + if (!hooks.Remove(hook)) + { + persistentHooks.Remove(hook); + } } } } diff --git a/NWN.Anvil/src/main/Services/Core/Logging/ModuleLoadTracker.cs b/NWN.Anvil/src/main/Services/Core/Logging/ModuleLoadTracker.cs index 0d9429cc9..30ac55cfa 100644 --- a/NWN.Anvil/src/main/Services/Core/Logging/ModuleLoadTracker.cs +++ b/NWN.Anvil/src/main/Services/Core/Logging/ModuleLoadTracker.cs @@ -12,7 +12,7 @@ internal sealed unsafe class ModuleLoadTracker(HookService hookService) : ICoreS void ICoreService.Init() { - loadModuleInProgressHook = hookService.RequestHook(OnModuleLoadProgressChange, HookOrder.Earliest); + loadModuleInProgressHook = hookService.RequestCoreHook(OnModuleLoadProgressChange, HookOrder.Earliest); } void ICoreService.Load() {} diff --git a/NWN.Anvil/src/main/Services/ObjectStorage/ObjectStorageService.cs b/NWN.Anvil/src/main/Services/ObjectStorage/ObjectStorageService.cs index b0769b225..09c642ff1 100644 --- a/NWN.Anvil/src/main/Services/ObjectStorage/ObjectStorageService.cs +++ b/NWN.Anvil/src/main/Services/ObjectStorage/ObjectStorageService.cs @@ -9,8 +9,7 @@ namespace Anvil.Services { - [ServiceBinding(typeof(ObjectStorageService))] - public sealed unsafe class ObjectStorageService + public sealed unsafe class ObjectStorageService : ICoreService { private static readonly Logger Log = LogManager.GetCurrentClassLogger(); @@ -22,9 +21,10 @@ public sealed unsafe class ObjectStorageService private readonly FunctionHook eatTURDHook; private readonly FunctionHook loadFromGffHook; private readonly FunctionHook objectDestructorHook; - private readonly Dictionary objectStorage = new Dictionary(); private readonly FunctionHook saveToGffHook; + private readonly Dictionary objectStorage = new Dictionary(); + public ObjectStorageService(HookService hookService) { objectDestructorHook = hookService.RequestHook(OnObjectDestructor, HookOrder.VeryEarly); @@ -177,5 +177,15 @@ private void OnSaveToGff(void* pUUID, void* pRes, void* pStruct) saveToGffHook.CallOriginal(pUUID, pRes, pStruct); } + + void ICoreService.Init() {} + + void ICoreService.Load() {} + + void ICoreService.Shutdown() {} + + void ICoreService.Start() {} + + void ICoreService.Unload() {} } } diff --git a/NWN.Anvil/src/main/Services/Services/AnvilCoreServiceManager.cs b/NWN.Anvil/src/main/Services/Services/AnvilCoreServiceManager.cs index 4f4bbe984..a0a878e0a 100644 --- a/NWN.Anvil/src/main/Services/Services/AnvilCoreServiceManager.cs +++ b/NWN.Anvil/src/main/Services/Services/AnvilCoreServiceManager.cs @@ -22,6 +22,7 @@ internal sealed class AnvilCoreServiceManager private readonly EncodingService encodingService; private readonly ResourceManager resourceManager; private readonly AnvilMessageService anvilMessageService; + private readonly ObjectStorageService objectStorageService; public AnvilCoreServiceManager(IServiceContainer container) { @@ -37,6 +38,7 @@ public AnvilCoreServiceManager(IServiceContainer container) container.RegisterCoreService(); container.RegisterCoreService(); container.RegisterCoreService(); + container.RegisterCoreService(); container.Compile(); @@ -52,6 +54,7 @@ public AnvilCoreServiceManager(IServiceContainer container) encodingService = container.GetInstance(); hookService = container.GetInstance(); moduleLoadTracker = container.GetInstance(); + objectStorageService = container.GetInstance(); } public void Init() @@ -69,6 +72,7 @@ public void Init() InitService(encodingService); InitService(hookService); InitService(moduleLoadTracker); + InitService(objectStorageService); } public void Load() @@ -86,6 +90,7 @@ public void Load() LoadService(encodingService); LoadService(hookService); LoadService(moduleLoadTracker); + LoadService(objectStorageService); } public void Start() @@ -103,11 +108,13 @@ public void Start() StartService(encodingService); StartService(hookService); StartService(moduleLoadTracker); + StartService(objectStorageService); } public void Unload() { Log.Info("Unloading core services..."); + UnloadService(objectStorageService); UnloadService(moduleLoadTracker); UnloadService(hookService); UnloadService(encodingService); @@ -124,6 +131,7 @@ public void Unload() public void Shutdown() { + ShutdownService(objectStorageService); ShutdownService(moduleLoadTracker); ShutdownService(hookService); ShutdownService(encodingService);