Skip to content

Commit a74314c

Browse files
committed
Updated to be7c3034108473beda214fd1d7c98fd6a7a3bdf5
1 parent 0778277 commit a74314c

File tree

8 files changed

+67
-34
lines changed

8 files changed

+67
-34
lines changed

LLama.Unittest/LLamaContextTests.cs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@ public LLamaContextTests()
1414
var @params = new ModelParams(Constants.GenerativeModelPath2)
1515
{
1616
ContextSize = 128,
17+
BatchSize = 8,
18+
UBatchSize = 8,
19+
SeqMax = 1,
20+
VocabOnly = false,
1721
GpuLayerCount = Constants.CIGpuLayerCount,
1822
};
1923
_weights = LLamaWeights.LoadFromFile(@params);
@@ -84,6 +88,11 @@ public void TokenizeEmpty()
8488
[Fact]
8589
public void SaveLoadState()
8690
{
91+
// Make sure there's something in the context worth saving
92+
var batch = new LLamaBatch();
93+
batch.Add(17, 0, LLamaSeqId.Zero, true);
94+
_context.Decode(batch);
95+
8796
using var state1 = _context.GetState();
8897

8998
var stream = new MemoryStream();
@@ -99,6 +108,11 @@ public void SaveLoadState()
99108
[Fact]
100109
public async Task SaveLoadStateAsync()
101110
{
111+
// Make sure there's something in the context worth saving
112+
var batch = new LLamaBatch();
113+
batch.Add(17, 0, LLamaSeqId.Zero, true);
114+
_context.Decode(batch);
115+
102116
using var state1 = _context.GetState();
103117

104118
var stream = new MemoryStream();

LLama/Batched/Conversation.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ public Conversation Fork()
128128
_forked = true;
129129

130130
// Assign tokens to the new sequence
131-
NativeApi.llama_kv_cache_seq_cp(Executor.Context.NativeHandle, ConversationId, c.ConversationId, 0, _end);
131+
Executor.Context.NativeHandle.KvCacheSequenceCopy(ConversationId, c.ConversationId, 0, _end);
132132

133133
return c;
134134
}

LLama/LLamaExecutorBase.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,8 +193,8 @@ protected virtual void HandleRunOutOfContext(int tokensToKeep)
193193
int n_left = _pastTokensCount - tokensToKeep;
194194
int n_discard = n_left / 2;
195195

196-
NativeApi.llama_kv_cache_seq_rm(Context.NativeHandle, LLamaSeqId.Zero, tokensToKeep, tokensToKeep + n_discard);
197-
NativeApi.llama_kv_cache_seq_add(Context.NativeHandle, LLamaSeqId.Zero, tokensToKeep + n_discard, _pastTokensCount, -n_discard);
196+
NativeApi.llama_kv_self_seq_rm(Context.NativeHandle, LLamaSeqId.Zero, tokensToKeep, tokensToKeep + n_discard);
197+
NativeApi.llama_kv_self_seq_add(Context.NativeHandle, LLamaSeqId.Zero, tokensToKeep + n_discard, _pastTokensCount, -n_discard);
198198

199199
_pastTokensCount -= n_discard;
200200
// stop saving session if we run out of context

LLama/LLamaSharp.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@
5656
</ItemGroup>
5757

5858
<PropertyGroup>
59-
<BinaryReleaseId>6fefc05a7a4e67678v2</BinaryReleaseId>
59+
<BinaryReleaseId>be7c3034108473be</BinaryReleaseId>
6060
</PropertyGroup>
6161

6262
<PropertyGroup>

LLama/LLamaStatelessExecutor.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,8 +155,8 @@ public async IAsyncEnumerable<string> InferAsync(string prompt, IInferenceParams
155155
var n_left = n_past - tokensKeep;
156156
var n_discard = n_left / 2;
157157

158-
NativeApi.llama_kv_cache_seq_rm(Context.NativeHandle, LLamaSeqId.Zero, tokensKeep , tokensKeep + n_discard);
159-
NativeApi.llama_kv_cache_seq_add(Context.NativeHandle, LLamaSeqId.Zero, tokensKeep + n_discard, n_past, -n_discard);
158+
NativeApi.llama_kv_self_seq_rm(Context.NativeHandle, LLamaSeqId.Zero, tokensKeep , tokensKeep + n_discard);
159+
NativeApi.llama_kv_self_seq_add(Context.NativeHandle, LLamaSeqId.Zero, tokensKeep + n_discard, n_past, -n_discard);
160160

161161
n_past -= n_discard;
162162
}

LLama/Native/LLamaKvCache.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
namespace LLama.Native;
2+
3+
/// <summary>
4+
/// C# representation of llama_kv_cache
5+
/// </summary>
6+
/// <remarks>llama_kv_cache</remarks>
7+
internal struct LLamaKvCacheNative
8+
{
9+
10+
}

LLama/Native/NativeApi.cs

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -273,22 +273,22 @@ public static void llama_log_set(NativeLogConfig.LLamaLogCallback logCallback)
273273
/// <param name="ctx"></param>
274274
/// <returns></returns>
275275
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
276-
public static extern int llama_get_kv_cache_token_count(SafeLLamaContextHandle ctx);
276+
internal static extern int llama_kv_self_n_tokens(SafeLLamaContextHandle ctx);
277277

278278
/// <summary>
279279
/// Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
280280
/// </summary>
281281
/// <param name="ctx"></param>
282282
/// <returns></returns>
283283
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
284-
public static extern int llama_get_kv_cache_used_cells(SafeLLamaContextHandle ctx);
284+
internal static extern int llama_kv_self_used_cells(SafeLLamaContextHandle ctx);
285285

286286
/// <summary>
287287
/// Clear the KV cache. Both cell info is erased and KV data is zeroed
288288
/// </summary>
289289
/// <param name="ctx"></param>
290290
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
291-
public static extern void llama_kv_cache_clear(SafeLLamaContextHandle ctx);
291+
internal static extern void llama_kv_self_clear(SafeLLamaContextHandle ctx);
292292

293293
/// <summary>
294294
/// Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
@@ -300,7 +300,7 @@ public static void llama_log_set(NativeLogConfig.LLamaLogCallback logCallback)
300300
/// <returns>Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails</returns>
301301
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
302302
[return: MarshalAs(UnmanagedType.U1)]
303-
public static extern bool llama_kv_cache_seq_rm(SafeLLamaContextHandle ctx, LLamaSeqId seq, LLamaPos p0, LLamaPos p1);
303+
public static extern bool llama_kv_self_seq_rm(SafeLLamaContextHandle ctx, LLamaSeqId seq, LLamaPos p0, LLamaPos p1);
304304

305305
/// <summary>
306306
/// Copy all tokens that belong to the specified sequence to another sequence
@@ -312,35 +312,35 @@ public static void llama_log_set(NativeLogConfig.LLamaLogCallback logCallback)
312312
/// <param name="p0"></param>
313313
/// <param name="p1"></param>
314314
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
315-
public static extern void llama_kv_cache_seq_cp(SafeLLamaContextHandle ctx, LLamaSeqId src, LLamaSeqId dest, LLamaPos p0, LLamaPos p1);
315+
internal static extern void llama_kv_self_seq_cp(SafeLLamaContextHandle ctx, LLamaSeqId src, LLamaSeqId dest, LLamaPos p0, LLamaPos p1);
316316

317317
/// <summary>
318318
/// Removes all tokens that do not belong to the specified sequence
319319
/// </summary>
320320
/// <param name="ctx"></param>
321321
/// <param name="seq"></param>
322322
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
323-
public static extern void llama_kv_cache_seq_keep(SafeLLamaContextHandle ctx, LLamaSeqId seq);
323+
internal static extern void llama_kv_self_seq_keep(SafeLLamaContextHandle ctx, LLamaSeqId seq);
324324

325325
/// <summary>
326326
/// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
327327
/// If the KV cache is RoPEd, the KV data is updated accordingly:
328328
/// - lazily on next llama_decode()
329-
/// - explicitly with llama_kv_cache_update()
329+
/// - explicitly with llama_kv_self_update()
330330
/// </summary>
331331
/// <param name="ctx"></param>
332332
/// <param name="seq"></param>
333333
/// <param name="p0"></param>
334334
/// <param name="p1"></param>
335335
/// <param name="delta"></param>
336336
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
337-
public static extern void llama_kv_cache_seq_add(SafeLLamaContextHandle ctx, LLamaSeqId seq, LLamaPos p0, LLamaPos p1, int delta);
337+
internal static extern void llama_kv_self_seq_add(SafeLLamaContextHandle ctx, LLamaSeqId seq, LLamaPos p0, LLamaPos p1, int delta);
338338

339339
/// <summary>
340340
/// Integer division of the positions by factor of `d > 1`
341341
/// If the KV cache is RoPEd, the KV data is updated accordingly:
342342
/// - lazily on next llama_decode()
343-
/// - explicitly with llama_kv_cache_update()
343+
/// - explicitly with llama_kv_self_update()
344344
/// <br />
345345
/// p0 &lt; 0 : [0, p1]
346346
/// <br />
@@ -352,7 +352,7 @@ public static void llama_log_set(NativeLogConfig.LLamaLogCallback logCallback)
352352
/// <param name="p1"></param>
353353
/// <param name="d"></param>
354354
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
355-
public static extern void llama_kv_cache_seq_div(SafeLLamaContextHandle ctx, LLamaSeqId seq, LLamaPos p0, LLamaPos p1, int d);
355+
internal static extern void llama_kv_self_seq_div(SafeLLamaContextHandle ctx, LLamaSeqId seq, LLamaPos p0, LLamaPos p1, int d);
356356

357357
/// <summary>
358358
/// Returns the largest position present in the KV cache for the specified sequence
@@ -361,7 +361,7 @@ public static void llama_log_set(NativeLogConfig.LLamaLogCallback logCallback)
361361
/// <param name="seq"></param>
362362
/// <returns></returns>
363363
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
364-
public static extern LLamaPos llama_kv_cache_seq_pos_max(SafeLLamaContextHandle ctx, LLamaSeqId seq);
364+
internal static extern LLamaPos llama_kv_self_seq_pos_max(SafeLLamaContextHandle ctx, LLamaSeqId seq);
365365

366366
/// <summary>
367367
/// Allocates a batch of tokens on the heap

LLama/Native/SafeLLamaContextHandle.cs

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -313,27 +313,27 @@ static SafeLLamaContextHandle()
313313
/// <summary>
314314
/// Defragment the KV cache. This will be applied:
315315
/// - lazily on next llama_decode()
316-
/// - explicitly with llama_kv_cache_update()
316+
/// - explicitly with llama_kv_self_update()
317317
/// </summary>
318318
/// <param name="ctx"></param>
319319
/// <returns></returns>
320320
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
321-
private static extern void llama_kv_cache_defrag(SafeLLamaContextHandle ctx);
321+
private static extern void llama_kv_self_defrag(SafeLLamaContextHandle ctx);
322322

323323
/// <summary>
324324
/// Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
325325
/// </summary>
326326
/// <param name="ctx"></param>
327327
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
328-
private static extern void llama_kv_cache_update(SafeLLamaContextHandle ctx);
328+
private static extern void llama_kv_self_update(SafeLLamaContextHandle ctx);
329329

330330
/// <summary>
331331
/// Check if the context supports KV cache shifting
332332
/// </summary>
333333
/// <param name="ctx"></param>
334334
/// <returns></returns>
335335
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
336-
private static extern bool llama_kv_cache_can_shift(SafeLLamaContextHandle ctx);
336+
private static extern bool llama_kv_self_can_shift(SafeLLamaContextHandle ctx);
337337

338338
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
339339
private static extern LLamaPerfContextTimings llama_perf_context(SafeLLamaContextHandle ctx);
@@ -386,6 +386,9 @@ static SafeLLamaContextHandle()
386386
/// <returns>A pointer to the first float in an embedding, length = ctx.EmbeddingSize</returns>
387387
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
388388
private static extern unsafe float* llama_get_embeddings_ith(SafeLLamaContextHandle ctx, int i);
389+
390+
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
391+
private static extern LLamaKvCacheNative llama_get_kv_self(SafeLLamaContextHandle ctx);
389392
#endregion
390393

391394
#region LoRA
@@ -751,25 +754,25 @@ public void ResetTimings()
751754
/// <summary>
752755
/// Check if the context supports KV cache shifting
753756
/// </summary>
754-
public bool KvCacheCanShift => llama_kv_cache_can_shift(this);
757+
public bool KvCacheCanShift => llama_kv_self_can_shift(this);
755758

756759
/// <summary>
757760
/// Apply KV cache updates (such as K-shifts, defragmentation, etc.)
758761
/// </summary>
759762
public void KvCacheUpdate()
760763
{
761-
llama_kv_cache_update(this);
764+
llama_kv_self_update(this);
762765
}
763766

764767
/// <summary>
765768
/// Defragment the KV cache. This will be applied:
766769
/// - lazily on next llama_decode()
767-
/// - explicitly with llama_kv_cache_update()
770+
/// - explicitly with llama_kv_self_update()
768771
/// </summary>
769772
/// <returns></returns>
770773
public void KvCacheDefrag()
771774
{
772-
llama_kv_cache_defrag(this);
775+
llama_kv_self_defrag(this);
773776
}
774777

775778
/// <summary>
@@ -788,7 +791,7 @@ public LLamaKvCacheViewSafeHandle KvCacheGetDebugView(int maxSequences = 4)
788791
/// <returns></returns>
789792
public int KvCacheCountCells()
790793
{
791-
return NativeApi.llama_get_kv_cache_used_cells(this);
794+
return NativeApi.llama_kv_self_used_cells(this);
792795
}
793796

794797
/// <summary>
@@ -798,15 +801,15 @@ public int KvCacheCountCells()
798801
/// <returns></returns>
799802
public int KvCacheCountTokens()
800803
{
801-
return NativeApi.llama_get_kv_cache_token_count(this);
804+
return NativeApi.llama_kv_self_n_tokens(this);
802805
}
803806

804807
/// <summary>
805808
/// Clear the KV cache - both cell info is erased and KV data is zeroed
806809
/// </summary>
807810
public void KvCacheClear()
808811
{
809-
NativeApi.llama_kv_cache_clear(this);
812+
NativeApi.llama_kv_self_clear(this);
810813
}
811814

812815
/// <summary>
@@ -817,7 +820,7 @@ public void KvCacheClear()
817820
/// <param name="p1"></param>
818821
public void KvCacheRemove(LLamaSeqId seq, LLamaPos p0, LLamaPos p1)
819822
{
820-
NativeApi.llama_kv_cache_seq_rm(this, seq, p0, p1);
823+
NativeApi.llama_kv_self_seq_rm(this, seq, p0, p1);
821824
}
822825

823826
/// <summary>
@@ -831,7 +834,7 @@ public void KvCacheRemove(LLamaSeqId seq, LLamaPos p0, LLamaPos p1)
831834
/// <param name="p1"></param>
832835
public void KvCacheSequenceCopy(LLamaSeqId src, LLamaSeqId dest, LLamaPos p0, LLamaPos p1)
833836
{
834-
NativeApi.llama_kv_cache_seq_cp(this, src, dest, p0, p1);
837+
NativeApi.llama_kv_self_seq_cp(this, src, dest, p0, p1);
835838
}
836839

837840
/// <summary>
@@ -840,7 +843,7 @@ public void KvCacheSequenceCopy(LLamaSeqId src, LLamaSeqId dest, LLamaPos p0, LL
840843
/// <param name="seq"></param>
841844
public void KvCacheSequenceKeep(LLamaSeqId seq)
842845
{
843-
NativeApi.llama_kv_cache_seq_keep(this, seq);
846+
NativeApi.llama_kv_self_seq_keep(this, seq);
844847
}
845848

846849
/// <summary>
@@ -854,7 +857,10 @@ public void KvCacheSequenceKeep(LLamaSeqId seq)
854857
/// <param name="delta"></param>
855858
public void KvCacheSequenceAdd(LLamaSeqId seq, LLamaPos p0, LLamaPos p1, int delta)
856859
{
857-
NativeApi.llama_kv_cache_seq_add(this, seq, p0, p1, delta);
860+
if (!KvCacheCanShift)
861+
throw new InvalidOperationException("Cannot shift KV cache (KvCacheCanShift=False)");
862+
863+
NativeApi.llama_kv_self_seq_add(this, seq, p0, p1, delta);
858864
}
859865

860866
/// <summary>
@@ -869,7 +875,10 @@ public void KvCacheSequenceAdd(LLamaSeqId seq, LLamaPos p0, LLamaPos p1, int del
869875
/// <param name="divisor"></param>
870876
public void KvCacheSequenceDivide(LLamaSeqId seq, LLamaPos p0, LLamaPos p1, int divisor)
871877
{
872-
NativeApi.llama_kv_cache_seq_div(this, seq, p0, p1, divisor);
878+
if (!KvCacheCanShift)
879+
throw new InvalidOperationException("Cannot shift KV cache (KvCacheCanShift=False)");
880+
881+
NativeApi.llama_kv_self_seq_div(this, seq, p0, p1, divisor);
873882
}
874883

875884
/// <summary>
@@ -879,7 +888,7 @@ public void KvCacheSequenceDivide(LLamaSeqId seq, LLamaPos p0, LLamaPos p1, int
879888
/// <returns></returns>
880889
public LLamaPos KvCacheMaxPosition(LLamaSeqId seq)
881890
{
882-
return NativeApi.llama_kv_cache_seq_pos_max(this, seq);
891+
return NativeApi.llama_kv_self_seq_pos_max(this, seq);
883892
}
884893
#endregion
885894
}

0 commit comments

Comments
 (0)