diff --git a/src/xAI.Tests/EmbeddingGeneratorTests.cs b/src/xAI.Tests/EmbeddingGeneratorTests.cs new file mode 100644 index 0000000..4914c1d --- /dev/null +++ b/src/xAI.Tests/EmbeddingGeneratorTests.cs @@ -0,0 +1,192 @@ +using Microsoft.Extensions.AI; +using Moq; +using Tests.Client.Helpers; +using xAI; +using xAI.Protocol; +using static ConfigurationExtensions; + +namespace xAI.Tests; + +public class EmbeddingGeneratorTests(ITestOutputHelper output) +{ + [SecretsFact("XAI_API_KEY")] + public async Task GrokGeneratesEmbeddings() + { + var client = new GrokClient(Configuration["XAI_API_KEY"]!); + var generator = client.AsIEmbeddingGenerator("v1"); + + var response = await generator.GenerateAsync(["Hello, world!", "How are you?"]); + + Assert.NotNull(response); + Assert.Equal(2, response.Count); + + foreach (var embedding in response) + { + Assert.NotNull(embedding.ModelId); + Assert.NotNull(embedding.CreatedAt); + Assert.NotEmpty(embedding.Vector.ToArray()); + } + } + + [Fact] + public void AsIEmbeddingGenerator_NullClient_Throws() + { + Assert.Throws("client", () => ((GrokClient)null!).AsIEmbeddingGenerator("model")); + } + + [Fact] + public void AsIEmbeddingGenerator_ProducesExpectedMetadata() + { + Uri endpoint = new("https://api.x.ai"); + string model = "v1"; + + var clientOptions = new GrokClientOptions { Endpoint = endpoint }; + var mockClient = new Mock(MockBehavior.Strict); + + var embeddingGenerator = new GrokEmbeddingGenerator(mockClient.Object, model); + var metadata = embeddingGenerator.GetService(); + + Assert.NotNull(metadata); + Assert.Equal("xai", metadata.ProviderName); + Assert.Equal(model, metadata.DefaultModelId); + } + + [Fact] + public void GetService_SuccessfullyReturnsUnderlyingClient() + { + var mockClient = new Mock(MockBehavior.Strict); + var embeddingGenerator = new GrokEmbeddingGenerator(mockClient.Object, "model"); + + Assert.Same(embeddingGenerator, embeddingGenerator.GetService>>()); + Assert.Same(mockClient.Object, embeddingGenerator.GetService()); + Assert.Same(embeddingGenerator, embeddingGenerator.GetService()); + } + + [Fact] + public async Task GenerateAsync_ExpectedRequestResponse() + { + var mockClient = new Mock(MockBehavior.Strict); + + var response = new EmbedResponse + { + Id = "test-id", + Model = "v1", + SystemFingerprint = "test-fingerprint", + Usage = new EmbeddingUsage + { + NumTextEmbeddings = 2, + NumImageEmbeddings = 0 + } + }; + + // Add first embedding + var embedding1 = new xAI.Protocol.Embedding { Index = 0 }; + var featureVector1 = new FeatureVector(); + featureVector1.FloatArray.AddRange([0.1f, 0.2f, 0.3f]); + embedding1.Embeddings.Add(featureVector1); + response.Embeddings.Add(embedding1); + + // Add second embedding + var embedding2 = new xAI.Protocol.Embedding { Index = 1 }; + var featureVector2 = new FeatureVector(); + featureVector2.FloatArray.AddRange([0.4f, 0.5f, 0.6f]); + embedding2.Embeddings.Add(featureVector2); + response.Embeddings.Add(embedding2); + + mockClient + .Setup(x => x.EmbedAsync(It.IsAny(), null, null, CancellationToken.None)) + .Returns(CallHelpers.CreateAsyncUnaryCall(response)); + + var embeddingGenerator = new GrokEmbeddingGenerator(mockClient.Object, "v1"); + + var result = await embeddingGenerator.GenerateAsync(["hello, world!", "how are you?"]); + + Assert.NotNull(result); + Assert.Equal(2, result.Count); + + Assert.NotNull(result.Usage); + Assert.Equal(2, result.Usage.InputTokenCount); + Assert.Equal(2, result.Usage.TotalTokenCount); + + var first = result[0]; + Assert.Equal("v1", first.ModelId); + Assert.NotNull(first.CreatedAt); + Assert.Equal(3, first.Vector.Length); + Assert.Equal([0.1f, 0.2f, 0.3f], first.Vector.ToArray()); + + var second = result[1]; + Assert.Equal("v1", second.ModelId); + Assert.NotNull(second.CreatedAt); + Assert.Equal(3, second.Vector.Length); + Assert.Equal([0.4f, 0.5f, 0.6f], second.Vector.ToArray()); + } + + [Fact] + public async Task GenerateAsync_MissingUsage_ReturnsNullUsage() + { + var mockClient = new Mock(MockBehavior.Strict); + + var response = new EmbedResponse + { + Id = "test-id", + Model = "v1", + }; + + // Add embedding without usage + var embedding = new xAI.Protocol.Embedding { Index = 0 }; + var featureVector = new FeatureVector(); + featureVector.FloatArray.AddRange([0.1f, 0.2f, 0.3f]); + embedding.Embeddings.Add(featureVector); + response.Embeddings.Add(embedding); + + mockClient + .Setup(x => x.EmbedAsync(It.IsAny(), null, null, CancellationToken.None)) + .Returns(CallHelpers.CreateAsyncUnaryCall(response)); + + var embeddingGenerator = new GrokEmbeddingGenerator(mockClient.Object, "v1"); + + var result = await embeddingGenerator.GenerateAsync(["hello, world!"]); + + Assert.NotNull(result); + Assert.Single(result); + Assert.Null(result.Usage); + + var first = result[0]; + Assert.Equal("v1", first.ModelId); + Assert.NotNull(first.CreatedAt); + Assert.Equal(3, first.Vector.Length); + } + + [Fact] + public async Task GenerateAsync_UsesCustomModelId() + { + var mockClient = new Mock(MockBehavior.Strict); + EmbedRequest? capturedRequest = null; + + var response = new EmbedResponse + { + Id = "test-id", + Model = "custom-model", + }; + + var embedding = new xAI.Protocol.Embedding { Index = 0 }; + var featureVector = new FeatureVector(); + featureVector.FloatArray.AddRange([0.1f, 0.2f, 0.3f]); + embedding.Embeddings.Add(featureVector); + response.Embeddings.Add(embedding); + + mockClient + .Setup(x => x.EmbedAsync(It.IsAny(), null, null, CancellationToken.None)) + .Callback((req, _, _, _) => capturedRequest = req) + .Returns(CallHelpers.CreateAsyncUnaryCall(response)); + + var embeddingGenerator = new GrokEmbeddingGenerator(mockClient.Object, "default-model"); + + var result = await embeddingGenerator.GenerateAsync( + ["hello"], + new EmbeddingGenerationOptions { ModelId = "custom-model" }); + + Assert.NotNull(capturedRequest); + Assert.Equal("custom-model", capturedRequest.Model); + } +} diff --git a/src/xAI/GrokClientExtensions.cs b/src/xAI/GrokClientExtensions.cs index 49910d7..16790c6 100644 --- a/src/xAI/GrokClientExtensions.cs +++ b/src/xAI/GrokClientExtensions.cs @@ -10,9 +10,17 @@ public static class GrokClientExtensions { /// Creates a new from the specified using the given model as the default. public static IChatClient AsIChatClient(this GrokClient client, string defaultModelId) - => new GrokChatClient(client.Channel, client.Options, defaultModelId); + => new GrokChatClient(Throw.IfNull(client).Channel, client.Options, defaultModelId); /// Creates a new from the specified using the given model as the default. public static IChatClient AsIChatClient(this Chat.ChatClient client, string defaultModelId) - => new GrokChatClient(client, defaultModelId); + => new GrokChatClient(Throw.IfNull(client), defaultModelId); + + /// Creates a new from the specified using the given model as the default. + public static IEmbeddingGenerator> AsIEmbeddingGenerator(this GrokClient client, string defaultModelId) + => new GrokEmbeddingGenerator(Throw.IfNull(client).Channel, client.Options, defaultModelId); + + /// Creates a new from the specified using the given model as the default. + public static IEmbeddingGenerator> AsIEmbeddingGenerator(this Embedder.EmbedderClient client, string defaultModelId) + => new GrokEmbeddingGenerator(Throw.IfNull(client), defaultModelId); } \ No newline at end of file diff --git a/src/xAI/GrokEmbeddingGenerator.cs b/src/xAI/GrokEmbeddingGenerator.cs new file mode 100644 index 0000000..b6fd36d --- /dev/null +++ b/src/xAI/GrokEmbeddingGenerator.cs @@ -0,0 +1,95 @@ +using Grpc.Net.Client; +using Microsoft.Extensions.AI; +using xAI.Protocol; +using static xAI.Protocol.Embedder; + +namespace xAI; + +class GrokEmbeddingGenerator : IEmbeddingGenerator> +{ + readonly EmbeddingGeneratorMetadata metadata; + readonly EmbedderClient client; + readonly string defaultModelId; + readonly GrokClientOptions clientOptions; + + internal GrokEmbeddingGenerator(GrpcChannel channel, GrokClientOptions clientOptions, string defaultModelId) + : this(new EmbedderClient(channel), clientOptions, defaultModelId) + { } + + /// + /// Test constructor. + /// + internal GrokEmbeddingGenerator(EmbedderClient client, string defaultModelId) + : this(client, new(), defaultModelId) + { } + + GrokEmbeddingGenerator(EmbedderClient client, GrokClientOptions clientOptions, string defaultModelId) + { + this.client = client; + this.clientOptions = clientOptions; + this.defaultModelId = defaultModelId; + metadata = new EmbeddingGeneratorMetadata("xai", clientOptions.Endpoint, defaultModelId); + } + + /// + public async Task>> GenerateAsync(IEnumerable values, EmbeddingGenerationOptions? options = null, CancellationToken cancellationToken = default) + { + var request = new EmbedRequest + { + Model = options?.ModelId ?? defaultModelId, + EncodingFormat = EmbedEncodingFormat.FormatFloat + }; + + foreach (var value in values) + { + request.Input.Add(new EmbedInput { String = value }); + } + + if ((clientOptions.EndUserId) is { } user) + request.User = user; + + var response = await client.EmbedAsync(request, cancellationToken: cancellationToken); + + var result = new GeneratedEmbeddings>(); + + foreach (var embedding in response.Embeddings.OrderBy(e => e.Index)) + { + // Each input can produce multiple feature vectors, we take the first one for text inputs + if (embedding.Embeddings.FirstOrDefault() is { } featureVector) + { + result.Add(new Embedding(featureVector.FloatArray.ToArray()) + { + CreatedAt = DateTimeOffset.UtcNow, + ModelId = response.Model, + }); + } + } + + if (response.Usage != null) + { + result.Usage = new UsageDetails + { + InputTokenCount = response.Usage.NumTextEmbeddings, + TotalTokenCount = response.Usage.NumTextEmbeddings + response.Usage.NumImageEmbeddings + }; + } + + return result; + } + + /// + public object? GetService(Type serviceType, object? serviceKey = null) => serviceType switch + { + Type t when t == typeof(EmbeddingGeneratorMetadata) => metadata, + Type t when t == typeof(GrokEmbeddingGenerator) => this, + Type t when t == typeof(EmbedderClient) => client, + Type t when t.IsInstanceOfType(this) => this, + _ => null + }; + + /// + public void Dispose() + { + // Nothing to dispose. Implementation required for the IEmbeddingGenerator interface. + } +}