diff --git a/LLama.SemanticKernel/ChatCompletion/LLamaSharpChatCompletion.cs b/LLama.SemanticKernel/ChatCompletion/LLamaSharpChatCompletion.cs
index 413c9ed35..5c675ab13 100644
--- a/LLama.SemanticKernel/ChatCompletion/LLamaSharpChatCompletion.cs
+++ b/LLama.SemanticKernel/ChatCompletion/LLamaSharpChatCompletion.cs
@@ -1,5 +1,6 @@
 using LLama;
 using LLama.Abstractions;
+using Microsoft.Extensions.AI;
 using Microsoft.SemanticKernel;
 using Microsoft.SemanticKernel.ChatCompletion;
 using System.Runtime.CompilerServices;
@@ -12,7 +13,7 @@ namespace LLamaSharp.SemanticKernel.ChatCompletion;
 /// <summary>
 /// LLamaSharp ChatCompletion
 /// </summary>
-public sealed class LLamaSharpChatCompletion : IChatCompletionService
+public sealed class LLamaSharpChatCompletion : IChatClient, IChatCompletionService
 {
     private readonly ILLamaExecutor _model;
     private readonly LLamaSharpPromptExecutionSettings _defaultRequestSettings;
@@ -64,14 +65,7 @@ public ChatHistory CreateNewChat(string? instructions = "")
     /// <inheritdoc/>
     public async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsync(ChatHistory chatHistory, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default)
     {
-        var settings = executionSettings != null
-            ? LLamaSharpPromptExecutionSettings.FromRequestSettings(executionSettings)
-            : _defaultRequestSettings;
-
-        var prompt = _getFormattedPrompt(chatHistory);
-        var result = _model.InferAsync(prompt, settings.ToLLamaSharpInferenceParams(), cancellationToken);
-
-        var output = _outputTransform.TransformAsync(result);
+        var output = InferChatHistory(chatHistory, executionSettings, cancellationToken);
 
         var sb = new StringBuilder();
         await foreach (var token in output)
@@ -84,20 +78,26 @@ public async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsync
     /// <inheritdoc/>
     public async IAsyncEnumerable<StreamingChatMessageContent> GetStreamingChatMessageContentsAsync(ChatHistory chatHistory, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+    {
+        var output = InferChatHistory(chatHistory, executionSettings, cancellationToken);
+
+        await foreach (var token in output)
+        {
+            yield return new StreamingChatMessageContent(AuthorRole.Assistant, token);
+        }
+    }
+
+    private IAsyncEnumerable<string> InferChatHistory(ChatHistory chatHistory, PromptExecutionSettings? executionSettings, CancellationToken cancellationToken)
     {
         var settings = executionSettings != null
-            ? LLamaSharpPromptExecutionSettings.FromRequestSettings(executionSettings)
-            : _defaultRequestSettings;
+            ? LLamaSharpPromptExecutionSettings.FromRequestSettings(executionSettings)
+            : _defaultRequestSettings;
 
         var prompt = _getFormattedPrompt(chatHistory);
         var result = _model.InferAsync(prompt, settings.ToLLamaSharpInferenceParams(), cancellationToken);
 
         var output = _outputTransform.TransformAsync(result);
-
-        await foreach (var token in output)
-        {
-            yield return new StreamingChatMessageContent(AuthorRole.Assistant, token);
-        }
+        return output;
     }
 
     /// <summary>
@@ -130,4 +130,81 @@ private string _getFormattedPrompt(ChatHistory chatHistory)
 
         return prompt;
     }
+
+    private string _getFormattedPrompt(IEnumerable<ChatMessage> messages)
+    {
+        string prompt;
+        if (_isStatefulExecutor)
+        {
+            var state = (InteractiveExecutorState)((StatefulExecutorBase)_model).GetStateData();
+            if (state.IsPromptRun)
+            {
+                prompt = _historyTransform.HistoryToText(messages.ToLLamaSharpChatHistory());
+            }
+            else
+            {
+                ChatHistory tempHistory = new();
+                tempHistory.AddUserMessage(messages.Last().Text ?? "");
+                prompt = _historyTransform.HistoryToText(tempHistory.ToLLamaSharpChatHistory());
+            }
+        }
+        else
+        {
+            prompt = _historyTransform.HistoryToText(messages.ToLLamaSharpChatHistory());
+        }
+
+        return prompt;
+    }
+
+    /// <inheritdoc/>
+    public async Task<ChatResponse> GetResponseAsync(IEnumerable<ChatMessage> messages, ChatOptions? options = null, CancellationToken cancellationToken = default)
+    {
+        var output = InferChatMessage(messages, options, cancellationToken);
+
+        var sb = new StringBuilder();
+        await foreach (var token in output)
+        {
+            sb.Append(token);
+        }
+
+        return new ChatResponse(new ChatMessage(ChatRole.Assistant, sb.ToString()));
+    }
+
+    /// <inheritdoc/>
+    public async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(IEnumerable<ChatMessage> messages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+    {
+        var output = InferChatMessage(messages, options, cancellationToken);
+
+        await foreach (var token in output)
+        {
+            yield return new ChatResponseUpdate(ChatRole.Assistant, token);
+        }
+    }
+
+    private IAsyncEnumerable<string> InferChatMessage(IEnumerable<ChatMessage> messages, ChatOptions? options, CancellationToken cancellationToken)
+    {
+        var settings = options != null
+            ? LLamaSharpPromptExecutionSettings.FromRequestSettings(options)
+            : _defaultRequestSettings;
+
+        var prompt = _getFormattedPrompt(messages);
+        var result = _model.InferAsync(prompt, settings.ToLLamaSharpInferenceParams(), cancellationToken);
+
+        var output = _outputTransform.TransformAsync(result);
+        return output;
+    }
+
+    /// <inheritdoc/>
+    public object? GetService(Type serviceType, object? serviceKey = null)
+    {
+        throw new NotImplementedException();
+    }
+
+    /// <inheritdoc/>
+    public void Dispose()
+    {
+        if (_outputTransform is IDisposable disposable)
+        {
+            disposable.Dispose();
+        }
+    }
 }
diff --git a/LLama.SemanticKernel/ExtensionMethods.cs b/LLama.SemanticKernel/ExtensionMethods.cs
index ba1b74479..3b8d73fb4 100644
--- a/LLama.SemanticKernel/ExtensionMethods.cs
+++ b/LLama.SemanticKernel/ExtensionMethods.cs
@@ -1,4 +1,5 @@
 using LLama.Sampling;
+using Microsoft.Extensions.AI;
 using Microsoft.SemanticKernel.ChatCompletion;
 using AuthorRole = LLama.Common.AuthorRole;
 
@@ -26,6 +27,26 @@ public static LLama.Common.ChatHistory ToLLamaSharpChatHistory(this ChatHistory
         return history;
     }
 
+    public static LLama.Common.ChatHistory ToLLamaSharpChatHistory(this IEnumerable<ChatMessage> messages, bool ignoreCase = true)
+    {
+        if (messages is null)
+        {
+            throw new ArgumentNullException(nameof(messages));
+        }
+
+        var history = new LLama.Common.ChatHistory();
+
+        foreach (var chat in messages)
+        {
+            if (!Enum.TryParse<AuthorRole>(chat.Role.Value, ignoreCase, out var role))
+                role = AuthorRole.Unknown;
+
+            history.AddMessage(role, chat.Text ?? "");
+        }
+
+        return history;
+    }
+
     /// <summary>
     /// Convert LLamaSharpPromptExecutionSettings to LLamaSharp InferenceParams
     /// </summary>
diff --git a/LLama.SemanticKernel/LLamaSharpPromptExecutionSettings.cs b/LLama.SemanticKernel/LLamaSharpPromptExecutionSettings.cs
index 77fe9a75c..34b166ccd 100644
--- a/LLama.SemanticKernel/LLamaSharpPromptExecutionSettings.cs
+++ b/LLama.SemanticKernel/LLamaSharpPromptExecutionSettings.cs
@@ -1,3 +1,4 @@
+using Microsoft.Extensions.AI;
 using Microsoft.SemanticKernel;
 using System.Text.Json;
 using System.Text.Json.Serialization;
@@ -97,6 +98,47 @@ public static LLamaSharpPromptExecutionSettings FromRequestSettings(PromptExecutionSettings? requestSettings, int? defaultMaxTokens = null)
         throw new ArgumentException($"Invalid request settings, cannot convert to {nameof(LLamaSharpPromptExecutionSettings)}", nameof(requestSettings));
     }
 
+    internal static LLamaSharpPromptExecutionSettings FromRequestSettings(ChatOptions? options, int? defaultMaxTokens = null)
+    {
+        if (options == null)
+        {
+            return new LLamaSharpPromptExecutionSettings
+            {
+                MaxTokens = defaultMaxTokens
+            };
+        }
+
+        // Convert the nullable float values on ChatOptions to the double-based settings, defaulting to 0.0
+        double GetDoubleOrDefault(float? value, double defaultValue = 0.0) => value.HasValue ? (double)value.Value : defaultValue;
+
+        // StopSequences: always produce an IList<string>
+        IList<string> stopSequences = options.StopSequences != null
+            ? new List<string>(options.StopSequences)
+            : new List<string>();
+
+        // ResultsPerPrompt and TokenSelectionBiases have no ChatOptions equivalent, so use defaults;
+        // MaxTokens prefers the explicit default and falls back to options.MaxOutputTokens below
+        int resultsPerPrompt = 1;
+        int? maxTokens = defaultMaxTokens;
+        IDictionary<int, int> tokenSelectionBiases = new Dictionary<int, int>();
+        string responseFormat = options.ResponseFormat?.ToString() ?? string.Empty;
+
+        var settings = new LLamaSharpPromptExecutionSettings
+        {
+            Temperature = GetDoubleOrDefault(options.Temperature),
+            TopP = GetDoubleOrDefault(options.TopP),
+            PresencePenalty = GetDoubleOrDefault(options.PresencePenalty),
+            FrequencyPenalty = GetDoubleOrDefault(options.FrequencyPenalty),
+            StopSequences = stopSequences,
+            ResultsPerPrompt = resultsPerPrompt,
+            MaxTokens = maxTokens ?? options.MaxOutputTokens,
+            TokenSelectionBiases = tokenSelectionBiases,
+            ResponseFormat = responseFormat
+        };
+
+        return settings;
+    }
+
     private static readonly JsonSerializerOptions SerializerOptions = new()
     {
         WriteIndented = true,
diff --git a/LLama.SemanticKernel/TextEmbedding/LLamaSharpEmbeddingGeneration.cs b/LLama.SemanticKernel/TextEmbedding/LLamaSharpEmbeddingGeneration.cs
index d50945117..08c015566 100644
--- a/LLama.SemanticKernel/TextEmbedding/LLamaSharpEmbeddingGeneration.cs
+++ b/LLama.SemanticKernel/TextEmbedding/LLamaSharpEmbeddingGeneration.cs
@@ -1,11 +1,12 @@
 using LLama;
+using Microsoft.Extensions.AI;
 using Microsoft.SemanticKernel;
 using Microsoft.SemanticKernel.Embeddings;
 
 namespace LLamaSharp.SemanticKernel.TextEmbedding;
 
 public sealed class LLamaSharpEmbeddingGeneration
-    : ITextEmbeddingGenerationService
+    : IEmbeddingGenerator, ITextEmbeddingGenerationService, IDisposable
 {
     private readonly LLamaEmbedder _embedder;
 
@@ -28,4 +29,19 @@ public async Task<IList<ReadOnlyMemory<float>>> GenerateEmbeddingsAsync(IList<string> data, Kernel? kernel = null, CancellationToken cancellationToken = default)
         return results;
     }
+
+    /// <inheritdoc/>
+    public object? GetService(Type serviceType, object? serviceKey = null)
+    {
+        throw new NotImplementedException();
+    }
+
+    /// <inheritdoc/>
+    public void Dispose()
+    {
+        if (_embedder is IDisposable disposable)
+        {
+            disposable.Dispose();
+        }
+    }
 }
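A sketch of what the extra interface buys on the embedding side: the service can now be handed to code that only knows the non-generic `Microsoft.Extensions.AI.IEmbeddingGenerator` contract and disposed through it. The `LLamaEmbedder` construction shown is assumed from the existing API and is illustrative only.

```csharp
using LLama;
using LLama.Common;
using LLamaSharp.SemanticKernel.TextEmbedding;
using Microsoft.Extensions.AI;

// Placeholder embedder setup; LLamaEmbedder(weights, params) is assumed from the existing API.
var parameters = new ModelParams("path/to/embedding-model.gguf");
using var weights = LLamaWeights.LoadFromFile(parameters);
var embedder = new LLamaEmbedder(weights, parameters);

// Dispose() on the service forwards to the wrapped embedder, as implemented above.
using IEmbeddingGenerator generator = new LLamaSharpEmbeddingGeneration(embedder);
```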
diff --git a/LLama.Unittest/SemanticKernel/ChatRequestSettingsTests.cs b/LLama.Unittest/SemanticKernel/ChatRequestSettingsTests.cs
index d75a8d4b4..b16c16db8 100644
--- a/LLama.Unittest/SemanticKernel/ChatRequestSettingsTests.cs
+++ b/LLama.Unittest/SemanticKernel/ChatRequestSettingsTests.cs
@@ -1,16 +1,27 @@
-using LLamaSharp.SemanticKernel;
+using LLamaSharp.SemanticKernel;
+using Microsoft.Extensions.AI;
 using Microsoft.SemanticKernel;
 
 namespace LLama.Unittest.SemanticKernel
 {
     public class ChatRequestSettingsTests
     {
-        [Fact]
-        public void ChatRequestSettings_FromRequestSettingsNull()
+        public static IEnumerable<object[]> NullRequestSettingsData()
+        {
+            yield return new object[] { null, typeof(PromptExecutionSettings) };
+            yield return new object[] { null, typeof(ChatOptions) };
+        }
+
+        [Theory]
+        [MemberData(nameof(NullRequestSettingsData))]
+        public void ChatRequestSettings_FromRequestSettingsNull(object settings, Type botType)
         {
             // Arrange
             // Act
-            var requestSettings = LLamaSharpPromptExecutionSettings.FromRequestSettings(null, null);
+            LLamaSharpPromptExecutionSettings requestSettings = botType == typeof(PromptExecutionSettings)
+                ? LLamaSharpPromptExecutionSettings.FromRequestSettings((PromptExecutionSettings?)settings, null)
+                : LLamaSharpPromptExecutionSettings.FromRequestSettings((ChatOptions?)settings, null);
+
             // Assert
             Assert.NotNull(requestSettings);
@@ -26,12 +37,15 @@ public void ChatRequestSettings_FromRequestSettingsNull()
             Assert.Equal(0, requestSettings.TopP);
         }
 
-        [Fact]
-        public void ChatRequestSettings_FromRequestSettingsNullWithMaxTokens()
+        [Theory]
+        [MemberData(nameof(NullRequestSettingsData))]
+        public void ChatRequestSettings_FromRequestSettingsNullWithMaxTokens(object settings, Type botType)
         {
             // Arrange
             // Act
-            var requestSettings = LLamaSharpPromptExecutionSettings.FromRequestSettings(null, 200);
+            LLamaSharpPromptExecutionSettings requestSettings = botType == typeof(PromptExecutionSettings)
+                ? LLamaSharpPromptExecutionSettings.FromRequestSettings((PromptExecutionSettings?)settings, 200)
+                : LLamaSharpPromptExecutionSettings.FromRequestSettings((ChatOptions?)settings, 200);
 
             // Assert
             Assert.NotNull(requestSettings);
diff --git a/LLama/Native/LLamaContextParams.cs b/LLama/Native/LLamaContextParams.cs
index 76f5d6c77..4aaee0cd6 100644
--- a/LLama/Native/LLamaContextParams.cs
+++ b/LLama/Native/LLamaContextParams.cs
@@ -64,7 +64,12 @@ public struct LLamaContextParams
     /// Attention type to use for embeddings
     /// </summary>
     public LLamaAttentionType attention_type;
-
+
+    /// <summary>
+    /// When to enable Flash Attention
+    /// </summary>
+    public LLamaFlashAttnType flash_attn_type;
+
     /// <summary>
     /// RoPE base frequency, 0 = from model
     /// </summary>
diff --git a/LLama/Native/LLamaFlashAttnType.cs b/LLama/Native/LLamaFlashAttnType.cs
new file mode 100644
index 000000000..116fbb298
--- /dev/null
+++ b/LLama/Native/LLamaFlashAttnType.cs
@@ -0,0 +1,23 @@
+namespace LLama.Native;
+
+/// <summary>
+/// When to use flash attention
+/// </summary>
+/// <remarks>llama_flash_attn_type</remarks>
+public enum LLamaFlashAttnType
+{
+    /// <summary>
+    /// Automatically decide whether to use flash attention
+    /// </summary>
+    Auto = -1,
+
+    /// <summary>
+    /// Never use flash attention
+    /// </summary>
+    Disabled = 0,
+
+    /// <summary>
+    /// Always use flash attention
+    /// </summary>
+    Enabled = 1,
+}
\ No newline at end of file
diff --git a/LLama/Native/LLamaFtype.cs b/LLama/Native/LLamaFtype.cs
index 705f8032e..70db7986d 100644
--- a/LLama/Native/LLamaFtype.cs
+++ b/LLama/Native/LLamaFtype.cs
@@ -202,6 +202,11 @@ public enum LLamaFtype
     /// </summary>
     LLAMA_FTYPE_MOSTLY_TQ2_0 = 37,
 
+    /// <summary>
+    /// except 1d tensors
+    /// </summary>
+    LLAMA_FTYPE_MOSTLY_MXFP4_MOE = 38,
+
     /// <summary>
     /// File type was not specified
     /// </summary>
diff --git a/LLama/Native/LLamaModelParams.cs b/LLama/Native/LLamaModelParams.cs
index acb024852..20a9c99d0 100644
--- a/LLama/Native/LLamaModelParams.cs
+++ b/LLama/Native/LLamaModelParams.cs
@@ -99,7 +99,17 @@ public bool check_tensors
         readonly get => Convert.ToBoolean(_check_tensors);
         set => _check_tensors = Convert.ToSByte(value);
     }
-    private sbyte _check_tensors;
+    private sbyte _check_tensors;
+
+    /// <summary>
+    /// use extra buffer types (used for weight repacking)
+    /// </summary>
+    public bool use_extra_bufts
+    {
+        readonly get => Convert.ToBoolean(_use_extra_bufts);
+        set => _use_extra_bufts = Convert.ToSByte(value);
+    }
+    private sbyte _use_extra_bufts;
 
     /// <summary>
     /// Create a LLamaModelParams with default values
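A hypothetical sketch of how the new native fields might be populated when filling the structs by hand. The `Default()` helpers are assumed from the existing codebase; the higher-level plumbing through `ModelParams` is not shown here and may differ.

```csharp
using LLama.Native;

// Assumed helpers that wrap llama_context_default_params / llama_model_default_params.
var contextParams = LLamaContextParams.Default();
contextParams.flash_attn_type = LLamaFlashAttnType.Auto;   // let llama.cpp decide per backend/model

var modelParams = LLamaModelParams.Default();
modelParams.use_extra_bufts = true;                        // allow extra buffer types for weight repacking
```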
diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs
index e26619b26..915b08be8 100644
--- a/LLama/Native/SafeLLamaContextHandle.cs
+++ b/LLama/Native/SafeLLamaContextHandle.cs
@@ -294,7 +294,7 @@ static SafeLLamaContextHandle()
     /// Get the exact size needed to copy the state of a single sequence
     /// </summary>
     [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
-    private static extern nuint llama_state_seq_get_size(SafeLLamaContextHandle ctx, LLamaSeqId seqId);
+    private static extern nuint llama_state_seq_get_size(SafeLLamaContextHandle ctx, LLamaSeqId seqId, uint llama_state_seq_flags);
 
     /// <summary>
     /// Copy the state of a single sequence into the specified buffer
@@ -303,9 +303,10 @@ static SafeLLamaContextHandle()
     /// <param name="ctx"></param>
     /// <param name="dst"></param>
     /// <param name="size"></param>
     /// <param name="seqId"></param>
+    /// <param name="llama_state_seq_flags"></param>
     /// <returns></returns>
     [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
-    private static extern unsafe nuint llama_state_seq_get_data(SafeLLamaContextHandle ctx, byte* dst, nuint size, LLamaSeqId seqId);
+    private static extern unsafe nuint llama_state_seq_get_data(SafeLLamaContextHandle ctx, byte* dst, nuint size, LLamaSeqId seqId, uint llama_state_seq_flags);
 
     /// <summary>
     /// Copy the sequence data (originally copied with `llama_state_seq_get_data`) into the specified sequence
@@ -314,12 +315,13 @@ static SafeLLamaContextHandle()
     /// <param name="ctx"></param>
     /// <param name="src"></param>
     /// <param name="size"></param>
     /// <param name="destSeqId"></param>
+    /// <param name="llama_state_seq_flags"></param>
     /// <returns>
     ///  - Positive: Ok
     ///  - Zero: Failed to load
     /// </returns>
     [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
-    private static extern unsafe nuint llama_state_seq_set_data(SafeLLamaContextHandle ctx, byte* src, nuint size, LLamaSeqId destSeqId);
+    private static extern unsafe nuint llama_state_seq_set_data(SafeLLamaContextHandle ctx, byte* src, nuint size, LLamaSeqId destSeqId, uint llama_state_seq_flags);
 
     [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
     private static extern LLamaPerfContextTimings llama_perf_context(SafeLLamaContextHandle ctx);
@@ -680,7 +682,7 @@ public nuint GetStateSize()
     /// </summary>
     public nuint GetStateSize(LLamaSeqId sequence)
     {
-        return llama_state_seq_get_size(this, sequence);
+        return llama_state_seq_get_size(this, sequence, 0u);
     }
 
     /// <summary>
@@ -712,7 +714,7 @@ public unsafe nuint GetState(byte* dest, nuint size, LLamaSeqId sequence)
         if (size < required)
             throw new ArgumentOutOfRangeException(nameof(size), $"Allocated space is too small, {size} < {required}");
 
-        return llama_state_seq_get_data(this, dest, size, sequence);
+        return llama_state_seq_get_data(this, dest, size, sequence, 0u);
     }
 
     /// <summary>
@@ -735,7 +737,7 @@ public unsafe nuint SetState(byte* src, nuint size)
     /// <returns>Number of bytes read from the src pointer</returns>
    public unsafe nuint SetState(byte* src, nuint size, LLamaSeqId sequence)
     {
-        return llama_state_seq_set_data(this, src, size, sequence);
+        return llama_state_seq_set_data(this, src, size, sequence, 0u);
     }
 #endregion
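An illustrative round-trip of a single sequence's state through the updated wrappers, which now pass `llama_state_seq_flags = 0`; the context handle and sequence id are assumed to exist already and are not set up here.

```csharp
using LLama.Native;

static unsafe void RoundTripSequenceState(SafeLLamaContextHandle ctx, LLamaSeqId seq)
{
    // Size the buffer, copy the sequence state out, then load it back in.
    nuint size = ctx.GetStateSize(seq);
    var buffer = new byte[checked((int)size)];

    fixed (byte* ptr = buffer)
    {
        nuint written = ctx.GetState(ptr, size, seq);
        nuint read = ctx.SetState(ptr, written, seq);

        // Per the native docs above, zero means the state failed to load.
        if (read == 0)
            throw new InvalidOperationException("Failed to restore sequence state");
    }
}
```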