
Commit 3f176be

Merge pull request #964 from stephentoub/meai
Add Microsoft.Extensions.AI support for IChatClient / IEmbeddingGenerator
2 parents b2c5e3f + 0d7875f commit 3f176be

File tree: 6 files changed, +265 -3 lines changed
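
This merge wires LLamaSharp into the Microsoft.Extensions.AI abstractions from two directions: ILLamaExecutor gains an AsChatClient() adapter exposing executors as IChatClient, and LLamaEmbedder now implements IEmbeddingGenerator&lt;string, Embedding&lt;float&gt;&gt;. Short usage sketches (illustrative only, with placeholder model paths) follow the relevant diffs below.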

LLama.Unittest/LLamaEmbedderTests.cs

Lines changed: 22 additions & 0 deletions
@@ -1,6 +1,7 @@
 using LLama.Common;
 using LLama.Extensions;
 using LLama.Native;
+using Microsoft.Extensions.AI;
 using Xunit.Abstractions;

 namespace LLama.Unittest;
@@ -41,6 +42,27 @@ private async Task CompareEmbeddings(string modelPath)
         var spoon = (await embedder.GetEmbeddings("The spoon is not real")).Single().EuclideanNormalization();
         Assert.DoesNotContain(float.NaN, spoon);

+        var generator = (IEmbeddingGenerator<string, Embedding<float>>)embedder;
+        Assert.NotNull(generator.Metadata);
+        Assert.Equal(nameof(LLamaEmbedder), generator.Metadata.ProviderName);
+        Assert.NotNull(generator.Metadata.ModelId);
+        Assert.NotEmpty(generator.Metadata.ModelId);
+        Assert.Same(embedder, generator.GetService<LLamaEmbedder>());
+        Assert.Same(generator, generator.GetService<IEmbeddingGenerator<string, Embedding<float>>>());
+        Assert.Null(generator.GetService<string>());
+
+        var embeddings = await generator.GenerateAsync(
+            [
+                "The cat is cute",
+                "The kitten is cute",
+                "The spoon is not real"
+            ]);
+        Assert.All(cat.Zip(embeddings[0].Vector.Span.EuclideanNormalization()), e => Assert.Equal(e.First, e.Second, 0.001));
+        Assert.All(kitten.Zip(embeddings[1].Vector.Span.EuclideanNormalization()), e => Assert.Equal(e.First, e.Second, 0.001));
+        Assert.All(spoon.Zip(embeddings[2].Vector.Span.EuclideanNormalization()), e => Assert.Equal(e.First, e.Second, 0.001));
+        Assert.True(embeddings.Usage?.InputTokenCount is 19 or 20);
+        Assert.True(embeddings.Usage?.TotalTokenCount is 19 or 20);
+
         _testOutputHelper.WriteLine($"Cat = [{string.Join(",", cat.AsMemory().Slice(0, 7).ToArray())}...]");
         _testOutputHelper.WriteLine($"Kitten = [{string.Join(",", kitten.AsMemory().Slice(0, 7).ToArray())}...]");
         _testOutputHelper.WriteLine($"Spoon = [{string.Join(",", spoon.AsMemory().Slice(0, 7).ToArray())}...]");
Lines changed: 169 additions & 0 deletions
@@ -0,0 +1,169 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Runtime.CompilerServices;
+using System.Text;
+using System.Threading;
+using System.Threading.Tasks;
+using LLama.Common;
+using LLama.Sampling;
+using Microsoft.Extensions.AI;
+
+namespace LLama.Abstractions;
+
+/// <summary>
+/// Extension methods for the <see cref="ILLamaExecutor"/> interface.
+/// </summary>
+public static class LLamaExecutorExtensions
+{
+    /// <summary>Gets an <see cref="IChatClient"/> instance for the specified <see cref="ILLamaExecutor"/>.</summary>
+    /// <param name="executor">The executor.</param>
+    /// <param name="historyTransform">The <see cref="IHistoryTransform"/> to use to transform an input list of messages into a prompt.</param>
+    /// <param name="outputTransform">The <see cref="ITextStreamTransform"/> to use to transform the output into text.</param>
+    /// <returns>An <see cref="IChatClient"/> instance for the provided <see cref="ILLamaExecutor"/>.</returns>
+    /// <exception cref="ArgumentNullException"><paramref name="executor"/> is null.</exception>
+    public static IChatClient AsChatClient(
+        this ILLamaExecutor executor,
+        IHistoryTransform? historyTransform = null,
+        ITextStreamTransform? outputTransform = null) =>
+        new LLamaExecutorChatClient(
+            executor ?? throw new ArgumentNullException(nameof(executor)),
+            historyTransform,
+            outputTransform);
+
+    private sealed class LLamaExecutorChatClient(
+        ILLamaExecutor executor,
+        IHistoryTransform? historyTransform = null,
+        ITextStreamTransform? outputTransform = null) : IChatClient
+    {
+        private static readonly InferenceParams s_defaultParams = new();
+        private static readonly DefaultSamplingPipeline s_defaultPipeline = new();
+        private static readonly string[] s_antiPrompts = ["User:", "Assistant:", "System:"];
+        [ThreadStatic]
+        private static Random? t_random;
+
+        private readonly ILLamaExecutor _executor = executor;
+        private readonly IHistoryTransform _historyTransform = historyTransform ?? new AppendAssistantHistoryTransform();
+        private readonly ITextStreamTransform _outputTransform = outputTransform ??
+            new LLamaTransforms.KeywordTextOutputStreamTransform(s_antiPrompts);
+
+        /// <inheritdoc/>
+        public ChatClientMetadata Metadata { get; } = new(nameof(LLamaExecutorChatClient));
+
+        /// <inheritdoc/>
+        public void Dispose() { }
+
+        /// <inheritdoc/>
+        public TService? GetService<TService>(object? key = null) where TService : class =>
+            typeof(TService) == typeof(ILLamaExecutor) ? (TService)_executor :
+            this as TService;
+
+        /// <inheritdoc/>
+        public async Task<ChatCompletion> CompleteAsync(
+            IList<ChatMessage> chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default)
+        {
+            var result = _executor.InferAsync(CreatePrompt(chatMessages), CreateInferenceParams(options), cancellationToken);
+
+            StringBuilder text = new();
+            await foreach (var token in _outputTransform.TransformAsync(result))
+            {
+                text.Append(token);
+            }
+
+            return new(new ChatMessage(ChatRole.Assistant, text.ToString()))
+            {
+                CreatedAt = DateTime.UtcNow,
+            };
+        }
+
+        /// <inheritdoc/>
+        public async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAsync(
+            IList<ChatMessage> chatMessages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+        {
+            var result = _executor.InferAsync(CreatePrompt(chatMessages), CreateInferenceParams(options), cancellationToken);
+
+            await foreach (var token in _outputTransform.TransformAsync(result))
+            {
+                yield return new()
+                {
+                    CreatedAt = DateTime.UtcNow,
+                    Role = ChatRole.Assistant,
+                    Text = token,
+                };
+            }
+        }
+
+        /// <summary>Format the chat messages into a string prompt.</summary>
+        private string CreatePrompt(IList<ChatMessage> messages)
+        {
+            if (messages is null)
+            {
+                throw new ArgumentNullException(nameof(messages));
+            }
+
+            ChatHistory history = new();
+
+            if (_executor is not StatefulExecutorBase seb ||
+                seb.GetStateData() is InteractiveExecutor.InteractiveExecutorState { IsPromptRun: true })
+            {
+                foreach (var message in messages)
+                {
+                    history.AddMessage(
+                        message.Role == ChatRole.System ? AuthorRole.System :
+                        message.Role == ChatRole.Assistant ? AuthorRole.Assistant :
+                        AuthorRole.User,
+                        string.Concat(message.Contents.OfType<TextContent>()));
+                }
+            }
+            else
+            {
+                // Stateful executor that has already processed its prompt: send only the most recent message.
+                history.AddMessage(AuthorRole.User, string.Concat(messages.LastOrDefault()?.Contents.OfType<TextContent>() ?? []));
+            }
+
+            return _historyTransform.HistoryToText(history);
+        }
+
+        /// <summary>Convert the chat options to inference parameters.</summary>
+        private static InferenceParams? CreateInferenceParams(ChatOptions? options)
+        {
+            List<string> antiPrompts = new(s_antiPrompts);
+            if (options?.AdditionalProperties?.TryGetValue(nameof(InferenceParams.AntiPrompts), out IReadOnlyList<string>? anti) is true)
+            {
+                antiPrompts.AddRange(anti);
+            }
+
+            return new()
+            {
+                AntiPrompts = antiPrompts,
+                TokensKeep = options?.AdditionalProperties?.TryGetValue(nameof(InferenceParams.TokensKeep), out int tk) is true ? tk : s_defaultParams.TokensKeep,
+                MaxTokens = options?.MaxOutputTokens ?? 256, // arbitrary upper limit
+                SamplingPipeline = new DefaultSamplingPipeline()
+                {
+                    AlphaFrequency = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.AlphaFrequency), out float af) is true ? af : s_defaultPipeline.AlphaFrequency,
+                    AlphaPresence = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.AlphaPresence), out float ap) is true ? ap : s_defaultPipeline.AlphaPresence,
+                    PenalizeEOS = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.PenalizeEOS), out bool eos) is true ? eos : s_defaultPipeline.PenalizeEOS,
+                    PenalizeNewline = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.PenalizeNewline), out bool pnl) is true ? pnl : s_defaultPipeline.PenalizeNewline,
+                    RepeatPenalty = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.RepeatPenalty), out float rp) is true ? rp : s_defaultPipeline.RepeatPenalty,
+                    RepeatPenaltyCount = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.RepeatPenaltyCount), out int rpc) is true ? rpc : s_defaultPipeline.RepeatPenaltyCount,
+                    Grammar = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.Grammar), out Grammar? g) is true ? g : s_defaultPipeline.Grammar,
+                    MinKeep = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.MinKeep), out int mk) is true ? mk : s_defaultPipeline.MinKeep,
+                    MinP = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.MinP), out float mp) is true ? mp : s_defaultPipeline.MinP,
+                    Seed = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.Seed), out uint seed) is true ? seed : (uint)(t_random ??= new()).Next(),
+                    TailFreeZ = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.TailFreeZ), out float tfz) is true ? tfz : s_defaultPipeline.TailFreeZ,
+                    Temperature = options?.Temperature ?? 0,
+                    TopP = options?.TopP ?? 0,
+                    TopK = options?.TopK ?? s_defaultPipeline.TopK,
+                    TypicalP = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.TypicalP), out float tp) is true ? tp : s_defaultPipeline.TypicalP,
+                },
+            };
+        }
+
+        /// <summary>A default transform that appends "Assistant: " to the end.</summary>
+        private sealed class AppendAssistantHistoryTransform : LLamaTransforms.DefaultHistoryTransform
+        {
+            public override string HistoryToText(ChatHistory history) =>
+                $"{base.HistoryToText(history)}{AuthorRole.Assistant}: ";
+        }
+    }
+}
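
For context, here is a minimal consumption sketch of the adapter above (illustrative, not part of the commit). The model path is a placeholder; any ILLamaExecutor can be wrapped, and InteractiveExecutor is used here because it keeps conversation state between calls:

using System;
using LLama;
using LLama.Abstractions;
using LLama.Common;
using Microsoft.Extensions.AI;

// Hypothetical model path; any chat-capable GGUF model will do.
var parameters = new ModelParams("models/chat-model.gguf");
using var weights = LLamaWeights.LoadFromFile(parameters);
using var context = weights.CreateContext(parameters);

IChatClient client = new InteractiveExecutor(context).AsChatClient();

// Stream the assistant's reply token by token.
await foreach (var update in client.CompleteStreamingAsync(
    [new ChatMessage(ChatRole.User, "What is an embedding?")]))
{
    Console.Write(update.Text);
}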

LLama/Extensions/SpanNormalizationExtensions.cs

Lines changed: 12 additions & 0 deletions
@@ -81,6 +81,18 @@ public static Span<float> EuclideanNormalization(this Span<float> vector)
         return vector;
     }

+    /// <summary>
+    /// Creates a new array containing an L2 normalization of the input vector.
+    /// </summary>
+    /// <param name="vector">The vector to normalize.</param>
+    /// <returns>A new array containing the normalized vector.</returns>
+    public static float[] EuclideanNormalization(this ReadOnlySpan<float> vector)
+    {
+        var result = new float[vector.Length];
+        TensorPrimitives.Divide(vector, TensorPrimitives.Norm(vector), result);
+        return result;
+    }
+
     /// <summary>
     /// <b>In-place</b> apply p-normalization. https://en.wikipedia.org/wiki/Norm_(mathematics)#p-norm
     /// <list type="bullet">
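
Unlike the in-place Span&lt;float&gt; overload directly above it, this new overload allocates a fresh array, which is what lets the tests normalize a ReadOnlySpan&lt;float&gt; such as Embedding&lt;float&gt;.Vector.Span. A quick worked example with hand-picked values:

ReadOnlySpan<float> raw = [3f, 4f];
float[] unit = raw.EuclideanNormalization();
// Euclidean norm = sqrt(3*3 + 4*4) = 5, so unit is [0.6f, 0.8f]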
Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Threading;
+using System.Threading.Tasks;
+using LLama.Native;
+using Microsoft.Extensions.AI;
+
+namespace LLama;
+
+public partial class LLamaEmbedder
+    : IEmbeddingGenerator<string, Embedding<float>>
+{
+    private EmbeddingGeneratorMetadata? _metadata;
+
+    /// <inheritdoc />
+    EmbeddingGeneratorMetadata IEmbeddingGenerator<string, Embedding<float>>.Metadata =>
+        _metadata ??= new(
+            nameof(LLamaEmbedder),
+            modelId: Context.NativeHandle.ModelHandle.ReadMetadata().TryGetValue("general.name", out var name) ? name : null,
+            dimensions: EmbeddingSize);
+
+    /// <inheritdoc />
+    TService? IEmbeddingGenerator<string, Embedding<float>>.GetService<TService>(object? key) where TService : class =>
+        typeof(TService) == typeof(LLamaContext) ? (TService)(object)Context :
+        this as TService;
+
+    /// <inheritdoc />
+    async Task<GeneratedEmbeddings<Embedding<float>>> IEmbeddingGenerator<string, Embedding<float>>.GenerateAsync(IEnumerable<string> values, EmbeddingGenerationOptions? options, CancellationToken cancellationToken)
+    {
+        if (Context.NativeHandle.PoolingType == LLamaPoolingType.None)
+        {
+            throw new NotSupportedException($"Embedding generation is not supported with {nameof(LLamaPoolingType)}.{nameof(LLamaPoolingType.None)}.");
+        }
+
+        GeneratedEmbeddings<Embedding<float>> results = new()
+        {
+            Usage = new() { InputTokenCount = 0 },
+        };
+
+        foreach (var value in values)
+        {
+            var (embeddings, tokenCount) = await GetEmbeddingsWithTokenCount(value, cancellationToken).ConfigureAwait(false);
+            Debug.Assert(embeddings.Count == 1, "Should be one and only one embedding when pooling is enabled.");
+
+            results.Usage.InputTokenCount += tokenCount;
+            results.Add(new Embedding<float>(embeddings[0]) { CreatedAt = DateTime.UtcNow });
+        }
+
+        results.Usage.TotalTokenCount = results.Usage.InputTokenCount;
+
+        return results;
+    }
+}
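
And a matching consumption sketch for the embedder side (again illustrative; the model path is a placeholder, and the Embeddings/PoolingType settings assume the usual ModelParams options). Note that GenerateAsync throws NotSupportedException when the context uses LLamaPoolingType.None:

using System;
using LLama;
using LLama.Common;
using LLama.Native;
using Microsoft.Extensions.AI;

// Hypothetical model path; pooling must not be None for GenerateAsync.
var parameters = new ModelParams("models/embedding-model.gguf")
{
    Embeddings = true,
    PoolingType = LLamaPoolingType.Mean,
};
using var weights = LLamaWeights.LoadFromFile(parameters);
using var embedder = new LLamaEmbedder(weights, parameters);

IEmbeddingGenerator<string, Embedding<float>> generator = embedder;
var embeddings = await generator.GenerateAsync(["The cat is cute"]);
Console.WriteLine($"{embeddings[0].Vector.Length} dimensions, {embeddings.Usage?.InputTokenCount} input tokens");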

LLama/LLamaEmbedder.cs

Lines changed: 6 additions & 3 deletions
@@ -12,7 +12,7 @@ namespace LLama;
 /// <summary>
 /// Generate high dimensional embedding vectors from text
 /// </summary>
-public sealed class LLamaEmbedder
+public sealed partial class LLamaEmbedder
     : IDisposable
 {
     /// <summary>
@@ -58,7 +58,10 @@ public void Dispose()
     /// <returns></returns>
     /// <exception cref="RuntimeError"></exception>
     /// <exception cref="NotSupportedException"></exception>
-    public async Task<IReadOnlyList<float[]>> GetEmbeddings(string input, CancellationToken cancellationToken = default)
+    public async Task<IReadOnlyList<float[]>> GetEmbeddings(string input, CancellationToken cancellationToken = default) =>
+        (await GetEmbeddingsWithTokenCount(input, cancellationToken).ConfigureAwait(false)).Embeddings;
+
+    private async Task<(IReadOnlyList<float[]> Embeddings, int Tokens)> GetEmbeddingsWithTokenCount(string input, CancellationToken cancellationToken = default)
     {
         // Add all of the tokens to the batch
         var tokens = Context.Tokenize(input);
@@ -113,6 +116,6 @@ public async Task<IReadOnlyList<float[]>> GetEmbeddings(string input, Cancellati

         Context.NativeHandle.KvCacheClear();

-        return results;
+        return (results, tokens.Length);
     }
 }

LLama/LLamaSharp.csproj

Lines changed: 2 additions & 0 deletions
@@ -49,6 +49,8 @@
   </ItemGroup>

   <ItemGroup>
+    <PackageReference Include="Microsoft.Bcl.AsyncInterfaces" Version="8.0.0" />
+    <PackageReference Include="Microsoft.Extensions.AI.Abstractions" Version="9.0.0-preview.9.24525.1" />
     <PackageReference Include="Microsoft.Extensions.Logging.Abstractions" Version="8.0.1" />
     <PackageReference Include="System.Numerics.Tensors" Version="8.0.0" />
   </ItemGroup>
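
Microsoft.Extensions.AI.Abstractions (still a preview package at this version) supplies the IChatClient and IEmbeddingGenerator types; Microsoft.Bcl.AsyncInterfaces presumably backfills IAsyncEnumerable for the library's .NET Standard 2.0 target.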
