Skip to content

Commit a296b55

Browse files
committed
Initial investigation into potential issue with batched executor (or batching in general)
1 parent 2f7a9d5 commit a296b55

File tree

2 files changed

+31
-7
lines changed

2 files changed

+31
-7
lines changed

LLama.Examples/Examples/BatchedExecutorSimple.cs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,12 +84,15 @@ await AnsiConsole.Live(table).StartAsync(async ctx =>
8484

8585
foreach (var conversationData in conversations.Where(c => c.IsComplete == false))
8686
{
87-
if (conversationData.Conversation.RequiresSampling == false) continue;
87+
if (conversationData.Conversation.RequiresSampling == false)
88+
continue;
8889

8990
// sample a single token for the executor, passing the sample index of the conversation
91+
var sampleIndex = conversationData.Conversation.GetSampleIndex();
9092
var token = conversationData.Sampler.Sample(
91-
executor.Context.NativeHandle,
92-
conversationData.Conversation.GetSampleIndex());
93+
executor.Context,
94+
sampleIndex
95+
);
9396

9497
if (modelTokens.IsEndOfGeneration(token))
9598
{
@@ -99,7 +102,7 @@ await AnsiConsole.Live(table).StartAsync(async ctx =>
99102
{
100103
// it isn't the end of generation, so add this token to the decoder and then add that to our tracked data
101104
conversationData.Decoder.Add(token);
102-
conversationData.AppendAnswer(conversationData.Decoder.Read().ReplaceLineEndings(" "));
105+
todo: conversationData.AppendAnswer(conversationData.Decoder.Read().ReplaceLineEndings(" "));
103106

104107
// add the token to the conversation
105108
conversationData.Conversation.Prompt(token);

LLama/Native/LLamaBatch.cs

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using System;
22
using System.Collections.Generic;
3+
using System.Diagnostics;
34

45
namespace LLama.Native;
56

@@ -105,6 +106,25 @@ private void GrowMaxSequences(int atLeast)
105106

106107
internal GroupDisposable ToNativeBatch(out LLamaNativeBatch batch)
107108
{
109+
// Sanity checking
110+
#if DEBUG
111+
// Check every output logit position is actually generating logits for exactly one sequence
112+
foreach (var (seq, idx) in _logitPositions)
113+
{
114+
Debug.Assert(_logits[idx] != 0);
115+
Debug.Assert(_sequenceIdCount[idx] == 1);
116+
Debug.Assert(_sequenceIds[idx][0] == seq);
117+
}
118+
119+
// Check the reverse
120+
for (var i = 0; i < _logits.Length; i++)
121+
{
122+
var actual = _logitPositions.FindIndex(x => x.Item2 == i) >= 0;
123+
var expected = _logits[i] != 0;
124+
Debug.Assert(actual == expected);
125+
}
126+
#endif
127+
108128
// This group holds all of the memory pins
109129
var group = new GroupDisposable();
110130

@@ -146,6 +166,7 @@ internal GroupDisposable ToNativeBatch(out LLamaNativeBatch batch)
146166
/// <returns>The index that the token was added at. Use this for GetLogitsIth</returns>
147167
public int Add(LLamaToken token, LLamaPos pos, ReadOnlySpan<LLamaSeqId> sequences, bool logits)
148168
{
169+
// todo: token sharing in batch is broken?
149170
// Try to find this (token, position) combo somewhere in the batch to re-use it by adding this
150171
// sequence ID to the list.
151172
// Do **not** do this if this token wants logits, to prevent logits being shared between sequences.
@@ -171,9 +192,9 @@ public int Add(LLamaToken token, LLamaPos pos, ReadOnlySpan<LLamaSeqId> sequence
171192
if (sequences.Length > SequenceCapacity)
172193
GrowMaxSequences(sequences.Length);
173194

174-
// Store the position in the index, so it can be found later.
175-
// We need to check that it's not already there in case we skipped the check above (because logits is true).
176-
if (!_index.ContainsKey((token, pos)))
195+
// Store the position in the index, so it can be found later. We don't want to share tokens when logits are being generated so
196+
// do not add to the index in that case.
197+
if (!logits && !_index.ContainsKey((token, pos)))
177198
_index.Add((token, pos), TokenCount);
178199

179200
// Add the items to the arrays

0 commit comments

Comments (0)