109 changes: 93 additions & 16 deletions LLama.SemanticKernel/ChatCompletion/LLamaSharpChatCompletion.cs
@@ -1,5 +1,6 @@
using LLama;
using LLama.Abstractions;
using Microsoft.Extensions.AI;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;
using System.Runtime.CompilerServices;
@@ -12,7 +13,7 @@ namespace LLamaSharp.SemanticKernel.ChatCompletion;
/// <summary>
/// LLamaSharp ChatCompletion
/// </summary>
public sealed class LLamaSharpChatCompletion : IChatCompletionService
public sealed class LLamaSharpChatCompletion : IChatClient, IChatCompletionService
{
private readonly ILLamaExecutor _model;
private readonly LLamaSharpPromptExecutionSettings _defaultRequestSettings;
@@ -64,14 +65,7 @@ public ChatHistory CreateNewChat(string? instructions = "")
/// <inheritdoc/>
public async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsync(ChatHistory chatHistory, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default)
{
var settings = executionSettings != null
? LLamaSharpPromptExecutionSettings.FromRequestSettings(executionSettings)
: _defaultRequestSettings;

var prompt = _getFormattedPrompt(chatHistory);
var result = _model.InferAsync(prompt, settings.ToLLamaSharpInferenceParams(), cancellationToken);

var output = _outputTransform.TransformAsync(result);
var output = InferChatHistory(chatHistory, executionSettings, cancellationToken);

var sb = new StringBuilder();
await foreach (var token in output)
@@ -84,20 +78,26 @@ public async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsync

/// <inheritdoc/>
public async IAsyncEnumerable<StreamingChatMessageContent> GetStreamingChatMessageContentsAsync(ChatHistory chatHistory, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var output = InferChatHistory(chatHistory, executionSettings, cancellationToken);

await foreach (var token in output)
{
yield return new StreamingChatMessageContent(AuthorRole.Assistant, token);
}
}

private IAsyncEnumerable<string> InferChatHistory(ChatHistory chatHistory, PromptExecutionSettings? executionSettings, CancellationToken cancellationToken)
{
var settings = executionSettings != null
? LLamaSharpPromptExecutionSettings.FromRequestSettings(executionSettings)
: _defaultRequestSettings;
? LLamaSharpPromptExecutionSettings.FromRequestSettings(executionSettings)
: _defaultRequestSettings;

var prompt = _getFormattedPrompt(chatHistory);
var result = _model.InferAsync(prompt, settings.ToLLamaSharpInferenceParams(), cancellationToken);

var output = _outputTransform.TransformAsync(result);

await foreach (var token in output)
{
yield return new StreamingChatMessageContent(AuthorRole.Assistant, token);
}
return output;
}

/// <summary>
@@ -130,4 +130,81 @@ private string _getFormattedPrompt(ChatHistory chatHistory)

return prompt;
}

private string _getFormattedPrompt(IEnumerable<ChatMessage> messages)
{
string prompt;
if (_isStatefulExecutor)
{
var state = (InteractiveExecutorState)((StatefulExecutorBase)_model).GetStateData();
if (state.IsPromptRun)
{
prompt = _historyTransform.HistoryToText(messages.ToLLamaSharpChatHistory());
}
else
{
ChatHistory tempHistory = new();
tempHistory.AddUserMessage(messages.Last().Text ?? "");
prompt = _historyTransform.HistoryToText(tempHistory.ToLLamaSharpChatHistory());
}
}
else
{
prompt = _historyTransform.HistoryToText(messages.ToLLamaSharpChatHistory());
}

return prompt;
}

public async Task<ChatResponse> GetResponseAsync(IEnumerable<ChatMessage> messages, ChatOptions? options = null, CancellationToken cancellationToken = default)
{
var output = InferChatMessage(messages, options, cancellationToken);

var sb = new StringBuilder();
await foreach (var token in output)
{
sb.Append(token);
}

return new ChatResponse(new ChatMessage(ChatRole.Assistant, sb.ToString()));
}

/// <inheritdoc/>
public async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(IEnumerable<ChatMessage> messages, ChatOptions? options = null, CancellationToken cancellationToken = default)
{
var output = InferChatMessage(messages, options, cancellationToken);

await foreach (var token in output)
{
yield return new ChatResponseUpdate(ChatRole.Assistant, token);
}
}

private IAsyncEnumerable<string> InferChatMessage(IEnumerable<ChatMessage> messages, ChatOptions? options, CancellationToken cancellationToken)
{
var settings = options != null
? LLamaSharpPromptExecutionSettings.FromRequestSettings(options)
: _defaultRequestSettings;

var prompt = _getFormattedPrompt(messages);
var result = _model.InferAsync(prompt, settings.ToLLamaSharpInferenceParams(), cancellationToken);

var output = _outputTransform.TransformAsync(result);
return output;
}

/// <inheritdoc/>
public object? GetService(Type serviceType, object? serviceKey = null)
{
throw new NotImplementedException();
}

/// <inheritdoc/>
public void Dispose()
{
if (_outputTransform is IDisposable disposable)
{
disposable.Dispose();
}
}
}
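
With this change one instance serves both abstractions. A minimal usage sketch follows — the model setup is an assumption for the example, not part of this diff:

using LLama;
using LLama.Common;
using LLamaSharp.SemanticKernel.ChatCompletion;
using Microsoft.Extensions.AI;

// Hypothetical model setup; any ILLamaExecutor works here.
var parameters = new ModelParams("path/to/model.gguf");
using var weights = LLamaWeights.LoadFromFile(parameters);
using var context = weights.CreateContext(parameters);
var executor = new InteractiveExecutor(context);

// The same instance now satisfies both IChatCompletionService and IChatClient.
IChatClient client = new LLamaSharpChatCompletion(executor);
var response = await client.GetResponseAsync(
    new[] { new ChatMessage(ChatRole.User, "Hello!") });
Console.WriteLine(response.Text);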
21 changes: 21 additions & 0 deletions LLama.SemanticKernel/ExtensionMethods.cs
@@ -1,4 +1,5 @@
using LLama.Sampling;
using Microsoft.Extensions.AI;
using Microsoft.SemanticKernel.ChatCompletion;
using AuthorRole = LLama.Common.AuthorRole;

@@ -26,6 +27,26 @@ public static LLama.Common.ChatHistory ToLLamaSharpChatHistory(this ChatHistory
return history;
}

public static LLama.Common.ChatHistory ToLLamaSharpChatHistory(this IEnumerable<ChatMessage> messages, bool ignoreCase = true)
{
if (messages is null)
{
throw new ArgumentNullException(nameof(messages));
}

var history = new LLama.Common.ChatHistory();

foreach (var chat in messages)
{
if (!Enum.TryParse<AuthorRole>(chat.Role.Value, ignoreCase, out var role))
role = AuthorRole.Unknown;

history.AddMessage(role, chat.Text ?? "");
}

return history;
}

/// <summary>
/// Convert LLamaSharpPromptExecutionSettings to LLamaSharp InferenceParams
/// </summary>
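A short sketch of the new overload's behaviour (the messages are made up for the example): roles are parsed case-insensitively onto LLama.Common.AuthorRole, and anything unrecognised becomes AuthorRole.Unknown.

using Microsoft.Extensions.AI;

var messages = new[]
{
    new ChatMessage(ChatRole.System, "You are a helpful assistant."),
    new ChatMessage(ChatRole.User, "Hi there"),
};

var history = messages.ToLLamaSharpChatHistory();
// "system" and "user" map onto AuthorRole.System and AuthorRole.User;
// a role such as ChatRole.Tool has no AuthorRole counterpart and becomes Unknown.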
42 changes: 42 additions & 0 deletions LLama.SemanticKernel/LLamaSharpPromptExecutionSettings.cs
@@ -1,3 +1,4 @@
using Microsoft.Extensions.AI;
using Microsoft.SemanticKernel;
using System.Text.Json;
using System.Text.Json.Serialization;
@@ -97,6 +98,47 @@ public static LLamaSharpPromptExecutionSettings FromRequestSettings(PromptExecut
throw new ArgumentException($"Invalid request settings, cannot convert to {nameof(LLamaSharpPromptExecutionSettings)}", nameof(requestSettings));
}

internal static LLamaSharpPromptExecutionSettings FromRequestSettings(ChatOptions? options, int? defaultMaxTokens = null)
{
if (options == null)
{
return new LLamaSharpPromptExecutionSettings
{
MaxTokens = defaultMaxTokens
};
}

// Convert nullable float values to double, substituting a default when absent
double GetDoubleOrDefault(float? value, double defaultValue = 0.0) => value.HasValue ? (double)value.Value : defaultValue;

// StopSequences: ensure the result is always a concrete IList<string>
IList<string> stopSequences = options.StopSequences != null
? new List<string>(options.StopSequences)
: new List<string>();

// ResultsPerPrompt and TokenSelectionBiases have no ChatOptions equivalent, so defaults are used;
// MaxTokens prefers defaultMaxTokens and falls back to options.MaxOutputTokens below.
int resultsPerPrompt = 1;
int? maxTokens = defaultMaxTokens;
IDictionary<int, int> tokenSelectionBiases = new Dictionary<int, int>();
string responseFormat = options.ResponseFormat?.ToString() ?? string.Empty;

var settings = new LLamaSharpPromptExecutionSettings
{
Temperature = GetDoubleOrDefault(options.Temperature),
TopP = GetDoubleOrDefault(options.TopP),
PresencePenalty = GetDoubleOrDefault(options.PresencePenalty),
FrequencyPenalty = GetDoubleOrDefault(options.FrequencyPenalty),
StopSequences = stopSequences,
ResultsPerPrompt = resultsPerPrompt,
MaxTokens = maxTokens ?? options.MaxOutputTokens,
TokenSelectionBiases = tokenSelectionBiases,
ResponseFormat = responseFormat
};

return settings;
}

private static readonly JsonSerializerOptions SerializerOptions = new()
{
WriteIndented = true,
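The ChatOptions overload is internal, so the sketch below (with made-up values) only illustrates what the conversion produces. Note that a supplied defaultMaxTokens takes precedence over options.MaxOutputTokens in the code above:

using Microsoft.Extensions.AI;

var options = new ChatOptions
{
    Temperature = 0.7f,
    TopP = 0.9f,
    MaxOutputTokens = 256,
    StopSequences = new List<string> { "User:" },
};

var settings = LLamaSharpPromptExecutionSettings.FromRequestSettings(options, defaultMaxTokens: 512);
// settings.Temperature == 0.7, settings.TopP == 0.9,
// settings.MaxTokens == 512 (defaultMaxTokens wins over MaxOutputTokens),
// settings.StopSequences == ["User:"].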
LLama.SemanticKernel/TextEmbedding/LLamaSharpEmbeddingGeneration.cs
@@ -1,11 +1,12 @@
using LLama;
using Microsoft.Extensions.AI;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Embeddings;

namespace LLamaSharp.SemanticKernel.TextEmbedding;

public sealed class LLamaSharpEmbeddingGeneration
: ITextEmbeddingGenerationService
: IEmbeddingGenerator, ITextEmbeddingGenerationService, IDisposable
{
private readonly LLamaEmbedder _embedder;

@@ -28,4 +29,19 @@ public async Task<IList<ReadOnlyMemory<float>>> GenerateEmbeddingsAsync(IList<st

return result;
}

/// <inheritdoc/>
public object? GetService(Type serviceType, object? serviceKey = null)
{
throw new NotImplementedException();
}

/// <inheritdoc/>
public void Dispose()
{
if (_embedder is IDisposable disposable)
{
disposable.Dispose();
}
}
}
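
A hedged usage sketch (the model path and the Embeddings flag are assumptions for the example). Note that disposing the generation service disposes the wrapped embedder, per the Dispose implementation above:

using LLama;
using LLama.Common;
using LLamaSharp.SemanticKernel.TextEmbedding;

var parameters = new ModelParams("path/to/model.gguf") { Embeddings = true };
using var weights = LLamaWeights.LoadFromFile(parameters);
var embedder = new LLamaEmbedder(weights, parameters);
using var generator = new LLamaSharpEmbeddingGeneration(embedder);  // disposes embedder on Dispose

var embeddings = await generator.GenerateEmbeddingsAsync(new List<string> { "hello world" });
Console.WriteLine($"Dimensions: {embeddings[0].Length}");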
28 changes: 21 additions & 7 deletions LLama.Unittest/SemanticKernel/ChatRequestSettingsTests.cs
@@ -1,16 +1,27 @@
using LLamaSharp.SemanticKernel;
using Microsoft.Extensions.AI;
using Microsoft.SemanticKernel;

namespace LLama.Unittest.SemanticKernel
{
public class ChatRequestSettingsTests
{
[Fact]
public void ChatRequestSettings_FromRequestSettingsNull()
public static IEnumerable<object[]> NullRequestSettingsData()
{
yield return new object[] { null, typeof(PromptExecutionSettings) };
yield return new object[] { null, typeof(ChatOptions) };
}

[Theory]
[MemberData(nameof(NullRequestSettingsData))]
public void ChatRequestSettings_FromRequestSettingsNull(object settings, Type botType)
{
// Arrange
// Act
var requestSettings = LLamaSharpPromptExecutionSettings.FromRequestSettings(null, null);
LLamaSharpPromptExecutionSettings requestSettings = botType == typeof(PromptExecutionSettings)
? LLamaSharpPromptExecutionSettings.FromRequestSettings((PromptExecutionSettings?)settings, null)
: LLamaSharpPromptExecutionSettings.FromRequestSettings((ChatOptions?)settings, null);


// Assert
Assert.NotNull(requestSettings);
@@ -26,12 +37,15 @@ public void ChatRequestSettings_FromRequestSettingsNull()
Assert.Equal(0, requestSettings.TopP);
}

[Fact]
public void ChatRequestSettings_FromRequestSettingsNullWithMaxTokens()
[Theory]
[MemberData(nameof(NullRequestSettingsData))]
public void ChatRequestSettings_FromRequestSettingsNullWithMaxTokens(object settings, Type botType)
{
// Arrange
// Act
var requestSettings = LLamaSharpPromptExecutionSettings.FromRequestSettings(null, 200);
LLamaSharpPromptExecutionSettings requestSettings = botType == typeof(PromptExecutionSettings)
? LLamaSharpPromptExecutionSettings.FromRequestSettings((PromptExecutionSettings?)settings, 200)
: LLamaSharpPromptExecutionSettings.FromRequestSettings((ChatOptions?)settings, 200);

// Assert
Assert.NotNull(requestSettings);
7 changes: 6 additions & 1 deletion LLama/Native/LLamaContextParams.cs
@@ -64,7 +64,12 @@ public struct LLamaContextParams
/// Attention type to use for embeddings
/// </summary>
public LLamaAttentionType attention_type;


/// <summary>
/// When to enable Flash Attention
/// </summary>
public LLamaFlashAttnType flash_attn_type;

/// <summary>
/// RoPE base frequency, 0 = from model
/// </summary>
23 changes: 23 additions & 0 deletions LLama/Native/LLamaFlashAttnType.cs
@@ -0,0 +1,23 @@
namespace LLama.Native;

/// <summary>
/// When to enable Flash Attention
/// </summary>
/// <remarks>llama_flash_attn_type</remarks>
public enum LLamaFlashAttnType
{
/// <summary>
/// Decide automatically whether to use Flash Attention
/// </summary>
Auto = -1,

/// <summary>
/// Disable Flash Attention
/// </summary>
Disable = 0,

/// <summary>
/// Enable Flash Attention
/// </summary>
Enabled = 1,
}
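A sketch at the native-params level, assuming LLamaContextParams.Default() is used as elsewhere in LLama.Native (the higher-level ModelParams plumbing is not shown in this diff):

var ctxParams = LLamaContextParams.Default();
ctxParams.flash_attn_type = LLamaFlashAttnType.Auto;       // let llama.cpp decide
// ctxParams.flash_attn_type = LLamaFlashAttnType.Enabled; // or force it on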
5 changes: 5 additions & 0 deletions LLama/Native/LLamaFtype.cs
@@ -202,6 +202,11 @@ public enum LLamaFtype
/// </summary>
LLAMA_FTYPE_MOSTLY_TQ2_0 = 37,

/// <summary>
/// MXFP4 MoE quantization, except 1d tensors
/// </summary>
LLAMA_FTYPE_MOSTLY_MXFP4_MOE = 38,

/// <summary>
/// File type was not specified
/// </summary>
12 changes: 11 additions & 1 deletion LLama/Native/LLamaModelParams.cs
@@ -99,7 +99,17 @@ public bool check_tensors
readonly get => Convert.ToBoolean(_check_tensors);
set => _check_tensors = Convert.ToSByte(value);
}
private sbyte _check_tensors;

/// <summary>
/// use extra buffer types (used for weight repacking)
/// </summary>
public bool use_extra_bufts
{
readonly get => Convert.ToBoolean(_use_extra_bufts);
set => _use_extra_bufts = Convert.ToSByte(value);
}
private sbyte _use_extra_bufts;

/// <summary>
/// Create a LLamaModelParams with default values
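A matching sketch for the new model-params flag, again assuming the Default() factory used elsewhere in LLama.Native:

var modelParams = LLamaModelParams.Default();
modelParams.use_extra_bufts = true;   // opt in to extra buffer types for weight repacking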