Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions NWN.Anvil/src/main/Services/Core/Hooking/FunctionHook.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,26 +15,37 @@ public sealed unsafe class FunctionHook<T> : IDisposable where T : Delegate

// We hold a reference to the delegate to prevent clean up from the garbage collector.
[UsedImplicitly]
private readonly T? handler;
private readonly T? managedHandle;

private readonly HookService hookService;
private readonly FunctionHook* functionHook;

internal FunctionHook(HookService hookService, FunctionHook* functionHook, T? handler = null)
internal FunctionHook(HookService hookService, FunctionHook* functionHook, T? managedHandle = null)
{
this.hookService = hookService;
this.functionHook = functionHook;
this.handler = handler;
this.managedHandle = managedHandle;
CallOriginal = Marshal.GetDelegateForFunctionPointer<T>((IntPtr)functionHook->m_trampoline);
}

private void ReleaseUnmanagedResources()
{
NWNXAPI.ReturnFunctionHook(functionHook);
}

/// <summary>
/// Releases the FunctionHook, restoring the previous behaviour.
/// </summary>
public void Dispose()
{
NWNXAPI.ReturnFunctionHook(functionHook);
ReleaseUnmanagedResources();
GC.SuppressFinalize(this);
hookService.RemoveHook(this);
}

~FunctionHook()
{
ReleaseUnmanagedResources();
}
}
}
53 changes: 40 additions & 13 deletions NWN.Anvil/src/main/Services/Core/Hooking/HookService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
private static readonly Logger Log = LogManager.GetCurrentClassLogger();

private readonly HashSet<IDisposable> hooks = [];
private readonly HashSet<IDisposable> persistentHooks = [];

/// <summary>
/// Requests a hook for a native function.
Expand All @@ -28,11 +29,7 @@
public FunctionHook<T> RequestHook<T>(T handler, int order = HookOrder.Default) where T : Delegate
{
IntPtr managedFuncPtr = Marshal.GetFunctionPointerForDelegate(handler);
FunctionHook* functionHook = CreateHook<T>(managedFuncPtr, order);
FunctionHook<T> hook = new FunctionHook<T>(this, functionHook, handler);
hooks.Add(hook);

return hook;
return CreateHook<T>(managedFuncPtr, false, order, handler);
}

/// <summary>
Expand All @@ -44,14 +41,21 @@
/// <returns>A wrapper object containing a delegate to the original function. The wrapped object can be disposed to release the hook.</returns>
public FunctionHook<T> RequestHook<T>(void* handler, int order = HookOrder.Default) where T : Delegate
{
FunctionHook* functionHook = CreateHook<T>((IntPtr)handler, order);
FunctionHook<T> hook = new FunctionHook<T>(this, functionHook);
hooks.Add(hook);
return CreateHook<T>((IntPtr)handler, false, order);
}

return hook;
internal FunctionHook<T> RequestCoreHook<T>(T handler, int order = HookOrder.Default) where T : Delegate
{
IntPtr managedFuncPtr = Marshal.GetFunctionPointerForDelegate(handler);
return CreateHook<T>(managedFuncPtr, true, order);
}

private FunctionHook* CreateHook<T>(IntPtr handler, int order)
internal FunctionHook<T> RequestCoreHook<T>(void* handler, int order = HookOrder.Default) where T : Delegate
{
return CreateHook<T>((IntPtr)handler, true, order);
}

private FunctionHook<T> CreateHook<T>(IntPtr managedFuncPtr, bool persist, int order = HookOrder.Default, T? managedFunc = null) where T : Delegate
{
NativeFunctionAttribute? info = typeof(T).GetCustomAttribute<NativeFunctionAttribute>();
if (info == null)
Expand All @@ -60,14 +64,34 @@
}

Log.Debug("Requesting function hook for {HookType}, address {Address}", typeof(T).Name, $"0x{info.Address:X}");
return NWNXAPI.RequestFunctionHook(info.Address, handler, order);
FunctionHook* nativeHook = NWNXAPI.RequestFunctionHook(info.Address, managedFuncPtr, order);
FunctionHook<T> hook = new FunctionHook<T>(this, nativeHook, managedFunc);

if (persist)
{
persistentHooks.Add(hook);
}
else
{
hooks.Add(hook);
}

return hook;
}

void ICoreService.Init() {}

void ICoreService.Load() {}

void ICoreService.Shutdown() {}
void ICoreService.Shutdown()
{
foreach (IDisposable hook in persistentHooks.ToList())
{
hook.Dispose();
}

persistentHooks.Clear();
}

void ICoreService.Start() {}

