Skip to content

Commit 6df26d3

Browse files
authored
Merge pull request #961 from martindevans/experimental_custom_sampler_wip
Custom Sampler Stages
2 parents 40ea046 + 8af713e commit 6df26d3

File tree

4 files changed

+337
-38
lines changed

4 files changed

+337
-38
lines changed

LLama.Examples/ExampleRunner.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ public class ExampleRunner
3535
{ "Batched Executor: LLava", BatchedExecutorLLava.Run },
3636
{ "Batched Executor: BoolQ Benchmark", BatchedExecutorBoolQ.Run },
3737
{ "Batched Executor: Beam Search", BatchedExecutorBeamSearch.Run },
38+
{ "Custom Sampling Pipeline", CustomSampler.Run },
3839
{ "Speech Chat: Integration with Whisper.net", SpeechChat.Run },
3940
{ "Exit", () => { Environment.Exit(0); return Task.CompletedTask; } }
4041
};
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
using LLama.Common;
2+
using LLama.Examples.Extensions;
3+
using LLama.Native;
4+
using LLama.Sampling;
5+
6+
namespace LLama.Examples.Examples
7+
{
8+
/// <summary>
/// Interactive console example which answers questions using <see cref="CustomSamplingPipeline"/>,
/// a sampling pipeline that contains a completely custom sampler stage.
/// </summary>
public class CustomSampler
{
    /// <summary>
    /// Load the model selected in the user settings, print an explanation of the demo, then
    /// repeatedly read a question from the console and stream the generated answer back.
    /// </summary>
    /// <returns>A task which completes once console input has ended</returns>
    public static async Task Run()
    {
        var modelPath = UserSettings.GetModelPath();

        var parameters = new ModelParams(modelPath);
        using var model = await LLamaWeights.LoadFromFileAsync(parameters);

        var ex = new StatelessExecutor(model, parameters);

        // The pieces of this message are concatenated, so each piece must carry its own
        // separator. The explicit line break keeps the two paragraphs from running together.
        Console.ForegroundColor = ConsoleColor.Yellow;
        Console.WriteLine("In this example a custom sampling pipeline with a custom sampler stage is being used. This demonstrates how to customise the samplers used, and " +
                          "how to create a completely custom sampler stage which modifies the logits or selects a token." +
                          "\n" +
                          "In this case the custom sampler stage removes the most likely token. This will probably produce bad results, it's just a demo!"
        );
        Console.ForegroundColor = ConsoleColor.White;

        var inferenceParams = new InferenceParams
        {
            SamplingPipeline = new CustomSamplingPipeline(),
            MaxTokens = 50
        };

        while (true)
        {
            Console.Write("\nQuestion: ");
            Console.ForegroundColor = ConsoleColor.Green;
            var prompt = Console.ReadLine();
            Console.ForegroundColor = ConsoleColor.White;

            // ReadLine returns null once the input stream has ended (e.g. redirected stdin or
            // Ctrl+Z / Ctrl+D). Without this check the loop would spin forever, running
            // inference on an empty question each time.
            if (prompt is null)
                break;

            Console.Write("Answer: ");
            prompt = $"Question: {prompt.Trim()} Answer: ";
            await foreach (var text in ex.InferAsync(prompt, inferenceParams).Spinner())
            {
                Console.Write(text);
            }
        }
    }
}
48+
49+
/// <summary>
/// Sampling pipeline used by the custom sampler example. It keeps only the top candidates,
/// removes the favourite with <see cref="RemoveMostLikelyToken"/>, then selects from what is left.
/// </summary>
public class CustomSamplingPipeline : BaseSamplingPipeline
{
    /// <summary>
    /// Build the sampler chain for this pipeline, adding each stage in the order it should run.
    /// </summary>
    /// <param name="context">Context the chain is being created for (not used by this pipeline)</param>
    /// <returns>The newly created sampler chain</returns>
    protected override SafeLLamaSamplerChainHandle CreateChain(SafeLLamaContextHandle context)
    {
        var chainParams = LLamaSamplerChainParams.Default();
        var samplers = SafeLLamaSamplerChainHandle.Create(chainParams);

        // First stage: keep only the 10 most likely tokens
        samplers.AddTopK(10);

        // Second stage: the custom sampler, which removes the most likely token
        samplers.AddCustom(new RemoveMostLikelyToken());

        // Final stages: select a token from the distribution
        samplers.AddSoftmax();
        samplers.AddDistributionSampler(42);

        return samplers;
    }
}
69+
70+
/// <summary>
/// Demonstration sampler stage which makes the single most likely candidate token impossible to pick.
/// </summary>
public class RemoveMostLikelyToken : ICustomSampler
{
    /// <summary>
    /// Display name of this sampler stage
    /// </summary>
    public string Name => "Remove Most Likely Token";

    /// <summary>
    /// Set the logit of the highest-logit candidate to negative infinity, modifying the data in place.
    /// Does nothing when there are fewer than two candidates.
    /// </summary>
    /// <param name="tokenData">The candidate tokens to modify</param>
    public void Apply(ref LLamaTokenDataArrayNative tokenData)
    {
        // With a single candidate (or none) there is nothing that can sensibly be removed
        if (tokenData.Size <= 1)
        {
            return;
        }

        var alreadySorted = tokenData.Sorted;
        var candidates = tokenData.Data;

        // The most likely token has to be at index 0, so unsorted data is sorted first.
        // The comparison is reversed to get a descending sort: the **largest** logit comes first.
        if (!alreadySorted)
        {
            candidates.Sort((x, y) => y.Logit.CompareTo(x.Logit));
        }

        // A logit of negative infinity means this token can never be chosen
        candidates[0].Logit = float.NegativeInfinity;

        // It is **critical** to clear this flag whenever a custom sampler leaves the logits in an
        // order which is no longer sorted. When unsure, setting it to false is always the safe choice.
        //
        // Here the first logit was just overwritten with negative infinity, so the data is
        // definitely not sorted any more.
        tokenData.Sorted = false;
    }

    /// <summary>
    /// Nothing to do: this stage does not act on accepted tokens.
    /// </summary>
    /// <param name="token">The token which was accepted</param>
    public void Accept(LLamaToken token) { }

    /// <summary>
    /// Nothing to do: this stage holds no state.
    /// </summary>
    public void Reset() { }

    /// <summary>
    /// Create another instance of this sampler stage.
    /// </summary>
    /// <returns>A new <see cref="RemoveMostLikelyToken"/></returns>
    public ICustomSampler Clone() => new RemoveMostLikelyToken();

    /// <summary>
    /// Nothing to do: this stage owns no resources.
    /// </summary>
    public void Dispose() { }
}
114+
}

