diff --git a/LLama.SemanticKernel/ChatCompletion/LLamaSharpChatCompletion.cs b/LLama.SemanticKernel/ChatCompletion/LLamaSharpChatCompletion.cs
index 413c9ed35..5c675ab13 100644
--- a/LLama.SemanticKernel/ChatCompletion/LLamaSharpChatCompletion.cs
+++ b/LLama.SemanticKernel/ChatCompletion/LLamaSharpChatCompletion.cs
@@ -1,5 +1,6 @@
using LLama;
using LLama.Abstractions;
+using Microsoft.Extensions.AI;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;
using System.Runtime.CompilerServices;
@@ -12,7 +13,7 @@ namespace LLamaSharp.SemanticKernel.ChatCompletion;
/// <summary>
/// LLamaSharp ChatCompletion
/// </summary>
-public sealed class LLamaSharpChatCompletion : IChatCompletionService
+public sealed class LLamaSharpChatCompletion : IChatClient, IChatCompletionService
{
private readonly ILLamaExecutor _model;
private readonly LLamaSharpPromptExecutionSettings _defaultRequestSettings;
@@ -64,14 +65,7 @@ public ChatHistory CreateNewChat(string? instructions = "")
/// <inheritdoc/>
public async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsync(ChatHistory chatHistory, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default)
{
- var settings = executionSettings != null
- ? LLamaSharpPromptExecutionSettings.FromRequestSettings(executionSettings)
- : _defaultRequestSettings;
-
- var prompt = _getFormattedPrompt(chatHistory);
- var result = _model.InferAsync(prompt, settings.ToLLamaSharpInferenceParams(), cancellationToken);
-
- var output = _outputTransform.TransformAsync(result);
+ var output = InferChatHistory(chatHistory, executionSettings, cancellationToken);
var sb = new StringBuilder();
await foreach (var token in output)
@@ -84,20 +78,26 @@ public async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsync
/// <inheritdoc/>
public async IAsyncEnumerable<StreamingChatMessageContent> GetStreamingChatMessageContentsAsync(ChatHistory chatHistory, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+ {
+ var output = InferChatHistory(chatHistory, executionSettings, cancellationToken);
+
+ await foreach (var token in output)
+ {
+ yield return new StreamingChatMessageContent(AuthorRole.Assistant, token);
+ }
+ }
+
+ private IAsyncEnumerable<string> InferChatHistory(ChatHistory chatHistory, PromptExecutionSettings? executionSettings, CancellationToken cancellationToken)
{
var settings = executionSettings != null
- ? LLamaSharpPromptExecutionSettings.FromRequestSettings(executionSettings)
- : _defaultRequestSettings;
+ ? LLamaSharpPromptExecutionSettings.FromRequestSettings(executionSettings)
+ : _defaultRequestSettings;
var prompt = _getFormattedPrompt(chatHistory);
var result = _model.InferAsync(prompt, settings.ToLLamaSharpInferenceParams(), cancellationToken);
var output = _outputTransform.TransformAsync(result);
-
- await foreach (var token in output)
- {
- yield return new StreamingChatMessageContent(AuthorRole.Assistant, token);
- }
+ return output;
}
/// <summary>
@@ -130,4 +130,81 @@ private string _getFormattedPrompt(ChatHistory chatHistory)
return prompt;
}
+
+ private string _getFormattedPrompt(IEnumerable<ChatMessage> messages)
+ {
+ string prompt;
+ if (_isStatefulExecutor)
+ {
+ var state = (InteractiveExecutorState)((StatefulExecutorBase)_model).GetStateData();
+ if (state.IsPromptRun)
+ {
+ prompt = _historyTransform.HistoryToText(messages.ToLLamaSharpChatHistory());
+ }
+ else
+ {
+ ChatHistory tempHistory = new();
+ tempHistory.AddUserMessage(messages.Last().Text ?? "");
+ prompt = _historyTransform.HistoryToText(tempHistory.ToLLamaSharpChatHistory());
+ }
+ }
+ else
+ {
+ prompt = _historyTransform.HistoryToText(messages.ToLLamaSharpChatHistory());
+ }
+
+ return prompt;
+ }
+
+ public async Task<ChatResponse> GetResponseAsync(IEnumerable<ChatMessage> messages, ChatOptions? options = null, CancellationToken cancellationToken = default)
+ {
+ var output = InferChatMessage(messages, options, cancellationToken);
+
+ var sb = new StringBuilder();
+ await foreach (var token in output)
+ {
+ sb.Append(token);
+ }
+
+ return new ChatResponse(new ChatMessage(ChatRole.Assistant, sb.ToString()));
+ }
+
+ /// <inheritdoc/>
+ public async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(IEnumerable<ChatMessage> messages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+ {
+ var output = InferChatMessage(messages, options, cancellationToken);
+
+ await foreach (var token in output)
+ {
+ yield return new ChatResponseUpdate(ChatRole.Assistant, token);
+ }
+ }
+
+ private IAsyncEnumerable<string> InferChatMessage(IEnumerable<ChatMessage> messages, ChatOptions? options, CancellationToken cancellationToken)
+ {
+ var settings = options != null
+ ? LLamaSharpPromptExecutionSettings.FromRequestSettings(options)
+ : _defaultRequestSettings;
+
+ var prompt = _getFormattedPrompt(messages);
+ var result = _model.InferAsync(prompt, settings.ToLLamaSharpInferenceParams(), cancellationToken);
+
+ var output = _outputTransform.TransformAsync(result);
+ return output;
+ }
+
+ /// <inheritdoc/>
+ public object? GetService(Type serviceType, object? serviceKey = null)
+ {
+     // Return this instance for compatible service requests; per Microsoft.Extensions.AI
+     // convention, unknown services yield null rather than an exception
+     return serviceKey is null && serviceType.IsInstanceOfType(this) ? this : null;
+ }
+
+ /// <inheritdoc/>
+ public void Dispose()
+ {
+ if (_outputTransform is IDisposable disposable)
+ {
+ disposable.Dispose();
+ }
+ }
}
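With the class now implementing both IChatCompletionService and IChatClient, the same instance can be driven from Semantic Kernel or from Microsoft.Extensions.AI. A minimal usage sketch, assuming `executor` is an already-configured ILLamaExecutor and the existing constructor is used unchanged:

using LLamaSharp.SemanticKernel.ChatCompletion;
using Microsoft.Extensions.AI;

IChatClient client = new LLamaSharpChatCompletion(executor);

// Non-streaming: the token stream is aggregated into a single ChatResponse
var response = await client.GetResponseAsync(
    new[] { new ChatMessage(ChatRole.User, "Hello!") });
Console.WriteLine(response.Text);

// Streaming: one ChatResponseUpdate per generated token
await foreach (var update in client.GetStreamingResponseAsync(
    new[] { new ChatMessage(ChatRole.User, "Hello!") }))
{
    Console.Write(update.Text);
}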
diff --git a/LLama.SemanticKernel/ExtensionMethods.cs b/LLama.SemanticKernel/ExtensionMethods.cs
index ba1b74479..3b8d73fb4 100644
--- a/LLama.SemanticKernel/ExtensionMethods.cs
+++ b/LLama.SemanticKernel/ExtensionMethods.cs
@@ -1,4 +1,5 @@
using LLama.Sampling;
+using Microsoft.Extensions.AI;
using Microsoft.SemanticKernel.ChatCompletion;
using AuthorRole = LLama.Common.AuthorRole;
@@ -26,6 +27,26 @@ public static LLama.Common.ChatHistory ToLLamaSharpChatHistory(this ChatHistory
return history;
}
+ public static LLama.Common.ChatHistory ToLLamaSharpChatHistory(this IEnumerable<ChatMessage> messages, bool ignoreCase = true)
+ {
+ if (messages is null)
+ {
+ throw new ArgumentNullException(nameof(messages));
+ }
+
+ var history = new LLama.Common.ChatHistory();
+
+ foreach (var chat in messages)
+ {
+ if (!Enum.TryParse<AuthorRole>(chat.Role.Value, ignoreCase, out var role))
+ role = AuthorRole.Unknown;
+
+ history.AddMessage(role, chat.Text ?? "");
+ }
+
+ return history;
+ }
+
/// <summary>
/// Convert LLamaSharpPromptExecutionSettings to LLamaSharp InferenceParams
/// </summary>
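The new overload mirrors the existing ChatHistory conversion: role names are parsed case-insensitively into LLama.Common.AuthorRole, and anything that does not parse (for example ChatRole.Tool) falls back to AuthorRole.Unknown. A small illustrative sketch:

using Microsoft.Extensions.AI;

var messages = new[]
{
    new ChatMessage(ChatRole.System, "You are a helpful assistant."),
    new ChatMessage(ChatRole.User, "Hi!"),
};

// "system" and "user" parse into AuthorRole; unrecognised roles become AuthorRole.Unknown
var history = messages.ToLLamaSharpChatHistory();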
diff --git a/LLama.SemanticKernel/LLamaSharpPromptExecutionSettings.cs b/LLama.SemanticKernel/LLamaSharpPromptExecutionSettings.cs
index 77fe9a75c..34b166ccd 100644
--- a/LLama.SemanticKernel/LLamaSharpPromptExecutionSettings.cs
+++ b/LLama.SemanticKernel/LLamaSharpPromptExecutionSettings.cs
@@ -1,3 +1,4 @@
+using Microsoft.Extensions.AI;
using Microsoft.SemanticKernel;
using System.Text.Json;
using System.Text.Json.Serialization;
@@ -97,6 +98,47 @@ public static LLamaSharpPromptExecutionSettings FromRequestSettings(PromptExecut
throw new ArgumentException($"Invalid request settings, cannot convert to {nameof(LLamaSharpPromptExecutionSettings)}", nameof(requestSettings));
}
+ internal static LLamaSharpPromptExecutionSettings FromRequestSettings(ChatOptions? options, int? defaultMaxTokens = null)
+ {
+ if (options == null)
+ {
+ return new LLamaSharpPromptExecutionSettings
+ {
+ MaxTokens = defaultMaxTokens
+ };
+ }
+
+ // Handle nullable float? to double conversion and nullability
+ double GetDoubleOrDefault(float? value, double defaultValue = 0.0) => value.HasValue ? (double)value.Value : defaultValue;
+
+ // Handle StopSequences: ensure always IList<string>
+ IList<string> stopSequences = options.StopSequences != null
+     ? new List<string>(options.StopSequences)
+     : new List<string>();
+
+ // ResultsPerPrompt and TokenSelectionBiases have no ChatOptions equivalent, so use defaults;
+ // MaxOutputTokens and ResponseFormat are mapped from ChatOptions below
+ int resultsPerPrompt = 1;
+ int? maxTokens = defaultMaxTokens;
+ IDictionary<int, int> tokenSelectionBiases = new Dictionary<int, int>();
+ string responseFormat = options.ResponseFormat?.ToString() ?? string.Empty;
+
+ var settings = new LLamaSharpPromptExecutionSettings
+ {
+ Temperature = GetDoubleOrDefault(options.Temperature),
+ TopP = GetDoubleOrDefault(options.TopP),
+ PresencePenalty = GetDoubleOrDefault(options.PresencePenalty),
+ FrequencyPenalty = GetDoubleOrDefault(options.FrequencyPenalty),
+ StopSequences = stopSequences,
+ ResultsPerPrompt = resultsPerPrompt,
+ MaxTokens = options.MaxOutputTokens ?? maxTokens,
+ TokenSelectionBiases = tokenSelectionBiases,
+ ResponseFormat = responseFormat
+ };
+
+ return settings;
+ }
+
private static readonly JsonSerializerOptions SerializerOptions = new()
{
WriteIndented = true,
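To illustrate the ChatOptions mapping (the overload is internal, so this sketch assumes it is called from inside the assembly or from a test via InternalsVisibleTo): unset sampling values collapse to 0.0 rather than staying null, stop sequences are copied into a fresh list, and an explicit MaxOutputTokens takes precedence over the supplied default:

using Microsoft.Extensions.AI;

var options = new ChatOptions
{
    Temperature = 0.6f,
    TopP = 0.9f,
    MaxOutputTokens = 256,
    StopSequences = new List<string> { "User:" },
};

// Temperature/TopP are widened from float? to double; MaxTokens resolves to 256 here,
// falling back to defaultMaxTokens only when MaxOutputTokens is unset
var settings = LLamaSharpPromptExecutionSettings.FromRequestSettings(options, defaultMaxTokens: 512);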
diff --git a/LLama.SemanticKernel/TextEmbedding/LLamaSharpEmbeddingGeneration.cs b/LLama.SemanticKernel/TextEmbedding/LLamaSharpEmbeddingGeneration.cs
index d50945117..08c015566 100644
--- a/LLama.SemanticKernel/TextEmbedding/LLamaSharpEmbeddingGeneration.cs
+++ b/LLama.SemanticKernel/TextEmbedding/LLamaSharpEmbeddingGeneration.cs
@@ -1,11 +1,12 @@
using LLama;
+using Microsoft.Extensions.AI;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Embeddings;
namespace LLamaSharp.SemanticKernel.TextEmbedding;
public sealed class LLamaSharpEmbeddingGeneration
- : ITextEmbeddingGenerationService
+ : IEmbeddingGenerator<string, Embedding<float>>, ITextEmbeddingGenerationService, IDisposable
{
private readonly LLamaEmbedder _embedder;
@@ -28,4 +29,19 @@ public async Task<IList<ReadOnlyMemory<float>>> GenerateEmbeddingsAsync(IList<string>
+ public object? GetService(Type serviceType, object? serviceKey = null)
+ {
+     // Return this instance for compatible service requests; per Microsoft.Extensions.AI
+     // convention, unknown services yield null rather than an exception
+     return serviceKey is null && serviceType.IsInstanceOfType(this) ? this : null;
+ }
+
+ /// <inheritdoc/>
+ public void Dispose()
+ {
+ if (_embedder is IDisposable disposable)
+ {
+ disposable.Dispose();
+ }
+ }
}
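A usage sketch for the embedding service through the Semantic Kernel interface, assuming `embedder` is an already-constructed LLamaEmbedder:

using LLamaSharp.SemanticKernel.TextEmbedding;

var generation = new LLamaSharpEmbeddingGeneration(embedder);

// Each input string yields one embedding vector
IList<ReadOnlyMemory<float>> embeddings =
    await generation.GenerateEmbeddingsAsync(new List<string> { "hello world" });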
diff --git a/LLama.Unittest/SemanticKernel/ChatRequestSettingsTests.cs b/LLama.Unittest/SemanticKernel/ChatRequestSettingsTests.cs
index d75a8d4b4..b16c16db8 100644
--- a/LLama.Unittest/SemanticKernel/ChatRequestSettingsTests.cs
+++ b/LLama.Unittest/SemanticKernel/ChatRequestSettingsTests.cs
@@ -1,16 +1,27 @@
-using LLamaSharp.SemanticKernel;
+using LLamaSharp.SemanticKernel;
+using Microsoft.Extensions.AI;
using Microsoft.SemanticKernel;
namespace LLama.Unittest.SemanticKernel
{
public class ChatRequestSettingsTests
{
- [Fact]
- public void ChatRequestSettings_FromRequestSettingsNull()
+ public static IEnumerable<object[]>
diff --git a/LLama/Native/LLamaContextParams.cs b/LLama/Native/LLamaContextParams.cs
--- a/LLama/Native/LLamaContextParams.cs
+++ b/LLama/Native/LLamaContextParams.cs
public LLamaAttentionType attention_type;
-
+
+ /// <summary>
+ /// When to enable Flash Attention
+ /// </summary>
+ public LLamaFlashAttnType flash_attn_type;
+
/// <summary>
/// RoPE base frequency, 0 = from model
/// </summary>
diff --git a/LLama/Native/LLamaFlashAttnType.cs b/LLama/Native/LLamaFlashAttnType.cs
new file mode 100644
index 000000000..116fbb298
--- /dev/null
+++ b/LLama/Native/LLamaFlashAttnType.cs
@@ -0,0 +1,23 @@
+namespace LLama.Native;
+
+/// <summary>
+/// Whether to use Flash Attention when evaluating the context
+/// </summary>
+/// <remarks>llama_flash_attn_type</remarks>
+public enum LLamaFlashAttnType
+{
+ /// <summary>
+ /// Decide automatically whether Flash Attention can be used
+ /// </summary>
+ Auto = -1,
+
+ /// <summary>
+ /// Never use Flash Attention
+ /// </summary>
+ Disable = 0,
+
+ /// <summary>
+ /// Always use Flash Attention
+ /// </summary>
+ Enabled = 1,
+}
\ No newline at end of file
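As a sketch of how the enum is consumed, the flash_attn_type field added to the native context parameters above would be set before creating a context; LLamaContextParams.Default() is assumed here as the usual way to obtain initialised native parameters:

using LLama.Native;

var ctxParams = LLamaContextParams.Default();
ctxParams.flash_attn_type = LLamaFlashAttnType.Auto; // let llama.cpp decide per model/device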
diff --git a/LLama/Native/LLamaFtype.cs b/LLama/Native/LLamaFtype.cs
index 705f8032e..70db7986d 100644
--- a/LLama/Native/LLamaFtype.cs
+++ b/LLama/Native/LLamaFtype.cs
@@ -202,6 +202,11 @@ public enum LLamaFtype
/// </summary>
LLAMA_FTYPE_MOSTLY_TQ2_0 = 37,
+ /// <summary>
+ /// except 1d tensors
+ /// </summary>
+ LLAMA_FTYPE_MOSTLY_MXFP4_MOE = 38,
+
/// <summary>
/// File type was not specified
/// </summary>
diff --git a/LLama/Native/LLamaModelParams.cs b/LLama/Native/LLamaModelParams.cs
index acb024852..20a9c99d0 100644
--- a/LLama/Native/LLamaModelParams.cs
+++ b/LLama/Native/LLamaModelParams.cs
@@ -99,7 +99,17 @@ public bool check_tensors
readonly get => Convert.ToBoolean(_check_tensors);
set => _check_tensors = Convert.ToSByte(value);
}
- private sbyte _check_tensors;
+ private sbyte _check_tensors;
+
+ /// <summary>
+ /// use extra buffer types (used for weight repacking)
+ /// </summary>
+ public bool use_extra_bufts
+ {
+ readonly get => Convert.ToBoolean(_use_extra_bufts);
+ set => _use_extra_bufts = Convert.ToSByte(value);
+ }
+ private sbyte _use_extra_bufts;
/// <summary>
/// Create a LLamaModelParams with default values
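A corresponding sketch for the new model parameter, assuming the Default() factory referenced in the context above:

using LLama.Native;

var modelParams = LLamaModelParams.Default();
modelParams.use_extra_bufts = true; // allow extra buffer types for weight repacking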
diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs
index e26619b26..915b08be8 100644
--- a/LLama/Native/SafeLLamaContextHandle.cs
+++ b/LLama/Native/SafeLLamaContextHandle.cs
@@ -294,7 +294,7 @@ static SafeLLamaContextHandle()
/// Get the exact size needed to copy the state of a single sequence
/// </summary>
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
- private static extern nuint llama_state_seq_get_size(SafeLLamaContextHandle ctx, LLamaSeqId seqId);
+ private static extern nuint llama_state_seq_get_size(SafeLLamaContextHandle ctx, LLamaSeqId seqId, uint llama_state_seq_flags);
/// <summary>
/// Copy the state of a single sequence into the specified buffer
@@ -303,9 +303,10 @@ static SafeLLamaContextHandle()
/// <param name="dst"></param>
/// <param name="size"></param>
/// <param name="seqId"></param>
+ /// <param name="llama_state_seq_flags"></param>
/// <returns></returns>
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
- private static extern unsafe nuint llama_state_seq_get_data(SafeLLamaContextHandle ctx, byte* dst, nuint size, LLamaSeqId seqId);
+ private static extern unsafe nuint llama_state_seq_get_data(SafeLLamaContextHandle ctx, byte* dst, nuint size, LLamaSeqId seqId, uint llama_state_seq_flags);
/// <summary>
/// Copy the sequence data (originally copied with `llama_state_seq_get_data`) into the specified sequence
@@ -314,12 +315,13 @@ static SafeLLamaContextHandle()
/// <param name="src"></param>
/// <param name="size"></param>
/// <param name="destSeqId"></param>
+ /// <param name="llama_state_seq_flags"></param>
/// <returns>
/// - Positive: Ok
/// - Zero: Failed to load
/// </returns>
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
- private static extern unsafe nuint llama_state_seq_set_data(SafeLLamaContextHandle ctx, byte* src, nuint size, LLamaSeqId destSeqId);
+ private static extern unsafe nuint llama_state_seq_set_data(SafeLLamaContextHandle ctx, byte* src, nuint size, LLamaSeqId destSeqId, uint llama_state_seq_flags);
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern LLamaPerfContextTimings llama_perf_context(SafeLLamaContextHandle ctx);
@@ -680,7 +682,7 @@ public nuint GetStateSize()
/// </summary>
public nuint GetStateSize(LLamaSeqId sequence)
{
- return llama_state_seq_get_size(this, sequence);
+ return llama_state_seq_get_size(this, sequence, 0u);
}
/// <summary>
@@ -712,7 +714,7 @@ public unsafe nuint GetState(byte* dest, nuint size, LLamaSeqId sequence)
if (size < required)
throw new ArgumentOutOfRangeException(nameof(size), $"Allocated space is too small, {size} < {required}");
- return llama_state_seq_get_data(this, dest, size, sequence);
+ return llama_state_seq_get_data(this, dest, size, sequence, 0u);
}
/// <summary>
@@ -735,7 +737,7 @@ public unsafe nuint SetState(byte* src, nuint size)
/// <returns>Number of bytes read from the src pointer</returns>
public unsafe nuint SetState(byte* src, nuint size, LLamaSeqId sequence)
{
- return llama_state_seq_set_data(this, src, size, sequence);
+ return llama_state_seq_set_data(this, src, size, sequence, 0u);
}
#endregion
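Putting the wrapper methods together, a per-sequence state round trip now passes flags = 0 (full sequence state) on both the save and load side. A sketch, assuming `ctx` is a live SafeLLamaContextHandle, the explicit int-to-LLamaSeqId conversion, and an unsafe context:

var seq = (LLamaSeqId)0;
var size = ctx.GetStateSize(seq);
var buffer = new byte[size];

unsafe
{
    fixed (byte* ptr = buffer)
    {
        // Save sequence 0, then restore it; both calls forward flags = 0 to the
        // native API, so state written by GetState can be read back by SetState
        var written = ctx.GetState(ptr, size, seq);
        var read = ctx.SetState(ptr, written, seq);
    }
}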