Skip to content

Commit d761188

Browse files
committed
applied changes from header, still need to update C#
1 parent 40ea046 commit d761188

File tree

7 files changed

+83
-45
lines changed

7 files changed

+83
-45
lines changed

LLama/Native/LLamaNativeBatch.cs

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ public unsafe struct LLamaNativeBatch
2525

2626
/// <summary>
2727
/// the positions of the respective token in the sequence
28+
/// (if set to NULL, the token position will be tracked automatically by llama_decode)
2829
/// </summary>
2930
public LLamaPos* pos;
3031

@@ -35,18 +36,13 @@ public unsafe struct LLamaNativeBatch
3536

3637
/// <summary>
3738
/// the sequence to which the respective token belongs
39+
/// (if set to NULL, the sequence ID will be assumed to be 0)
3840
/// </summary>
3941
public LLamaSeqId** seq_id;
4042

4143
/// <summary>
4244
/// if zero, the logits for the respective token will not be output
45+
/// (if set to NULL, only the logits for last token will be returned)
4346
/// </summary>
4447
public byte* logits;
45-
46-
// Note from llama.cpp:
47-
// > helpers for smooth API transition - can be deprecated in the future
48-
// > for future-proof code, use the above fields instead and ignore everything below
49-
private LLamaPos _all_pos_0;
50-
private LLamaPos _all_pos_1;
51-
private LLamaSeqId _all_seq_id;
5248
}

LLama/Native/LLamaPoolingType.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,4 +29,9 @@ public enum LLamaPoolingType
2929
CLS = 2,
3030

3131
Last = 3,
32+
33+
/// <summary>
34+
/// Used by reranking models to attach the classification head to the graph
35+
/// </summary>
36+
Rank,
3237
}

LLama/Native/LLamaVocabPreType.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,4 +33,5 @@ internal enum LLamaVocabPreType
3333
BLOOM = 23,
3434
GPT3_FINNISH = 24,
3535
EXAONE = 25,
36+
CHAMELEON = 26,
3637
}

LLama/Native/NativeApi.cs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,14 @@ public static void llama_empty_call()
4949
[return: MarshalAs(UnmanagedType.U1)]
5050
public static extern bool llama_supports_gpu_offload();
5151

52+
/// <summary>
53+
/// Check if RPC offload is supported
54+
/// </summary>
55+
/// <returns></returns>
56+
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
57+
[return: MarshalAs(UnmanagedType.U1)]
58+
public static extern bool llama_supports_rpc();
59+
5260
/// <summary>
5361
/// Initialize the llama + ggml backend. Call once at the start of the program.
5462
///

LLama/Native/SafeLLamaContextHandle.cs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -368,8 +368,10 @@ static SafeLLamaContextHandle()
368368
private static extern LLamaPoolingType llama_pooling_type(SafeLLamaContextHandle ctx);
369369

370370
/// <summary>
371-
/// Get the embeddings for the a specific sequence.
372-
/// Equivalent to: llama_get_embeddings(ctx) + ctx->output_ids[i]*n_embd
371+
/// Get the embeddings for a sequence id.
372+
/// Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE
373+
/// when pooling_type == LLAMA_POOLING_TYPE_RANK, returns float[1] with the rank of the sequence
374+
/// otherwise: float[n_embd] (1-dimensional)
373375
/// </summary>
374376
/// <returns>A pointer to the first float in an embedding, length = ctx.EmbeddingSize</returns>
375377
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]

LLama/Native/SafeLLamaSamplerHandle.cs

Lines changed: 48 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -249,19 +249,6 @@ public void AddMirostat2Sampler(uint seed, float tau, float eta)
249249
static extern IntPtr llama_sampler_init_mirostat_v2(uint seed, float tau, float eta);
250250
}
251251