Expand All @@ -83,7 +107,10 @@

internal void RemoveHook<T>(FunctionHook<T> hook) where T : Delegate
{
hooks.Remove(hook);
if (!hooks.Remove(hook))
{
persistentHooks.Remove(hook);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ internal sealed unsafe class ModuleLoadTracker(HookService hookService) : ICoreS

void ICoreService.Init()
{
loadModuleInProgressHook = hookService.RequestHook<Functions.CNWSModule.LoadModuleInProgress>(OnModuleLoadProgressChange, HookOrder.Earliest);
loadModuleInProgressHook = hookService.RequestCoreHook<Functions.CNWSModule.LoadModuleInProgress>(OnModuleLoadProgressChange, HookOrder.Earliest);
}

void ICoreService.Load() {}
Expand Down
16 changes: 13 additions & 3 deletions NWN.Anvil/src/main/Services/ObjectStorage/ObjectStorageService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@

namespace Anvil.Services
{
[ServiceBinding(typeof(ObjectStorageService))]
public sealed unsafe class ObjectStorageService
public sealed unsafe class ObjectStorageService : ICoreService
{
private static readonly Logger Log = LogManager.GetCurrentClassLogger();

Expand All @@ -22,9 +21,10 @@ public sealed unsafe class ObjectStorageService
private readonly FunctionHook<Functions.CNWSPlayer.EatTURD> eatTURDHook;
private readonly FunctionHook<Functions.CNWSUUID.LoadFromGff> loadFromGffHook;
private readonly FunctionHook<Functions.CNWSObject.Destructor> objectDestructorHook;
private readonly Dictionary<IntPtr, ObjectStorage> objectStorage = new Dictionary<IntPtr, ObjectStorage>();
private readonly FunctionHook<Functions.CNWSUUID.SaveToGff> saveToGffHook;

private readonly Dictionary<IntPtr, ObjectStorage> objectStorage = new Dictionary<IntPtr, ObjectStorage>();

public ObjectStorageService(HookService hookService)
{
objectDestructorHook = hookService.RequestHook<Functions.CNWSObject.Destructor>(OnObjectDestructor, HookOrder.VeryEarly);
Expand Down Expand Up @@ -177,5 +177,15 @@ private void OnSaveToGff(void* pUUID, void* pRes, void* pStruct)

saveToGffHook.CallOriginal(pUUID, pRes, pStruct);
}

void ICoreService.Init() {}

void ICoreService.Load() {}

void ICoreService.Shutdown() {}

void ICoreService.Start() {}

void ICoreService.Unload() {}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ internal sealed class AnvilCoreServiceManager
private readonly EncodingService encodingService;
private readonly ResourceManager resourceManager;
private readonly AnvilMessageService anvilMessageService;
private readonly ObjectStorageService objectStorageService;

public AnvilCoreServiceManager(IServiceContainer container)
{
Expand All @@ -37,6 +38,7 @@ public AnvilCoreServiceManager(IServiceContainer container)
container.RegisterCoreService<EncodingService>();
container.RegisterCoreService<ResourceManager>();
container.RegisterCoreService<AnvilMessageService>();
container.RegisterCoreService<ObjectStorageService>();

container.Compile();

Expand All @@ -52,6 +54,7 @@ public AnvilCoreServiceManager(IServiceContainer container)
encodingService = container.GetInstance<EncodingService>();
hookService = container.GetInstance<HookService>();
moduleLoadTracker = container.GetInstance<ModuleLoadTracker>();
objectStorageService = container.GetInstance<ObjectStorageService>();
}

public void Init()
Expand All @@ -69,6 +72,7 @@ public void Init()
InitService(encodingService);
InitService(hookService);
InitService(moduleLoadTracker);
InitService(objectStorageService);
}

public void Load()
Expand All @@ -86,6 +90,7 @@ public void Load()
LoadService(encodingService);
LoadService(hookService);
LoadService(moduleLoadTracker);
LoadService(objectStorageService);
}

public void Start()
Expand All @@ -103,11 +108,13 @@ public void Start()
StartService(encodingService);
StartService(hookService);
StartService(moduleLoadTracker);
StartService(objectStorageService);
}

public void Unload()
{
Log.Info("Unloading core services...");
UnloadService(objectStorageService);
UnloadService(moduleLoadTracker);
UnloadService(hookService);
UnloadService(encodingService);
Expand All @@ -124,6 +131,7 @@ public void Unload()

public void Shutdown()
{
ShutdownService(objectStorageService);
ShutdownService(moduleLoadTracker);
ShutdownService(hookService);
ShutdownService(encodingService);
Expand Down
Loading