Commit 59cd487 (parent a296b55)

- Removed the automatic shared-token mechanism in LLamaBatch; it was causing problems with tokens that should never have been shared.
- Added a helper to simplify sampling a conversation with a sampling pipeline.
- Added many more comments explaining BatchedExecutorSimple in detail.

File tree: 4 files changed (+65, -60 lines)

LLama.Examples/Examples/BatchedExecutorSimple.cs
Lines changed: 45 additions & 23 deletions

@@ -1,4 +1,3 @@
-using System.Diagnostics.CodeAnalysis;
 using System.Text;
 using LLama.Batched;
 using LLama.Common;
@@ -34,6 +33,7 @@ public static async Task Run()
     var name = model.Metadata.GetValueOrDefault("general.name", "unknown model name");
     Console.WriteLine($"Created executor with model: {name}");
 
+    // A set of questions to evaluate all at once
     var messages = new[]
     {
         "What's 2+2?",
@@ -46,8 +46,10 @@ public static async Task Run()
         "I have two sons, Bert and Ernie. What should I name my daughter?",
         "What day comes after Friday?",
         "What color shoes should I wear with dark blue pants?",
+        "Wy ae cts btr tn dgs?"
     };
 
+    // Create a "Conversation" for each question
     var conversations = new List<ConversationData>();
     foreach (var message in messages)
     {
@@ -57,11 +59,14 @@ public static async Task Run()
         template.Add("user", message);
         template.AddAssistant = true;
         var templatedMessage = Encoding.UTF8.GetString(template.Apply());
-
+
         // create a new conversation and prompt it. include special and bos because we are using the template
+        // - BOS is the "Beginning of Sequence" token and should be included at the start of any prompt
+        // - Special tokens are special non-text tokens which an LLM is trained to understand (e.g. BOS). The templated text may contains special tokens.
         var conversation = executor.Create();
         conversation.Prompt(executor.Context.Tokenize(templatedMessage, addBos: true, special: true));
 
+        // Store everything we need to process this conversation
         conversations.Add(new ConversationData {
             Prompt = message,
             Conversation = conversation,
@@ -73,50 +78,64 @@ public static async Task Run()
     var table = BuildTable(conversations);
     await AnsiConsole.Live(table).StartAsync(async ctx =>
     {
+        // Enter a loop generating tokens
        for (var i = 0; i < TokenCount; i++)
        {
            // Run inference for all conversations in the batch which have pending tokens.
            var decodeResult = await executor.Infer();
+
+            // Inference can fail, always check the return value!
+            // NoKvSlot is not a fatal error, it just means that there's not enough memory available in the KV cache to process everything. You can force
+            // this to happen by setting a small value for ContextSize in the ModelParams at the top of this file (e.g. 512).
+            // In this case it's handled by ending a conversation (which will free up some space) and trying again. You could also handle this by
+            // saving the conversation to disk and loading it up again later once some other conversations have finished.
            if (decodeResult == DecodeResult.NoKvSlot)
-                throw new Exception("Could not find a KV slot for the batch. Try reducing the size of the batch or increase the context.");
+            {
+                conversations.FirstOrDefault(a => !a.IsComplete)?.MarkComplete(failed:true);
+                continue;
+            }
+
+            // A generic error, this is fatal and the batch can no longer be used. This should never occur and generally indicates
+            // a bug in LLamaSharp, llama.cpp or a hardware error.
            if (decodeResult == DecodeResult.Error)
                throw new Exception("Unknown error occurred while inferring.");
 
-            foreach (var conversationData in conversations.Where(c => c.IsComplete == false))
+            // After inference all of the conversations must be sampled before running inference again.
+            foreach (var conversationData in conversations)
            {
-                if (conversationData.Conversation.RequiresSampling == false)
+                // Completed conversations don't need sampling.
+                if (conversationData.IsComplete)
                    continue;
-
-                // sample a single token for the executor, passing the sample index of the conversation
-                var sampleIndex = conversationData.Conversation.GetSampleIndex();
-                var token = conversationData.Sampler.Sample(
-                    executor.Context,
-                    sampleIndex
-                );
-
+
+                // If the conversation wasn't prompted before the last call to Infer then it won't need sampling.
+                if (!conversationData.Conversation.RequiresSampling)
+                    continue;
+
+                // Use the sampling pipeline to choose a single token for this conversation.
+                var token = conversationData.Conversation.Sample(conversationData.Sampler);
+
+                // Some special tokens indicate that this sequence has ended. Check if that's what has been chosen by the sampling pipeline.
                if (modelTokens.IsEndOfGeneration(token))
                {
                    conversationData.MarkComplete();
                }
                else
                {
-                    // it isn't the end of generation, so add this token to the decoder and then add that to our tracked data
+                    // It isn't the end of generation, so add this token to the decoder and then add that to our tracked data
                    conversationData.Decoder.Add(token);
-                    todo: conversationData.AppendAnswer(conversationData.Decoder.Read().ReplaceLineEndings(" "));
+                    conversationData.AppendAnswer(conversationData.Decoder.Read().ReplaceLineEndings(" "));
 
-                    // add the token to the conversation
+                    // Prompt the conversation with this token, ready for the next round of inference to generate another token
                    conversationData.Conversation.Prompt(token);
                }
            }
 
-            // render the current state
+            // Render the current state
            table = BuildTable(conversations);
            ctx.UpdateTarget(table);
 
            if (conversations.All(c => c.IsComplete))
-            {
                break;
-            }
        }
 
        // if we ran out of tokens before completing just mark them as complete for rendering purposes.
@@ -155,20 +174,23 @@ public class ConversationData
     public required BaseSamplingPipeline Sampler { get; init; }
     public required StreamingTokenDecoder Decoder { get; init; }
 
-    public string AnswerMarkdown => IsComplete
-        ? $"[green]{_inProgressAnswer.Message.EscapeMarkup()}{_inProgressAnswer.LatestToken.EscapeMarkup()}[/]"
-        : $"[grey]{_inProgressAnswer.Message.EscapeMarkup()}[/][white]{_inProgressAnswer.LatestToken.EscapeMarkup()}[/]";
+    public string AnswerMarkdown =>
+        IsComplete
+            ? $"[{(IsFailed ? "red" : "green")}]{_inProgressAnswer.Message.EscapeMarkup()}{_inProgressAnswer.LatestToken.EscapeMarkup()}[/]"
+            : $"[grey]{_inProgressAnswer.Message.EscapeMarkup()}[/][white]{_inProgressAnswer.LatestToken.EscapeMarkup()}[/]";
 
     public bool IsComplete { get; private set; }
+    public bool IsFailed { get; private set; }
 
     // we are only keeping track of the answer in two parts to render them differently.
     private (string Message, string LatestToken) _inProgressAnswer = (string.Empty, string.Empty);
 
     public void AppendAnswer(string newText) => _inProgressAnswer = (_inProgressAnswer.Message + _inProgressAnswer.LatestToken, newText);
 
-    public void MarkComplete()
+    public void MarkComplete(bool failed = false)
     {
         IsComplete = true;
+        IsFailed = failed;
         if (Conversation.IsDisposed == false)
         {
             // clean up the conversation and sampler to release more memory for inference.
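Read end to end, the updated example reduces to a small per-token loop: run Infer, check the DecodeResult, sample every active conversation with its pipeline, and prompt each conversation with the token it sampled. A condensed sketch of that loop, assuming the executor, modelTokens, conversations and TokenCount are set up exactly as in the example above (token decoding and table rendering omitted):

    for (var i = 0; i < TokenCount; i++)
    {
        var result = await executor.Infer();
        if (result == DecodeResult.NoKvSlot)
        {
            // Not fatal: free some KV cache space by ending one conversation, then retry.
            conversations.FirstOrDefault(c => !c.IsComplete)?.MarkComplete(failed: true);
            continue;
        }
        if (result == DecodeResult.Error)
            throw new Exception("Unknown error occurred while inferring.");

        foreach (var data in conversations)
        {
            // Skip finished conversations and ones with nothing to sample.
            if (data.IsComplete || !data.Conversation.RequiresSampling)
                continue;

            var token = data.Conversation.Sample(data.Sampler);
            if (modelTokens.IsEndOfGeneration(token))
                data.MarkComplete();
            else
                data.Conversation.Prompt(token); // feed the token back in for the next Infer()
        }
    }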

LLama/Batched/ConversationExtensions.cs
Lines changed: 13 additions & 0 deletions

@@ -1,5 +1,6 @@
 using System;
 using LLama.Native;
+using LLama.Sampling;
 
 namespace LLama.Batched;
 
@@ -20,6 +21,18 @@ public static LLamaToken Sample(this Conversation conversation, SafeLLamaSampler
     return sampler.Sample(conversation.Executor.Context.NativeHandle, conversation.GetSampleIndex(offset));
 }
 
+/// <summary>
+/// Sample a token from this conversation using the given sampling pipeline
+/// </summary>
+/// <param name="conversation"><see cref="Conversation"/> to sample from</param>
+/// <param name="sampler"></param>
+/// <param name="offset">Offset from the end of the conversation to the logits to sample, see <see cref="Conversation.GetSampleIndex"/> for more details</param>
+/// <returns></returns>
+public static LLamaToken Sample(this Conversation conversation, ISamplingPipeline sampler, int offset = 0)
+{
+    return sampler.Sample(conversation.Executor.Context.NativeHandle, conversation.GetSampleIndex(offset));
+}
+
 /// <summary>
 /// Rewind a <see cref="Conversation"/> back to an earlier state by removing tokens from the end
 /// </summary>
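For reference, the new overload is used just like the existing SafeLLamaSamplerChainHandle-based Sample extension. A minimal sketch, assuming DefaultSamplingPipeline as the ISamplingPipeline implementation and an executor/conversation created as in the example above:

    // Sample one token from the conversation with an ISamplingPipeline, then
    // prompt the conversation with it, ready for the next call to Infer().
    var sampler = new DefaultSamplingPipeline();

    await executor.Infer();
    if (conversation.RequiresSampling)
    {
        var token = conversation.Sample(sampler); // new ISamplingPipeline overload
        conversation.Prompt(token);
    }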

LLama/LLamaTemplate.cs
Lines changed: 1 addition & 1 deletion

@@ -210,7 +210,7 @@ public void Clear()
     #endregion
 
     /// <summary>
-    /// Apply the template to the messages and write it into the output buffer
+    /// Apply the template to the messages and return a span containing the results
     /// </summary>
     /// <returns>A span over the buffer that holds the applied template</returns>
     public ReadOnlySpan<byte> Apply()
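This matches how BatchedExecutorSimple consumes Apply(): the returned span is UTF-8 bytes which the caller decodes. A short sketch of that pattern, assuming a LLamaTemplate constructed from the loaded model weights:

    // Build a templated prompt and decode the UTF-8 bytes returned by Apply().
    var template = new LLamaTemplate(model); // assumes the constructor overload taking the loaded weights
    template.Add("user", "What's 2+2?");
    template.AddAssistant = true;

    var templatedMessage = Encoding.UTF8.GetString(template.Apply());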

LLama/Native/LLamaBatch.cs
Lines changed: 6 additions & 36 deletions

@@ -1,6 +1,7 @@
 using System;
 using System.Collections.Generic;
 using System.Diagnostics;
+using System.Linq;
 
 namespace LLama.Native;
 
@@ -18,11 +19,6 @@ public class LLamaBatch
     private LLamaSeqId[][] _sequenceIds;
     private IntPtr[] _sequenceIdsPtrs;
 
-    /// <summary>
-    /// Keep track of the index of existing token/position combos in the batch
-    /// </summary>
-    private readonly Dictionary<(LLamaToken, LLamaPos), int> _index = new();
-
     /// <summary>
     /// Keep a list of where logits can be sampled from
     /// </summary>
@@ -108,18 +104,18 @@ internal GroupDisposable ToNativeBatch(out LLamaNativeBatch batch)
     {
         // Sanity checking
 #if DEBUG
-        // Check every output logit position is actually generating logits for exactly one sequence
+        // Check every output logit position is generating logits for exactly one sequence
         foreach (var (seq, idx) in _logitPositions)
         {
             Debug.Assert(_logits[idx] != 0);
             Debug.Assert(_sequenceIdCount[idx] == 1);
             Debug.Assert(_sequenceIds[idx][0] == seq);
         }
 
-        // Check the reverse
+        // Check every index, if it's generating logits it must be in the _logitPositions list. Otherwise it must not.
        for (var i = 0; i < _logits.Length; i++)
        {
-            var actual = _logitPositions.FindIndex(x => x.Item2 == i) >= 0;
+            var actual = _logitPositions.Any(x => x.Item2 == i);
            var expected = _logits[i] != 0;
            Debug.Assert(actual == expected);
        }
@@ -166,37 +162,12 @@ internal GroupDisposable ToNativeBatch(out LLamaNativeBatch batch)
     /// <returns>The index that the token was added at. Use this for GetLogitsIth</returns>
     public int Add(LLamaToken token, LLamaPos pos, ReadOnlySpan<LLamaSeqId> sequences, bool logits)
     {
-        // todo: token sharing in batch is broken?
-        // Try to find this (token, position) combo somewhere in the batch to re-use it by adding this
-        // sequence ID to the list.
-        // Do **not** do this if this token wants logits, to prevent logits being shared between sequences.
-        if (!logits && _index.TryGetValue((token, pos), out var existingIndex))
-        {
-            if (_sequenceIdCount[existingIndex] + sequences.Length > SequenceCapacity)
-                GrowMaxSequences(_sequenceIdCount[existingIndex] + sequences.Length);
-
-            foreach (var sequence in sequences)
-            {
-                _sequenceIds[existingIndex][_sequenceIdCount[existingIndex]] = sequence;
-                _sequenceIdCount[existingIndex]++;
-            }
-
-            return existingIndex;
-        }
-
-        // Couldn't find this token/position combo anywhere in the batch. Add a new item.
-
        // Grow capacity as necessary
        if (TokenCount == TokenCapacity)
            GrowTokenCapacity();
        if (sequences.Length > SequenceCapacity)
            GrowMaxSequences(sequences.Length);
 
-        // Store the position in the index, so it can be found later. We don't want to share tokens when logits are being generated so
-        // do not add to the index in that case.
-        if (!logits && !_index.ContainsKey((token, pos)))
-            _index.Add((token, pos), TokenCount);
-
        // Add the items to the arrays
        _tokens[TokenCount] = token;
        _positions[TokenCount] = pos;
@@ -234,15 +205,15 @@ public int Add(LLamaToken token, LLamaPos pos, List<LLamaSeqId> sequences, bool
        // the list. Instead rent an array and copy the data into it. This avoids an allocation, but can't
        // avoid the copying.
 
-        var rented = System.Buffers.ArrayPool<LLamaSeqId>.Shared.Rent(sequences.Count);
+        var rented = ArrayPool<LLamaSeqId>.Shared.Rent(sequences.Count);
        try
        {
            sequences.CopyTo(rented, 0);
            return Add(token, pos, rented.AsSpan(0, sequences.Count), logits);
        }
        finally
        {
-            System.Buffers.ArrayPool<LLamaSeqId>.Shared.Return(rented);
+            ArrayPool<LLamaSeqId>.Shared.Return(rented);
        }
 #endif
     }
@@ -294,7 +265,6 @@ public void Clear()
     {
         TokenCount = 0;
 
-        _index.Clear();
         _logitPositions.Clear();
     }
 