Skip to content

Commit daffe73

Browse files
authored
Merge pull request #1180 from dpmm99/feat/tensor-override
Feat/tensor override
2 parents 1668e76 + eeb8c8f commit daffe73

File tree

12 files changed

+265
-6
lines changed

12 files changed

+265
-6
lines changed

.github/_typos.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,6 @@ extend-exclude = [
1717

1818
[default.extend-words]
1919
# Used in a comment in SafeLLamaSamplerHandle.cs, as a prefix of "hello"
20-
teh = "hel"
20+
teh = "hel"
21+
# ot is the shorthand version of llama.cpp's override-tensor parameter
22+
ot = "ot"

LLama.Unittest/ModelsParamsTests.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,11 @@ public void SerializeRoundTripSystemTextJson()
4141
actual.MetadataOverrides = null!;
4242
expected.MetadataOverrides = null!;
4343

44+
// Same deal
45+
Assert.True(expected.TensorBufferOverrides.SequenceEqual(actual.TensorBufferOverrides));
46+
actual.TensorBufferOverrides = null!;
47+
expected.TensorBufferOverrides = null!;
48+
4449
// Check encoding is the same
4550
var b1 = expected.Encoding.GetBytes("Hello");
4651
var b2 = actual.Encoding.GetBytes("Hello");

LLama.Web/Common/ModelOptions.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ public class ModelOptions
2626
/// <inheritdoc />
2727
public GPUSplitMode? SplitMode { get; set; }
2828

29+
/// <inheritdoc />
30+
public List<TensorBufferOverride> TensorBufferOverrides { get; set; } = new();
31+
2932
/// <inheritdoc />
3033
public int GpuLayerCount { get; set; } = 20;
3134

LLama/Abstractions/IModelParams.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,12 @@ public interface IModelParams
3838
/// </summary>
3939
GPUSplitMode? SplitMode { get; }
4040

41+
/// <summary>
42+
/// Buffer type overrides for specific tensor patterns, allowing you to specify hardware devices to use for individual tensors or sets of tensors.
43+
/// Equivalent to --override-tensor or -ot on the llama.cpp command line or tensor_buft_overrides internally.
44+
/// </summary>
45+
List<TensorBufferOverride> TensorBufferOverrides { get; }
46+
4147
/// <summary>
4248
/// Number of layers to run in VRAM / GPU memory (n_gpu_layers)
4349
/// </summary>
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
using System;
2+
3+
namespace LLama.Abstractions
4+
{
5+
/// <summary>
6+
/// Represents a mapping between a tensor name pattern and a specific buffer type
7+
/// </summary>
8+
public class TensorBufferOverride
9+
{
10+
/// <summary>
11+
/// Pattern to match tensor names. This is a regular expression. You can check the tensor names via the model.Metadata.
12+
/// </summary>
13+
public string Pattern { get; set; }
14+
15+
/// <summary>
16+
/// Buffer type to use for matching tensors. Examples: CPU, GPU0, GPU1
17+
/// </summary>
18+
public string BufferType { get; set; }
19+
20+
/// <summary>
21+
/// Creates a new tensor buffer override
22+
/// </summary>
23+
/// <param name="pattern">Pattern to match tensor names</param>
24+
/// <param name="bufferType">Buffer type to use for matching tensors</param>
25+
public TensorBufferOverride(string pattern, string bufferType)
26+
{
27+
if (string.IsNullOrEmpty(pattern))
28+
throw new ArgumentException("Pattern cannot be null or empty", nameof(pattern));
29+
if (string.IsNullOrEmpty(bufferType))
30+
throw new ArgumentException("Buffer type cannot be null or empty", nameof(bufferType));
31+
32+
Pattern = pattern;
33+
BufferType = bufferType;
34+
}
35+
}
36+
}

LLama/Common/ModelParams.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ public record ModelParams
2121
/// <inheritdoc />
2222
public GPUSplitMode? SplitMode { get; set; }
2323

24+
/// <inheritdoc />
25+
public List<TensorBufferOverride> TensorBufferOverrides { get; set; } = new();
26+
2427
/// <inheritdoc />
2528
public int GpuLayerCount { get; set; } = 20;
2629

LLama/Extensions/IModelParamsExtensions.cs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,20 @@ public static IDisposable ToLlamaModelParams(this IModelParams @params, out LLam
4545
result.tensor_split = (float*)disposer.Add(@params.TensorSplits.Pin()).Pointer;
4646
}
4747

48+
// Add tensor buffer overrides, if any
49+
if (@params.TensorBufferOverrides.Count > 0)
50+
{
51+
var bufferOverrideHelper = new LLamaTensorBufferOverrideHelper();
52+
disposer.Add(bufferOverrideHelper);
53+
54+
foreach (var tensorOverride in @params.TensorBufferOverrides)
55+
{
56+
bufferOverrideHelper.AddOverride(tensorOverride.Pattern, tensorOverride.BufferType);
57+
}
58+
59+
bufferOverrideHelper.ApplyToModelParams(ref result);
60+
}
61+
4862
if (@params.MetadataOverrides.Count == 0)
4963
{
5064
unsafe

LLama/Native/LLamaModelParams.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@ public unsafe struct LLamaModelParams
1313
/// todo: add support for llama_model_params.devices
1414
/// </summary>
1515
private IntPtr devices;
16-
17-
// NULL-terminated list of buffer types to use for tensors that match a pattern
18-
// actual type: llama_model_tensor_buft_override*
19-
// todo: add support for tensor_buft_overrides
20-
private IntPtr tensor_buft_overrides;
16+
17+
/// <summary>
18+
/// NULL-terminated list of buffer types to use for tensors that match a pattern
19+
/// </summary>
20+
public LLamaModelTensorBufferOverride* tensor_buft_overrides;
2121

2222
/// <summary>
2323
/// // number of layers to store in VRAM
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
using System;
2+
3+
namespace LLama.Native
4+
{
5+
/// <summary>
6+
/// Represents a mapping between a tensor name pattern and a backend buffer type<br/>
7+
/// Original type: llama_model_tensor_buft_override
8+
/// </summary>
9+
[StructLayout(LayoutKind.Sequential)]
10+
public struct LLamaModelTensorBufferOverride
11+
{
12+
/// <summary>
13+
/// Tensor name pattern to match
14+
/// </summary>
15+
public IntPtr Pattern;
16+
17+
/// <summary>
18+
/// Backend buffer type to use for matching tensors, as obtained via ggml_backend_dev_buffer_type
19+
/// </summary>
20+
public IntPtr BufferType;
21+
}
22+
}
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace LLama.Native
6+
{
7+
/// <summary>
8+
/// Helper for creating and managing tensor buffer overrides
9+
/// </summary>
10+
internal class LLamaTensorBufferOverrideHelper : IDisposable
11+
{
12+
private readonly List<IntPtr> _allocatedMemory = new();
13+
private readonly List<LLamaModelTensorBufferOverride> _overrides = new();
14+
private IntPtr _overrideArray = IntPtr.Zero;
15+
private readonly Dictionary<string, IntPtr> _bufferTypeCache = new();
16+
17+
/// <summary>
18+
/// Get all available buffer types
19+
/// </summary>
20+
/// <returns>Dictionary mapping buffer type names to their handles</returns>
21+
public Dictionary<string, IntPtr> GetAvailableBufferTypes()
22+
{
23+
var result = new Dictionary<string, IntPtr>();
24+
25+
nuint count = NativeApi.ggml_backend_dev_count();
26+
for (nuint i = 0; i < count; i++)
27+
{
28+
IntPtr dev = NativeApi.ggml_backend_dev_get(i);
29+
IntPtr buft = NativeApi.ggml_backend_dev_buffer_type(dev);
30+
31+
if (buft != IntPtr.Zero)
32+
{
33+
IntPtr namePtr = NativeApi.ggml_backend_buft_name(buft);
34+
string name = Marshal.PtrToStringAnsi(namePtr) ?? string.Empty;
35+
36+
if (!string.IsNullOrEmpty(name))
37+
{
38+
result[name] = buft;
39+
_bufferTypeCache[name] = buft;
40+
}
41+
}
42+
}
43+
44+
return result;
45+
}
46+
47+
/// <summary>
48+
/// Add a tensor buffer override
49+
/// </summary>
50+
/// <param name="pattern">Tensor name pattern to match</param>
51+
/// <param name="bufferTypeName">Name of the buffer type to use</param>
52+
/// <returns>True if the override was added successfully</returns>
53+
public bool AddOverride(string pattern, string bufferTypeName)
54+
{
55+
if (string.IsNullOrEmpty(pattern) || string.IsNullOrEmpty(bufferTypeName))
56+
return false;
57+
58+
// Get all buffer types if cache is empty
59+
if (_bufferTypeCache.Count == 0)
60+
{
61+
GetAvailableBufferTypes();
62+
}
63+
64+
// Check if we have this buffer type
65+
if (!_bufferTypeCache.TryGetValue(bufferTypeName, out IntPtr bufferType))
66+
return false;
67+
68+
// Allocate memory for the pattern string and keep track of it
69+
byte[] patternBytes = Encoding.UTF8.GetBytes(pattern + "\0");
70+
IntPtr patternPtr = Marshal.AllocHGlobal(patternBytes.Length);
71+
Marshal.Copy(patternBytes, 0, patternPtr, patternBytes.Length);
72+
_allocatedMemory.Add(patternPtr);
73+
74+
// Create the override
75+
var @override = new LLamaModelTensorBufferOverride
76+
{
77+
Pattern = patternPtr,
78+
BufferType = bufferType
79+
};
80+
81+
_overrides.Add(@override);
82+
return true;
83+
}
84+
85+
/// <summary>
86+
/// Apply the overrides to model parameters
87+
/// </summary>
88+
/// <param name="modelParams">Model parameters to update</param>
89+
public unsafe void ApplyToModelParams(ref LLamaModelParams modelParams)
90+
{
91+
if (_overrides.Count == 0)
92+
{
93+
modelParams.tensor_buft_overrides = null;
94+
return;
95+
}
96+
97+
// Free previous array if it exists
98+
if (_overrideArray != IntPtr.Zero)
99+
{
100+
Marshal.FreeHGlobal(_overrideArray);
101+
}
102+
103+
// Allocate memory for the array + null terminator
104+
int size = Marshal.SizeOf<LLamaModelTensorBufferOverride>() * (_overrides.Count + 1);
105+
_overrideArray = Marshal.AllocHGlobal(size);
106+
_allocatedMemory.Add(_overrideArray);
107+
108+
// Copy overrides to array
109+
for (int i = 0; i < _overrides.Count; i++)
110+
{
111+
IntPtr elemPtr = IntPtr.Add(_overrideArray, i * Marshal.SizeOf<LLamaModelTensorBufferOverride>());
112+
Marshal.StructureToPtr(_overrides[i], elemPtr, false);
113+
}
114+
115+
// Add null terminator
116+
IntPtr nullTermPtr = IntPtr.Add(_overrideArray, _overrides.Count * Marshal.SizeOf<LLamaModelTensorBufferOverride>());
117+
Marshal.StructureToPtr(new LLamaModelTensorBufferOverride { Pattern = IntPtr.Zero, BufferType = IntPtr.Zero }, nullTermPtr, false);
118+
119+
// Update model params
120+
modelParams.tensor_buft_overrides = (LLamaModelTensorBufferOverride*)_overrideArray.ToPointer();
121+
}
122+
123+
/// <inheritdoc />
124+
public void Dispose()
125+
{
126+
foreach (IntPtr ptr in _allocatedMemory)
127+
{
128+
Marshal.FreeHGlobal(ptr);
129+
}
130+
_allocatedMemory.Clear();
131+
_overrides.Clear();
132+
_overrideArray = IntPtr.Zero;
133+
}
134+
}
135+
}

0 commit comments

Comments
 (0)