
Commit 50f2425: WIP
Parent: 01f466c

22 files changed (+378, -527 lines)

LLama.Benchmark/LLamaExecutorBenchmark/Prefill.cs

Lines changed: 1 addition & 1 deletion
@@ -119,7 +119,7 @@ public void GlobalCleanup()
     {
         if(ExecutorType != ExecutorType.Stateless) // stateless executor always dispose its `Context` property
         {
-            Executor.Context.NativeHandle.KvCacheClear();
+            Executor.Context.NativeHandle.MemoryClear();
         }
     }

LLama.Examples/Examples/LlavaInteractiveModeExecute.cs

Lines changed: 1 addition & 1 deletion
@@ -79,7 +79,7 @@ public static async Task Run()
     // When the prompt contains images we clear KV_CACHE to restart conversation
     // See:
     // https://github.com/ggerganov/llama.cpp/discussions/3620
-    ex.Context.NativeHandle.KvCacheRemove( LLamaSeqId.Zero, -1, -1 );
+    ex.Context.NativeHandle.MemorySequenceRemove( LLamaSeqId.Zero, -1, -1 );

     int index = 0;
     foreach (var path in imagePathsWithCurlyBraces)
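
Note: throughout this commit the KvCache* wrappers are renamed to Memory*, tracking llama.cpp's move from the llama_kv_self_* functions to the llama_memory_* API. A minimal sketch of the two reset patterns touched in these files, written as hypothetical helpers (not part of this commit), assuming the -1, -1 bounds keep llama.cpp's convention that a negative position means "open ended":

using LLama;
using LLama.Native;

static class MemoryResetSketch
{
    // Hypothetical helper. Drops every cached token of a single sequence,
    // leaving other sequences intact: p0 = -1 means "from the start",
    // p1 = -1 means "to the end of the sequence".
    public static void ResetConversation(LLamaContext context, LLamaSeqId seq)
        => context.NativeHandle.MemorySequenceRemove(seq, -1, -1);

    // Drops the cached state of all sequences at once, as the benchmark
    // cleanup above and ChatSession.LoadSession below do.
    public static void ResetEverything(LLamaContext context)
        => context.NativeHandle.MemoryClear();
}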

LLama/Abstractions/IContextParams.cs

Lines changed: 14 additions & 2 deletions
@@ -109,8 +109,7 @@ public interface IContextParams
     bool FlashAttention { get; }

     /// <summary>
-    /// defragment the KV cache if holes/size &gt; defrag_threshold, Set to &lt; 0 to disable (default)
-    /// defragment the KV cache if holes/size &gt; defrag_threshold, Set to <see langword="null"/> or &lt; 0 to disable (default)
+    /// defragment the KV cache if holes/size &gt; defrag_threshold, Set to &lt;= 0 to disable (default)
     /// </summary>
     float? DefragThreshold { get; }

@@ -123,4 +122,17 @@ public interface IContextParams
     /// Attention type to use for embeddings
     /// </summary>
     LLamaAttentionType AttentionType { get; }
+
+    /// <summary>
+    /// Offload host tensor operations to device
+    /// </summary>
+    bool? OpOffload { get; }
+
+    /// <summary>
+    /// Use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
+    /// </summary>
+    /// <remarks>Setting to false when n_seq_max > 1 can cause bad performance in some cases
+    /// ref: https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573
+    /// </remarks>
+    bool? SwaFull { get; }
 }
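
The two new nullable properties surface llama.cpp's swa_full and op_offload context flags; leaving them null keeps the llama.cpp default (see the IContextParamsExtensions change below, which only copies a value when HasValue is true). A hedged sketch of setting them, assuming ModelParams (the stock IContextParams implementation) gains the new properties:

using LLama;
using LLama.Common;

// Sketch only: assumes ModelParams implements the new IContextParams members.
var parameters = new ModelParams("model.gguf")
{
    // Keep the full-size sliding-window-attention cache; per the doc comment
    // above, false with n_seq_max > 1 can hurt performance in some cases.
    SwaFull = true,
    // Leave host tensor-op offloading at the llama.cpp default.
    OpOffload = null,
};

using var weights = LLamaWeights.LoadFromFile(parameters);
using var context = weights.CreateContext(parameters);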

LLama/Batched/Conversation.cs

Lines changed: 6 additions & 6 deletions
@@ -84,7 +84,7 @@ public void Dispose()
     _disposed = true;

     // Remove this conversation from the KV cache
-    Executor.Context.NativeHandle.KvCacheRemove(ConversationId, -1, -1);
+    Executor.Context.NativeHandle.MemorySequenceRemove(ConversationId, -1, -1);

     // Prevent finalizer from running
     GC.SuppressFinalize(this);
@@ -129,7 +129,7 @@ public Conversation Fork()
     _forked = true;

     // Assign tokens to the new sequence
-    Executor.Context.NativeHandle.KvCacheSequenceCopy(ConversationId, c.ConversationId, 0, _end);
+    Executor.Context.NativeHandle.MemorySequenceCopy(ConversationId, c.ConversationId, 0, _end);

     return c;
 }
@@ -406,7 +406,7 @@ internal KvAccessor(Conversation conversation)
     /// <param name="end">End position (exclusive)</param>
     public void Remove(LLamaPos start, LLamaPos end)
     {
-        _conversation.Executor.Context.NativeHandle.KvCacheRemove(_conversation.ConversationId, start, end);
+        _conversation.Executor.Context.NativeHandle.MemorySequenceRemove(_conversation.ConversationId, start, end);
     }

     /// <summary>
@@ -420,7 +420,7 @@ public void Remove(LLamaPos start, int count)
         return;

     var end = start.Value + count;
-    _conversation.Executor.Context.NativeHandle.KvCacheRemove(_conversation.ConversationId, start, end);
+    _conversation.Executor.Context.NativeHandle.MemorySequenceRemove(_conversation.ConversationId, start, end);
 }
 #endregion

@@ -435,7 +435,7 @@ public void Remove(LLamaPos start, int count)
     /// <param name="delta">Amount to add on to each token position</param>
     public void Add(LLamaPos start, LLamaPos end, int delta)
     {
-        _conversation.Executor.Context.NativeHandle.KvCacheSequenceAdd(_conversation.ConversationId, start, end, delta);
+        _conversation.Executor.Context.NativeHandle.MemorySequenceAdd(_conversation.ConversationId, start, end, delta);
     }
 #endregion

@@ -452,7 +452,7 @@ public void Divide(LLamaPos start, LLamaPos end, int divisor)
     if (divisor <= 0)
         throw new ArgumentOutOfRangeException(nameof(divisor));

-    _conversation.Executor.Context.NativeHandle.KvCacheSequenceDivide(_conversation.ConversationId, start, end, divisor);
+    _conversation.Executor.Context.NativeHandle.MemorySequenceDivide(_conversation.ConversationId, start, end, divisor);
 }
 #endregion
}
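
The KvAccessor methods above are the batched-executor route to the same Memory* calls. A rough sketch of a manual context shift through them, assuming Conversation.Modify still hands an (end, KvAccessor) callback to the caller and expects the new end position back (hedged; check the Batched executor docs for the exact delegate shape):

using LLama.Batched;
using LLama.Native;

static void ShiftLeft(Conversation conversation, int keep, int discard)
{
    conversation.Modify((end, kv) =>
    {
        // Drop `discard` tokens immediately after the kept prefix...
        kv.Remove(keep, discard);

        // ...then slide the surviving suffix back so positions stay contiguous.
        kv.Add(keep + discard, end, -discard);

        return end.Value - discard;
    });
}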

LLama/ChatSession.cs

Lines changed: 1 addition & 1 deletion
@@ -199,7 +199,7 @@ public void LoadSession(SessionState state, bool loadTransforms = true)
     }
     if (state.ContextState is null)
     {
-        Executor.Context.NativeHandle.KvCacheClear();
+        Executor.Context.NativeHandle.MemoryClear();
     }
     else
     {

LLama/Extensions/IContextParamsExtensions.cs

Lines changed: 5 additions & 0 deletions
@@ -55,6 +55,11 @@ public static void ToLlamaContextParams(this IContextParams @params, out LLamaCo

     result.n_threads = Threads(@params.Threads);
     result.n_threads_batch = Threads(@params.BatchThreads);
+
+    if (@params.SwaFull.HasValue)
+        result.swa_full = @params.SwaFull.Value;
+    if (@params.OpOffload.HasValue)
+        result.op_offload = @params.OpOffload.Value;
 }

 private static int Threads(int? value)

LLama/LLamaExecutorBase.cs

Lines changed: 7 additions & 6 deletions
@@ -128,15 +128,16 @@ public StatefulExecutorBase WithSessionFile(string filename)
     }
     if (File.Exists(filename))
     {
-        _logger?.LogInformation($"[LLamaExecutor] Attempting to load saved session from {filename}");
+        _logger?.LogInformation("[LLamaExecutor] Attempting to load saved session from {0}", filename);
+
         var session_tokens = new LLamaToken[Context.ContextSize];
         if (!NativeApi.llama_state_load_file(Context.NativeHandle, _pathSession, session_tokens, (ulong)Context.ContextSize, out var n_token_count_out))
         {
             _logger?.LogError($"[LLamaExecutor] Failed to load session file {filename}");
             throw new RuntimeError($"Failed to load session file {_pathSession}");
         }
         _session_tokens = session_tokens.Take((int)n_token_count_out).ToList();
-        _logger?.LogInformation($"[LLamaExecutor] Loaded a session with prompt size of {session_tokens.Length} tokens");
+        _logger?.LogInformation("[LLamaExecutor] Loaded a session with prompt size of {0} tokens", session_tokens.Length);
     }
     else
     {
@@ -190,11 +191,11 @@ protected virtual void HandleRunOutOfContext(int tokensToKeep)
     // if we run out of context:
     // - take the tokensToKeep first tokens from the original prompt (via n_past)
     // - take half of the last (n_ctx - tokensToKeep) tokens and recompute the logits in batches
-    int n_left = _pastTokensCount - tokensToKeep;
-    int n_discard = n_left / 2;
+    var n_left = _pastTokensCount - tokensToKeep;
+    var n_discard = n_left / 2;

-    NativeApi.llama_kv_self_seq_rm(Context.NativeHandle, LLamaSeqId.Zero, tokensToKeep, tokensToKeep + n_discard);
-    NativeApi.llama_kv_self_seq_add(Context.NativeHandle, LLamaSeqId.Zero, tokensToKeep + n_discard, _pastTokensCount, -n_discard);
+    Context.NativeHandle.MemorySequenceRemove(LLamaSeqId.Zero, tokensToKeep, tokensToKeep + n_discard);
+    Context.NativeHandle.MemorySequenceAdd(LLamaSeqId.Zero, tokensToKeep + n_discard, _pastTokensCount, -n_discard);

     _pastTokensCount -= n_discard;
     // stop saving session if we run out of context
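
The rewritten HandleRunOutOfContext keeps the first tokensToKeep tokens, evicts the older half of the remainder, and shifts the surviving suffix left. The arithmetic in isolation, as a hypothetical helper (not part of this commit):

// Plans the shift: returns the removal range and the position delta to apply
// to the tokens that survive.
static (int removeStart, int removeEnd, int delta) PlanContextShift(int pastTokens, int tokensToKeep)
{
    var n_left = pastTokens - tokensToKeep; // tokens eligible for eviction
    var n_discard = n_left / 2;             // evict the older half of them
    return (tokensToKeep, tokensToKeep + n_discard, -n_discard);
}

For example, PlanContextShift(4096, 512) gives n_left = 3584 and n_discard = 1792: positions [512, 2304) are removed, positions [2304, 4096) are shifted by -1792, and 2304 tokens remain cached.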

LLama/LLamaReranker.cs

Lines changed: 4 additions & 4 deletions
@@ -114,7 +114,7 @@ public async Task<IReadOnlyList<float>> GetRelevanceScores(string input, IReadOn
         batch.Add(tokens[i], i, LLamaSeqId.Zero, true);

     // clear previous kv_cache values
-    Context.NativeHandle.KvCacheClear();
+    Context.NativeHandle.MemoryClear();

     // Check if we should cancel the work, just before doing anything expensive (encode/decode)
     cancellationToken.ThrowIfCancellationRequested();
@@ -144,7 +144,7 @@ public async Task<IReadOnlyList<float>> GetRelevanceScores(string input, IReadOn

     var score = Context.NativeHandle.GetEmbeddingsSeq(LLamaSeqId.Zero)[0];

-    Context.NativeHandle.KvCacheClear();
+    Context.NativeHandle.MemoryClear();

     return (normalize ? Sigmoid(score) : score, tokens.Length);
 }
@@ -155,7 +155,7 @@ private async Task<IReadOnlyList<float>> CalcRelevanceScores(LLamaBatch batch, b
     var seqNum = logicCap.Value + 1;
     List<float> scores = new List<float>(seqNum);
     // clear previous kv_cache values
-    Context.NativeHandle.KvCacheClear();
+    Context.NativeHandle.MemoryClear();

     // Check if we should cancel the work, just before doing anything expensive (encode/decode)
     cancellationToken.ThrowIfCancellationRequested();
@@ -189,7 +189,7 @@ private async Task<IReadOnlyList<float>> CalcRelevanceScores(LLamaBatch batch, b
         scores.Add(normalize ? Sigmoid(score) : score);
     }

-    Context.NativeHandle.KvCacheClear();
+    Context.NativeHandle.MemoryClear();

     return scores;
 }

LLama/LLamaStatelessExecutor.cs

Lines changed: 2 additions & 2 deletions
@@ -158,8 +158,8 @@ public async IAsyncEnumerable<string> InferAsync(string prompt, IInferenceParams
     var n_left = n_past - tokensKeep;
     var n_discard = n_left / 2;

-    NativeApi.llama_kv_self_seq_rm(Context.NativeHandle, LLamaSeqId.Zero, tokensKeep , tokensKeep + n_discard);
-    NativeApi.llama_kv_self_seq_add(Context.NativeHandle, LLamaSeqId.Zero, tokensKeep + n_discard, n_past, -n_discard);
+    Context.NativeHandle.MemorySequenceRemove(LLamaSeqId.Zero, tokensKeep, tokensKeep + n_discard);
+    Context.NativeHandle.MemorySequenceAdd(LLamaSeqId.Zero, tokensKeep + n_discard, n_past, -n_discard);

     n_past -= n_discard;
 }

LLama/Native/DecodeResult.cs

Lines changed: 3 additions & 3 deletions
@@ -1,14 +1,14 @@
-namespace LLama.Native;
+namespace LLama.Native;

 /// <summary>
 /// Return codes from llama_decode
 /// </summary>
 public enum DecodeResult
 {
     /// <summary>
-    /// An unspecified error
+    /// Input batch was invalid
     /// </summary>
-    Error = -1,
+    InvalidInputBatch = -1,

     /// <summary>
     /// Ok.
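
With this rename, -1 from llama_decode now specifically signals a rejected input batch rather than an unspecified error. A hedged sketch of branching on it (assuming the SafeLLamaContextHandle.Decode(LLamaBatch) wrapper returns DecodeResult, as in current LLamaSharp):

using System;
using LLama.Native;

static void CheckedDecode(SafeLLamaContextHandle ctx, LLamaBatch batch)
{
    switch (ctx.Decode(batch))
    {
        case DecodeResult.Ok:
            break;
        case DecodeResult.InvalidInputBatch:
            // Previously DecodeResult.Error; -1 now means the batch itself
            // was malformed (e.g. bad positions or sequence ids).
            throw new ArgumentException("llama_decode rejected the batch", nameof(batch));
        case var other:
            throw new InvalidOperationException($"llama_decode failed: {other}");
    }
}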