LLama/Native/LLamaTokenDataArray.cs

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ public struct LLamaTokenDataArrayNative
149149
/// <summary>
150150
/// Number of LLamaTokenData in the array
151151
/// </summary>
152-
public ulong size;
152+
private ulong _size;
153153

154154
/// <summary>
155155
/// The index in the array (i.e. not the token id)
@@ -167,13 +167,13 @@ public Span<LLamaTokenData> Data
167167
{
168168
unsafe
169169
{
170-
return new Span<LLamaTokenData>(_data, checked((int)size));
170+
return new Span<LLamaTokenData>(_data, checked((int)Size));
171171
}
172172
}
173173
}
174174

175175
/// <summary>
176-
/// Indicates if the items in the array are sorted
176+
/// Indicates if the items in the array are sorted, so the most likely token is first
177177
/// </summary>
178178
public bool Sorted
179179
{
@@ -190,6 +190,20 @@ public long Selected
190190
set => _selected = value;
191191
}
192192

193+
/// <summary>
/// Number of LLamaTokenData in the array. Assign a smaller value to shrink the array;
/// it cannot be grown through this property.
/// </summary>
/// <exception cref="ArgumentOutOfRangeException">Thrown if the assigned value is larger than the current size</exception>
public ulong Size
{
    get => _size;
    set => _size = value <= _size
        ? value
        : throw new ArgumentOutOfRangeException(nameof(value), "Cannot set Size property to a larger value");
}
206+
193207
/// <summary>
194208
/// Create a new LLamaTokenDataArrayNative around the data in the LLamaTokenDataArray
195209
/// </summary>
@@ -205,7 +219,7 @@ public static MemoryHandle Create(LLamaTokenDataArray array, out LLamaTokenDataA
205219
// Initialise the private _size field directly instead of going through the Size property.
// The Size setter rejects any value larger than the current size, and on a freshly created
// struct the current size is zero, so assigning through the property would throw an
// ArgumentOutOfRangeException for every non-empty array.
native = new LLamaTokenDataArrayNative
{
    _data = (LLamaTokenData*)handle.Pointer,
    _size = (ulong)array.Data.Length,
    Sorted = array.Sorted
};
211225
}

0 commit comments

Comments
 (0)