2 changes: 1 addition & 1 deletion LLama/LLamaSharp.csproj
@@ -57,7 +57,7 @@
</ItemGroup>

<PropertyGroup>
<BinaryReleaseId>11dd5a44eb180e</BinaryReleaseId>
<BinaryReleaseId>86587da</BinaryReleaseId>
</PropertyGroup>

<PropertyGroup>
5 changes: 5 additions & 0 deletions LLama/Native/LLamaContextParams.cs
@@ -64,6 +64,11 @@ public struct LLamaContextParams
/// Attention type to use for embeddings
/// </summary>
public LLamaAttentionType attention_type;

/// <summary>
/// When to enable Flash Attention
/// </summary>
public LLamaFlashAttentionType llama_flash_attn_type;

/// <summary>
/// RoPE base frequency, 0 = from model
19 changes: 19 additions & 0 deletions LLama/Native/LLamaFlashAttentionType.cs
@@ -0,0 +1,19 @@
namespace LLama.Native;
/// <summary>
/// When to enable Flash Attention (mirrors llama_flash_attn_type in llama.cpp)
/// </summary>
public enum LLamaFlashAttentionType
{
/// <summary>
/// Let the backend decide automatically whether to use Flash Attention
/// </summary>
LLAMA_FLASH_ATTENTION_TYPE_AUTO = -1,
/// <summary>
/// Flash Attention disabled
/// </summary>
LLAMA_FLASH_ATTENTION_TYPE_DISABLED = 0,
/// <summary>
/// Flash Attention enabled
/// </summary>
LLAMA_FLASH_ATTENTION_TYPE_ENABLED = 1,
}
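
A minimal sketch of how the new enum and the `llama_flash_attn_type` context-params field fit together when creating a context. The `LLamaContextParams.Default()` factory and `SafeLLamaContextHandle.Create(model, ...)` call are assumptions based on LLamaSharp's existing patterns, not part of this diff:

```csharp
// Sketch only: assumes a default-params factory and an already-loaded model handle.
var ctxParams = LLamaContextParams.Default();

// AUTO (-1) lets llama.cpp decide per model/backend whether Flash Attention is used;
// DISABLED (0) and ENABLED (1) force the choice either way.
ctxParams.llama_flash_attn_type = LLamaFlashAttentionType.LLAMA_FLASH_ATTENTION_TYPE_AUTO;

using var ctx = SafeLLamaContextHandle.Create(model, ctxParams);
```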
7 changes: 6 additions & 1 deletion LLama/Native/LLamaFtype.cs
@@ -201,7 +201,12 @@ public enum LLamaFtype
/// except 1d tensors
/// </summary>
LLAMA_FTYPE_MOSTLY_TQ2_0 = 37,


/// <summary>
/// except 1d tensors
/// </summary>
LLAMA_FTYPE_MOSTLY_MXFP4_MOE = 38,

/// <summary>
/// File type was not specified
/// </summary>
11 changes: 10 additions & 1 deletion LLama/Native/LLamaModelParams.cs
@@ -100,7 +100,16 @@ public bool check_tensors
set => _check_tensors = Convert.ToSByte(value);
}
private sbyte _check_tensors;


/// <summary>
/// use extra buffer types (used for weight repacking)
/// </summary>
public bool use_extra_bufts
{
readonly get => Convert.ToBoolean(_use_extra_bufts);
set => _use_extra_bufts = Convert.ToSByte(value);
}
private sbyte _use_extra_bufts;
/// <summary>
/// Create a LLamaModelParams with default values
/// </summary>
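
`use_extra_bufts` follows the same `sbyte`-backed bool pattern as `check_tensors` above, and because `LLamaModelParams` mirrors the native `llama_model_params` struct field-for-field, the field's position right after `_check_tensors` has to match the native layout. A short usage sketch, assuming the `Default()` factory documented just below:

```csharp
// Sketch: the bool properties wrap sbyte fields so the struct stays blittable.
var mparams = LLamaModelParams.Default();
mparams.check_tensors = true;    // stored internally as (sbyte)1
mparams.use_extra_bufts = true;  // opt in to extra buffer types (weight repacking)
```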
95 changes: 75 additions & 20 deletions LLama/Native/NativeApi.cs
@@ -99,7 +99,8 @@ public static void llama_empty_call()
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
[return: MarshalAs(UnmanagedType.U1)]
public static extern bool llama_state_load_file(SafeLLamaContextHandle ctx, string path_session, LLamaToken[] tokens_out, ulong n_token_capacity, out ulong n_token_count_out);
public static extern bool llama_state_load_file(SafeLLamaContextHandle ctx, string path_session,
LLamaToken[] tokens_out, ulong n_token_capacity, out ulong n_token_count_out);

/// <summary>
/// Save session file
@@ -111,39 +112,45 @@
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
[return: MarshalAs(UnmanagedType.U1)]
public static extern bool llama_state_save_file(SafeLLamaContextHandle ctx, string path_session, LLamaToken[] tokens, ulong n_token_count);
public static extern bool llama_state_save_file(SafeLLamaContextHandle ctx, string path_session,
LLamaToken[] tokens, ulong n_token_count);

/// <summary>
/// Saves the specified sequence as a file on specified filepath. Can later be loaded via <see cref="llama_state_load_file(SafeLLamaContextHandle, string, LLamaToken[], ulong, out ulong)"/>
/// </summary>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern unsafe nuint llama_state_seq_save_file(SafeLLamaContextHandle ctx, string filepath, LLamaSeqId seq_id, LLamaToken* tokens, nuint n_token_count);
public static extern unsafe nuint llama_state_seq_save_file(SafeLLamaContextHandle ctx, string filepath,
LLamaSeqId seq_id, LLamaToken* tokens, nuint n_token_count);

/// <summary>
/// Loads a sequence saved as a file via <see cref="llama_state_save_file(SafeLLamaContextHandle, string, LLamaToken[], ulong)"/> into the specified sequence
/// </summary>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern unsafe nuint llama_state_seq_load_file(SafeLLamaContextHandle ctx, string filepath, LLamaSeqId dest_seq_id, LLamaToken* tokens_out, nuint n_token_capacity, out nuint n_token_count_out);
public static extern unsafe nuint llama_state_seq_load_file(SafeLLamaContextHandle ctx, string filepath,
LLamaSeqId dest_seq_id, LLamaToken* tokens_out, nuint n_token_capacity, out nuint n_token_count_out);

/// <summary>
/// Set whether to use causal attention or not. If set to true, the model will only attend to the past tokens
/// </summary>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_set_causal_attn(SafeLLamaContextHandle ctx, [MarshalAs(UnmanagedType.U1)] bool causalAttn);
public static extern void llama_set_causal_attn(SafeLLamaContextHandle ctx,
[MarshalAs(UnmanagedType.U1)] bool causalAttn);

/// <summary>
/// Set whether the context outputs embeddings or not
/// </summary>
/// <param name="ctx"></param>
/// <param name="embeddings">If true, embeddings will be returned but logits will not</param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_set_embeddings(SafeLLamaContextHandle ctx, [MarshalAs(UnmanagedType.U1)] bool embeddings);
public static extern void llama_set_embeddings(SafeLLamaContextHandle ctx,
[MarshalAs(UnmanagedType.U1)] bool embeddings);

/// <summary>
/// Set abort callback
/// </summary>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_set_abort_callback(SafeLlamaModelHandle ctx, IntPtr /* ggml_abort_callback */ abortCallback, IntPtr abortCallbackData);
public static extern void llama_set_abort_callback(SafeLlamaModelHandle ctx,
IntPtr /* ggml_abort_callback */ abortCallback, IntPtr abortCallbackData);

/// <summary>
/// Get the n_seq_max for this context
@@ -175,12 +182,15 @@
/// <param name="buf">A buffer to hold the output formatted prompt. The recommended alloc size is 2 * (total number of characters of all messages)</param>
/// <param name="length">The size of the allocated buffer</param>
/// <returns>The total number of bytes of the formatted prompt. If is it larger than the size of buffer, you may need to re-alloc it and then re-apply the template.</returns>
public static unsafe int llama_chat_apply_template(byte* tmpl, LLamaChatMessage* chat, nuint n_msg, [MarshalAs(UnmanagedType.U1)] bool add_ass, byte* buf, int length)
public static unsafe int llama_chat_apply_template(byte* tmpl, LLamaChatMessage* chat, nuint n_msg,
[MarshalAs(UnmanagedType.U1)] bool add_ass, byte* buf, int length)
{
return internal_llama_chat_apply_template(tmpl, chat, n_msg, add_ass, buf, length);

[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_chat_apply_template")]
static extern int internal_llama_chat_apply_template(byte* tmpl, LLamaChatMessage* chat, nuint n_msg, [MarshalAs(UnmanagedType.U1)] bool add_ass, byte* buf, int length);
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl,
EntryPoint = "llama_chat_apply_template")]
static extern int internal_llama_chat_apply_template(byte* tmpl, LLamaChatMessage* chat, nuint n_msg,
[MarshalAs(UnmanagedType.U1)] bool add_ass, byte* buf, int length);
}

/// <summary>
@@ -215,7 +225,8 @@ public static unsafe int llama_chat_apply_template(byte* tmpl, LLamaChatMessage*
/// <param name="lstrip">User can skip up to 'lstrip' leading spaces before copying (useful when encoding/decoding multiple tokens with 'add_space_prefix')</param>
/// <param name="special">If true, special tokens are rendered in the output</param>
/// <returns>The length written, or if the buffer is too small a negative that indicates the length required</returns>
public static int llama_token_to_piece(SafeLlamaModelHandle.Vocabulary vocab, LLamaToken llamaToken, Span<byte> buffer, int lstrip, bool special)
public static int llama_token_to_piece(SafeLlamaModelHandle.Vocabulary vocab, LLamaToken llamaToken,
Span<byte> buffer, int lstrip, bool special)
{
// Handle invalid tokens
if ((int)llamaToken < 0)
@@ -225,12 +236,14 @@
{
fixed (byte* bufferPtr = buffer)
{
return llama_token_to_piece_native(vocab.VocabNative, llamaToken, bufferPtr, buffer.Length, lstrip, special);
return llama_token_to_piece_native(vocab.VocabNative, llamaToken, bufferPtr, buffer.Length, lstrip,
special);
}
}

[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_token_to_piece")]
static extern unsafe int llama_token_to_piece_native(LLamaVocabNative* model, LLamaToken llamaToken, byte* buffer, int length, int lstrip, [MarshalAs(UnmanagedType.U1)] bool special);
static extern unsafe int llama_token_to_piece_native(LLamaVocabNative* model, LLamaToken llamaToken,
byte* buffer, int length, int lstrip, [MarshalAs(UnmanagedType.U1)] bool special);
}

/// <summary>
@@ -247,7 +260,9 @@
/// Returns a negative number on failure - the number of tokens that would have been returned. Returns INT32_MIN on overflow (e.g., tokenization result size exceeds int32_t limit)
/// </returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
internal static extern unsafe int llama_tokenize(LLamaVocabNative* model, byte* text, int text_len, LLamaToken* tokens, int n_max_tokens, [MarshalAs(UnmanagedType.U1)] bool add_special, [MarshalAs(UnmanagedType.U1)] bool parse_special);
internal static extern unsafe int llama_tokenize(LLamaVocabNative* model, byte* text, int text_len,
LLamaToken* tokens, int n_max_tokens, [MarshalAs(UnmanagedType.U1)] bool add_special,
[MarshalAs(UnmanagedType.U1)] bool parse_special);

/// <summary>
/// Convert the provided tokens into text (inverse of llama_tokenize()).
@@ -261,7 +276,8 @@
/// <param name="unparseSpecial">unparse_special If true, special tokens are rendered in the output.</param>
/// <returns>Returns the number of chars/bytes on success, no more than textLengthMax. Returns a negative number on failure - the number of chars/bytes that would have been returned.</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
internal static extern unsafe int llama_detokenize(LLamaVocabNative* model, LLamaToken* tokens, int nTokens, byte* textOut, int textLengthMax, bool removeSpecial, bool unparseSpecial);
internal static extern unsafe int llama_detokenize(LLamaVocabNative* model, LLamaToken* tokens, int nTokens,
byte* textOut, int textLengthMax, bool removeSpecial, bool unparseSpecial);

/// <summary>
/// Register a callback to receive llama log messages
@@ -272,7 +288,7 @@ public static void llama_log_set(NativeLogConfig.LLamaLogCallback logCallback)
{
NativeLogConfig.llama_log_set(logCallback);
}

/// <summary>
/// Allocates a batch of tokens on the heap
/// Each token can be assigned up to n_seq_max sequence ids
@@ -311,7 +327,8 @@ public static void llama_log_set(NativeLogConfig.LLamaLogCallback logCallback)
/// <param name="il_end"></param>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern unsafe int llama_apply_adapter_cvec(SafeLLamaContextHandle ctx, float* data, nuint len, int n_embd, int il_start, int il_end);
public static extern unsafe int llama_apply_adapter_cvec(SafeLLamaContextHandle ctx, float* data, nuint len,
int n_embd, int il_start, int il_end);

/// <summary>
/// Build a split GGUF final path for this chunk.
@@ -324,7 +341,8 @@
/// <param name="split_count"></param>
/// <returns>Returns the split_path length.</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_split_path(string split_path, nuint maxlen, string path_prefix, int split_no, int split_count);
public static extern int llama_split_path(string split_path, nuint maxlen, string path_prefix, int split_no,
int split_count);

/// <summary>
/// Extract the path prefix from the split_path if and only if the split_no and split_count match.
@@ -337,7 +355,8 @@
/// <param name="split_count"></param>
/// <returns>Returns the split_prefix length.</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_split_prefix(string split_prefix, nuint maxlen, string split_path, int split_no, int split_count);
public static extern int llama_split_prefix(string split_prefix, nuint maxlen, string split_path, int split_no,
int split_count);

//[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
//todo: public static void llama_attach_threadpool(SafeLLamaContextHandle ctx, ggml_threadpool_t threadpool, ggml_threadpool_t threadpool_batch);
@@ -380,5 +399,41 @@ public static void llama_log_set(NativeLogConfig.LLamaLogCallback logCallback)
/// <returns>Name of the buffer type</returns>
[DllImport(ggmlBaseLibraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern IntPtr ggml_backend_buft_name(IntPtr buft);

/// <summary>
/// Get the exact size in bytes needed to copy the state of a single sequence
/// </summary>
/// <param name="ctx">Raw llama_context pointer</param>
/// <param name="seq_id">Sequence to measure</param>
/// <param name="flags">State flags (0 for default behaviour)</param>
/// <returns>Size in bytes of the sequence state</returns>
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern UIntPtr llama_state_seq_get_size_ext(IntPtr ctx, int seq_id, uint flags);

/// <summary>
/// Copy the state of a single sequence into the supplied buffer
/// </summary>
/// <param name="ctx">Raw llama_context pointer</param>
/// <param name="dst">Destination buffer</param>
/// <param name="size">Size of the destination buffer in bytes</param>
/// <param name="seq_id">Sequence to copy</param>
/// <param name="flags">State flags (0 for default behaviour)</param>
/// <returns>Number of bytes written to the buffer</returns>
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern UIntPtr llama_state_seq_get_data_ext(IntPtr ctx, [Out] byte[] dst, UIntPtr size,
int seq_id, uint flags);

/// <summary>
/// Load a previously saved sequence state from a buffer into the specified sequence
/// </summary>
/// <param name="ctx">Raw llama_context pointer</param>
/// <param name="src">Source buffer containing a saved sequence state</param>
/// <param name="size">Size of the source buffer in bytes</param>
/// <param name="dest_seq_id">Sequence to load the state into</param>
/// <param name="flags">State flags (0 for default behaviour)</param>
/// <returns>Number of bytes read from the buffer, or 0 on failure</returns>
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern UIntPtr llama_state_seq_set_data_ext(IntPtr ctx, byte[] src, UIntPtr size, int dest_seq_id,
uint flags);
}
}
}
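
For review context, a sketch of how the three `_ext` imports compose to move one sequence's state between two contexts sharing the same model. The raw `IntPtr` handles (e.g. from `DangerousGetHandle()`) and `flags = 0` for default behaviour are assumptions; the flag values themselves are defined by llama.cpp's `llama_state_seq_flags`:

```csharp
// Sketch: snapshot sequence 0 from ctxA, then restore it into ctxB.
UIntPtr size = llama_state_seq_get_size_ext(ctxA, 0, 0);
var buffer = new byte[checked((int)(ulong)size)];

UIntPtr written = llama_state_seq_get_data_ext(ctxA, buffer, size, 0, 0);
UIntPtr read = llama_state_seq_set_data_ext(ctxB, buffer, written, 0, 0);
// A zero result from the set call would indicate the restore failed.
```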
41 changes: 41 additions & 0 deletions LLama/Native/SafeLLamaContextHandle.cs
@@ -341,6 +341,47 @@ static SafeLLamaContextHandle()
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern int llama_set_adapter_lora(SafeLLamaContextHandle context, IntPtr adapter, float scale);

/// <summary>
/// Get metadata value as a string by key name
/// </summary>
/// <param name="adapter"></param>
/// <param name="key"></param>
/// <param name="buf"></param>
/// <param name="buf_size"></param>
/// <returns></returns>
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern int llama_adapter_meta_val_str(IntPtr adapter, string key, StringBuilder buf, UIntPtr buf_size);

/// <summary>
/// Get the number of metadata key value pairs
/// </summary>
/// <param name="adapter"></param>
/// <returns></returns>
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern int llama_adapter_meta_count(IntPtr adapter);

/// <summary>
/// Get metadata key name by index
/// </summary>
/// <param name="adapter"></param>
/// <param name="i"></param>
/// <param name="buf"></param>
/// <param name="buf_size"></param>
/// <returns></returns>
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern int llama_adapter_meta_key_by_index(IntPtr adapter, int i, StringBuilder buf, UIntPtr buf_size);

/// <summary>
/// Get metadata value by index
/// </summary>
/// <param name="adapter">Raw llama_adapter_lora pointer</param>
/// <param name="i">Zero-based index of the metadata pair</param>
/// <param name="buf">Buffer to receive the value</param>
/// <param name="buf_size">Size of the buffer</param>
/// <returns>The length of the string on success, or -1 on failure</returns>
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern int llama_adapter_meta_val_by_index(IntPtr adapter, int i, StringBuilder buf, UIntPtr buf_size);

[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern int llama_rm_adapter_lora(SafeLLamaContextHandle context, IntPtr adapter);

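
A sketch of how the private metadata imports could be composed into a helper inside `SafeLLamaContextHandle` (or a future adapter wrapper class); the buffer sizes and the `DumpAdapterMetadata` name are illustrative only, and `System.Text` is assumed to be imported for `StringBuilder`:

```csharp
// Sketch: enumerate every metadata key/value pair of a loaded LoRA adapter.
private static void DumpAdapterMetadata(IntPtr adapter)
{
    int count = llama_adapter_meta_count(adapter);
    for (var i = 0; i < count; i++)
    {
        var key = new StringBuilder(128);
        var val = new StringBuilder(256);
        if (llama_adapter_meta_key_by_index(adapter, i, key, (nuint)key.Capacity) >= 0 &&
            llama_adapter_meta_val_by_index(adapter, i, val, (nuint)val.Capacity) >= 0)
        {
            Console.WriteLine($"{key} = {val}");
        }
    }
}
```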
2 changes: 1 addition & 1 deletion LLama/Native/SafeLLamaSamplerHandle.cs
@@ -616,7 +616,7 @@ static extern unsafe IntPtr llama_sampler_init_logit_bias(

// This is a tricky method to work with!
// It can't return a handle, because that would create a second handle to these resources.
// Instead It returns the raw pointer, and that can be looked up in the _samplers dictionary.
// Instead, it returns the raw pointer, and that can be looked up in the _samplers dictionary.
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern IntPtr llama_sampler_chain_get(SafeLLamaSamplerChainHandle chain, int i);
// ReSharper restore InconsistentNaming