Skip to content

Commit 8af713e

Browse files
committed
Fixed getting name of sampler stages:
- Preallocating a block of memory to store the name, we can return a pointer to this - Fixed string marshalling back from C++
1 parent 44c1ff0 commit 8af713e

File tree

1 file changed

+24
-7
lines changed

1 file changed

+24
-7
lines changed

LLama/Native/SafeLLamaSamplerHandle.cs

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
using System;
2+
using System.Runtime.CompilerServices;
3+
using System.Text;
24

35
namespace LLama.Native;
46

@@ -117,10 +119,10 @@ public string GetName(int index)
117119
if (index < 0 || index >= Count)
118120
throw new ArgumentOutOfRangeException(nameof(index));
119121

120-
return llama_sampler_name(llama_sampler_chain_get(this, index));
122+
return Marshal.PtrToStringAnsi(llama_sampler_name(llama_sampler_chain_get(this, index))) ?? "Unknown Name";
121123

122124
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
123-
static extern string llama_sampler_name(IntPtr smpl);
125+
static extern IntPtr llama_sampler_name(IntPtr smpl);
124126
}
125127

126128
/// <summary>
@@ -531,7 +533,7 @@ internal struct LLamaSamplerINative
531533
/// <param name="smpl"></param>
532534
/// <returns></returns>
533535
[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
534-
public delegate string NameDelegate(ref LLamaSamplerNative smpl);
536+
public unsafe delegate byte* NameDelegate(ref LLamaSamplerNative smpl);
535537

536538
/// <summary>
537539
/// Update internal sampler state after a token has been chosen
@@ -571,7 +573,7 @@ internal struct LLamaSamplerINative
571573
[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
572574
public delegate void FreeDelegate(ref LLamaSamplerNative smpl);
573575

574-
public unsafe delegate*<char*> Name;
576+
public unsafe delegate*<byte*> Name;
575577
public unsafe delegate*<LLamaSamplerNative*, LLamaToken, void> Accept;
576578
public unsafe delegate*<LLamaSamplerNative*, LLamaTokenDataArrayNative*, void> Apply;
577579
public unsafe delegate*<LLamaSamplerNative*, void> Reset;
@@ -612,6 +614,7 @@ internal class CustomSamplerHandle
612614

613615
private unsafe LLamaSamplerNative* _samplerNativePtr;
614616
private unsafe LLamaSamplerINative* _samplerNativeInterfacePtr;
617+
private unsafe byte* _samplerNamePtr;
615618

616619
private CustomSamplerHandle(ICustomSampler sampler)
617620
{
@@ -620,22 +623,30 @@ private CustomSamplerHandle(ICustomSampler sampler)
620623

621624
public static CustomSamplerHandle Create(ICustomSampler sampler)
622625
{
626+
var nameArr = Encoding.UTF8.GetBytes(sampler.Name + '\0');
627+
623628
var handle = new CustomSamplerHandle(sampler);
624629
handle._gcHandle = GCHandle.Alloc(handle);
625630

626631
unsafe
627632
{
633+
// Allocate space for a `LLamaSamplerINative` struct. So we can pass pointers to it.
628634
handle._samplerNativeInterfacePtr = (LLamaSamplerINative*)Marshal.AllocHGlobal(sizeof(LLamaSamplerINative));
629-
handle._samplerNativeInterfacePtr->Name = (delegate*<char*>)Marshal.GetFunctionPointerForDelegate<LLamaSamplerINative.NameDelegate>(Name);
635+
handle._samplerNativeInterfacePtr->Name = (delegate*<byte*>)Marshal.GetFunctionPointerForDelegate<LLamaSamplerINative.NameDelegate>(Name);
630636
handle._samplerNativeInterfacePtr->Accept = (delegate*<LLamaSamplerNative*, LLamaToken, void>)Marshal.GetFunctionPointerForDelegate<LLamaSamplerINative.AcceptDelegate>(Accept);
631637
handle._samplerNativeInterfacePtr->Apply = (delegate*<LLamaSamplerNative*, LLamaTokenDataArrayNative*, void>)Marshal.GetFunctionPointerForDelegate<LLamaSamplerINative.ApplyDelegate>(Apply);
632638
handle._samplerNativeInterfacePtr->Reset = (delegate*<LLamaSamplerNative*, void>)Marshal.GetFunctionPointerForDelegate<LLamaSamplerINative.ResetDelegate>(Reset);
633639
handle._samplerNativeInterfacePtr->Clone = (delegate*<LLamaSamplerNative*, IntPtr>)Marshal.GetFunctionPointerForDelegate<LLamaSamplerINative.CloneDelegate>(Clone);
634640
handle._samplerNativeInterfacePtr->Free = (delegate*<LLamaSamplerNative*, void>)Marshal.GetFunctionPointerForDelegate<LLamaSamplerINative.FreeDelegate>(Free);
635641

642+
// Allocate space for a `LLamaSamplerNative` struct. So we can pass pointers to it.
636643
handle._samplerNativePtr = (LLamaSamplerNative*)Marshal.AllocHGlobal(sizeof(LLamaSamplerNative));
637644
handle._samplerNativePtr->Context = (IntPtr)handle._gcHandle;
638645
handle._samplerNativePtr->Interface = handle._samplerNativeInterfacePtr;
646+
647+
// Allocate space for the name string
648+
handle._samplerNamePtr = (byte*)Marshal.AllocHGlobal(nameArr.Length);
649+
nameArr.AsSpan().CopyTo(new Span<byte>(handle._samplerNamePtr, nameArr.Length));
639650
}
640651

641652
return handle;
@@ -656,9 +667,9 @@ private static CustomSamplerHandle GetSampler(ref LLamaSamplerNative smpl)
656667
return (CustomSamplerHandle)GCHandle.FromIntPtr(smpl.Context).Target!;
657668
}
658669

659-
private static string Name(ref LLamaSamplerNative smpl)
670+
private static unsafe byte* Name(ref LLamaSamplerNative smpl)
660671
{
661-
return GetSampler(ref smpl)._sampler.Name;
672+
return GetSampler(ref smpl)._samplerNamePtr;
662673
}
663674

664675
private static void Accept(ref LLamaSamplerNative smpl, LLamaToken token)
@@ -699,6 +710,12 @@ private static unsafe void Free(ref LLamaSamplerNative smpl)
699710
sampler._samplerNativeInterfacePtr = null;
700711
}
701712

713+
if (sampler._samplerNamePtr != null)
714+
{
715+
Marshal.FreeHGlobal((IntPtr)sampler._samplerNamePtr);
716+
sampler._samplerNamePtr = null;
717+
}
718+
702719
sampler._gcHandle.Free();
703720

704721
sampler._sampler.Dispose();

0 commit comments

Comments
 (0)