diff --git a/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/GoogleAI/GoogleAIEmbeddingRequestTests.cs b/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/GoogleAI/GoogleAIEmbeddingRequestTests.cs index 731e20dda585..26643dd45d97 100644 --- a/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/GoogleAI/GoogleAIEmbeddingRequestTests.cs +++ b/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/GoogleAI/GoogleAIEmbeddingRequestTests.cs @@ -83,4 +83,24 @@ public void FromDataJsonIncludesDimensionsWhenProvided() // Assert Assert.Contains($"{DimensionalityJsonPropertyName}:{Dimensions}", json); } + + [Fact] + public void FromDataShouldIncludeTaskTypeWhenProvided() + { + // Arrange + var input = new[] { "This is a retrieval document." }; + var modelId = "embedding-001"; + var dimensions = 1024; + var taskType = "RETRIEVAL_DOCUMENT"; + + // Act + var request = GoogleAIEmbeddingRequest.FromData(input, modelId, dimensions, taskType); + + // Serialize to JSON (this is what would be sent in the HTTP request) + var json = System.Text.Json.JsonSerializer.Serialize(request); + + // Assert + Assert.Contains("\"taskType\":\"RETRIEVAL_DOCUMENT\"", json); + Assert.Contains("\"model\":\"models/embedding-001\"", json); + } } diff --git a/dotnet/src/Connectors/Connectors.Google/Core/GoogleAI/GoogleAIEmbeddingClient.cs b/dotnet/src/Connectors/Connectors.Google/Core/GoogleAI/GoogleAIEmbeddingClient.cs index 6a801acff76e..5ad762b8608a 100644 --- a/dotnet/src/Connectors/Connectors.Google/Core/GoogleAI/GoogleAIEmbeddingClient.cs +++ b/dotnet/src/Connectors/Connectors.Google/Core/GoogleAI/GoogleAIEmbeddingClient.cs @@ -6,6 +6,7 @@ using System.Net.Http; using System.Threading; using System.Threading.Tasks; +using Microsoft.Extensions.AI; using Microsoft.Extensions.Logging; namespace Microsoft.SemanticKernel.Connectors.Google.Core; @@ -54,15 +55,24 @@ public GoogleAIEmbeddingClient( /// Generates embeddings for the given data asynchronously. /// /// The list of strings to generate embeddings for. + /// The embedding generation options. /// The cancellation token to cancel the operation. /// Result contains a list of read-only memories of floats representing the generated embeddings. public async Task>> GenerateEmbeddingsAsync( IList data, + EmbeddingGenerationOptions? options = null, CancellationToken cancellationToken = default) { Verify.NotNullOrEmpty(data); - var geminiRequest = this.GetEmbeddingRequest(data); + string? taskType = null; + if (options?.AdditionalProperties?.TryGetValue("task_type", out var taskTypeValue) == true) + { + taskType = taskTypeValue?.ToString(); + } + + var geminiRequest = this.GetEmbeddingRequest(data, taskType); + using var httpRequestMessage = await this.CreateHttpRequestAsync(geminiRequest, this._embeddingEndpoint).ConfigureAwait(false); string body = await this.SendRequestAndGetStringBodyAsync(httpRequestMessage, cancellationToken) @@ -71,8 +81,8 @@ public async Task>> GenerateEmbeddingsAsync( return DeserializeAndProcessEmbeddingsResponse(body); } - private GoogleAIEmbeddingRequest GetEmbeddingRequest(IEnumerable data) - => GoogleAIEmbeddingRequest.FromData(data, this._embeddingModelId, this._dimensions); + private GoogleAIEmbeddingRequest GetEmbeddingRequest(IEnumerable data, string? taskType = null) + => GoogleAIEmbeddingRequest.FromData(data, this._embeddingModelId, this._dimensions, taskType); private static List> DeserializeAndProcessEmbeddingsResponse(string body) => ProcessEmbeddingsResponse(DeserializeResponse(body)); diff --git a/dotnet/src/Connectors/Connectors.Google/Core/GoogleAI/GoogleAIEmbeddingRequest.cs b/dotnet/src/Connectors/Connectors.Google/Core/GoogleAI/GoogleAIEmbeddingRequest.cs index d69953dc5423..4b019832d034 100644 --- a/dotnet/src/Connectors/Connectors.Google/Core/GoogleAI/GoogleAIEmbeddingRequest.cs +++ b/dotnet/src/Connectors/Connectors.Google/Core/GoogleAI/GoogleAIEmbeddingRequest.cs @@ -11,7 +11,7 @@ internal sealed class GoogleAIEmbeddingRequest [JsonPropertyName("requests")] public IList Requests { get; set; } = null!; - public static GoogleAIEmbeddingRequest FromData(IEnumerable data, string modelId, int? dimensions = null) => new() + public static GoogleAIEmbeddingRequest FromData(IEnumerable data, string modelId, int? dimensions = null, string? taskType = null) => new() { Requests = data.Select(text => new RequestEmbeddingContent { @@ -26,7 +26,8 @@ internal sealed class GoogleAIEmbeddingRequest } ] }, - Dimensions = dimensions + Dimensions = dimensions, + TaskType = taskType }).ToList() }; diff --git a/dotnet/src/Connectors/Connectors.Google/Services/GoogleAITextEmbeddingGenerationService.cs b/dotnet/src/Connectors/Connectors.Google/Services/GoogleAITextEmbeddingGenerationService.cs index d526801c52c9..6958105f3ee3 100644 --- a/dotnet/src/Connectors/Connectors.Google/Services/GoogleAITextEmbeddingGenerationService.cs +++ b/dotnet/src/Connectors/Connectors.Google/Services/GoogleAITextEmbeddingGenerationService.cs @@ -5,6 +5,7 @@ using System.Net.Http; using System.Threading; using System.Threading.Tasks; +using Microsoft.Extensions.AI; using Microsoft.Extensions.Logging; using Microsoft.SemanticKernel.Connectors.Google.Core; using Microsoft.SemanticKernel.Embeddings; @@ -68,6 +69,24 @@ public Task>> GenerateEmbeddingsAsync( Kernel? kernel = null, CancellationToken cancellationToken = default) { - return this._embeddingClient.GenerateEmbeddingsAsync(data, cancellationToken); + return this._embeddingClient.GenerateEmbeddingsAsync(data, null, cancellationToken); + } + + /// + /// Generates embeddings for the specified input text, allowing additional configuration + /// via (e.g., specifying the Google task type). + /// + /// The input text collection to generate embeddings for. + /// Embedding generation options (e.g., task_type). + /// Optional Kernel instance. + /// Token for cancelling the request. + /// A list of generated embeddings. + public Task>> GenerateEmbeddingsAsync( + IList data, + EmbeddingGenerationOptions? options, + Kernel? kernel = null, + CancellationToken cancellationToken = default) + { + return this._embeddingClient.GenerateEmbeddingsAsync(data, options, cancellationToken); } }