Commit dedaae2

applied changes from header, still need to update C#

1 parent 6df26d3 commit dedaae2

7 files changed: +83 -45 lines changed

LLama/Native/LLamaNativeBatch.cs

Lines changed: 3 additions & 7 deletions

```diff
@@ -25,6 +25,7 @@ public unsafe struct LLamaNativeBatch

     /// <summary>
     /// the positions of the respective token in the sequence
+    /// (if set to NULL, the token position will be tracked automatically by llama_decode)
     /// </summary>
     public LLamaPos* pos;

@@ -35,18 +36,13 @@ public unsafe struct LLamaNativeBatch

     /// <summary>
     /// the sequence to which the respective token belongs
+    /// (if set to NULL, the sequence ID will be assumed to be 0)
     /// </summary>
     public LLamaSeqId** seq_id;

     /// <summary>
     /// if zero, the logits for the respective token will not be output
+    /// (if set to NULL, only the logits for last token will be returned)
     /// </summary>
     public byte* logits;
-
-    // Note from llama.cpp:
-    // > helpers for smooth API transition - can be deprecated in the future
-    // > for future-proof code, use the above fields instead and ignore everything below
-    private LLamaPos _all_pos_0;
-    private LLamaPos _all_pos_1;
-    private LLamaSeqId _all_seq_id;
 }
```
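
The removed `_all_*` fields were llama.cpp's deprecated transition helpers; the struct now matches the slimmed-down `llama_batch` layout, where `pos`, `seq_id` and `logits` may be left as NULL to get the documented defaults. A minimal sketch of what that permits; field names other than the three documented above (`n_tokens`, `token`) are assumed from the llama.cpp layout, not taken from this diff:

```csharp
// Sketch only, assuming llama_batch-style n_tokens/token fields exist
// alongside the three pointers documented above.
static unsafe LLamaNativeBatch SingleSequenceBatch(LLamaToken* tokens, int count)
{
    return new LLamaNativeBatch
    {
        n_tokens = count,
        token = tokens,
        pos = null,     // positions tracked automatically by llama_decode
        seq_id = null,  // every token assumed to belong to sequence 0
        logits = null,  // logits returned for the last token only
    };
}
```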

LLama/Native/LLamaPoolingType.cs

Lines changed: 5 additions & 0 deletions

```diff
@@ -29,4 +29,9 @@ public enum LLamaPoolingType
     CLS = 2,

     Last = 3,
+
+    /// <summary>
+    /// Used by reranking models to attach the classification head to the graph
+    /// </summary>
+    Rank,
 }
```
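
Following `Last = 3`, the new member takes the value 4, matching `LLAMA_POOLING_TYPE_RANK` in llama.cpp. A hedged sketch of sizing the per-sequence output by pooling mode; the `None` member is assumed to exist at 0, as in llama.cpp:

```csharp
// Sketch: expected length of the per-sequence output, by pooling mode.
static int OutputLengthPerSequence(LLamaPoolingType pooling, int nEmbd)
{
    return pooling switch
    {
        LLamaPoolingType.None => 0,    // no pooled output is produced
        LLamaPoolingType.Rank => 1,    // a single reranking score
        _ => nEmbd,                    // Mean/CLS/Last: one float per embedding dimension
    };
}
```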

LLama/Native/LLamaVocabPreType.cs

Lines changed: 1 addition & 0 deletions

```diff
@@ -33,4 +33,5 @@ internal enum LLamaVocabPreType
     BLOOM = 23,
     GPT3_FINNISH = 24,
     EXAONE = 25,
+    CHAMELEON = 26,
 }
```

LLama/Native/NativeApi.cs

Lines changed: 8 additions & 0 deletions

```diff
@@ -49,6 +49,14 @@ public static void llama_empty_call()
     [return: MarshalAs(UnmanagedType.U1)]
     public static extern bool llama_supports_gpu_offload();

+    /// <summary>
+    /// Check if RPC offload is supported
+    /// </summary>
+    /// <returns></returns>
+    [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+    [return: MarshalAs(UnmanagedType.U1)]
+    public static extern bool llama_supports_rpc();
+
     /// <summary>
     /// Initialize the llama + ggml backend. Call once at the start of the program.
     ///
```
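
A short usage sketch for the new capability probe, alongside the existing one visible in the context lines:

```csharp
using System;
using LLama.Native;

// Both P/Invokes return a plain bool (marshalled as U1).
bool gpu = NativeApi.llama_supports_gpu_offload();
bool rpc = NativeApi.llama_supports_rpc();   // new in this commit
Console.WriteLine($"GPU offload: {gpu}, RPC offload: {rpc}");
```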

LLama/Native/SafeLLamaContextHandle.cs

Lines changed: 4 additions & 2 deletions

```diff
@@ -368,8 +368,10 @@ static SafeLLamaContextHandle()
     private static extern LLamaPoolingType llama_pooling_type(SafeLLamaContextHandle ctx);

     /// <summary>
-    /// Get the embeddings for the a specific sequence.
-    /// Equivalent to: llama_get_embeddings(ctx) + ctx->output_ids[i]*n_embd
+    /// Get the embeddings for a sequence id.
+    /// Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE
+    /// when pooling_type == LLAMA_POOLING_TYPE_RANK, returns float[1] with the rank of the sequence
+    /// otherwise: float[n_embd] (1-dimensional)
     /// </summary>
     /// <returns>A pointer to the first float in an embedding, length = ctx.EmbeddingSize</returns>
     [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
```
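
The updated comment changes the contract: callers can no longer assume a fixed `EmbeddingSize` result. A hedged sketch of reading the pointer per the three documented cases; `ptr` is assumed to come from whatever wrapper exposes the native embeddings-for-sequence call, and only its documented shape is used:

```csharp
using System;
using LLama.Native;

// Sketch: copy a per-sequence result out of native memory according to the
// updated doc comment above.
static unsafe float[] ReadSequenceOutput(float* ptr, LLamaPoolingType pooling, int nEmbd)
{
    if (ptr == null)
        return Array.Empty<float>();   // LLAMA_POOLING_TYPE_NONE

    var length = pooling == LLamaPoolingType.Rank
        ? 1                            // LLAMA_POOLING_TYPE_RANK: a single score
        : nEmbd;                       // otherwise: float[n_embd]
    var result = new float[length];
    new ReadOnlySpan<float>(ptr, length).CopyTo(result);
    return result;
}
```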

LLama/Native/SafeLLamaSamplerHandle.cs

Lines changed: 48 additions & 19 deletions

```diff
@@ -267,19 +267,6 @@ public void AddMirostat2Sampler(uint seed, float tau, float eta)
         static extern IntPtr llama_sampler_init_mirostat_v2(uint seed, float tau, float eta);
     }

-
-    /// <summary>
-    /// Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
-    /// </summary>
-    /// <returns></returns>
-    public void AddSoftmax()
-    {
-        llama_sampler_chain_add(this, llama_sampler_init_softmax());
-
-        [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
-        static extern IntPtr llama_sampler_init_softmax();
-    }
-
     /// <summary>
     /// Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
     /// </summary>
@@ -309,7 +296,6 @@ public void AddTopP(float p, nint minKeep)
     /// <summary>
     /// Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841
     /// </summary>
-    /// <returns></returns>
     public void AddMinP(float p, nint minKeep)
     {
         llama_sampler_chain_add(this, llama_sampler_init_min_p(p, minKeep));
@@ -323,7 +309,6 @@ public void AddMinP(float p, nint minKeep)
     /// <summary>
     /// Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841
     /// </summary>
-    /// <returns></returns>
    public void AddTailFree(float z, nint minKeep)
    {
        llama_sampler_chain_add(this, llama_sampler_init_tail_free(z, minKeep));
@@ -337,7 +322,6 @@ public void AddTailFree(float z, nint minKeep)
     /// <summary>
     /// Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
     /// </summary>
-    /// <returns></returns>
     public void AddTypical(float p, nint minKeep)
     {
         llama_sampler_chain_add(this, llama_sampler_init_typical(p, minKeep));
@@ -349,14 +333,15 @@ public void AddTypical(float p, nint minKeep)
     }

     /// <summary>
-    /// Apply temperature to the logits
+    /// Apply temperature to the logits.
+    /// If temperature is less than zero the maximum logit is left unchanged and the rest are set to -infinity
     /// </summary>
     /// <param name="t"></param>
-    /// <returns></returns>
     public void AddTemperature(float t)
     {
         llama_sampler_chain_add(this, llama_sampler_init_temp(t));

+        // #details Updates the logits l_i` = l_i/t. When t <= 0.0f, the maximum logit is kept at it's original value, the rest are set to -inf
         [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
         static extern IntPtr llama_sampler_init_temp(float t);
     }
@@ -367,7 +352,6 @@ public void AddTemperature(float t)
     /// <param name="t"></param>
     /// <param name="delta"></param>
     /// <param name="exponent"></param>
-    /// <returns></returns>
     public void AddDynamicTemperature(float t, float delta, float exponent)
     {
         llama_sampler_chain_add(this, llama_sampler_init_temp_ext(t, delta, exponent));
@@ -376,6 +360,51 @@ public void AddDynamicTemperature(float t, float delta, float exponent)
         static extern IntPtr llama_sampler_init_temp_ext(float t, float delta, float exponent);
     }

+    /// <summary>
+    /// XTC sampler as described in https://github.com/oobabooga/text-generation-webui/pull/6335
+    /// </summary>
+    /// <param name="p"></param>
+    /// <param name="t"></param>
+    /// <param name="minKeep"></param>
+    /// <param name="seed"></param>
+    public void AddXTC(float p, float t, int minKeep, uint seed)
+    {
+        llama_sampler_chain_add(this, llama_sampler_init_xtc(p, t, minKeep, seed));
+
+        [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+        static extern IntPtr llama_sampler_init_xtc(float p, float t, nint minKeep, uint seed);
+    }
+
+    /// <summary>
+    /// This sampler is meant to be used for fill-in-the-middle infilling, after top_k + top_p sampling
+    ///<br />
+    /// 1. if the sum of the EOG probs times the number of candidates is higher than the sum of the other probs -> pick EOG<br />
+    /// 2. combine probs of tokens that have the same prefix<br />
+    /// <br />
+    /// example:<br />
+    /// <br />
+    /// - before:<br />
+    ///   "hel": 0.5<br />
+    ///   "hell": 0.2<br />
+    ///   "hello": 0.1<br />
+    ///   "dummy": 0.1<br />
+    ///<br />
+    /// - after:<br />
+    ///   "hel": 0.8<br />
+    ///   "dummy": 0.1<br />
+    ///<br />
+    /// 3. discard non-EOG tokens with low prob<br />
+    /// 4. if no tokens are left -> pick EOT
+    /// </summary>
+    /// <param name="model"></param>
+    public void AddFillInMiddleInfill(SafeLlamaModelHandle model)
+    {
+        llama_sampler_chain_add(this, llama_sampler_init_infill(model));
+
+        [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+        static extern IntPtr llama_sampler_init_infill(SafeLlamaModelHandle model);
+    }
+
     /// <summary>
     /// Create a sampler which makes tokens impossible unless they match the grammar
     /// </summary>
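```

With `AddSoftmax` gone (llama.cpp dropped `llama_sampler_init_softmax`), a chain is assembled from the remaining samplers. A sketch using only method signatures visible in this diff; the chain construction itself (`SafeLLamaSamplerChainHandle.Create` with default `LLamaSamplerChainParams`) is assumed from the existing LLamaSharp API and is not part of this commit:

```csharp
using LLama.Native;

// Sketch: assemble a sampling chain with the samplers touched by this commit.
var chain = SafeLLamaSamplerChainHandle.Create(LLamaSamplerChainParams.Default());
chain.AddMinP(0.05f, minKeep: 1);
chain.AddXTC(p: 0.5f, t: 0.1f, minKeep: 1, seed: 42);  // new in this commit
chain.AddTemperature(0.8f);  // t < 0 would leave only the maximum logit
```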

LLama/Native/SafeLlamaModelHandle.cs

Lines changed: 14 additions & 17 deletions

```diff
@@ -386,32 +386,29 @@ private static int llama_model_meta_val_str(SafeLlamaModelHandle model, string k
     private static extern LLamaToken llama_token_pad(SafeLlamaModelHandle model);

     /// <summary>
-    /// codellama infill tokens, Beginning of infill prefix
+    /// codellama infill tokens, End of infill middle
     /// </summary>
     /// <returns></returns>
     [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
-    private static extern int llama_token_prefix(SafeLlamaModelHandle model);
+    private static extern int llama_token_eot(SafeLlamaModelHandle model);

-    /// <summary>
-    /// codellama infill tokens, Beginning of infill middle
-    /// </summary>
-    /// <returns></returns>
     [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
-    private static extern int llama_token_middle(SafeLlamaModelHandle model);
+    private static extern int llama_token_fim_pre(SafeLlamaModelHandle model);

-    /// <summary>
-    /// codellama infill tokens, Beginning of infill suffix
-    /// </summary>
-    /// <returns></returns>
     [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
-    private static extern int llama_token_suffix(SafeLlamaModelHandle model);
+    private static extern int llama_token_fim_suf(SafeLlamaModelHandle model);

-    /// <summary>
-    /// codellama infill tokens, End of infill middle
-    /// </summary>
-    /// <returns></returns>
     [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
-    private static extern int llama_token_eot(SafeLlamaModelHandle model);
+    private static extern int llama_token_fim_mid(SafeLlamaModelHandle model);
+
+    [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+    private static extern int llama_token_fim_pad(SafeLlamaModelHandle model);
+
+    [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+    private static extern int llama_token_fim_rep(SafeLlamaModelHandle model);
+
+    [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+    private static extern int llama_token_fim_sep(SafeLlamaModelHandle model);

     /// <summary>
     /// For encoder-decoder models, this function returns id of the token that must be provided
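```

These renames track llama.cpp's move from `llama_token_prefix`/`llama_token_middle`/`llama_token_suffix` to the `llama_token_fim_*` family, plus the new pad/rep/sep tokens. Since the externs are private, here is a hedged sketch of the classic prefix-suffix-middle prompt layout they enable, with the special tokens passed in from whatever public accessor wraps them (those wrappers are assumed, not shown in this diff):

```csharp
using System.Collections.Generic;
using LLama.Native;

// Sketch: standard FIM prompt order. fimPre/fimSuf/fimMid are assumed to be
// fetched via public wrappers over the private llama_token_fim_* externs.
static List<LLamaToken> BuildFimPrompt(
    IReadOnlyList<LLamaToken> prefix, IReadOnlyList<LLamaToken> suffix,
    LLamaToken fimPre, LLamaToken fimSuf, LLamaToken fimMid)
{
    var prompt = new List<LLamaToken> { fimPre };
    prompt.AddRange(prefix);   // <fim_pre> {prefix}
    prompt.Add(fimSuf);
    prompt.AddRange(suffix);   // <fim_suf> {suffix}
    prompt.Add(fimMid);        // <fim_mid>, the model generates the middle
    return prompt;
}
```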
