Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
192 changes: 192 additions & 0 deletions src/xAI.Tests/EmbeddingGeneratorTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
using Microsoft.Extensions.AI;
using Moq;
using Tests.Client.Helpers;
using xAI;
using xAI.Protocol;
using static ConfigurationExtensions;

namespace xAI.Tests;

public class EmbeddingGeneratorTests(ITestOutputHelper output)
{
[SecretsFact("XAI_API_KEY")]
public async Task GrokGeneratesEmbeddings()
{
var client = new GrokClient(Configuration["XAI_API_KEY"]!);
var generator = client.AsIEmbeddingGenerator("v1");

var response = await generator.GenerateAsync(["Hello, world!", "How are you?"]);

Assert.NotNull(response);
Assert.Equal(2, response.Count);

foreach (var embedding in response)
{
Assert.NotNull(embedding.ModelId);
Assert.NotNull(embedding.CreatedAt);
Assert.NotEmpty(embedding.Vector.ToArray());
}
}

[Fact]
public void AsIEmbeddingGenerator_NullClient_Throws()
{
Assert.Throws<ArgumentNullException>("client", () => ((GrokClient)null!).AsIEmbeddingGenerator("model"));
}

[Fact]
public void AsIEmbeddingGenerator_ProducesExpectedMetadata()
{
Uri endpoint = new("https://api.x.ai");
string model = "v1";

var clientOptions = new GrokClientOptions { Endpoint = endpoint };
var mockClient = new Mock<xAI.Protocol.Embedder.EmbedderClient>(MockBehavior.Strict);

var embeddingGenerator = new GrokEmbeddingGenerator(mockClient.Object, model);
var metadata = embeddingGenerator.GetService<EmbeddingGeneratorMetadata>();

Assert.NotNull(metadata);
Assert.Equal("xai", metadata.ProviderName);
Assert.Equal(model, metadata.DefaultModelId);
}

[Fact]
public void GetService_SuccessfullyReturnsUnderlyingClient()
{
var mockClient = new Mock<xAI.Protocol.Embedder.EmbedderClient>(MockBehavior.Strict);
var embeddingGenerator = new GrokEmbeddingGenerator(mockClient.Object, "model");

Assert.Same(embeddingGenerator, embeddingGenerator.GetService<IEmbeddingGenerator<string, Embedding<float>>>());
Assert.Same(mockClient.Object, embeddingGenerator.GetService<xAI.Protocol.Embedder.EmbedderClient>());
Assert.Same(embeddingGenerator, embeddingGenerator.GetService<GrokEmbeddingGenerator>());
}

[Fact]
public async Task GenerateAsync_ExpectedRequestResponse()
{
var mockClient = new Mock<xAI.Protocol.Embedder.EmbedderClient>(MockBehavior.Strict);

var response = new EmbedResponse
{
Id = "test-id",
Model = "v1",
SystemFingerprint = "test-fingerprint",
Usage = new EmbeddingUsage
{
NumTextEmbeddings = 2,
NumImageEmbeddings = 0
}
};

// Add first embedding
var embedding1 = new xAI.Protocol.Embedding { Index = 0 };
var featureVector1 = new FeatureVector();
featureVector1.FloatArray.AddRange([0.1f, 0.2f, 0.3f]);
embedding1.Embeddings.Add(featureVector1);
response.Embeddings.Add(embedding1);

// Add second embedding
var embedding2 = new xAI.Protocol.Embedding { Index = 1 };
var featureVector2 = new FeatureVector();
featureVector2.FloatArray.AddRange([0.4f, 0.5f, 0.6f]);
embedding2.Embeddings.Add(featureVector2);
response.Embeddings.Add(embedding2);

mockClient
.Setup(x => x.EmbedAsync(It.IsAny<EmbedRequest>(), null, null, CancellationToken.None))
.Returns(CallHelpers.CreateAsyncUnaryCall(response));

var embeddingGenerator = new GrokEmbeddingGenerator(mockClient.Object, "v1");

var result = await embeddingGenerator.GenerateAsync(["hello, world!", "how are you?"]);

Assert.NotNull(result);
Assert.Equal(2, result.Count);

Assert.NotNull(result.Usage);
Assert.Equal(2, result.Usage.InputTokenCount);
Assert.Equal(2, result.Usage.TotalTokenCount);

var first = result[0];
Assert.Equal("v1", first.ModelId);
Assert.NotNull(first.CreatedAt);
Assert.Equal(3, first.Vector.Length);
Assert.Equal([0.1f, 0.2f, 0.3f], first.Vector.ToArray());

var second = result[1];
Assert.Equal("v1", second.ModelId);
Assert.NotNull(second.CreatedAt);
Assert.Equal(3, second.Vector.Length);
Assert.Equal([0.4f, 0.5f, 0.6f], second.Vector.ToArray());
}

[Fact]
public async Task GenerateAsync_MissingUsage_ReturnsNullUsage()
{
var mockClient = new Mock<xAI.Protocol.Embedder.EmbedderClient>(MockBehavior.Strict);

var response = new EmbedResponse
{
Id = "test-id",
Model = "v1",
};

// Add embedding without usage
var embedding = new xAI.Protocol.Embedding { Index = 0 };
var featureVector = new FeatureVector();
featureVector.FloatArray.AddRange([0.1f, 0.2f, 0.3f]);
embedding.Embeddings.Add(featureVector);
response.Embeddings.Add(embedding);

mockClient
.Setup(x => x.EmbedAsync(It.IsAny<EmbedRequest>(), null, null, CancellationToken.None))
.Returns(CallHelpers.CreateAsyncUnaryCall(response));

var embeddingGenerator = new GrokEmbeddingGenerator(mockClient.Object, "v1");

var result = await embeddingGenerator.GenerateAsync(["hello, world!"]);

Assert.NotNull(result);
Assert.Single(result);
Assert.Null(result.Usage);

var first = result[0];
Assert.Equal("v1", first.ModelId);
Assert.NotNull(first.CreatedAt);
Assert.Equal(3, first.Vector.Length);
}

[Fact]
public async Task GenerateAsync_UsesCustomModelId()
{
var mockClient = new Mock<xAI.Protocol.Embedder.EmbedderClient>(MockBehavior.Strict);
EmbedRequest? capturedRequest = null;

var response = new EmbedResponse
{
Id = "test-id",
Model = "custom-model",
};

var embedding = new xAI.Protocol.Embedding { Index = 0 };
var featureVector = new FeatureVector();
featureVector.FloatArray.AddRange([0.1f, 0.2f, 0.3f]);
embedding.Embeddings.Add(featureVector);
response.Embeddings.Add(embedding);

mockClient
.Setup(x => x.EmbedAsync(It.IsAny<EmbedRequest>(), null, null, CancellationToken.None))
.Callback<EmbedRequest, Grpc.Core.Metadata?, DateTime?, CancellationToken>((req, _, _, _) => capturedRequest = req)
.Returns(CallHelpers.CreateAsyncUnaryCall(response));

var embeddingGenerator = new GrokEmbeddingGenerator(mockClient.Object, "default-model");

var result = await embeddingGenerator.GenerateAsync(
["hello"],
new EmbeddingGenerationOptions { ModelId = "custom-model" });

Assert.NotNull(capturedRequest);
Assert.Equal("custom-model", capturedRequest.Model);
}
}
12 changes: 10 additions & 2 deletions src/xAI/GrokClientExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,17 @@ public static class GrokClientExtensions
{
/// <summary>Creates a new <see cref="IChatClient"/> from the specified <see cref="GrokClient"/> using the given model as the default.</summary>
public static IChatClient AsIChatClient(this GrokClient client, string defaultModelId)
=> new GrokChatClient(client.Channel, client.Options, defaultModelId);
=> new GrokChatClient(Throw.IfNull(client).Channel, client.Options, defaultModelId);

/// <summary>Creates a new <see cref="IChatClient"/> from the specified <see cref="Chat.ChatClient"/> using the given model as the default.</summary>
public static IChatClient AsIChatClient(this Chat.ChatClient client, string defaultModelId)
=> new GrokChatClient(client, defaultModelId);
=> new GrokChatClient(Throw.IfNull(client), defaultModelId);

/// <summary>Creates a new <see cref="IEmbeddingGenerator{String, Embedding}"/> from the specified <see cref="GrokClient"/> using the given model as the default.</summary>
public static IEmbeddingGenerator<string, Embedding<float>> AsIEmbeddingGenerator(this GrokClient client, string defaultModelId)
=> new GrokEmbeddingGenerator(Throw.IfNull(client).Channel, client.Options, defaultModelId);

/// <summary>Creates a new <see cref="IEmbeddingGenerator{String, Embedding}"/> from the specified <see cref="Embedder.EmbedderClient"/> using the given model as the default.</summary>
public static IEmbeddingGenerator<string, Embedding<float>> AsIEmbeddingGenerator(this Embedder.EmbedderClient client, string defaultModelId)
=> new GrokEmbeddingGenerator(Throw.IfNull(client), defaultModelId);
}
95 changes: 95 additions & 0 deletions src/xAI/GrokEmbeddingGenerator.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
using Grpc.Net.Client;
using Microsoft.Extensions.AI;
using xAI.Protocol;
using static xAI.Protocol.Embedder;