252-
253-
/// <summary>
254-
/// Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
255-
/// </summary>
256-
/// <returns></returns>
257-
public void AddSoftmax()
258-
{
259-
llama_sampler_chain_add(this, llama_sampler_init_softmax());
260-
261-
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
262-
static extern IntPtr llama_sampler_init_softmax();
263-
}
264-
265252
/// <summary>
266253
/// Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
267254
/// </summary>
@@ -291,7 +278,6 @@ public void AddTopP(float p, nint minKeep)
291278
/// <summary>
292279
/// Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841
293280
/// </summary>
294-
/// <returns></returns>
295281
public void AddMinP(float p, nint minKeep)
296282
{
297283
llama_sampler_chain_add(this, llama_sampler_init_min_p(p, minKeep));
@@ -305,7 +291,6 @@ public void AddMinP(float p, nint minKeep)
305291
/// <summary>
306292
/// Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841
307293
/// </summary>
308-
/// <returns></returns>
309294
public void AddTailFree(float z, nint minKeep)
310295
{
311296
llama_sampler_chain_add(this, llama_sampler_init_tail_free(z, minKeep));
@@ -319,7 +304,6 @@ public void AddTailFree(float z, nint minKeep)
319304
/// <summary>
320305
/// Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
321306
/// </summary>
322-
/// <returns></returns>
323307
public void AddTypical(float p, nint minKeep)
324308
{
325309
llama_sampler_chain_add(this, llama_sampler_init_typical(p, minKeep));
@@ -331,14 +315,15 @@ public void AddTypical(float p, nint minKeep)
331315
}
332316

333317
/// <summary>
334-
/// Apply temperature to the logits
318+
/// Apply temperature to the logits.
319+
/// If temperature is less than zero the maximum logit is left unchanged and the rest are set to -infinity
335320
/// </summary>
336321
/// <param name="t"></param>
337-
/// <returns></returns>
338322
public void AddTemperature(float t)
339323
{
340324
llama_sampler_chain_add(this, llama_sampler_init_temp(t));
341325

326+
// #details Updates the logits l_i` = l_i/t. When t <= 0.0f, the maximum logit is kept at it's original value, the rest are set to -inf
342327
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
343328
static extern IntPtr llama_sampler_init_temp(float t);
344329
}
@@ -349,7 +334,6 @@ public void AddTemperature(float t)
349334
/// <param name="t"></param>
350335
/// <param name="delta"></param>
351336
/// <param name="exponent"></param>
352-
/// <returns></returns>
353337
public void AddDynamicTemperature(float t, float delta, float exponent)
354338
{
355339
llama_sampler_chain_add(this, llama_sampler_init_temp_ext(t, delta, exponent));
@@ -358,6 +342,51 @@ public void AddDynamicTemperature(float t, float delta, float exponent)
358342
static extern IntPtr llama_sampler_init_temp_ext(float t, float delta, float exponent);
359343
}
360344

345+
/// <summary>
346+
/// XTC sampler as described in https://github.com/oobabooga/text-generation-webui/pull/6335
347+
/// </summary>
348+
/// <param name="p"></param>
349+
/// <param name="t"></param>
350+
/// <param name="minKeep"></param>
351+
/// <param name="seed"></param>
352+
public void AddXTC(float p, float t, int minKeep, uint seed)
353+
{
354+
llama_sampler_chain_add(this, llama_sampler_init_xtc(p, t, minKeep, seed));
355+
356+
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
357+
static extern IntPtr llama_sampler_init_xtc(float p, float t, nint minKeep, uint seed);
358+
}
359+
360+
/// <summary>
361+
/// This sampler is meant to be used for fill-in-the-middle infilling, after top_k + top_p sampling
362+
///<br />
363+
/// 1. if the sum of the EOG probs times the number of candidates is higher than the sum of the other probs -> pick EOG<br />
364+
/// 2. combine probs of tokens that have the same prefix<br />
365+
/// <br />
366+
/// example:<br />
367+
/// <br />
368+
/// - before:<br />
369+
/// "hel": 0.5<br />
370+
/// "hell": 0.2<br />
371+
/// "hello": 0.1<br />
372+
/// "dummy": 0.1<br />
373+
///<br />
374+
/// - after:<br />
375+
/// "hel": 0.8<br />
376+
/// "dummy": 0.1<br />
377+
///<br />
378+
/// 3. discard non-EOG tokens with low prob<br />
379+
/// 4. if no tokens are left -> pick EOT
380+
/// </summary>
381+
/// <param name="model"></param>
382+
public void AddFillInMiddleInfill(SafeLlamaModelHandle model)
383+
{
384+
llama_sampler_chain_add(this, llama_sampler_init_infill(model));
385+
386+
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
387+
static extern IntPtr llama_sampler_init_infill(SafeLlamaModelHandle model);
388+
}
389+
361390
/// <summary>
362391
/// Create a sampler which makes tokens impossible unless they match the grammar
363392
/// </summary>

LLama/Native/SafeLlamaModelHandle.cs

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -386,32 +386,29 @@ private static int llama_model_meta_val_str(SafeLlamaModelHandle model, string k
386386
private static extern LLamaToken llama_token_pad(SafeLlamaModelHandle model);
387387

388388
/// <summary>
389-
/// codellama infill tokens, Beginning of infill prefix
389+
/// codellama infill tokens, End of infill middle
390390
/// </summary>
391391
/// <returns></returns>
392392
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
393-
private static extern int llama_token_prefix(SafeLlamaModelHandle model);
393+
private static extern int llama_token_eot(SafeLlamaModelHandle model);
394394

395-
/// <summary>
396-
/// codellama infill tokens, Beginning of infill middle
397-
/// </summary>
398-
/// <returns></returns>
399395
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
400-
private static extern int llama_token_middle(SafeLlamaModelHandle model);
396+
private static extern int llama_token_fim_pre(SafeLlamaModelHandle model);
401397

402-
/// <summary>
403-
/// codellama infill tokens, Beginning of infill suffix
404-
/// </summary>
405-
/// <returns></returns>
406398
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
407-
private static extern int llama_token_suffix(SafeLlamaModelHandle model);
399+
private static extern int llama_token_fim_suf(SafeLlamaModelHandle model);
408400

409-
/// <summary>
410-
/// codellama infill tokens, End of infill middle
411-
/// </summary>
412-
/// <returns></returns>
413401
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
414-
private static extern int llama_token_eot(SafeLlamaModelHandle model);
402+
private static extern int llama_token_fim_mid(SafeLlamaModelHandle model);
403+
404+
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
405+
private static extern int llama_token_fim_pad(SafeLlamaModelHandle model);
406+
407+
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
408+
private static extern int llama_token_fim_rep(SafeLlamaModelHandle model);
409+
410+
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
411+
private static extern int llama_token_fim_sep(SafeLlamaModelHandle model);
415412

416413
/// <summary>
417414
/// For encoder-decoder models, this function returns id of the token that must be provided

0 commit comments

Comments
 (0)