namespace xAI;

class GrokEmbeddingGenerator : IEmbeddingGenerator<string, Embedding<float>>
{
readonly EmbeddingGeneratorMetadata metadata;
readonly EmbedderClient client;
readonly string defaultModelId;
readonly GrokClientOptions clientOptions;

internal GrokEmbeddingGenerator(GrpcChannel channel, GrokClientOptions clientOptions, string defaultModelId)
: this(new EmbedderClient(channel), clientOptions, defaultModelId)
{ }

/// <summary>
/// Test constructor.
/// </summary>
internal GrokEmbeddingGenerator(EmbedderClient client, string defaultModelId)
: this(client, new(), defaultModelId)
{ }

GrokEmbeddingGenerator(EmbedderClient client, GrokClientOptions clientOptions, string defaultModelId)
{
this.client = client;
this.clientOptions = clientOptions;
this.defaultModelId = defaultModelId;
metadata = new EmbeddingGeneratorMetadata("xai", clientOptions.Endpoint, defaultModelId);
}

/// <inheritdoc />
public async Task<GeneratedEmbeddings<Embedding<float>>> GenerateAsync(IEnumerable<string> values, EmbeddingGenerationOptions? options = null, CancellationToken cancellationToken = default)
{
var request = new EmbedRequest
{
Model = options?.ModelId ?? defaultModelId,
EncodingFormat = EmbedEncodingFormat.FormatFloat
};

foreach (var value in values)
{
request.Input.Add(new EmbedInput { String = value });
}

if ((clientOptions.EndUserId) is { } user)
request.User = user;

var response = await client.EmbedAsync(request, cancellationToken: cancellationToken);

var result = new GeneratedEmbeddings<Embedding<float>>();

foreach (var embedding in response.Embeddings.OrderBy(e => e.Index))
{
// Each input can produce multiple feature vectors, we take the first one for text inputs
if (embedding.Embeddings.FirstOrDefault() is { } featureVector)
{
result.Add(new Embedding<float>(featureVector.FloatArray.ToArray())
{
CreatedAt = DateTimeOffset.UtcNow,
ModelId = response.Model,
});
}
}

if (response.Usage != null)
{
result.Usage = new UsageDetails
{
InputTokenCount = response.Usage.NumTextEmbeddings,
TotalTokenCount = response.Usage.NumTextEmbeddings + response.Usage.NumImageEmbeddings
};
}

return result;
}

/// <inheritdoc />
public object? GetService(Type serviceType, object? serviceKey = null) => serviceType switch
{
Type t when t == typeof(EmbeddingGeneratorMetadata) => metadata,
Type t when t == typeof(GrokEmbeddingGenerator) => this,
Type t when t == typeof(EmbedderClient) => client,
Type t when t.IsInstanceOfType(this) => this,
_ => null
};

/// <inheritdoc />
public void Dispose()
{
// Nothing to dispose. Implementation required for the IEmbeddingGenerator interface.
}
}