From 5db874b7fdd231f87d7b4e4195f417158553bd28 Mon Sep 17 00:00:00 2001
From: donalevans
Date: Thu, 9 Oct 2025 11:07:00 -0700
Subject: [PATCH 1/3] Replace "text embedding" with "dense embedding"
The name "text embedding" is used in many places where dense vector
embeddings are handled, despite the type of the embedding vector not
being exclusive to text embeddings. For example, image or multimodal
embeddings may also produce a dense vector. To allow future reuse of
classes related to dense vectors with multimodal embeddings, the naming
is being changed to the more general "dense embedding". Classes which
explicitly relate to text embeddings are not being renamed.
This rename is internal to the code only and does not change the name of
any JSON objects which currently use "text_embedding", as doing so would
be a breaking change.
- For everything not exclusively related to text embedding, rename
classes, methods and variables to use "dense embedding" instead of
"text embedding"
- Use correct class name in
ElasticTextEmbeddingPayload.TextEmbeddingFloat.PARSER
- Correct the javadoc in DenseEmbeddingBitResults
---
.../inference/InferenceServiceResults.java | 2 +-
.../inference/action/InferenceAction.java | 2 +-
...lts.java => DenseEmbeddingBitResults.java} | 23 +++--
...ts.java => DenseEmbeddingByteResults.java} | 17 ++--
...s.java => DenseEmbeddingFloatResults.java} | 37 ++++---
...esults.java => DenseEmbeddingResults.java} | 6 +-
.../results/LegacyTextEmbeddingResults.java | 9 +-
.../MlInferenceNamedXContentProvider.java | 4 +-
...ults.java => MlDenseEmbeddingResults.java} | 9 +-
.../TextEmbeddingQueryVectorBuilder.java | 6 +-
...ava => DenseEmbeddingBitResultsTests.java} | 56 +++++------
...va => DenseEmbeddingByteResultsTests.java} | 68 ++++++-------
...a => DenseEmbeddingFloatResultsTests.java} | 75 ++++++++-------
...a => LegacyTextEmbeddingResultsTests.java} | 2 +-
.../action/InferModelActionResponseTests.java | 8 +-
...erTrainedModelDeploymentResponseTests.java | 10 +-
...java => MlDenseEmbeddingResultsTests.java} | 18 ++--
.../TextEmbeddingOperatorOutputBuilder.java | 28 +++---
...xtEmbeddingOperatorOutputBuilderTests.java | 22 ++---
.../TextEmbeddingOperatorTests.java | 10 +-
.../TestDenseInferenceServiceExtension.java | 14 +--
...stStreamingCompletionServiceExtension.java | 16 ++--
.../InferenceNamedWriteablesProvider.java | 20 ++--
...viceDenseTextEmbeddingsResponseEntity.java | 12 +--
.../mapper/SemanticTextFieldMapper.java | 8 +-
...rceptedInferenceKnnVectorQueryBuilder.java | 14 +--
.../queries/SemanticQueryBuilder.java | 6 +-
...baCloudSearchEmbeddingsResponseEntity.java | 12 +--
.../AmazonBedrockEmbeddingsResponse.java | 20 ++--
.../CohereEmbeddingsResponseEntity.java | 20 ++--
.../custom/CustomServiceSettings.java | 4 +-
...java => DenseEmbeddingResponseParser.java} | 44 ++++-----
.../ElasticsearchInternalService.java | 16 ++--
...oogleAiStudioEmbeddingsResponseEntity.java | 12 +--
...oogleVertexAiEmbeddingsResponseEntity.java | 12 +--
.../elser/HuggingFaceElserService.java | 10 +-
.../HuggingFaceEmbeddingsResponseEntity.java | 20 ++--
.../IbmWatsonxEmbeddingsResponseEntity.java | 12 +--
.../JinaAIEmbeddingsResponseEntity.java | 22 ++---
.../OpenAiEmbeddingsResponseEntity.java | 12 +--
.../schema/SageMakerStoredServiceSchema.java | 3 +-
.../elastic/ElasticTextEmbeddingPayload.java | 48 +++++-----
.../openai/OpenAiTextEmbeddingPayload.java | 6 +-
...java => DenseEmbeddingModelValidator.java} | 14 +--
...icsearchInternalServiceModelValidator.java | 10 +-
.../validation/ModelValidatorBuilder.java | 2 +-
.../VoyageAIEmbeddingsResponseEntity.java | 26 ++---
.../action/InferenceActionResponseTests.java | 8 +-
.../EmbeddingRequestChunkerTests.java | 96 ++++++++++---------
.../http/sender/HttpRequestSenderTests.java | 2 +-
.../mapper/SemanticTextFieldTests.java | 12 +--
...edInferenceKnnVectorQueryBuilderTests.java | 6 +-
.../queries/MockInferenceClient.java | 10 +-
.../queries/SemanticQueryBuilderTests.java | 8 +-
.../rest/BaseInferenceActionTests.java | 4 +-
.../AlibabaCloudSearchServiceTests.java | 10 +-
.../AlibabaCloudSearchActionCreatorTests.java | 6 +-
...udSearchEmbeddingsResponseEntityTests.java | 6 +-
.../AmazonBedrockServiceTests.java | 26 ++---
.../AmazonBedrockActionCreatorTests.java | 12 +--
.../client/AmazonBedrockExecutorTests.java | 2 +-
.../AmazonBedrockRequestSenderTests.java | 2 +-
.../AzureAiStudioServiceTests.java | 10 +-
.../AzureAiStudioActionAndCreatorTests.java | 2 +-
...AiStudioEmbeddingsResponseEntityTests.java | 6 +-
.../azureopenai/AzureOpenAiServiceTests.java | 12 +--
.../action/AzureOpenAiActionCreatorTests.java | 2 +-
.../AzureOpenAiEmbeddingsActionTests.java | 2 +-
.../services/cohere/CohereServiceTests.java | 18 ++--
.../action/CohereActionCreatorTests.java | 2 +-
.../action/CohereEmbeddingsActionTests.java | 4 +-
.../CohereEmbeddingsResponseEntityTests.java | 50 +++++-----
.../services/custom/CustomModelTests.java | 4 +-
.../custom/CustomServiceSettingsTests.java | 56 +++++------
.../services/custom/CustomServiceTests.java | 42 ++++----
.../custom/request/CustomRequestTests.java | 8 +-
.../response/CustomResponseEntityTests.java | 10 +-
...=> DenseEmbeddingResponseParserTests.java} | 77 ++++++++-------
.../elastic/ElasticInferenceServiceTests.java | 10 +-
...ticInferenceServiceActionCreatorTests.java | 14 +--
...enseTextEmbeddingsResponseEntityTests.java | 10 +-
.../ElasticsearchInternalServiceTests.java | 28 +++---
.../GoogleAiStudioServiceTests.java | 12 +--
.../GoogleAiStudioEmbeddingsActionTests.java | 2 +-
...AiStudioEmbeddingsResponseEntityTests.java | 12 +--
...VertexAiEmbeddingsResponseEntityTests.java | 12 +--
.../huggingface/HuggingFaceServiceTests.java | 12 +--
.../action/HuggingFaceActionCreatorTests.java | 17 +++-
...gingFaceEmbeddingsResponseEntityTests.java | 38 ++++----
.../ibmwatsonx/IbmWatsonxServiceTests.java | 12 +--
.../IbmWatsonxEmbeddingsActionTests.java | 2 +-
...mWatsonxEmbeddingsResponseEntityTests.java | 12 +--
.../services/jinaai/JinaAIServiceTests.java | 12 +--
.../JinaAIEmbeddingsResponseEntityTests.java | 36 +++----
.../services/llama/LlamaServiceTests.java | 10 +-
.../llama/action/LlamaActionCreatorTests.java | 7 +-
.../services/mistral/MistralServiceTests.java | 10 +-
.../services/openai/OpenAiServiceTests.java | 12 +--
.../action/OpenAiActionCreatorTests.java | 2 +-
.../action/OpenAiEmbeddingsActionTests.java | 2 +-
.../OpenAiEmbeddingsResponseEntityTests.java | 28 +++---
.../sagemaker/SageMakerServiceTests.java | 4 +-
.../ElasticTextEmbeddingPayloadTests.java | 16 ++--
.../OpenAiTextEmbeddingPayloadTests.java | 8 +-
...=> DenseEmbeddingModelValidatorTests.java} | 18 ++--
...rchInternalServiceModelValidatorTests.java | 10 +-
.../ModelValidatorBuilderTests.java | 2 +-
.../voyageai/VoyageAIServiceTests.java | 15 +--
.../action/VoyageAIActionCreatorTests.java | 2 +-
.../action/VoyageAIEmbeddingsActionTests.java | 6 +-
...VoyageAIEmbeddingsResponseEntityTests.java | 30 +++---
.../inference/nlp/TextEmbeddingProcessor.java | 4 +-
.../nlp/TextEmbeddingProcessorTests.java | 6 +-
.../TextEmbeddingQueryVectorBuilderTests.java | 4 +-
114 files changed, 918 insertions(+), 861 deletions(-)
rename x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/{TextEmbeddingBitResults.java => DenseEmbeddingBitResults.java} (75%)
rename x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/{TextEmbeddingByteResults.java => DenseEmbeddingByteResults.java} (90%)
rename x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/{TextEmbeddingFloatResults.java => DenseEmbeddingFloatResults.java} (83%)
rename x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/{TextEmbeddingResults.java => DenseEmbeddingResults.java} (66%)
rename x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/{MlTextEmbeddingResults.java => MlDenseEmbeddingResults.java} (87%)
rename x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/{TextEmbeddingBitResultsTests.java => DenseEmbeddingBitResultsTests.java} (55%)
rename x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/{TextEmbeddingByteResultsTests.java => DenseEmbeddingByteResultsTests.java} (50%)
rename x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/{TextEmbeddingFloatResultsTests.java => DenseEmbeddingFloatResultsTests.java} (50%)
rename x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/{LegacyMlTextEmbeddingResultsTests.java => LegacyTextEmbeddingResultsTests.java} (97%)
rename x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/{MlTextEmbeddingResultsTests.java => MlDenseEmbeddingResultsTests.java} (68%)
rename x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/{TextEmbeddingResponseParser.java => DenseEmbeddingResponseParser.java} (80%)
rename x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/{TextEmbeddingModelValidator.java => DenseEmbeddingModelValidator.java} (85%)
rename x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/{TextEmbeddingResponseParserTests.java => DenseEmbeddingResponseParserTests.java} (71%)
rename x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/{TextEmbeddingModelValidatorTests.java => DenseEmbeddingModelValidatorTests.java} (90%)
diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceServiceResults.java b/server/src/main/java/org/elasticsearch/inference/InferenceServiceResults.java
index 3746960ad8f78..2d1b932271f25 100644
--- a/server/src/main/java/org/elasticsearch/inference/InferenceServiceResults.java
+++ b/server/src/main/java/org/elasticsearch/inference/InferenceServiceResults.java
@@ -24,7 +24,7 @@ public interface InferenceServiceResults extends NamedWriteable, ChunkedToXConte
/**
* Transform the result to match the format required for the TransportCoordinatedInferenceAction.
- * TransportCoordinatedInferenceAction expects an ml plugin TextEmbeddingResults or SparseEmbeddingResults.
+ * TransportCoordinatedInferenceAction expects an ml plugin DenseEmbeddingResults or SparseEmbeddingResults.
*/
default List extends InferenceResults> transformToCoordinationFormat() {
throw new UnsupportedOperationException("transformToCoordinationFormat() is not implemented");
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java
index c3066a56447c9..3f170f2272d7f 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java
@@ -539,7 +539,7 @@ public static InferenceServiceResults transformToServiceResults(List extends I
);
}
- return openaiResults.transformToTextEmbeddingResults();
+ return openaiResults.transformToDenseEmbeddingResults();
} else if (parsedResults.get(0) instanceof TextExpansionResults) {
return transformToSparseEmbeddingResult(parsedResults);
} else {
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingBitResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingBitResults.java
similarity index 75%
rename from x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingBitResults.java
rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingBitResults.java
index 37fca12f1697a..0792bf90dbe12 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingBitResults.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingBitResults.java
@@ -14,7 +14,7 @@
import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.xcontent.ToXContent;
-import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
+import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults;
import java.io.IOException;
import java.util.Iterator;
@@ -24,9 +24,10 @@
import java.util.Objects;
/**
- * Writes a text embedding result in the follow json format
+ * Writes a dense embedding result in the follow json format.
+ *
* {
- * "text_embedding_bytes": [
+ * "text_embedding_bits": [
* {
* "embedding": [
* 23
@@ -39,17 +40,19 @@
* }
* ]
* }
+ *
*/
-// Note: inheriting from TextEmbeddingByteResults gives a bad implementation of the
+// Note: inheriting from DenseEmbeddingByteResults gives a bad implementation of the
// Embedding.merge method for bits. TODO: implement a proper merge method
-public record TextEmbeddingBitResults(List embeddings)
+public record DenseEmbeddingBitResults(List embeddings)
implements
- TextEmbeddingResults {
+ DenseEmbeddingResults {
+ // This name is a holdover from before this class was renamed
public static final String NAME = "text_embedding_service_bit_results";
public static final String TEXT_EMBEDDING_BITS = "text_embedding_bits";
- public TextEmbeddingBitResults(StreamInput in) throws IOException {
- this(in.readCollectionAsList(TextEmbeddingByteResults.Embedding::new));
+ public DenseEmbeddingBitResults(StreamInput in) throws IOException {
+ this(in.readCollectionAsList(DenseEmbeddingByteResults.Embedding::new));
}
@Override
@@ -79,7 +82,7 @@ public String getWriteableName() {
@Override
public List extends InferenceResults> transformToCoordinationFormat() {
return embeddings.stream()
- .map(embedding -> new MlTextEmbeddingResults(TEXT_EMBEDDING_BITS, embedding.toDoubleArray(), false))
+ .map(embedding -> new MlDenseEmbeddingResults(TEXT_EMBEDDING_BITS, embedding.toDoubleArray(), false))
.toList();
}
@@ -94,7 +97,7 @@ public Map asMap() {
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
- TextEmbeddingBitResults that = (TextEmbeddingBitResults) o;
+ DenseEmbeddingBitResults that = (DenseEmbeddingBitResults) o;
return Objects.equals(embeddings, that.embeddings);
}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingByteResults.java
similarity index 90%
rename from x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResults.java
rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingByteResults.java
index 54f858cb20ae0..9e72dc9a7b2b4 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResults.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingByteResults.java
@@ -20,7 +20,7 @@
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContent;
import org.elasticsearch.xcontent.XContentBuilder;
-import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
+import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults;
import java.io.IOException;
import java.util.Arrays;
@@ -31,7 +31,8 @@
import java.util.Objects;
/**
- * Writes a text embedding result in the follow json format
+ * Writes a dense embedding result in the follow json format
+ *
* {
* "text_embedding_bytes": [
* {
@@ -46,13 +47,15 @@
* }
* ]
* }
+ *
*/
-public record TextEmbeddingByteResults(List embeddings) implements TextEmbeddingResults {
+public record DenseEmbeddingByteResults(List embeddings) implements DenseEmbeddingResults {
+ // This name is a holdover from before this class was renamed
public static final String NAME = "text_embedding_service_byte_results";
public static final String TEXT_EMBEDDING_BYTES = "text_embedding_bytes";
- public TextEmbeddingByteResults(StreamInput in) throws IOException {
- this(in.readCollectionAsList(TextEmbeddingByteResults.Embedding::new));
+ public DenseEmbeddingByteResults(StreamInput in) throws IOException {
+ this(in.readCollectionAsList(DenseEmbeddingByteResults.Embedding::new));
}
@Override
@@ -81,7 +84,7 @@ public String getWriteableName() {
@Override
public List extends InferenceResults> transformToCoordinationFormat() {
return embeddings.stream()
- .map(embedding -> new MlTextEmbeddingResults(TEXT_EMBEDDING_BYTES, embedding.toDoubleArray(), false))
+ .map(embedding -> new MlDenseEmbeddingResults(TEXT_EMBEDDING_BYTES, embedding.toDoubleArray(), false))
.toList();
}
@@ -96,7 +99,7 @@ public Map asMap() {
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
- TextEmbeddingByteResults that = (TextEmbeddingByteResults) o;
+ DenseEmbeddingByteResults that = (DenseEmbeddingByteResults) o;
return Objects.equals(embeddings, that.embeddings);
}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingFloatResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingFloatResults.java
similarity index 83%
rename from x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingFloatResults.java
rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingFloatResults.java
index e68a5e4bd13b0..797870198f5f5 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingFloatResults.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingFloatResults.java
@@ -23,7 +23,7 @@
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContent;
import org.elasticsearch.xcontent.XContentBuilder;
-import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
+import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults;
import java.io.IOException;
import java.util.ArrayList;
@@ -36,7 +36,8 @@
import java.util.stream.Collectors;
/**
- * Writes a text embedding result in the follow json format
+ * Writes a dense embedding result in the follow json format
+ *
* {
* "text_embedding": [
* {
@@ -51,17 +52,21 @@
* }
* ]
* }
+ *
*/
-public record TextEmbeddingFloatResults(List embeddings) implements TextEmbeddingResults {
+public record DenseEmbeddingFloatResults(List embeddings)
+ implements
+ DenseEmbeddingResults {
+ // This name is a holdover from before this class was renamed
public static final String NAME = "text_embedding_service_results";
public static final String TEXT_EMBEDDING = TaskType.TEXT_EMBEDDING.toString();
- public TextEmbeddingFloatResults(StreamInput in) throws IOException {
- this(in.readCollectionAsList(TextEmbeddingFloatResults.Embedding::new));
+ public DenseEmbeddingFloatResults(StreamInput in) throws IOException {
+ this(in.readCollectionAsList(DenseEmbeddingFloatResults.Embedding::new));
}
@SuppressWarnings("deprecation")
- TextEmbeddingFloatResults(LegacyTextEmbeddingResults legacyTextEmbeddingResults) {
+ DenseEmbeddingFloatResults(LegacyTextEmbeddingResults legacyTextEmbeddingResults) {
this(
legacyTextEmbeddingResults.embeddings()
.stream()
@@ -70,11 +75,11 @@ public TextEmbeddingFloatResults(StreamInput in) throws IOException {
);
}
- public static TextEmbeddingFloatResults of(List extends InferenceResults> results) {
+ public static DenseEmbeddingFloatResults of(List extends InferenceResults> results) {
List embeddings = new ArrayList<>(results.size());
for (InferenceResults result : results) {
- if (result instanceof MlTextEmbeddingResults embeddingResult) {
- embeddings.add(TextEmbeddingFloatResults.Embedding.of(embeddingResult));
+ if (result instanceof MlDenseEmbeddingResults embeddingResult) {
+ embeddings.add(DenseEmbeddingFloatResults.Embedding.of(embeddingResult));
} else if (result instanceof org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults errorResult) {
if (errorResult.getException() instanceof ElasticsearchStatusException statusException) {
throw statusException;
@@ -87,11 +92,15 @@ public static TextEmbeddingFloatResults of(List extends InferenceResults> resu
}
} else {
throw new IllegalArgumentException(
- "Received invalid inference result, of type " + result.getClass().getName() + " but expected TextEmbeddingResults."
+ "Received invalid inference result, of type "
+ + result.getClass().getName()
+ + " but expected "
+ + MlDenseEmbeddingResults.class.getName()
+ + "."
);
}
}
- return new TextEmbeddingFloatResults(embeddings);
+ return new DenseEmbeddingFloatResults(embeddings);
}
@Override
@@ -119,7 +128,7 @@ public String getWriteableName() {
@Override
public List extends InferenceResults> transformToCoordinationFormat() {
- return embeddings.stream().map(embedding -> new MlTextEmbeddingResults(TEXT_EMBEDDING, embedding.asDoubleArray(), false)).toList();
+ return embeddings.stream().map(embedding -> new MlDenseEmbeddingResults(TEXT_EMBEDDING, embedding.asDoubleArray(), false)).toList();
}
public Map asMap() {
@@ -133,7 +142,7 @@ public Map asMap() {
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
- TextEmbeddingFloatResults that = (TextEmbeddingFloatResults) o;
+ DenseEmbeddingFloatResults that = (DenseEmbeddingFloatResults) o;
return Objects.equals(embeddings, that.embeddings);
}
@@ -159,7 +168,7 @@ public Embedding(StreamInput in) throws IOException {
this(in.readFloatArray());
}
- public static Embedding of(MlTextEmbeddingResults embeddingResult) {
+ public static Embedding of(MlDenseEmbeddingResults embeddingResult) {
float[] embeddingAsArray = embeddingResult.getInferenceAsFloat();
return new Embedding(embeddingAsArray);
}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingResults.java
similarity index 66%
rename from x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java
rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingResults.java
index ea4e45ec67407..af6abb357bd98 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingResults.java
@@ -7,11 +7,11 @@
package org.elasticsearch.xpack.core.inference.results;
-public interface TextEmbeddingResults> extends EmbeddingResults {
+public interface DenseEmbeddingResults> extends EmbeddingResults {
/**
- * Returns the first text embedding entry in the result list's array size.
- * @return the size of the text embedding
+ * Returns the first embedding entry in the result list's array size.
+ * @return the size of the embedding
* @throws IllegalStateException if the list of embeddings is empty
*/
int getFirstEmbeddingSize() throws IllegalStateException;
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/LegacyTextEmbeddingResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/LegacyTextEmbeddingResults.java
index 60bbeb624b532..ff3bcac2a6a6b 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/LegacyTextEmbeddingResults.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/LegacyTextEmbeddingResults.java
@@ -27,6 +27,7 @@
/**
* Writes a text embedding result in the following json format
+ *
* {
* "text_embedding": [
* {
@@ -41,10 +42,10 @@
* }
* ]
* }
- *
+ *
* Legacy text embedding results represents what was returned prior to the
* {@link org.elasticsearch.TransportVersions#V_8_12_0} version.
- * @deprecated use {@link TextEmbeddingFloatResults} instead
+ * @deprecated use {@link DenseEmbeddingFloatResults} instead
*/
@Deprecated
public record LegacyTextEmbeddingResults(List embeddings) implements InferenceResults {
@@ -114,8 +115,8 @@ public int hashCode() {
return Objects.hash(embeddings);
}
- public TextEmbeddingFloatResults transformToTextEmbeddingResults() {
- return new TextEmbeddingFloatResults(this);
+ public DenseEmbeddingFloatResults transformToDenseEmbeddingResults() {
+ return new DenseEmbeddingFloatResults(this);
}
public record Embedding(float[] values) implements Writeable, ToXContentObject {
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java
index 667d7bf63efc9..f2952fee36cda 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java
@@ -25,7 +25,7 @@
import org.elasticsearch.xpack.core.ml.inference.results.FillMaskResults;
import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextEmbeddingFloatResults;
import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextExpansionResults;
-import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
+import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults;
import org.elasticsearch.xpack.core.ml.inference.results.NerResults;
import org.elasticsearch.xpack.core.ml.inference.results.NlpClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.PyTorchPassThroughResults;
@@ -669,7 +669,7 @@ public List getNamedWriteables() {
);
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class, TextExpansionResults.NAME, TextExpansionResults::new));
namedWriteables.add(
- new NamedWriteableRegistry.Entry(InferenceResults.class, MlTextEmbeddingResults.NAME, MlTextEmbeddingResults::new)
+ new NamedWriteableRegistry.Entry(InferenceResults.class, MlDenseEmbeddingResults.NAME, MlDenseEmbeddingResults::new)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/MlTextEmbeddingResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/MlDenseEmbeddingResults.java
similarity index 87%
rename from x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/MlTextEmbeddingResults.java
rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/MlDenseEmbeddingResults.java
index 0c0fa6f3f690e..b839ade030e1e 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/MlTextEmbeddingResults.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/MlDenseEmbeddingResults.java
@@ -16,20 +16,21 @@
import java.util.Map;
import java.util.Objects;
-public class MlTextEmbeddingResults extends NlpInferenceResults {
+public class MlDenseEmbeddingResults extends NlpInferenceResults {
+ // This name is a holdover from before this class was renamed
public static final String NAME = "text_embedding_result";
private final String resultsField;
private final double[] inference;
- public MlTextEmbeddingResults(String resultsField, double[] inference, boolean isTruncated) {
+ public MlDenseEmbeddingResults(String resultsField, double[] inference, boolean isTruncated) {
super(isTruncated);
this.inference = inference;
this.resultsField = resultsField;
}
- public MlTextEmbeddingResults(StreamInput in) throws IOException {
+ public MlDenseEmbeddingResults(StreamInput in) throws IOException {
super(in);
inference = in.readDoubleArray();
resultsField = in.readString();
@@ -89,7 +90,7 @@ public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
if (super.equals(o) == false) return false;
- MlTextEmbeddingResults that = (MlTextEmbeddingResults) o;
+ MlDenseEmbeddingResults that = (MlDenseEmbeddingResults) o;
return Objects.equals(resultsField, that.resultsField) && Arrays.equals(inference, that.inference);
}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/vectors/TextEmbeddingQueryVectorBuilder.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/vectors/TextEmbeddingQueryVectorBuilder.java
index f68fe805c404e..49af56b904c46 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/vectors/TextEmbeddingQueryVectorBuilder.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/vectors/TextEmbeddingQueryVectorBuilder.java
@@ -21,7 +21,7 @@
import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings;
-import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
+import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfigUpdate;
@@ -127,14 +127,14 @@ public void buildVector(Client client, ActionListener listener) {
return;
}
- if (response.getInferenceResults().get(0) instanceof MlTextEmbeddingResults textEmbeddingResults) {
+ if (response.getInferenceResults().get(0) instanceof MlDenseEmbeddingResults textEmbeddingResults) {
listener.onResponse(textEmbeddingResults.getInferenceAsFloat());
} else if (response.getInferenceResults().get(0) instanceof WarningInferenceResults warning) {
listener.onFailure(new IllegalStateException(warning.getWarning()));
} else {
throw new IllegalArgumentException(
"expected a result of type ["
- + MlTextEmbeddingResults.NAME
+ + MlDenseEmbeddingResults.NAME
+ "] received ["
+ response.getInferenceResults().get(0).getWriteableName()
+ "]. Is ["
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingBitResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingBitResultsTests.java
similarity index 55%
rename from x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingBitResultsTests.java
rename to x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingBitResultsTests.java
index 61b49075702a2..9e5814825d659 100644
--- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingBitResultsTests.java
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingBitResultsTests.java
@@ -10,7 +10,7 @@
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
-import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
+import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults;
import java.io.IOException;
import java.util.ArrayList;
@@ -19,19 +19,19 @@
import static org.hamcrest.Matchers.is;
-public class TextEmbeddingBitResultsTests extends AbstractWireSerializingTestCase {
- public static TextEmbeddingBitResults createRandomResults() {
+public class DenseEmbeddingBitResultsTests extends AbstractWireSerializingTestCase {
+ public static DenseEmbeddingBitResults createRandomResults() {
int embeddings = randomIntBetween(1, 10);
- List embeddingResults = new ArrayList<>(embeddings);
+ List embeddingResults = new ArrayList<>(embeddings);
for (int i = 0; i < embeddings; i++) {
embeddingResults.add(createRandomEmbedding());
}
- return new TextEmbeddingBitResults(embeddingResults);
+ return new DenseEmbeddingBitResults(embeddingResults);
}
- private static TextEmbeddingByteResults.Embedding createRandomEmbedding() {
+ private static DenseEmbeddingByteResults.Embedding createRandomEmbedding() {
int columns = randomIntBetween(1, 10);
byte[] bytes = new byte[columns];
@@ -39,11 +39,11 @@ private static TextEmbeddingByteResults.Embedding createRandomEmbedding() {
bytes[i] = randomByte();
}
- return new TextEmbeddingByteResults.Embedding(bytes);
+ return new DenseEmbeddingByteResults.Embedding(bytes);
}
public void testToXContent_CreatesTheRightFormatForASingleEmbedding() throws IOException {
- var entity = new TextEmbeddingBitResults(List.of(new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 23 })));
+ var entity = new DenseEmbeddingBitResults(List.of(new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) 23 })));
String xContentResult = Strings.toString(entity, true, true);
assertThat(xContentResult, is("""
@@ -59,10 +59,10 @@ public void testToXContent_CreatesTheRightFormatForASingleEmbedding() throws IOE
}
public void testToXContent_CreatesTheRightFormatForMultipleEmbeddings() throws IOException {
- var entity = new TextEmbeddingBitResults(
+ var entity = new DenseEmbeddingBitResults(
List.of(
- new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 23 }),
- new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 24 })
+ new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) 23 }),
+ new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) 24 })
)
);
@@ -85,10 +85,10 @@ public void testToXContent_CreatesTheRightFormatForMultipleEmbeddings() throws I
}
public void testTransformToCoordinationFormat() {
- var results = new TextEmbeddingBitResults(
+ var results = new DenseEmbeddingBitResults(
List.of(
- new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 23, (byte) 24 }),
- new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 25, (byte) 26 })
+ new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) 23, (byte) 24 }),
+ new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) 25, (byte) 26 })
)
).transformToCoordinationFormat();
@@ -96,18 +96,18 @@ public void testTransformToCoordinationFormat() {
results,
is(
List.of(
- new MlTextEmbeddingResults(TextEmbeddingBitResults.TEXT_EMBEDDING_BITS, new double[] { 23F, 24F }, false),
- new MlTextEmbeddingResults(TextEmbeddingBitResults.TEXT_EMBEDDING_BITS, new double[] { 25F, 26F }, false)
+ new MlDenseEmbeddingResults(DenseEmbeddingBitResults.TEXT_EMBEDDING_BITS, new double[] { 23F, 24F }, false),
+ new MlDenseEmbeddingResults(DenseEmbeddingBitResults.TEXT_EMBEDDING_BITS, new double[] { 25F, 26F }, false)
)
)
);
}
public void testGetFirstEmbeddingSize() {
- var firstEmbeddingSize = new TextEmbeddingBitResults(
+ var firstEmbeddingSize = new DenseEmbeddingBitResults(
List.of(
- new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 23, (byte) 24 }),
- new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 25, (byte) 26 })
+ new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) 23, (byte) 24 }),
+ new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) 25, (byte) 26 })
)
).getFirstEmbeddingSize();
@@ -115,33 +115,33 @@ public void testGetFirstEmbeddingSize() {
}
@Override
- protected Writeable.Reader instanceReader() {
- return TextEmbeddingBitResults::new;
+ protected Writeable.Reader instanceReader() {
+ return DenseEmbeddingBitResults::new;
}
@Override
- protected TextEmbeddingBitResults createTestInstance() {
+ protected DenseEmbeddingBitResults createTestInstance() {
return createRandomResults();
}
@Override
- protected TextEmbeddingBitResults mutateInstance(TextEmbeddingBitResults instance) throws IOException {
+ protected DenseEmbeddingBitResults mutateInstance(DenseEmbeddingBitResults instance) throws IOException {
// if true we reduce the embeddings list by a random amount, if false we add an embedding to the list
if (randomBoolean()) {
// -1 to remove at least one item from the list
int end = randomInt(instance.embeddings().size() - 1);
- return new TextEmbeddingBitResults(instance.embeddings().subList(0, end));
+ return new DenseEmbeddingBitResults(instance.embeddings().subList(0, end));
} else {
- List embeddings = new ArrayList<>(instance.embeddings());
+ List embeddings = new ArrayList<>(instance.embeddings());
embeddings.add(createRandomEmbedding());
- return new TextEmbeddingBitResults(embeddings);
+ return new DenseEmbeddingBitResults(embeddings);
}
}
public static Map buildExpectationByte(List> embeddings) {
return Map.of(
- TextEmbeddingBitResults.TEXT_EMBEDDING_BITS,
- embeddings.stream().map(embedding -> Map.of(TextEmbeddingByteResults.Embedding.EMBEDDING, embedding)).toList()
+ DenseEmbeddingBitResults.TEXT_EMBEDDING_BITS,
+ embeddings.stream().map(embedding -> Map.of(DenseEmbeddingByteResults.Embedding.EMBEDDING, embedding)).toList()
);
}
}
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingByteResultsTests.java
similarity index 50%
rename from x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResultsTests.java
rename to x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingByteResultsTests.java
index 60f45399cfb32..53cd8690b7dc0 100644
--- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResultsTests.java
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingByteResultsTests.java
@@ -10,7 +10,7 @@
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
-import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
+import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults;
import java.io.IOException;
import java.util.ArrayList;
@@ -20,19 +20,19 @@
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
-public class TextEmbeddingByteResultsTests extends AbstractWireSerializingTestCase {
- public static TextEmbeddingByteResults createRandomResults() {
+public class DenseEmbeddingByteResultsTests extends AbstractWireSerializingTestCase {
+ public static DenseEmbeddingByteResults createRandomResults() {
int embeddings = randomIntBetween(1, 10);
- List embeddingResults = new ArrayList<>(embeddings);
+ List embeddingResults = new ArrayList<>(embeddings);
for (int i = 0; i < embeddings; i++) {
embeddingResults.add(createRandomEmbedding());
}
- return new TextEmbeddingByteResults(embeddingResults);
+ return new DenseEmbeddingByteResults(embeddingResults);
}
- private static TextEmbeddingByteResults.Embedding createRandomEmbedding() {
+ private static DenseEmbeddingByteResults.Embedding createRandomEmbedding() {
int columns = randomIntBetween(1, 10);
byte[] bytes = new byte[columns];
@@ -40,11 +40,11 @@ private static TextEmbeddingByteResults.Embedding createRandomEmbedding() {
bytes[i] = randomByte();
}
- return new TextEmbeddingByteResults.Embedding(bytes);
+ return new DenseEmbeddingByteResults.Embedding(bytes);
}
public void testToXContent_CreatesTheRightFormatForASingleEmbedding() throws IOException {
- var entity = new TextEmbeddingByteResults(List.of(new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 23 })));
+ var entity = new DenseEmbeddingByteResults(List.of(new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) 23 })));
String xContentResult = Strings.toString(entity, true, true);
assertThat(xContentResult, is("""
@@ -60,10 +60,10 @@ public void testToXContent_CreatesTheRightFormatForASingleEmbedding() throws IOE
}
public void testToXContent_CreatesTheRightFormatForMultipleEmbeddings() throws IOException {
- var entity = new TextEmbeddingByteResults(
+ var entity = new DenseEmbeddingByteResults(
List.of(
- new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 23 }),
- new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 24 })
+ new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) 23 }),
+ new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) 24 })
)
);
@@ -86,10 +86,10 @@ public void testToXContent_CreatesTheRightFormatForMultipleEmbeddings() throws I
}
public void testTransformToCoordinationFormat() {
- var results = new TextEmbeddingByteResults(
+ var results = new DenseEmbeddingByteResults(
List.of(
- new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 23, (byte) 24 }),
- new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 25, (byte) 26 })
+ new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) 23, (byte) 24 }),
+ new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) 25, (byte) 26 })
)
).transformToCoordinationFormat();
@@ -97,18 +97,18 @@ public void testTransformToCoordinationFormat() {
results,
is(
List.of(
- new MlTextEmbeddingResults(TextEmbeddingByteResults.TEXT_EMBEDDING_BYTES, new double[] { 23F, 24F }, false),
- new MlTextEmbeddingResults(TextEmbeddingByteResults.TEXT_EMBEDDING_BYTES, new double[] { 25F, 26F }, false)
+ new MlDenseEmbeddingResults(DenseEmbeddingByteResults.TEXT_EMBEDDING_BYTES, new double[] { 23F, 24F }, false),
+ new MlDenseEmbeddingResults(DenseEmbeddingByteResults.TEXT_EMBEDDING_BYTES, new double[] { 25F, 26F }, false)
)
)
);
}
public void testGetFirstEmbeddingSize() {
- var firstEmbeddingSize = new TextEmbeddingByteResults(
+ var firstEmbeddingSize = new DenseEmbeddingByteResults(
List.of(
- new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 23, (byte) 24 }),
- new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 25, (byte) 26 })
+ new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) 23, (byte) 24 }),
+ new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) 25, (byte) 26 })
)
).getFirstEmbeddingSize();
@@ -116,43 +116,43 @@ public void testGetFirstEmbeddingSize() {
}
public void testEmbeddingMerge() {
- TextEmbeddingByteResults.Embedding embedding1 = new TextEmbeddingByteResults.Embedding(new byte[] { 1, 1, -128 });
- TextEmbeddingByteResults.Embedding embedding2 = new TextEmbeddingByteResults.Embedding(new byte[] { 1, 0, 127 });
- TextEmbeddingByteResults.Embedding embedding3 = new TextEmbeddingByteResults.Embedding(new byte[] { 0, 0, 100 });
- TextEmbeddingByteResults.Embedding mergedEmbedding = embedding1.merge(embedding2);
- assertThat(mergedEmbedding, equalTo(new TextEmbeddingByteResults.Embedding(new byte[] { 1, 1, 0 })));
+ DenseEmbeddingByteResults.Embedding embedding1 = new DenseEmbeddingByteResults.Embedding(new byte[] { 1, 1, -128 });
+ DenseEmbeddingByteResults.Embedding embedding2 = new DenseEmbeddingByteResults.Embedding(new byte[] { 1, 0, 127 });
+ DenseEmbeddingByteResults.Embedding embedding3 = new DenseEmbeddingByteResults.Embedding(new byte[] { 0, 0, 100 });
+ DenseEmbeddingByteResults.Embedding mergedEmbedding = embedding1.merge(embedding2);
+ assertThat(mergedEmbedding, equalTo(new DenseEmbeddingByteResults.Embedding(new byte[] { 1, 1, 0 })));
mergedEmbedding = mergedEmbedding.merge(embedding3);
- assertThat(mergedEmbedding, equalTo(new TextEmbeddingByteResults.Embedding(new byte[] { 1, 0, 33 })));
+ assertThat(mergedEmbedding, equalTo(new DenseEmbeddingByteResults.Embedding(new byte[] { 1, 0, 33 })));
}
@Override
- protected Writeable.Reader instanceReader() {
- return TextEmbeddingByteResults::new;
+ protected Writeable.Reader instanceReader() {
+ return DenseEmbeddingByteResults::new;
}
@Override
- protected TextEmbeddingByteResults createTestInstance() {
+ protected DenseEmbeddingByteResults createTestInstance() {
return createRandomResults();
}
@Override
- protected TextEmbeddingByteResults mutateInstance(TextEmbeddingByteResults instance) throws IOException {
+ protected DenseEmbeddingByteResults mutateInstance(DenseEmbeddingByteResults instance) throws IOException {
// if true we reduce the embeddings list by a random amount, if false we add an embedding to the list
if (randomBoolean()) {
// -1 to remove at least one item from the list
int end = randomInt(instance.embeddings().size() - 1);
- return new TextEmbeddingByteResults(instance.embeddings().subList(0, end));
+ return new DenseEmbeddingByteResults(instance.embeddings().subList(0, end));
} else {
- List embeddings = new ArrayList<>(instance.embeddings());
+ List embeddings = new ArrayList<>(instance.embeddings());
embeddings.add(createRandomEmbedding());
- return new TextEmbeddingByteResults(embeddings);
+ return new DenseEmbeddingByteResults(embeddings);
}
}
public static Map buildExpectationByte(List> embeddings) {
return Map.of(
- TextEmbeddingByteResults.TEXT_EMBEDDING_BYTES,
- embeddings.stream().map(embedding -> Map.of(TextEmbeddingByteResults.Embedding.EMBEDDING, embedding)).toList()
+ DenseEmbeddingByteResults.TEXT_EMBEDDING_BYTES,
+ embeddings.stream().map(embedding -> Map.of(DenseEmbeddingByteResults.Embedding.EMBEDDING, embedding)).toList()
);
}
}
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingFloatResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingFloatResultsTests.java
similarity index 50%
rename from x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingFloatResultsTests.java
rename to x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingFloatResultsTests.java
index 8cdd98bcdebc6..4481692127979 100644
--- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingFloatResultsTests.java
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingFloatResultsTests.java
@@ -10,7 +10,7 @@
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
-import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
+import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults;
import java.io.IOException;
import java.util.ArrayList;
@@ -20,30 +20,30 @@
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
-public class TextEmbeddingFloatResultsTests extends AbstractWireSerializingTestCase {
- public static TextEmbeddingFloatResults createRandomResults() {
+public class DenseEmbeddingFloatResultsTests extends AbstractWireSerializingTestCase {
+ public static DenseEmbeddingFloatResults createRandomResults() {
int embeddings = randomIntBetween(1, 10);
- List embeddingResults = new ArrayList<>(embeddings);
+ List embeddingResults = new ArrayList<>(embeddings);
for (int i = 0; i < embeddings; i++) {
embeddingResults.add(createRandomEmbedding());
}
- return new TextEmbeddingFloatResults(embeddingResults);
+ return new DenseEmbeddingFloatResults(embeddingResults);
}
- private static TextEmbeddingFloatResults.Embedding createRandomEmbedding() {
+ private static DenseEmbeddingFloatResults.Embedding createRandomEmbedding() {
int columns = randomIntBetween(1, 10);
float[] floats = new float[columns];
for (int i = 0; i < columns; i++) {
floats[i] = randomFloat();
}
- return new TextEmbeddingFloatResults.Embedding(floats);
+ return new DenseEmbeddingFloatResults.Embedding(floats);
}
public void testToXContent_CreatesTheRightFormatForASingleEmbedding() throws IOException {
- var entity = new TextEmbeddingFloatResults(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.1F })));
+ var entity = new DenseEmbeddingFloatResults(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.1F })));
String xContentResult = Strings.toString(entity, true, true);
assertThat(xContentResult, is("""
@@ -59,10 +59,10 @@ public void testToXContent_CreatesTheRightFormatForASingleEmbedding() throws IOE
}
public void testToXContent_CreatesTheRightFormatForMultipleEmbeddings() throws IOException {
- var entity = new TextEmbeddingFloatResults(
+ var entity = new DenseEmbeddingFloatResults(
List.of(
- new TextEmbeddingFloatResults.Embedding(new float[] { 0.1F }),
- new TextEmbeddingFloatResults.Embedding(new float[] { 0.2F })
+ new DenseEmbeddingFloatResults.Embedding(new float[] { 0.1F }),
+ new DenseEmbeddingFloatResults.Embedding(new float[] { 0.2F })
)
);
@@ -86,10 +86,10 @@ public void testToXContent_CreatesTheRightFormatForMultipleEmbeddings() throws I
}
public void testTransformToCoordinationFormat() {
- var results = new TextEmbeddingFloatResults(
+ var results = new DenseEmbeddingFloatResults(
List.of(
- new TextEmbeddingFloatResults.Embedding(new float[] { 0.1F, 0.2F }),
- new TextEmbeddingFloatResults.Embedding(new float[] { 0.3F, 0.4F })
+ new DenseEmbeddingFloatResults.Embedding(new float[] { 0.1F, 0.2F }),
+ new DenseEmbeddingFloatResults.Embedding(new float[] { 0.3F, 0.4F })
)
).transformToCoordinationFormat();
@@ -97,18 +97,18 @@ public void testTransformToCoordinationFormat() {
results,
is(
List.of(
- new MlTextEmbeddingResults(TextEmbeddingFloatResults.TEXT_EMBEDDING, new double[] { 0.1F, 0.2F }, false),
- new MlTextEmbeddingResults(TextEmbeddingFloatResults.TEXT_EMBEDDING, new double[] { 0.3F, 0.4F }, false)
+ new MlDenseEmbeddingResults(DenseEmbeddingFloatResults.TEXT_EMBEDDING, new double[] { 0.1F, 0.2F }, false),
+ new MlDenseEmbeddingResults(DenseEmbeddingFloatResults.TEXT_EMBEDDING, new double[] { 0.3F, 0.4F }, false)
)
)
);
}
public void testGetFirstEmbeddingSize() {
- var firstEmbeddingSize = new TextEmbeddingFloatResults(
+ var firstEmbeddingSize = new DenseEmbeddingFloatResults(
List.of(
- new TextEmbeddingFloatResults.Embedding(new float[] { 0.1F, 0.2F }),
- new TextEmbeddingFloatResults.Embedding(new float[] { 0.3F, 0.4F })
+ new DenseEmbeddingFloatResults.Embedding(new float[] { 0.1F, 0.2F }),
+ new DenseEmbeddingFloatResults.Embedding(new float[] { 0.3F, 0.4F })
)
).getFirstEmbeddingSize();
@@ -116,51 +116,54 @@ public void testGetFirstEmbeddingSize() {
}
public void testEmbeddingMerge() {
- TextEmbeddingFloatResults.Embedding embedding1 = new TextEmbeddingFloatResults.Embedding(new float[] { 0.1f, 0.2f, 0.3f, 0.4f });
- TextEmbeddingFloatResults.Embedding embedding2 = new TextEmbeddingFloatResults.Embedding(new float[] { 0.0f, 0.4f, 0.1f, 1.0f });
- TextEmbeddingFloatResults.Embedding embedding3 = new TextEmbeddingFloatResults.Embedding(new float[] { 0.2f, 0.9f, 0.8f, 0.1f });
- TextEmbeddingFloatResults.Embedding mergedEmbedding = embedding1.merge(embedding2);
- assertThat(mergedEmbedding, equalTo(new TextEmbeddingFloatResults.Embedding(new float[] { 0.05f, 0.3f, 0.2f, 0.7f })));
+ DenseEmbeddingFloatResults.Embedding embedding1 = new DenseEmbeddingFloatResults.Embedding(new float[] { 0.1f, 0.2f, 0.3f, 0.4f });
+ DenseEmbeddingFloatResults.Embedding embedding2 = new DenseEmbeddingFloatResults.Embedding(new float[] { 0.0f, 0.4f, 0.1f, 1.0f });
+ DenseEmbeddingFloatResults.Embedding embedding3 = new DenseEmbeddingFloatResults.Embedding(new float[] { 0.2f, 0.9f, 0.8f, 0.1f });
+ DenseEmbeddingFloatResults.Embedding mergedEmbedding = embedding1.merge(embedding2);
+ assertThat(mergedEmbedding, equalTo(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.05f, 0.3f, 0.2f, 0.7f })));
mergedEmbedding = mergedEmbedding.merge(embedding3);
- assertThat(mergedEmbedding, equalTo(new TextEmbeddingFloatResults.Embedding(new float[] { 0.1f, 0.5f, 0.4f, 0.5f })));
+ assertThat(mergedEmbedding, equalTo(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.1f, 0.5f, 0.4f, 0.5f })));
}
@Override
- protected Writeable.Reader instanceReader() {
- return TextEmbeddingFloatResults::new;
+ protected Writeable.Reader instanceReader() {
+ return DenseEmbeddingFloatResults::new;
}
@Override
- protected TextEmbeddingFloatResults createTestInstance() {
+ protected DenseEmbeddingFloatResults createTestInstance() {
return createRandomResults();
}
@Override
- protected TextEmbeddingFloatResults mutateInstance(TextEmbeddingFloatResults instance) throws IOException {
+ protected DenseEmbeddingFloatResults mutateInstance(DenseEmbeddingFloatResults instance) throws IOException {
// if true we reduce the embeddings list by a random amount, if false we add an embedding to the list
if (randomBoolean()) {
// -1 to remove at least one item from the list
int end = randomInt(instance.embeddings().size() - 1);
- return new TextEmbeddingFloatResults(instance.embeddings().subList(0, end));
+ return new DenseEmbeddingFloatResults(instance.embeddings().subList(0, end));
} else {
- List embeddings = new ArrayList<>(instance.embeddings());
+ List embeddings = new ArrayList<>(instance.embeddings());
embeddings.add(createRandomEmbedding());
- return new TextEmbeddingFloatResults(embeddings);
+ return new DenseEmbeddingFloatResults(embeddings);
}
}
public static Map buildExpectationFloat(List embeddings) {
- return Map.of(TextEmbeddingFloatResults.TEXT_EMBEDDING, embeddings.stream().map(TextEmbeddingFloatResults.Embedding::new).toList());
+ return Map.of(
+ DenseEmbeddingFloatResults.TEXT_EMBEDDING,
+ embeddings.stream().map(DenseEmbeddingFloatResults.Embedding::new).toList()
+ );
}
public static Map buildExpectationByte(List embeddings) {
return Map.of(
- TextEmbeddingByteResults.TEXT_EMBEDDING_BYTES,
- embeddings.stream().map(TextEmbeddingByteResults.Embedding::new).toList()
+ DenseEmbeddingByteResults.TEXT_EMBEDDING_BYTES,
+ embeddings.stream().map(DenseEmbeddingByteResults.Embedding::new).toList()
);
}
public static Map buildExpectationBinary(List embeddings) {
- return Map.of("text_embedding_bits", embeddings.stream().map(TextEmbeddingByteResults.Embedding::new).toList());
+ return Map.of("text_embedding_bits", embeddings.stream().map(DenseEmbeddingByteResults.Embedding::new).toList());
}
}
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/LegacyMlTextEmbeddingResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/LegacyTextEmbeddingResultsTests.java
similarity index 97%
rename from x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/LegacyMlTextEmbeddingResultsTests.java
rename to x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/LegacyTextEmbeddingResultsTests.java
index 6251881e41b8e..ad416b4d87f91 100644
--- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/LegacyMlTextEmbeddingResultsTests.java
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/LegacyTextEmbeddingResultsTests.java
@@ -22,7 +22,7 @@
import static org.hamcrest.Matchers.is;
@SuppressWarnings("deprecation")
-public class LegacyMlTextEmbeddingResultsTests extends AbstractWireSerializingTestCase {
+public class LegacyTextEmbeddingResultsTests extends AbstractWireSerializingTestCase {
public static LegacyTextEmbeddingResults createRandomResults() {
int embeddings = randomIntBetween(1, 10);
List embeddingResults = new ArrayList<>(embeddings);
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionResponseTests.java
index 87049d6bde90c..b897dc8aef6dd 100644
--- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionResponseTests.java
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionResponseTests.java
@@ -17,8 +17,8 @@
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResultsTests;
import org.elasticsearch.xpack.core.ml.inference.results.FillMaskResults;
import org.elasticsearch.xpack.core.ml.inference.results.FillMaskResultsTests;
-import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
-import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResultsTests;
+import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults;
+import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResultsTests;
import org.elasticsearch.xpack.core.ml.inference.results.NerResults;
import org.elasticsearch.xpack.core.ml.inference.results.NerResultsTests;
import org.elasticsearch.xpack.core.ml.inference.results.PyTorchPassThroughResults;
@@ -50,7 +50,7 @@ public class InferModelActionResponseTests extends AbstractWireSerializingTestCa
PyTorchPassThroughResults.NAME,
QuestionAnsweringInferenceResults.NAME,
RegressionInferenceResults.NAME,
- MlTextEmbeddingResults.NAME,
+ MlDenseEmbeddingResults.NAME,
TextExpansionResults.NAME,
TextSimilarityInferenceResults.NAME,
WarningInferenceResults.NAME
@@ -87,7 +87,7 @@ private static InferenceResults randomInferenceResult(String resultType) {
case PyTorchPassThroughResults.NAME -> PyTorchPassThroughResultsTests.createRandomResults();
case QuestionAnsweringInferenceResults.NAME -> QuestionAnsweringInferenceResultsTests.createRandomResults();
case RegressionInferenceResults.NAME -> RegressionInferenceResultsTests.createRandomResults();
- case MlTextEmbeddingResults.NAME -> MlTextEmbeddingResultsTests.createRandomResults();
+ case MlDenseEmbeddingResults.NAME -> MlDenseEmbeddingResultsTests.createRandomResults();
case TextExpansionResults.NAME -> TextExpansionResultsTests.createRandomResults();
case TextSimilarityInferenceResults.NAME -> TextSimilarityInferenceResultsTests.createRandomResults();
case WarningInferenceResults.NAME -> WarningInferenceResultsTests.createRandomResults();
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferTrainedModelDeploymentResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferTrainedModelDeploymentResponseTests.java
index eb373080eee4a..8c17f9e62b494 100644
--- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferTrainedModelDeploymentResponseTests.java
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferTrainedModelDeploymentResponseTests.java
@@ -14,7 +14,7 @@
import org.elasticsearch.xcontent.NamedXContentRegistry;
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
-import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResultsTests;
+import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResultsTests;
import org.junit.Before;
import java.util.List;
@@ -50,10 +50,10 @@ protected Writeable.Reader instanceR
protected InferTrainedModelDeploymentAction.Response createTestInstance() {
return new InferTrainedModelDeploymentAction.Response(
List.of(
- MlTextEmbeddingResultsTests.createRandomResults(),
- MlTextEmbeddingResultsTests.createRandomResults(),
- MlTextEmbeddingResultsTests.createRandomResults(),
- MlTextEmbeddingResultsTests.createRandomResults()
+ MlDenseEmbeddingResultsTests.createRandomResults(),
+ MlDenseEmbeddingResultsTests.createRandomResults(),
+ MlDenseEmbeddingResultsTests.createRandomResults(),
+ MlDenseEmbeddingResultsTests.createRandomResults()
)
);
}
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/MlTextEmbeddingResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/MlDenseEmbeddingResultsTests.java
similarity index 68%
rename from x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/MlTextEmbeddingResultsTests.java
rename to x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/MlDenseEmbeddingResultsTests.java
index 3338609eebdc3..b5aee244771ea 100644
--- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/MlTextEmbeddingResultsTests.java
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/MlDenseEmbeddingResultsTests.java
@@ -16,35 +16,35 @@
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.is;
-public class MlTextEmbeddingResultsTests extends InferenceResultsTestCase {
+public class MlDenseEmbeddingResultsTests extends InferenceResultsTestCase {
- public static MlTextEmbeddingResults createRandomResults() {
+ public static MlDenseEmbeddingResults createRandomResults() {
int columns = randomIntBetween(1, 10);
double[] arr = new double[columns];
for (int i = 0; i < columns; i++) {
arr[i] = randomDouble();
}
- return new MlTextEmbeddingResults(DEFAULT_RESULTS_FIELD, arr, randomBoolean());
+ return new MlDenseEmbeddingResults(DEFAULT_RESULTS_FIELD, arr, randomBoolean());
}
@Override
- protected Writeable.Reader instanceReader() {
- return MlTextEmbeddingResults::new;
+ protected Writeable.Reader instanceReader() {
+ return MlDenseEmbeddingResults::new;
}
@Override
- protected MlTextEmbeddingResults createTestInstance() {
+ protected MlDenseEmbeddingResults createTestInstance() {
return createRandomResults();
}
@Override
- protected MlTextEmbeddingResults mutateInstance(MlTextEmbeddingResults instance) {
+ protected MlDenseEmbeddingResults mutateInstance(MlDenseEmbeddingResults instance) {
return null;// TODO implement https://github.com/elastic/elasticsearch/issues/25929
}
public void testAsMap() {
- MlTextEmbeddingResults testInstance = createTestInstance();
+ MlDenseEmbeddingResults testInstance = createTestInstance();
Map asMap = testInstance.asMap();
int size = testInstance.isTruncated ? 2 : 1;
assertThat(asMap.keySet(), hasSize(size));
@@ -55,7 +55,7 @@ public void testAsMap() {
}
@Override
- void assertFieldValues(MlTextEmbeddingResults createdInstance, IngestDocument document, String parentField, String resultsField) {
+ void assertFieldValues(MlDenseEmbeddingResults createdInstance, IngestDocument document, String parentField, String resultsField) {
assertArrayEquals(document.getFieldValue(parentField + resultsField, double[].class), createdInstance.getInference(), 1e-10);
}
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorOutputBuilder.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorOutputBuilder.java
index a2b0a32e77b05..6cc27f931cfc8 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorOutputBuilder.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorOutputBuilder.java
@@ -12,14 +12,14 @@
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingResults;
import org.elasticsearch.xpack.esql.inference.InferenceOperator;
/**
* {@link TextEmbeddingOperatorOutputBuilder} builds the output page for text embedding by converting
- * {@link TextEmbeddingResults} into a {@link FloatBlock} containing dense vector embeddings.
+ * {@link DenseEmbeddingResults} into a {@link FloatBlock} containing dense vector embeddings.
*/
class TextEmbeddingOperatorOutputBuilder implements InferenceOperator.OutputBuilder {
private final Page inputPage;
@@ -39,7 +39,7 @@ public void close() {
* Adds an inference response to the output builder.
*
*
- * If the response is null or not of type {@link TextEmbeddingResults} an {@link IllegalStateException} is thrown.
+ * If the response is null or not of type {@link DenseEmbeddingResults} an {@link IllegalStateException} is thrown.
* Else, the embedding vector is added to the output block as a multi-value position.
*
*
@@ -55,7 +55,7 @@ public void addInferenceResponse(InferenceAction.Response inferenceResponse) {
return;
}
- TextEmbeddingResults> embeddingResults = inferenceResults(inferenceResponse);
+ DenseEmbeddingResults> embeddingResults = inferenceResults(inferenceResponse);
var embeddings = embeddingResults.embeddings();
if (embeddings.isEmpty()) {
@@ -82,21 +82,25 @@ public Page buildOutput() {
return inputPage.appendBlock(outputBlock);
}
- private TextEmbeddingResults> inferenceResults(InferenceAction.Response inferenceResponse) {
- return InferenceOperator.OutputBuilder.inferenceResults(inferenceResponse, TextEmbeddingResults.class);
+ private DenseEmbeddingResults> inferenceResults(InferenceAction.Response inferenceResponse) {
+ return InferenceOperator.OutputBuilder.inferenceResults(inferenceResponse, DenseEmbeddingResults.class);
}
/**
* Extracts the embedding as a float array from the embedding result.
*/
- private static float[] getEmbeddingAsFloatArray(TextEmbeddingResults> embedding) {
+ private static float[] getEmbeddingAsFloatArray(DenseEmbeddingResults> embedding) {
return switch (embedding.embeddings().get(0)) {
- case TextEmbeddingFloatResults.Embedding floatEmbedding -> floatEmbedding.values();
- case TextEmbeddingByteResults.Embedding byteEmbedding -> toFloatArray(byteEmbedding.values());
+ case DenseEmbeddingFloatResults.Embedding floatEmbedding -> floatEmbedding.values();
+ case DenseEmbeddingByteResults.Embedding byteEmbedding -> toFloatArray(byteEmbedding.values());
default -> throw new IllegalArgumentException(
"Unsupported embedding type: "
+ embedding.embeddings().get(0).getClass().getName()
- + ". Expected TextEmbeddingFloatResults.Embedding or TextEmbeddingByteResults.Embedding."
+ + ". Expected "
+ + DenseEmbeddingFloatResults.Embedding.class.getName()
+ + " or "
+ + DenseEmbeddingByteResults.Embedding.class.getName()
+ + "."
);
};
}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorOutputBuilderTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorOutputBuilderTests.java
index ea77c6bed3c38..ac1fb5daa36a5 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorOutputBuilderTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorOutputBuilderTests.java
@@ -15,8 +15,8 @@
import org.elasticsearch.compute.test.RandomBlock;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import java.util.List;
@@ -182,7 +182,7 @@ private void assertByteEmbeddingContent(FloatBlock block, byte[][] expectedByteE
int firstValueIndex = block.getFirstValueIndex(currentPos);
for (int i = 0; i < expectedByteEmbeddings[currentPos].length; i++) {
float actualValue = block.getFloat(firstValueIndex + i);
- // Convert byte to float the same way as TextEmbeddingByteResults.Embedding.toFloatArray()
+ // Convert byte to float the same way as DenseEmbeddingByteResults.Embedding.toFloatArray()
float expectedValue = expectedByteEmbeddings[currentPos][i];
assertThat(actualValue, equalTo(expectedValue));
}
@@ -206,20 +206,20 @@ private byte[] randomByteEmbedding(int dimension) {
}
private static InferenceAction.Response createFloatEmbeddingResponse(float[] embedding) {
- var embeddingResult = new TextEmbeddingFloatResults.Embedding(embedding);
- var textEmbeddingResults = new TextEmbeddingFloatResults(List.of(embeddingResult));
- return new InferenceAction.Response(textEmbeddingResults);
+ var embeddingResult = new DenseEmbeddingFloatResults.Embedding(embedding);
+ var denseEmbeddingResults = new DenseEmbeddingFloatResults(List.of(embeddingResult));
+ return new InferenceAction.Response(denseEmbeddingResults);
}
private static InferenceAction.Response createByteEmbeddingResponse(byte[] embedding) {
- var embeddingResult = new TextEmbeddingByteResults.Embedding(embedding);
- var textEmbeddingResults = new TextEmbeddingByteResults(List.of(embeddingResult));
- return new InferenceAction.Response(textEmbeddingResults);
+ var embeddingResult = new DenseEmbeddingByteResults.Embedding(embedding);
+ var denseEmbeddingResults = new DenseEmbeddingByteResults(List.of(embeddingResult));
+ return new InferenceAction.Response(denseEmbeddingResults);
}
private static InferenceAction.Response createEmptyFloatEmbeddingResponse() {
- var textEmbeddingResults = new TextEmbeddingFloatResults(List.of());
- return new InferenceAction.Response(textEmbeddingResults);
+ var denseEmbeddingResults = new DenseEmbeddingFloatResults(List.of());
+ return new InferenceAction.Response(denseEmbeddingResults);
}
private Page randomInputPage(int positionCount, int columnCount) {
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorTests.java
index 6ff9a90b70b16..06441be9e7148 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorTests.java
@@ -13,7 +13,7 @@
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.Operator;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.esql.inference.InferenceOperatorTestCase;
import org.hamcrest.Matcher;
import org.junit.Before;
@@ -23,7 +23,7 @@
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
-public class TextEmbeddingOperatorTests extends InferenceOperatorTestCase {
+public class TextEmbeddingOperatorTests extends InferenceOperatorTestCase {
private static final String SIMPLE_INFERENCE_ID = "test_text_embedding";
private static final int EMBEDDING_DIMENSION = 384; // Common embedding dimension
@@ -89,15 +89,15 @@ private void assertTextEmbeddingResults(Page inputPage, Page resultPage) {
}
@Override
- protected TextEmbeddingFloatResults mockInferenceResult(InferenceAction.Request request) {
+ protected DenseEmbeddingFloatResults mockInferenceResult(InferenceAction.Request request) {
// For text embedding, we expect one input text per request
String inputText = request.getInput().get(0);
// Generate a deterministic mock embedding based on the input text
float[] mockEmbedding = generateMockEmbedding(inputText, EMBEDDING_DIMENSION);
- var embeddingResult = new TextEmbeddingFloatResults.Embedding(mockEmbedding);
- return new TextEmbeddingFloatResults(List.of(embeddingResult));
+ var embeddingResult = new DenseEmbeddingFloatResults.Embedding(mockEmbedding);
+ return new DenseEmbeddingFloatResults(List.of(embeddingResult));
}
@Override
diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java
index 051b6dbf3e8fa..45ecb3dedf3f1 100644
--- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java
+++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java
@@ -36,7 +36,7 @@
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
@@ -167,22 +167,22 @@ public void chunkedInfer(
}
}
- private TextEmbeddingFloatResults makeResults(List input, ServiceSettings serviceSettings) {
- List embeddings = new ArrayList<>();
+ private DenseEmbeddingFloatResults makeResults(List input, ServiceSettings serviceSettings) {
+ List embeddings = new ArrayList<>();
for (String inputString : input) {
List floatEmbeddings = generateEmbedding(inputString, serviceSettings.dimensions(), serviceSettings.elementType());
- embeddings.add(TextEmbeddingFloatResults.Embedding.of(floatEmbeddings));
+ embeddings.add(DenseEmbeddingFloatResults.Embedding.of(floatEmbeddings));
}
- return new TextEmbeddingFloatResults(embeddings);
+ return new DenseEmbeddingFloatResults(embeddings);
}
private List makeChunkedResults(List inputs, ServiceSettings serviceSettings) {
var results = new ArrayList();
for (ChunkInferenceInput input : inputs) {
List chunkedInput = chunkInputs(input);
- List chunks = chunkedInput.stream()
+ List chunks = chunkedInput.stream()
.map(
- c -> new TextEmbeddingFloatResults.Chunk(
+ c -> new DenseEmbeddingFloatResults.Chunk(
makeResults(List.of(c.input()), serviceSettings).embeddings().get(0),
new ChunkedInference.TextOffset(c.startOffset(), c.endOffset())
)
diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java
index 28a191a1bbfac..9ea2301abfa0c 100644
--- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java
+++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java
@@ -35,9 +35,9 @@
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.inference.DequeUtils;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
import java.io.IOException;
import java.util.ArrayList;
@@ -138,9 +138,9 @@ public void infer(
)
);
} else {
- // Return text embedding results when creating a sparse_embedding inference endpoint to allow creation validation to
- // pass. This is required to test that streaming fails for a sparse_embedding endpoint.
- listener.onResponse(makeTextEmbeddingResults(input));
+ // Return dense embedding results when creating a sparse_embedding inference endpoint to allow creation validation
+ // to pass. This is required to test that streaming fails for a sparse_embedding endpoint.
+ listener.onResponse(makeDenseEmbeddingResults(input));
}
}
default -> listener.onFailure(
@@ -189,16 +189,16 @@ public void cancel() {}
});
}
- private TextEmbeddingFloatResults makeTextEmbeddingResults(List input) {
- var embeddings = new ArrayList();
+ private DenseEmbeddingFloatResults makeDenseEmbeddingResults(List input) {
+ var embeddings = new ArrayList();
for (int i = 0; i < input.size(); i++) {
var values = new float[5];
for (int j = 0; j < 5; j++) {
values[j] = random.nextFloat();
}
- embeddings.add(new TextEmbeddingFloatResults.Embedding(values));
+ embeddings.add(new DenseEmbeddingFloatResults.Embedding(values));
}
- return new TextEmbeddingFloatResults(embeddings);
+ return new DenseEmbeddingFloatResults(embeddings);
}
private InferenceServiceResults.Result completionChunk(String delta) {
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java
index e7008c2292def..a8cb6767de3ed 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java
@@ -21,13 +21,13 @@
import org.elasticsearch.inference.TaskSettings;
import org.elasticsearch.inference.UnifiedCompletionRequest;
import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.LegacyTextEmbeddingResults;
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
import org.elasticsearch.xpack.inference.chunking.NoneChunkingSettings;
import org.elasticsearch.xpack.inference.chunking.RecursiveChunkingSettings;
@@ -73,10 +73,10 @@
import org.elasticsearch.xpack.inference.services.custom.CustomTaskSettings;
import org.elasticsearch.xpack.inference.services.custom.response.CompletionResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.CustomResponseParser;
+import org.elasticsearch.xpack.inference.services.custom.response.DenseEmbeddingResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.NoopResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.SparseEmbeddingResponseParser;
-import org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser;
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekChatCompletionModel;
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankServiceSettings;
@@ -206,7 +206,11 @@ private static void addCustomNamedWriteables(List
namedWriteables.add(new NamedWriteableRegistry.Entry(SecretSettings.class, CustomSecretSettings.NAME, CustomSecretSettings::new));
namedWriteables.add(
- new NamedWriteableRegistry.Entry(CustomResponseParser.class, TextEmbeddingResponseParser.NAME, TextEmbeddingResponseParser::new)
+ new NamedWriteableRegistry.Entry(
+ CustomResponseParser.class,
+ DenseEmbeddingResponseParser.NAME,
+ DenseEmbeddingResponseParser::new
+ )
);
namedWriteables.add(
@@ -657,10 +661,14 @@ private static void addInferenceResultsNamedWriteables(List
*
*/
- public static TextEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException {
+ public static DenseEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException {
try (var p = XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, response.body())) {
- return EmbeddingFloatResult.PARSER.apply(p, null).toTextEmbeddingFloatResults();
+ return EmbeddingFloatResult.PARSER.apply(p, null).toDenseEmbeddingFloatResults();
}
}
@@ -81,9 +81,9 @@ public record EmbeddingFloatResult(List embeddingResu
}, new ParseField("data"), org.elasticsearch.xcontent.ObjectParser.ValueType.OBJECT_ARRAY);
}
- public TextEmbeddingFloatResults toTextEmbeddingFloatResults() {
- return new TextEmbeddingFloatResults(
- embeddingResults.stream().map(entry -> TextEmbeddingFloatResults.Embedding.of(entry.embedding)).toList()
+ public DenseEmbeddingFloatResults toDenseEmbeddingFloatResults() {
+ return new DenseEmbeddingFloatResults(
+ embeddingResults.stream().map(entry -> DenseEmbeddingFloatResults.Embedding.of(entry.embedding)).toList()
);
}
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java
index 1a8b162eb1b46..1e9a66c776b53 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java
@@ -80,7 +80,7 @@
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
-import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
+import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults;
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder;
import org.elasticsearch.xpack.inference.highlight.SemanticTextHighlighter;
@@ -967,13 +967,13 @@ yield new SparseVectorQueryBuilder(
);
}
case TEXT_EMBEDDING -> {
- if (inferenceResults instanceof MlTextEmbeddingResults == false) {
+ if (inferenceResults instanceof MlDenseEmbeddingResults == false) {
throw new IllegalArgumentException(
- generateQueryInferenceResultsTypeMismatchMessage(inferenceResults, MlTextEmbeddingResults.NAME)
+ generateQueryInferenceResultsTypeMismatchMessage(inferenceResults, MlDenseEmbeddingResults.NAME)
);
}
- MlTextEmbeddingResults textEmbeddingResults = (MlTextEmbeddingResults) inferenceResults;
+ MlDenseEmbeddingResults textEmbeddingResults = (MlDenseEmbeddingResults) inferenceResults;
float[] inference = textEmbeddingResults.getInferenceAsFloat();
int dimensions = modelSettings.elementType() == DenseVectorFieldMapper.ElementType.BIT
? inference.length * Byte.SIZE // Bit vectors encode 8 dimensions into each byte value
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceKnnVectorQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceKnnVectorQueryBuilder.java
index 210ab2b67f9c9..cd590b8410234 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceKnnVectorQueryBuilder.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceKnnVectorQueryBuilder.java
@@ -25,7 +25,7 @@
import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
import org.elasticsearch.search.vectors.QueryVectorBuilder;
import org.elasticsearch.search.vectors.VectorData;
-import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
+import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults;
import org.elasticsearch.xpack.core.ml.vectors.TextEmbeddingQueryVectorBuilder;
import org.elasticsearch.xpack.inference.mapper.SemanticTextField;
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
@@ -195,7 +195,7 @@ private QueryBuilder querySemanticTextField(String clusterAlias, SemanticTextFie
fullyQualifiedInferenceId = new FullyQualifiedInferenceId(clusterAlias, semanticTextFieldType.getSearchInferenceId());
}
- MlTextEmbeddingResults textEmbeddingResults = getTextEmbeddingResults(fullyQualifiedInferenceId);
+ MlDenseEmbeddingResults textEmbeddingResults = getTextEmbeddingResults(fullyQualifiedInferenceId);
queryVector = new VectorData(textEmbeddingResults.getInferenceAsFloat());
}
@@ -226,7 +226,7 @@ private QueryBuilder queryNonSemanticTextField() {
throw new IllegalStateException("No query vector or query vector builder model ID specified");
}
- MlTextEmbeddingResults textEmbeddingResults = getTextEmbeddingResults(fullyQualifiedInferenceId);
+ MlDenseEmbeddingResults textEmbeddingResults = getTextEmbeddingResults(fullyQualifiedInferenceId);
queryVector = new VectorData(textEmbeddingResults.getInferenceAsFloat());
}
@@ -244,20 +244,20 @@ private QueryBuilder queryNonSemanticTextField() {
return knnQuery;
}
- private MlTextEmbeddingResults getTextEmbeddingResults(FullyQualifiedInferenceId fullyQualifiedInferenceId) {
+ private MlDenseEmbeddingResults getTextEmbeddingResults(FullyQualifiedInferenceId fullyQualifiedInferenceId) {
InferenceResults inferenceResults = inferenceResultsMap.get(fullyQualifiedInferenceId);
if (inferenceResults == null) {
throw new IllegalStateException("Could not find inference results from inference endpoint [" + fullyQualifiedInferenceId + "]");
- } else if (inferenceResults instanceof MlTextEmbeddingResults == false) {
+ } else if (inferenceResults instanceof MlDenseEmbeddingResults == false) {
throw new IllegalArgumentException(
"Expected query inference results to be of type ["
- + MlTextEmbeddingResults.NAME
+ + MlDenseEmbeddingResults.NAME
+ "], got ["
+ inferenceResults.getWriteableName()
+ "]. Are you specifying a compatible inference endpoint? Has the inference endpoint configuration changed?"
);
}
- return (MlTextEmbeddingResults) inferenceResults;
+ return (MlDenseEmbeddingResults) inferenceResults;
}
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java
index 97d0caef98d0e..1c43358a4fbfd 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java
@@ -33,7 +33,7 @@
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
-import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
+import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults;
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
import org.elasticsearch.xpack.inference.InferenceException;
@@ -505,7 +505,7 @@ private static InferenceResults validateAndConvertInferenceResults(
InferenceResults inferenceResults = inferenceResultsList.getFirst();
if (inferenceResults instanceof TextExpansionResults == false
- && inferenceResults instanceof MlTextEmbeddingResults == false
+ && inferenceResults instanceof MlDenseEmbeddingResults == false
&& inferenceResults instanceof ErrorInferenceResults == false
&& inferenceResults instanceof WarningInferenceResults == false) {
return new ErrorInferenceResults(
@@ -513,7 +513,7 @@ private static InferenceResults validateAndConvertInferenceResults(
"Expected query inference results to be of type ["
+ TextExpansionResults.NAME
+ "] or ["
- + MlTextEmbeddingResults.NAME
+ + MlDenseEmbeddingResults.NAME
+ "], got ["
+ inferenceResults.getWriteableName()
+ "]. Has the inference endpoint ["
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/response/AlibabaCloudSearchEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/response/AlibabaCloudSearchEmbeddingsResponseEntity.java
index 4e73f03e2898b..0a314202c922c 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/response/AlibabaCloudSearchEmbeddingsResponseEntity.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/response/AlibabaCloudSearchEmbeddingsResponseEntity.java
@@ -9,7 +9,7 @@
import org.elasticsearch.common.xcontent.XContentParserUtils;
import org.elasticsearch.xcontent.XContentParser;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.request.Request;
@@ -70,20 +70,20 @@ public class AlibabaCloudSearchEmbeddingsResponseEntity extends AlibabaCloudSear
*
*
*/
- public static TextEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException {
+ public static DenseEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException {
return fromResponse(request, response, parser -> {
positionParserAtTokenAfterField(parser, "embeddings", FAILED_TO_FIND_FIELD_TEMPLATE);
- List embeddingList = XContentParserUtils.parseList(
+ List embeddingList = XContentParserUtils.parseList(
parser,
AlibabaCloudSearchEmbeddingsResponseEntity::parseEmbeddingObject
);
- return new TextEmbeddingFloatResults(embeddingList);
+ return new DenseEmbeddingFloatResults(embeddingList);
});
}
- private static TextEmbeddingFloatResults.Embedding parseEmbeddingObject(XContentParser parser) throws IOException {
+ private static DenseEmbeddingFloatResults.Embedding parseEmbeddingObject(XContentParser parser) throws IOException {
XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
positionParserAtTokenAfterField(parser, "embedding", FAILED_TO_FIND_FIELD_TEMPLATE);
@@ -95,7 +95,7 @@ private static TextEmbeddingFloatResults.Embedding parseEmbeddingObject(XContent
// if there are additional fields within this object, lets skip them, so we can begin parsing the next embedding array
parser.skipChildren();
- return TextEmbeddingFloatResults.Embedding.of(embeddingValues);
+ return DenseEmbeddingFloatResults.Embedding.of(embeddingValues);
}
private static float parseEmbeddingList(XContentParser parser) throws IOException {
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/response/embeddings/AmazonBedrockEmbeddingsResponse.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/response/embeddings/AmazonBedrockEmbeddingsResponse.java
index 831bf9938c211..61ec5f0c39790 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/response/embeddings/AmazonBedrockEmbeddingsResponse.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/response/embeddings/AmazonBedrockEmbeddingsResponse.java
@@ -16,7 +16,7 @@
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.response.XContentUtils;
import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProvider;
import org.elasticsearch.xpack.inference.services.amazonbedrock.request.AmazonBedrockRequest;
@@ -48,7 +48,7 @@ public InferenceServiceResults accept(AmazonBedrockRequest request) {
throw new ElasticsearchException("unexpected request type [" + request.getClass() + "]");
}
- public static TextEmbeddingFloatResults fromResponse(InvokeModelResponse response, AmazonBedrockProvider provider) {
+ public static DenseEmbeddingFloatResults fromResponse(InvokeModelResponse response, AmazonBedrockProvider provider) {
var charset = StandardCharsets.UTF_8;
var bodyText = String.valueOf(charset.decode(response.body().asByteBuffer()));
@@ -63,13 +63,13 @@ public static TextEmbeddingFloatResults fromResponse(InvokeModelResponse respons
var embeddingList = parseEmbeddings(jsonParser, provider);
- return new TextEmbeddingFloatResults(embeddingList);
+ return new DenseEmbeddingFloatResults(embeddingList);
} catch (IOException e) {
throw new ElasticsearchException(e);
}
}
- private static List parseEmbeddings(XContentParser jsonParser, AmazonBedrockProvider provider)
+ private static List parseEmbeddings(XContentParser jsonParser, AmazonBedrockProvider provider)
throws IOException {
switch (provider) {
case AMAZONTITAN -> {
@@ -82,7 +82,7 @@ private static List parseEmbeddings(XConten
}
}
- private static List parseTitanEmbeddings(XContentParser parser) throws IOException {
+ private static List parseTitanEmbeddings(XContentParser parser) throws IOException {
/*
Titan response:
{
@@ -92,11 +92,11 @@ private static List parseTitanEmbeddings(XC
*/
positionParserAtTokenAfterField(parser, "embedding", FAILED_TO_FIND_FIELD_TEMPLATE);
List embeddingValuesList = parseList(parser, XContentUtils::parseFloat);
- var embeddingValues = TextEmbeddingFloatResults.Embedding.of(embeddingValuesList);
+ var embeddingValues = DenseEmbeddingFloatResults.Embedding.of(embeddingValuesList);
return List.of(embeddingValues);
}
- private static List parseCohereEmbeddings(XContentParser parser) throws IOException {
+ private static List parseCohereEmbeddings(XContentParser parser) throws IOException {
/*
Cohere response:
{
@@ -111,7 +111,7 @@ private static List parseCohereEmbeddings(X
*/
positionParserAtTokenAfterField(parser, "embeddings", FAILED_TO_FIND_FIELD_TEMPLATE);
- List embeddingList = parseList(
+ List embeddingList = parseList(
parser,
AmazonBedrockEmbeddingsResponse::parseCohereEmbeddingsListItem
);
@@ -119,9 +119,9 @@ private static List parseCohereEmbeddings(X
return embeddingList;
}
- private static TextEmbeddingFloatResults.Embedding parseCohereEmbeddingsListItem(XContentParser parser) throws IOException {
+ private static DenseEmbeddingFloatResults.Embedding parseCohereEmbeddingsListItem(XContentParser parser) throws IOException {
List embeddingValuesList = parseList(parser, XContentUtils::parseFloat);
- return TextEmbeddingFloatResults.Embedding.of(embeddingValuesList);
+ return DenseEmbeddingFloatResults.Embedding.of(embeddingValuesList);
}
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/response/CohereEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/response/CohereEmbeddingsResponseEntity.java
index b4a2e142b3792..a5a174e80b289 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/response/CohereEmbeddingsResponseEntity.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/response/CohereEmbeddingsResponseEntity.java
@@ -15,9 +15,9 @@
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBitResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingBitResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.external.response.XContentUtils;
@@ -189,20 +189,20 @@ private static InferenceServiceResults parseBitEmbeddingsArray(XContentParser pa
// Cohere returns array of binary embeddings encoded as bytes with int8 precision so we can reuse the byte parser
var embeddingList = parseList(parser, CohereEmbeddingsResponseEntity::parseByteArrayEntry);
- return new TextEmbeddingBitResults(embeddingList);
+ return new DenseEmbeddingBitResults(embeddingList);
}
private static InferenceServiceResults parseByteEmbeddingsArray(XContentParser parser) throws IOException {
var embeddingList = parseList(parser, CohereEmbeddingsResponseEntity::parseByteArrayEntry);
- return new TextEmbeddingByteResults(embeddingList);
+ return new DenseEmbeddingByteResults(embeddingList);
}
- private static TextEmbeddingByteResults.Embedding parseByteArrayEntry(XContentParser parser) throws IOException {
+ private static DenseEmbeddingByteResults.Embedding parseByteArrayEntry(XContentParser parser) throws IOException {
ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
List embeddingValuesList = parseList(parser, CohereEmbeddingsResponseEntity::parseEmbeddingInt8Entry);
- return TextEmbeddingByteResults.Embedding.of(embeddingValuesList);
+ return DenseEmbeddingByteResults.Embedding.of(embeddingValuesList);
}
private static Byte parseEmbeddingInt8Entry(XContentParser parser) throws IOException {
@@ -223,13 +223,13 @@ private static void checkByteBounds(short value) {
private static InferenceServiceResults parseFloatEmbeddingsArray(XContentParser parser) throws IOException {
var embeddingList = parseList(parser, CohereEmbeddingsResponseEntity::parseFloatArrayEntry);
- return new TextEmbeddingFloatResults(embeddingList);
+ return new DenseEmbeddingFloatResults(embeddingList);
}
- private static TextEmbeddingFloatResults.Embedding parseFloatArrayEntry(XContentParser parser) throws IOException {
+ private static DenseEmbeddingFloatResults.Embedding parseFloatArrayEntry(XContentParser parser) throws IOException {
ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
List embeddingValuesList = parseList(parser, XContentUtils::parseFloat);
- return TextEmbeddingFloatResults.Embedding.of(embeddingValuesList);
+ return DenseEmbeddingFloatResults.Embedding.of(embeddingValuesList);
}
private CohereEmbeddingsResponseEntity() {}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java
index 5986b2104ffb1..44901b49a7e47 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java
@@ -25,10 +25,10 @@
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.custom.response.CompletionResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.CustomResponseParser;
+import org.elasticsearch.xpack.inference.services.custom.response.DenseEmbeddingResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.NoopResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.SparseEmbeddingResponseParser;
-import org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser;
import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
@@ -501,7 +501,7 @@ private static CustomResponseParser extractResponseParser(
}
return switch (taskType) {
- case TEXT_EMBEDDING -> TextEmbeddingResponseParser.fromMap(responseParserMap, RESPONSE_SCOPE, validationException);
+ case TEXT_EMBEDDING -> DenseEmbeddingResponseParser.fromMap(responseParserMap, RESPONSE_SCOPE, validationException);
case SPARSE_EMBEDDING -> SparseEmbeddingResponseParser.fromMap(responseParserMap, RESPONSE_SCOPE, validationException);
case RERANK -> RerankResponseParser.fromMap(responseParserMap, RESPONSE_SCOPE, validationException);
case COMPLETION -> CompletionResponseParser.fromMap(responseParserMap, RESPONSE_SCOPE, validationException);
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParser.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/DenseEmbeddingResponseParser.java
similarity index 80%
rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParser.java
rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/DenseEmbeddingResponseParser.java
index f665c0be81511..18aaf186b4a3f 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParser.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/DenseEmbeddingResponseParser.java
@@ -14,9 +14,9 @@
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xcontent.XContentBuilder;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBitResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingBitResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.common.MapPathExtractor;
import org.elasticsearch.xpack.inference.services.custom.CustomServiceEmbeddingType;
@@ -31,8 +31,8 @@
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString;
import static org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings.JSON_PARSER;
-public class TextEmbeddingResponseParser extends BaseCustomResponseParser {
-
+public class DenseEmbeddingResponseParser extends BaseCustomResponseParser {
+ // This name is a holdover from before this class was renamed
public static final String NAME = "text_embedding_response_parser";
public static final String TEXT_EMBEDDING_PARSER_EMBEDDINGS = "text_embeddings";
public static final String EMBEDDING_TYPE = "embedding_type";
@@ -41,7 +41,7 @@ public class TextEmbeddingResponseParser extends BaseCustomResponseParser {
"ml_inference_custom_service_embedding_type"
);
- public static TextEmbeddingResponseParser fromMap(
+ public static DenseEmbeddingResponseParser fromMap(
Map responseParserMap,
String scope,
ValidationException validationException
@@ -70,18 +70,18 @@ public static TextEmbeddingResponseParser fromMap(
throw validationException;
}
- return new TextEmbeddingResponseParser(path, embeddingType);
+ return new DenseEmbeddingResponseParser(path, embeddingType);
}
private final String textEmbeddingsPath;
private final CustomServiceEmbeddingType embeddingType;
- public TextEmbeddingResponseParser(String textEmbeddingsPath, CustomServiceEmbeddingType embeddingType) {
+ public DenseEmbeddingResponseParser(String textEmbeddingsPath, CustomServiceEmbeddingType embeddingType) {
this.textEmbeddingsPath = Objects.requireNonNull(textEmbeddingsPath);
this.embeddingType = Objects.requireNonNull(embeddingType);
}
- public TextEmbeddingResponseParser(StreamInput in) throws IOException {
+ public DenseEmbeddingResponseParser(StreamInput in) throws IOException {
this.textEmbeddingsPath = in.readString();
if (in.getTransportVersion().supports(ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_TYPE)) {
this.embeddingType = in.readEnum(CustomServiceEmbeddingType.class);
@@ -122,7 +122,7 @@ public CustomServiceEmbeddingType getEmbeddingType() {
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
- TextEmbeddingResponseParser that = (TextEmbeddingResponseParser) o;
+ DenseEmbeddingResponseParser that = (DenseEmbeddingResponseParser) o;
return Objects.equals(textEmbeddingsPath, that.textEmbeddingsPath) && Objects.equals(embeddingType, that.embeddingType);
}
@@ -174,7 +174,7 @@ private interface EmbeddingConverter {
private static class FloatEmbeddings implements EmbeddingConverter {
- private final List embeddings;
+ private final List embeddings;
FloatEmbeddings() {
this.embeddings = new ArrayList<>();
@@ -182,17 +182,17 @@ private static class FloatEmbeddings implements EmbeddingConverter {
public void toEmbedding(Object entry, String fieldName) {
var embeddingsAsListFloats = convertToListOfFloats(entry, fieldName);
- embeddings.add(TextEmbeddingFloatResults.Embedding.of(embeddingsAsListFloats));
+ embeddings.add(DenseEmbeddingFloatResults.Embedding.of(embeddingsAsListFloats));
}
- public TextEmbeddingFloatResults getResults() {
- return new TextEmbeddingFloatResults(embeddings);
+ public DenseEmbeddingFloatResults getResults() {
+ return new DenseEmbeddingFloatResults(embeddings);
}
}
private static class ByteEmbeddings implements EmbeddingConverter {
- private final List embeddings;
+ private final List embeddings;
ByteEmbeddings() {
this.embeddings = new ArrayList<>();
@@ -200,17 +200,17 @@ private static class ByteEmbeddings implements EmbeddingConverter {
public void toEmbedding(Object entry, String fieldName) {
var convertedEmbeddings = convertToListOfBytes(entry, fieldName);
- this.embeddings.add(TextEmbeddingByteResults.Embedding.of(convertedEmbeddings));
+ this.embeddings.add(DenseEmbeddingByteResults.Embedding.of(convertedEmbeddings));
}
- public TextEmbeddingByteResults getResults() {
- return new TextEmbeddingByteResults(embeddings);
+ public DenseEmbeddingByteResults getResults() {
+ return new DenseEmbeddingByteResults(embeddings);
}
}
private static class BitEmbeddings implements EmbeddingConverter {
- private final List embeddings;
+ private final List embeddings;
BitEmbeddings() {
this.embeddings = new ArrayList<>();
@@ -218,11 +218,11 @@ private static class BitEmbeddings implements EmbeddingConverter {
public void toEmbedding(Object entry, String fieldName) {
var convertedEmbeddings = convertToListOfBits(entry, fieldName);
- this.embeddings.add(TextEmbeddingByteResults.Embedding.of(convertedEmbeddings));
+ this.embeddings.add(DenseEmbeddingByteResults.Embedding.of(convertedEmbeddings));
}
- public TextEmbeddingBitResults getResults() {
- return new TextEmbeddingBitResults(embeddings);
+ public DenseEmbeddingBitResults getResults() {
+ return new DenseEmbeddingBitResults(embeddings);
}
}
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java
index 2c1ee96b519a3..646ee520de83f 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java
@@ -38,15 +38,15 @@
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.XPackSettings;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats;
import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
-import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
+import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults;
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.EmptyConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
@@ -642,7 +642,7 @@ public void inferTextEmbedding(
);
ActionListener mlResultsListener = listener.delegateFailureAndWrap(
- (l, inferenceResult) -> l.onResponse(TextEmbeddingFloatResults.of(inferenceResult.getInferenceResults()))
+ (l, inferenceResult) -> l.onResponse(DenseEmbeddingFloatResults.of(inferenceResult.getInferenceResults()))
);
var maybeDeployListener = mlResultsListener.delegateResponse(
@@ -772,22 +772,22 @@ private static void translateToChunkedResult(
ActionListener chunkPartListener
) {
if (taskType == TaskType.TEXT_EMBEDDING) {
- var translated = new ArrayList();
+ var translated = new ArrayList();
for (var inferenceResult : inferenceResults) {
- if (inferenceResult instanceof MlTextEmbeddingResults mlTextEmbeddingResult) {
- translated.add(new TextEmbeddingFloatResults.Embedding(mlTextEmbeddingResult.getInferenceAsFloat()));
+ if (inferenceResult instanceof MlDenseEmbeddingResults mlTextEmbeddingResult) {
+ translated.add(new DenseEmbeddingFloatResults.Embedding(mlTextEmbeddingResult.getInferenceAsFloat()));
} else if (inferenceResult instanceof ErrorInferenceResults error) {
chunkPartListener.onFailure(error.getException());
return;
} else {
chunkPartListener.onFailure(
- createInvalidChunkedResultException(MlTextEmbeddingResults.NAME, inferenceResult.getWriteableName())
+ createInvalidChunkedResultException(MlDenseEmbeddingResults.NAME, inferenceResult.getWriteableName())
);
return;
}
}
- chunkPartListener.onResponse(new TextEmbeddingFloatResults(translated));
+ chunkPartListener.onResponse(new DenseEmbeddingFloatResults(translated));
} else { // sparse
var translated = new ArrayList();
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/response/GoogleAiStudioEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/response/GoogleAiStudioEmbeddingsResponseEntity.java
index 499fe9ae0c6c7..67527698bf02c 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/response/GoogleAiStudioEmbeddingsResponseEntity.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/response/GoogleAiStudioEmbeddingsResponseEntity.java
@@ -12,7 +12,7 @@
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.external.response.XContentUtils;
@@ -70,7 +70,7 @@ public class GoogleAiStudioEmbeddingsResponseEntity {
*
*/
- public static TextEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException {
+ public static DenseEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException {
var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) {
@@ -81,16 +81,16 @@ public static TextEmbeddingFloatResults fromResponse(Request request, HttpResult
positionParserAtTokenAfterField(jsonParser, "embeddings", FAILED_TO_FIND_FIELD_TEMPLATE);
- List embeddingList = parseList(
+ List embeddingList = parseList(
jsonParser,
GoogleAiStudioEmbeddingsResponseEntity::parseEmbeddingObject
);
- return new TextEmbeddingFloatResults(embeddingList);
+ return new DenseEmbeddingFloatResults(embeddingList);
}
}
- private static TextEmbeddingFloatResults.Embedding parseEmbeddingObject(XContentParser parser) throws IOException {
+ private static DenseEmbeddingFloatResults.Embedding parseEmbeddingObject(XContentParser parser) throws IOException {
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
positionParserAtTokenAfterField(parser, "values", FAILED_TO_FIND_FIELD_TEMPLATE);
@@ -99,7 +99,7 @@ private static TextEmbeddingFloatResults.Embedding parseEmbeddingObject(XContent
// parse and discard the rest of the object
consumeUntilObjectEnd(parser);
- return TextEmbeddingFloatResults.Embedding.of(embeddingValuesList);
+ return DenseEmbeddingFloatResults.Embedding.of(embeddingValuesList);
}
private GoogleAiStudioEmbeddingsResponseEntity() {}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/response/GoogleVertexAiEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/response/GoogleVertexAiEmbeddingsResponseEntity.java
index b4038e42c62cb..94272815e8db2 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/response/GoogleVertexAiEmbeddingsResponseEntity.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/response/GoogleVertexAiEmbeddingsResponseEntity.java
@@ -13,7 +13,7 @@
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.request.Request;
@@ -64,7 +64,7 @@ public class GoogleVertexAiEmbeddingsResponseEntity {
*
*/
- public static TextEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException {
+ public static DenseEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException {
var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) {
@@ -75,16 +75,16 @@ public static TextEmbeddingFloatResults fromResponse(Request request, HttpResult
positionParserAtTokenAfterField(jsonParser, "predictions", FAILED_TO_FIND_FIELD_TEMPLATE);
- List embeddingList = parseList(
+ List embeddingList = parseList(
jsonParser,
GoogleVertexAiEmbeddingsResponseEntity::parseEmbeddingObject
);
- return new TextEmbeddingFloatResults(embeddingList);
+ return new DenseEmbeddingFloatResults(embeddingList);
}
}
- private static TextEmbeddingFloatResults.Embedding parseEmbeddingObject(XContentParser parser) throws IOException {
+ private static DenseEmbeddingFloatResults.Embedding parseEmbeddingObject(XContentParser parser) throws IOException {
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
positionParserAtTokenAfterField(parser, "embeddings", FAILED_TO_FIND_FIELD_TEMPLATE);
@@ -99,7 +99,7 @@ private static TextEmbeddingFloatResults.Embedding parseEmbeddingObject(XContent
consumeUntilObjectEnd(parser);
consumeUntilObjectEnd(parser);
- return TextEmbeddingFloatResults.Embedding.of(embeddingValueList);
+ return DenseEmbeddingFloatResults.Embedding.of(embeddingValueList);
}
private static float parseEmbeddingList(XContentParser parser) throws IOException {
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java
index 081d5c63b84ff..681058d8c21a5 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java
@@ -26,9 +26,9 @@
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.EmbeddingResults;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
@@ -127,8 +127,8 @@ private static List translateToChunkedResults(
List inputs,
InferenceServiceResults inferenceResults
) {
- if (inferenceResults instanceof TextEmbeddingFloatResults textEmbeddingResults) {
- validateInputSizeAgainstEmbeddings(ChunkInferenceInput.inputs(inputs), textEmbeddingResults.embeddings().size());
+ if (inferenceResults instanceof DenseEmbeddingFloatResults denseEmbeddingResults) {
+ validateInputSizeAgainstEmbeddings(ChunkInferenceInput.inputs(inputs), denseEmbeddingResults.embeddings().size());
var results = new ArrayList(inputs.size());
@@ -137,7 +137,7 @@ private static List translateToChunkedResults(
new ChunkedInferenceEmbedding(
List.of(
new EmbeddingResults.Chunk(
- textEmbeddingResults.embeddings().get(i),
+ denseEmbeddingResults.embeddings().get(i),
new ChunkedInference.TextOffset(0, inputs.get(i).input().length())
)
)
@@ -153,7 +153,7 @@ private static List translateToChunkedResults(
} else {
String expectedClasses = Strings.format(
"One of [%s,%s]",
- TextEmbeddingFloatResults.class.getSimpleName(),
+ DenseEmbeddingFloatResults.class.getSimpleName(),
SparseEmbeddingResults.class.getSimpleName()
);
throw createInvalidChunkedResultException(expectedClasses, inferenceResults.getWriteableName());
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceEmbeddingsResponseEntity.java
index baf1e884108fb..126d5b03fcde4 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceEmbeddingsResponseEntity.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceEmbeddingsResponseEntity.java
@@ -12,7 +12,7 @@
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.external.response.XContentUtils;
@@ -33,7 +33,7 @@ public class HuggingFaceEmbeddingsResponseEntity {
* Parse the response from hugging face. The known formats are an array of arrays and object with an {@code embeddings} field containing
* an array of arrays.
*/
- public static TextEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException {
+ public static DenseEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException {
var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) {
@@ -91,13 +91,13 @@ public static TextEmbeddingFloatResults fromResponse(Request request, HttpResult
* sentence-transformers/all-MiniLM-L6-v2
* sentence-transformers/all-MiniLM-L12-v2
*/
- private static TextEmbeddingFloatResults parseArrayFormat(XContentParser parser) throws IOException {
- List embeddingList = parseList(
+ private static DenseEmbeddingFloatResults parseArrayFormat(XContentParser parser) throws IOException {
+ List embeddingList = parseList(
parser,
HuggingFaceEmbeddingsResponseEntity::parseEmbeddingEntry
);
- return new TextEmbeddingFloatResults(embeddingList);
+ return new DenseEmbeddingFloatResults(embeddingList);
}
/**
@@ -136,22 +136,22 @@ private static TextEmbeddingFloatResults parseArrayFormat(XContentParser parser)
* intfloat/multilingual-e5-small
* sentence-transformers/all-mpnet-base-v2
*/
- private static TextEmbeddingFloatResults parseObjectFormat(XContentParser parser) throws IOException {
+ private static DenseEmbeddingFloatResults parseObjectFormat(XContentParser parser) throws IOException {
positionParserAtTokenAfterField(parser, "embeddings", FAILED_TO_FIND_FIELD_TEMPLATE);
- List embeddingList = parseList(
+ List embeddingList = parseList(
parser,
HuggingFaceEmbeddingsResponseEntity::parseEmbeddingEntry
);
- return new TextEmbeddingFloatResults(embeddingList);
+ return new DenseEmbeddingFloatResults(embeddingList);
}
- private static TextEmbeddingFloatResults.Embedding parseEmbeddingEntry(XContentParser parser) throws IOException {
+ private static DenseEmbeddingFloatResults.Embedding parseEmbeddingEntry(XContentParser parser) throws IOException {
ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
List embeddingValuesList = parseList(parser, XContentUtils::parseFloat);
- return TextEmbeddingFloatResults.Embedding.of(embeddingValuesList);
+ return DenseEmbeddingFloatResults.Embedding.of(embeddingValuesList);
}
private HuggingFaceEmbeddingsResponseEntity() {}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/response/IbmWatsonxEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/response/IbmWatsonxEmbeddingsResponseEntity.java
index 4fda9d5661a2c..d12f44932ed6e 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/response/IbmWatsonxEmbeddingsResponseEntity.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/response/IbmWatsonxEmbeddingsResponseEntity.java
@@ -12,7 +12,7 @@
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.external.response.XContentUtils;
@@ -30,7 +30,7 @@ public class IbmWatsonxEmbeddingsResponseEntity {
private static final String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in IBM watsonx embeddings response";
- public static TextEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException {
+ public static DenseEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException {
var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) {
@@ -41,16 +41,16 @@ public static TextEmbeddingFloatResults fromResponse(Request request, HttpResult
positionParserAtTokenAfterField(jsonParser, "results", FAILED_TO_FIND_FIELD_TEMPLATE);
- List embeddingList = parseList(
+ List embeddingList = parseList(
jsonParser,
IbmWatsonxEmbeddingsResponseEntity::parseEmbeddingObject
);
- return new TextEmbeddingFloatResults(embeddingList);
+ return new DenseEmbeddingFloatResults(embeddingList);
}
}
- private static TextEmbeddingFloatResults.Embedding parseEmbeddingObject(XContentParser parser) throws IOException {
+ private static DenseEmbeddingFloatResults.Embedding parseEmbeddingObject(XContentParser parser) throws IOException {
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
positionParserAtTokenAfterField(parser, "embedding", FAILED_TO_FIND_FIELD_TEMPLATE);
@@ -59,7 +59,7 @@ private static TextEmbeddingFloatResults.Embedding parseEmbeddingObject(XContent
// parse and discard the rest of the object
consumeUntilObjectEnd(parser);
- return TextEmbeddingFloatResults.Embedding.of(embeddingValuesList);
+ return DenseEmbeddingFloatResults.Embedding.of(embeddingValuesList);
}
private IbmWatsonxEmbeddingsResponseEntity() {}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/response/JinaAIEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/response/JinaAIEmbeddingsResponseEntity.java
index 8eee003accba0..f9c80d2593642 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/response/JinaAIEmbeddingsResponseEntity.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/response/JinaAIEmbeddingsResponseEntity.java
@@ -15,9 +15,9 @@
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBitResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingBitResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.external.response.XContentUtils;
@@ -126,15 +126,15 @@ public static InferenceServiceResults fromResponse(Request request, HttpResult r
}
private static InferenceServiceResults parseFloatDataObject(XContentParser jsonParser) throws IOException {
- List embeddingList = parseList(
+ List embeddingList = parseList(
jsonParser,
JinaAIEmbeddingsResponseEntity::parseFloatEmbeddingObject
);
- return new TextEmbeddingFloatResults(embeddingList);
+ return new DenseEmbeddingFloatResults(embeddingList);
}
- private static TextEmbeddingFloatResults.Embedding parseFloatEmbeddingObject(XContentParser parser) throws IOException {
+ private static DenseEmbeddingFloatResults.Embedding parseFloatEmbeddingObject(XContentParser parser) throws IOException {
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
positionParserAtTokenAfterField(parser, "embedding", FAILED_TO_FIND_FIELD_TEMPLATE);
@@ -143,19 +143,19 @@ private static TextEmbeddingFloatResults.Embedding parseFloatEmbeddingObject(XCo
// parse and discard the rest of the object
consumeUntilObjectEnd(parser);
- return TextEmbeddingFloatResults.Embedding.of(embeddingValuesList);
+ return DenseEmbeddingFloatResults.Embedding.of(embeddingValuesList);
}
private static InferenceServiceResults parseBitDataObject(XContentParser jsonParser) throws IOException {
- List embeddingList = parseList(
+ List embeddingList = parseList(
jsonParser,
JinaAIEmbeddingsResponseEntity::parseBitEmbeddingObject
);
- return new TextEmbeddingBitResults(embeddingList);
+ return new DenseEmbeddingBitResults(embeddingList);
}
- private static TextEmbeddingByteResults.Embedding parseBitEmbeddingObject(XContentParser parser) throws IOException {
+ private static DenseEmbeddingByteResults.Embedding parseBitEmbeddingObject(XContentParser parser) throws IOException {
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
positionParserAtTokenAfterField(parser, "embedding", FAILED_TO_FIND_FIELD_TEMPLATE);
@@ -164,7 +164,7 @@ private static TextEmbeddingByteResults.Embedding parseBitEmbeddingObject(XConte
// parse and discard the rest of the object
consumeUntilObjectEnd(parser);
- return TextEmbeddingByteResults.Embedding.of(embeddingList);
+ return DenseEmbeddingByteResults.Embedding.of(embeddingList);
}
private static Byte parseEmbeddingInt8Entry(XContentParser parser) throws IOException {
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/response/OpenAiEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/response/OpenAiEmbeddingsResponseEntity.java
index b8130545a711d..3298ca310c01d 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/response/OpenAiEmbeddingsResponseEntity.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/response/OpenAiEmbeddingsResponseEntity.java
@@ -12,7 +12,7 @@
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.request.Request;
@@ -65,9 +65,9 @@ public class OpenAiEmbeddingsResponseEntity {
*
*
*/
- public static TextEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException {
+ public static DenseEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException {
try (var p = XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, response.body())) {
- return EmbeddingFloatResult.PARSER.apply(p, null).toTextEmbeddingFloatResults();
+ return EmbeddingFloatResult.PARSER.apply(p, null).toDenseEmbeddingFloatResults();
}
}
@@ -83,9 +83,9 @@ public record EmbeddingFloatResult(List embeddingResu
PARSER.declareObjectArray(constructorArg(), EmbeddingFloatResultEntry.PARSER::apply, new ParseField("data"));
}
- public TextEmbeddingFloatResults toTextEmbeddingFloatResults() {
- return new TextEmbeddingFloatResults(
- embeddingResults.stream().map(entry -> TextEmbeddingFloatResults.Embedding.of(entry.embedding)).toList()
+ public DenseEmbeddingFloatResults toDenseEmbeddingFloatResults() {
+ return new DenseEmbeddingFloatResults(
+ embeddingResults.stream().map(entry -> DenseEmbeddingFloatResults.Embedding.of(entry.embedding)).toList()
);
}
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerStoredServiceSchema.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerStoredServiceSchema.java
index 70ed23e215dd6..94d96b3db6d26 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerStoredServiceSchema.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerStoredServiceSchema.java
@@ -12,6 +12,7 @@
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.inference.services.validation.DenseEmbeddingModelValidator;
/**
* Contains any model-specific settings that are stored in SageMakerServiceSettings.
@@ -72,7 +73,7 @@ default boolean isFragment() {
/**
* If this Schema supports Text Embeddings, then we need to implement this.
- * {@link org.elasticsearch.xpack.inference.services.validation.TextEmbeddingModelValidator} will set the dimensions if the user
+ * {@link DenseEmbeddingModelValidator} will set the dimensions if the user
* does not do it, so we need to store the dimensions and flip the {@link #dimensionsSetByUser()} boolean.
*/
default SageMakerStoredServiceSchema updateModelWithEmbeddingDetails(Integer dimensions) {
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticTextEmbeddingPayload.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticTextEmbeddingPayload.java
index a5fd194f12109..dbbf82f5703b9 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticTextEmbeddingPayload.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticTextEmbeddingPayload.java
@@ -24,10 +24,10 @@
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParserConfiguration;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBitResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingBitResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingResults;
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerInferenceRequest;
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredServiceSchema;
@@ -92,7 +92,7 @@ public Stream namedWriteables() {
}
@Override
- public TextEmbeddingResults> responseBody(SageMakerModel model, InvokeEndpointResponse response) throws Exception {
+ public DenseEmbeddingResults> responseBody(SageMakerModel model, InvokeEndpointResponse response) throws Exception {
try (var p = jsonXContent.createParser(XContentParserConfiguration.EMPTY, response.body().asInputStream())) {
return switch (model.apiServiceSettings().elementType()) {
case BIT -> TextEmbeddingBinary.PARSER.apply(p, null);
@@ -103,7 +103,7 @@ public TextEmbeddingResults> responseBody(SageMakerModel model, InvokeEndpoint
}
/**
- * Reads binary format (it says bytes, but the lengths are different)
+ * Reads binary format
* {
* "text_embedding_bits": [
* {
@@ -120,12 +120,12 @@ public TextEmbeddingResults> responseBody(SageMakerModel model, InvokeEndpoint
* }
*/
private static class TextEmbeddingBinary {
- private static final ParseField TEXT_EMBEDDING_BITS = new ParseField(TextEmbeddingBitResults.TEXT_EMBEDDING_BITS);
+ private static final ParseField TEXT_EMBEDDING_BITS = new ParseField(DenseEmbeddingBitResults.TEXT_EMBEDDING_BITS);
@SuppressWarnings("unchecked")
- private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(
- TextEmbeddingBitResults.class.getSimpleName(),
+ private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(
+ DenseEmbeddingBitResults.class.getSimpleName(),
IGNORE_UNKNOWN_FIELDS,
- args -> new TextEmbeddingBitResults((List) args[0])
+ args -> new DenseEmbeddingBitResults((List) args[0])
);
static {
@@ -151,20 +151,20 @@ private static class TextEmbeddingBinary {
* }
*/
private static class TextEmbeddingBytes {
- private static final ParseField TEXT_EMBEDDING_BYTES = new ParseField("text_embedding_bytes");
+ private static final ParseField TEXT_EMBEDDING_BYTES = new ParseField(DenseEmbeddingByteResults.TEXT_EMBEDDING_BYTES);
@SuppressWarnings("unchecked")
- private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(
- TextEmbeddingByteResults.class.getSimpleName(),
+ private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(
+ DenseEmbeddingByteResults.class.getSimpleName(),
IGNORE_UNKNOWN_FIELDS,
- args -> new TextEmbeddingByteResults((List) args[0])
+ args -> new DenseEmbeddingByteResults((List) args[0])
);
@SuppressWarnings("unchecked")
- private static final ConstructingObjectParser BYTE_PARSER =
+ private static final ConstructingObjectParser BYTE_PARSER =
new ConstructingObjectParser<>(
- TextEmbeddingByteResults.Embedding.class.getSimpleName(),
+ DenseEmbeddingByteResults.Embedding.class.getSimpleName(),
IGNORE_UNKNOWN_FIELDS,
- args -> TextEmbeddingByteResults.Embedding.of((List) args[0])
+ args -> DenseEmbeddingByteResults.Embedding.of((List) args[0])
);
static {
@@ -197,20 +197,20 @@ private static class TextEmbeddingBytes {
* }
*/
private static class TextEmbeddingFloat {
- private static final ParseField TEXT_EMBEDDING_FLOAT = new ParseField("text_embedding");
+ private static final ParseField TEXT_EMBEDDING_FLOAT = new ParseField(DenseEmbeddingFloatResults.TEXT_EMBEDDING);
@SuppressWarnings("unchecked")
- private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(
- TextEmbeddingByteResults.class.getSimpleName(),
+ private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(
+ DenseEmbeddingFloatResults.class.getSimpleName(),
IGNORE_UNKNOWN_FIELDS,
- args -> new TextEmbeddingFloatResults((List) args[0])
+ args -> new DenseEmbeddingFloatResults((List) args[0])
);
@SuppressWarnings("unchecked")
- private static final ConstructingObjectParser FLOAT_PARSER =
+ private static final ConstructingObjectParser FLOAT_PARSER =
new ConstructingObjectParser<>(
- TextEmbeddingFloatResults.Embedding.class.getSimpleName(),
+ DenseEmbeddingFloatResults.Embedding.class.getSimpleName(),
IGNORE_UNKNOWN_FIELDS,
- args -> TextEmbeddingFloatResults.Embedding.of((List) args[0])
+ args -> DenseEmbeddingFloatResults.Embedding.of((List) args[0])
);
static {
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiTextEmbeddingPayload.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiTextEmbeddingPayload.java
index f060c7a75797b..b4c0b62f7ca12 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiTextEmbeddingPayload.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiTextEmbeddingPayload.java
@@ -25,7 +25,7 @@
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.json.JsonXContent;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.services.openai.response.OpenAiEmbeddingsResponseEntity;
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerInferenceRequest;
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
@@ -116,9 +116,9 @@ public SdkBytes requestBytes(SageMakerModel model, SageMakerInferenceRequest req
}
@Override
- public TextEmbeddingFloatResults responseBody(SageMakerModel model, InvokeEndpointResponse response) throws Exception {
+ public DenseEmbeddingFloatResults responseBody(SageMakerModel model, InvokeEndpointResponse response) throws Exception {
try (var p = jsonXContent.createParser(XContentParserConfiguration.EMPTY, response.body().asInputStream())) {
- return OpenAiEmbeddingsResponseEntity.EmbeddingFloatResult.PARSER.apply(p, null).toTextEmbeddingFloatResults();
+ return OpenAiEmbeddingsResponseEntity.EmbeddingFloatResult.PARSER.apply(p, null).toDenseEmbeddingFloatResults();
}
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/DenseEmbeddingModelValidator.java
similarity index 85%
rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidator.java
rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/DenseEmbeddingModelValidator.java
index ce9df7376ebcb..fb7b415758ca5 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidator.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/DenseEmbeddingModelValidator.java
@@ -16,14 +16,14 @@
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.validation.ServiceIntegrationValidator;
import org.elasticsearch.rest.RestStatus;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingResults;
-public class TextEmbeddingModelValidator implements ModelValidator {
+public class DenseEmbeddingModelValidator implements ModelValidator {
private final ServiceIntegrationValidator serviceIntegrationValidator;
- public TextEmbeddingModelValidator(ServiceIntegrationValidator serviceIntegrationValidator) {
+ public DenseEmbeddingModelValidator(ServiceIntegrationValidator serviceIntegrationValidator) {
this.serviceIntegrationValidator = serviceIntegrationValidator;
}
@@ -35,7 +35,7 @@ public void validate(InferenceService service, Model model, TimeValue timeout, A
}
private Model postValidate(InferenceService service, Model model, InferenceServiceResults results) {
- if (results instanceof TextEmbeddingResults> embeddingResults) {
+ if (results instanceof DenseEmbeddingResults> embeddingResults) {
var serviceSettings = model.getServiceSettings();
var dimensions = serviceSettings.dimensions();
int embeddingSize = getEmbeddingSize(embeddingResults);
@@ -60,7 +60,7 @@ private Model postValidate(InferenceService service, Model model, InferenceServi
throw new ElasticsearchStatusException(
"Validation call did not return expected results type."
+ "Expected a result of type ["
- + TextEmbeddingFloatResults.NAME
+ + DenseEmbeddingFloatResults.NAME
+ "] got ["
+ (results == null ? "null" : results.getWriteableName())
+ "]",
@@ -69,7 +69,7 @@ private Model postValidate(InferenceService service, Model model, InferenceServi
}
}
- private int getEmbeddingSize(TextEmbeddingResults> embeddingResults) {
+ private int getEmbeddingSize(DenseEmbeddingResults> embeddingResults) {
int embeddingSize;
try {
embeddingSize = embeddingResults.getFirstEmbeddingSize();
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidator.java
index fa0e1b3e590a4..6078531eec3f9 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidator.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidator.java
@@ -17,8 +17,8 @@
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.validation.ServiceIntegrationValidator;
import org.elasticsearch.rest.RestStatus;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingResults;
import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandEmbeddingModel;
public class ElasticsearchInternalServiceModelValidator implements ModelValidator {
@@ -54,7 +54,7 @@ public void validate(InferenceService service, Model model, TimeValue timeout, A
}
private Model postValidate(InferenceService service, Model model, InferenceServiceResults results) {
- if (results instanceof TextEmbeddingResults> embeddingResults) {
+ if (results instanceof DenseEmbeddingResults> embeddingResults) {
var serviceSettings = model.getServiceSettings();
var dimensions = serviceSettings.dimensions();
int embeddingSize = getEmbeddingSize(embeddingResults);
@@ -79,7 +79,7 @@ private Model postValidate(InferenceService service, Model model, InferenceServi
throw new ElasticsearchStatusException(
"Validation call did not return expected results type."
+ "Expected a result of type ["
- + TextEmbeddingFloatResults.NAME
+ + DenseEmbeddingFloatResults.NAME
+ "] got ["
+ (results == null ? "null" : results.getWriteableName())
+ "]",
@@ -88,7 +88,7 @@ private Model postValidate(InferenceService service, Model model, InferenceServi
}
}
- private int getEmbeddingSize(TextEmbeddingResults> embeddingResults) {
+ private int getEmbeddingSize(DenseEmbeddingResults> embeddingResults) {
int embeddingSize;
try {
embeddingSize = embeddingResults.getFirstEmbeddingSize();
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilder.java
index fac9ee5e9c1c1..59fcdac82c2a4 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilder.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilder.java
@@ -36,7 +36,7 @@ private static ModelValidator buildModelValidatorForTaskType(TaskType taskType,
switch (taskType) {
case TEXT_EMBEDDING -> {
- return new TextEmbeddingModelValidator(
+ return new DenseEmbeddingModelValidator(
Objects.requireNonNullElse(validatorFromService, new SimpleServiceIntegrationValidator())
);
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/response/VoyageAIEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/response/VoyageAIEmbeddingsResponseEntity.java
index f9ba5fd58d21a..61436d509e45a 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/response/VoyageAIEmbeddingsResponseEntity.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/response/VoyageAIEmbeddingsResponseEntity.java
@@ -15,9 +15,9 @@
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBitResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingBitResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingType;
@@ -75,9 +75,9 @@ private static void checkByteBounds(Integer value) {
}
}
- public TextEmbeddingByteResults.Embedding toInferenceByteEmbedding() {
+ public DenseEmbeddingByteResults.Embedding toInferenceByteEmbedding() {
embedding.forEach(EmbeddingInt8ResultEntry::checkByteBounds);
- return TextEmbeddingByteResults.Embedding.of(embedding.stream().map(Integer::byteValue).toList());
+ return DenseEmbeddingByteResults.Embedding.of(embedding.stream().map(Integer::byteValue).toList());
}
}
@@ -108,8 +108,8 @@ record EmbeddingFloatResultEntry(Integer index, List embedding) {
PARSER.declareFloatArray(constructorArg(), new ParseField("embedding"));
}
- public TextEmbeddingFloatResults.Embedding toInferenceFloatEmbedding() {
- return TextEmbeddingFloatResults.Embedding.of(embedding);
+ public DenseEmbeddingFloatResults.Embedding toInferenceFloatEmbedding() {
+ return DenseEmbeddingFloatResults.Embedding.of(embedding);
}
}
@@ -166,22 +166,22 @@ public static InferenceServiceResults fromResponse(Request request, HttpResult r
if (embeddingType == null || embeddingType == VoyageAIEmbeddingType.FLOAT) {
var embeddingResult = EmbeddingFloatResult.PARSER.apply(jsonParser, null);
- List embeddingList = embeddingResult.entries.stream()
+ List embeddingList = embeddingResult.entries.stream()
.map(EmbeddingFloatResultEntry::toInferenceFloatEmbedding)
.toList();
- return new TextEmbeddingFloatResults(embeddingList);
+ return new DenseEmbeddingFloatResults(embeddingList);
} else if (embeddingType == VoyageAIEmbeddingType.INT8) {
var embeddingResult = EmbeddingInt8Result.PARSER.apply(jsonParser, null);
- List embeddingList = embeddingResult.entries.stream()
+ List embeddingList = embeddingResult.entries.stream()
.map(EmbeddingInt8ResultEntry::toInferenceByteEmbedding)
.toList();
- return new TextEmbeddingByteResults(embeddingList);
+ return new DenseEmbeddingByteResults(embeddingList);
} else if (embeddingType == VoyageAIEmbeddingType.BIT || embeddingType == VoyageAIEmbeddingType.BINARY) {
var embeddingResult = EmbeddingInt8Result.PARSER.apply(jsonParser, null);
- List embeddingList = embeddingResult.entries.stream()
+ List embeddingList = embeddingResult.entries.stream()
.map(EmbeddingInt8ResultEntry::toInferenceByteEmbedding)
.toList();
- return new TextEmbeddingBitResults(embeddingList);
+ return new DenseEmbeddingBitResults(embeddingList);
} else {
throw new IllegalArgumentException(
"Illegal embedding_type value: " + embeddingType + ". Supported types are: " + VALID_EMBEDDING_TYPES_STRING
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionResponseTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionResponseTests.java
index 7625534f8c41d..4e40d23dfd7f1 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionResponseTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionResponseTests.java
@@ -11,9 +11,9 @@
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
-import org.elasticsearch.xpack.core.inference.results.LegacyMlTextEmbeddingResultsTests;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests;
+import org.elasticsearch.xpack.core.inference.results.LegacyTextEmbeddingResultsTests;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests;
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.inference.InferenceNamedWriteablesProvider;
@@ -40,8 +40,8 @@ protected Writeable.Reader instanceReader() {
@Override
protected InferenceAction.Response createTestInstance() {
var result = switch (randomIntBetween(0, 2)) {
- case 0 -> TextEmbeddingFloatResultsTests.createRandomResults();
- case 1 -> LegacyMlTextEmbeddingResultsTests.createRandomResults().transformToTextEmbeddingResults();
+ case 0 -> DenseEmbeddingFloatResultsTests.createRandomResults();
+ case 1 -> LegacyTextEmbeddingResultsTests.createRandomResults().transformToDenseEmbeddingResults();
default -> SparseEmbeddingResultsTests.createRandomResults();
};
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java
index 411d992adfa3d..9d450912d9c4e 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java
@@ -14,10 +14,10 @@
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingBitResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBitResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
import org.hamcrest.Matchers;
import java.util.ArrayList;
@@ -393,12 +393,12 @@ public void testVeryLongInput_Float() {
// Produce inference results for each request, with increasing weights.
float weight = 0f;
for (var batch : batches) {
- var embeddings = new ArrayList();
+ var embeddings = new ArrayList();
for (int i = 0; i < batch.batch().requests().size(); i++) {
weight += 1 / 16384f;
- embeddings.add(new TextEmbeddingFloatResults.Embedding(new float[] { weight }));
+ embeddings.add(new DenseEmbeddingFloatResults.Embedding(new float[] { weight }));
}
- batch.listener().onResponse(new TextEmbeddingFloatResults(embeddings));
+ batch.listener().onResponse(new DenseEmbeddingFloatResults(embeddings));
}
assertNotNull(finalListener.results);
@@ -410,8 +410,10 @@ public void testVeryLongInput_Float() {
ChunkedInferenceEmbedding chunkedEmbedding = (ChunkedInferenceEmbedding) inference;
assertThat(chunkedEmbedding.chunks(), hasSize(1));
assertThat(getMatchedText(inputs.get(0).input(), chunkedEmbedding.chunks().get(0).offset()), equalTo("1st small"));
- assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class));
- TextEmbeddingFloatResults.Embedding embedding = (TextEmbeddingFloatResults.Embedding) chunkedEmbedding.chunks().get(0).embedding();
+ assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(DenseEmbeddingFloatResults.Embedding.class));
+ DenseEmbeddingFloatResults.Embedding embedding = (DenseEmbeddingFloatResults.Embedding) chunkedEmbedding.chunks()
+ .get(0)
+ .embedding();
assertThat(embedding.values(), equalTo(new float[] { 1 / 16384f }));
// The very long passage "word0 word1 ... word199999" is split into 10000 chunks for
@@ -427,8 +429,8 @@ public void testVeryLongInput_Float() {
// is the average of the weights 2/16384 ... 21/16384.
assertThat(getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(0).offset()), startsWith("word0 word1 "));
assertThat(getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(0).offset()), endsWith(" word398 word399"));
- assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class));
- embedding = (TextEmbeddingFloatResults.Embedding) chunkedEmbedding.chunks().get(0).embedding();
+ assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(DenseEmbeddingFloatResults.Embedding.class));
+ embedding = (DenseEmbeddingFloatResults.Embedding) chunkedEmbedding.chunks().get(0).embedding();
assertThat(embedding.values(), equalTo(new float[] { (2 + 21) / (2 * 16384f) }));
// The last merged chunk consists of 19 small chunks (so 380 words) and the weight
@@ -438,8 +440,8 @@ public void testVeryLongInput_Float() {
startsWith(" word199620 word199621 ")
);
assertThat(getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(511).offset()), endsWith(" word199998 word199999"));
- assertThat(chunkedEmbedding.chunks().get(511).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class));
- embedding = (TextEmbeddingFloatResults.Embedding) chunkedEmbedding.chunks().get(511).embedding();
+ assertThat(chunkedEmbedding.chunks().get(511).embedding(), instanceOf(DenseEmbeddingFloatResults.Embedding.class));
+ embedding = (DenseEmbeddingFloatResults.Embedding) chunkedEmbedding.chunks().get(511).embedding();
assertThat(embedding.values(), equalTo(new float[] { (9983 + 10001) / (2 * 16384f) }));
// The last input has the token with weight 10002/16384.
@@ -448,8 +450,8 @@ public void testVeryLongInput_Float() {
chunkedEmbedding = (ChunkedInferenceEmbedding) inference;
assertThat(chunkedEmbedding.chunks(), hasSize(1));
assertThat(getMatchedText(inputs.get(2).input(), chunkedEmbedding.chunks().get(0).offset()), equalTo("2nd small"));
- assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class));
- embedding = (TextEmbeddingFloatResults.Embedding) chunkedEmbedding.chunks().get(0).embedding();
+ assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(DenseEmbeddingFloatResults.Embedding.class));
+ embedding = (DenseEmbeddingFloatResults.Embedding) chunkedEmbedding.chunks().get(0).embedding();
assertThat(embedding.values(), equalTo(new float[] { 10002 / 16384f }));
}
@@ -484,12 +486,12 @@ public void testVeryLongInput_Byte() {
// Produce inference results for each request, with increasing weights.
byte weight = 0;
for (var batch : batches) {
- var embeddings = new ArrayList();
+ var embeddings = new ArrayList();
for (int i = 0; i < batch.batch().requests().size(); i++) {
weight += 1;
- embeddings.add(new TextEmbeddingByteResults.Embedding(new byte[] { weight }));
+ embeddings.add(new DenseEmbeddingByteResults.Embedding(new byte[] { weight }));
}
- batch.listener().onResponse(new TextEmbeddingByteResults(embeddings));
+ batch.listener().onResponse(new DenseEmbeddingByteResults(embeddings));
}
assertNotNull(finalListener.results);
@@ -501,8 +503,8 @@ public void testVeryLongInput_Byte() {
ChunkedInferenceEmbedding chunkedEmbedding = (ChunkedInferenceEmbedding) inference;
assertThat(chunkedEmbedding.chunks(), hasSize(1));
assertThat(getMatchedText(inputs.get(0).input(), chunkedEmbedding.chunks().get(0).offset()), equalTo("1st small"));
- assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(TextEmbeddingByteResults.Embedding.class));
- TextEmbeddingByteResults.Embedding embedding = (TextEmbeddingByteResults.Embedding) chunkedEmbedding.chunks().get(0).embedding();
+ assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(DenseEmbeddingByteResults.Embedding.class));
+ DenseEmbeddingByteResults.Embedding embedding = (DenseEmbeddingByteResults.Embedding) chunkedEmbedding.chunks().get(0).embedding();
assertThat(embedding.values(), equalTo(new byte[] { 1 }));
// The very long passage "word0 word1 ... word199999" is split into 10000 chunks for
@@ -518,8 +520,8 @@ public void testVeryLongInput_Byte() {
// is the average of the weights 2 ... 21, so 11.5, which is rounded to 12.
assertThat(getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(0).offset()), startsWith("word0 word1 "));
assertThat(getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(0).offset()), endsWith(" word398 word399"));
- assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(TextEmbeddingByteResults.Embedding.class));
- embedding = (TextEmbeddingByteResults.Embedding) chunkedEmbedding.chunks().get(0).embedding();
+ assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(DenseEmbeddingByteResults.Embedding.class));
+ embedding = (DenseEmbeddingByteResults.Embedding) chunkedEmbedding.chunks().get(0).embedding();
assertThat(embedding.values(), equalTo(new byte[] { 12 }));
// The last merged chunk consists of 19 small chunks (so 380 words) and the weight
@@ -530,8 +532,8 @@ public void testVeryLongInput_Byte() {
startsWith(" word199620 word199621 ")
);
assertThat(getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(511).offset()), endsWith(" word199998 word199999"));
- assertThat(chunkedEmbedding.chunks().get(511).embedding(), instanceOf(TextEmbeddingByteResults.Embedding.class));
- embedding = (TextEmbeddingByteResults.Embedding) chunkedEmbedding.chunks().get(511).embedding();
+ assertThat(chunkedEmbedding.chunks().get(511).embedding(), instanceOf(DenseEmbeddingByteResults.Embedding.class));
+ embedding = (DenseEmbeddingByteResults.Embedding) chunkedEmbedding.chunks().get(511).embedding();
assertThat(embedding.values(), equalTo(new byte[] { 8 }));
// The last input has the token with weight 10002 % 256 = 18
@@ -540,8 +542,8 @@ public void testVeryLongInput_Byte() {
chunkedEmbedding = (ChunkedInferenceEmbedding) inference;
assertThat(chunkedEmbedding.chunks(), hasSize(1));
assertThat(getMatchedText(inputs.get(2).input(), chunkedEmbedding.chunks().get(0).offset()), equalTo("2nd small"));
- assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(TextEmbeddingByteResults.Embedding.class));
- embedding = (TextEmbeddingByteResults.Embedding) chunkedEmbedding.chunks().get(0).embedding();
+ assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(DenseEmbeddingByteResults.Embedding.class));
+ embedding = (DenseEmbeddingByteResults.Embedding) chunkedEmbedding.chunks().get(0).embedding();
assertThat(embedding.values(), equalTo(new byte[] { 18 }));
}
@@ -570,18 +572,18 @@ public void testMergingListener_Float() {
// 4 inputs in 2 batches
{
- var embeddings = new ArrayList();
+ var embeddings = new ArrayList();
for (int i = 0; i < batchSize; i++) {
- embeddings.add(new TextEmbeddingFloatResults.Embedding(new float[] { randomFloat() }));
+ embeddings.add(new DenseEmbeddingFloatResults.Embedding(new float[] { randomFloat() }));
}
- batches.get(0).listener().onResponse(new TextEmbeddingFloatResults(embeddings));
+ batches.get(0).listener().onResponse(new DenseEmbeddingFloatResults(embeddings));
}
{
- var embeddings = new ArrayList();
+ var embeddings = new ArrayList();
for (int i = 0; i < 4; i++) { // 4 requests in the 2nd batch
- embeddings.add(new TextEmbeddingFloatResults.Embedding(new float[] { randomFloat() }));
+ embeddings.add(new DenseEmbeddingFloatResults.Embedding(new float[] { randomFloat() }));
}
- batches.get(1).listener().onResponse(new TextEmbeddingFloatResults(embeddings));
+ batches.get(1).listener().onResponse(new DenseEmbeddingFloatResults(embeddings));
}
assertNotNull(finalListener.results);
@@ -650,18 +652,18 @@ public void testMergingListener_Byte() {
// 4 inputs in 2 batches
{
- var embeddings = new ArrayList();
+ var embeddings = new ArrayList();
for (int i = 0; i < batchSize; i++) {
- embeddings.add(new TextEmbeddingByteResults.Embedding(new byte[] { randomByte() }));
+ embeddings.add(new DenseEmbeddingByteResults.Embedding(new byte[] { randomByte() }));
}
- batches.get(0).listener().onResponse(new TextEmbeddingByteResults(embeddings));
+ batches.get(0).listener().onResponse(new DenseEmbeddingByteResults(embeddings));
}
{
- var embeddings = new ArrayList();
+ var embeddings = new ArrayList();
for (int i = 0; i < 4; i++) { // 4 requests in the 2nd batch
- embeddings.add(new TextEmbeddingByteResults.Embedding(new byte[] { randomByte() }));
+ embeddings.add(new DenseEmbeddingByteResults.Embedding(new byte[] { randomByte() }));
}
- batches.get(1).listener().onResponse(new TextEmbeddingByteResults(embeddings));
+ batches.get(1).listener().onResponse(new DenseEmbeddingByteResults(embeddings));
}
assertNotNull(finalListener.results);
@@ -727,18 +729,18 @@ public void testMergingListener_Bit() {
// 4 inputs in 2 batches
{
- var embeddings = new ArrayList();
+ var embeddings = new ArrayList();
for (int i = 0; i < batchSize; i++) {
- embeddings.add(new TextEmbeddingByteResults.Embedding(new byte[] { randomByte() }));
+ embeddings.add(new DenseEmbeddingByteResults.Embedding(new byte[] { randomByte() }));
}
- batches.get(0).listener().onResponse(new TextEmbeddingBitResults(embeddings));
+ batches.get(0).listener().onResponse(new DenseEmbeddingBitResults(embeddings));
}
{
- var embeddings = new ArrayList();
+ var embeddings = new ArrayList();
for (int i = 0; i < 4; i++) { // 4 requests in the 2nd batch
- embeddings.add(new TextEmbeddingByteResults.Embedding(new byte[] { randomByte() }));
+ embeddings.add(new DenseEmbeddingByteResults.Embedding(new byte[] { randomByte() }));
}
- batches.get(1).listener().onResponse(new TextEmbeddingBitResults(embeddings));
+ batches.get(1).listener().onResponse(new DenseEmbeddingBitResults(embeddings));
}
assertNotNull(finalListener.results);
@@ -892,10 +894,10 @@ public void onFailure(Exception e) {
var batches = new EmbeddingRequestChunker<>(inputs, 10, 100, 0).batchRequestsWithListeners(listener);
assertThat(batches, hasSize(1));
- var embeddings = new ArrayList();
- embeddings.add(new TextEmbeddingFloatResults.Embedding(new float[] { randomFloat() }));
- embeddings.add(new TextEmbeddingFloatResults.Embedding(new float[] { randomFloat() }));
- batches.get(0).listener().onResponse(new TextEmbeddingFloatResults(embeddings));
+ var embeddings = new ArrayList();
+ embeddings.add(new DenseEmbeddingFloatResults.Embedding(new float[] { randomFloat() }));
+ embeddings.add(new DenseEmbeddingFloatResults.Embedding(new float[] { randomFloat() }));
+ batches.get(0).listener().onResponse(new DenseEmbeddingFloatResults(embeddings));
assertEquals("Error the number of embedding responses [2] does not equal the number of requests [3]", failureMessage.get());
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java
index 027a19aca6d1f..f104e7b87136e 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java
@@ -48,7 +48,7 @@
import java.util.concurrent.atomic.AtomicReference;
import static org.elasticsearch.core.Strings.format;
-import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
+import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java
index d1499f4009d0a..d596677f4125f 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java
@@ -25,10 +25,10 @@
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.EmbeddingResults;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
import org.elasticsearch.xpack.core.utils.FloatConversionUtils;
import org.elasticsearch.xpack.inference.chunking.NoneChunkingSettings;
import org.elasticsearch.xpack.inference.chunking.SentenceBoundaryChunkingSettings;
@@ -211,7 +211,7 @@ public static ChunkedInferenceEmbedding randomChunkedInferenceEmbeddingByte(Mode
}
chunks.add(
new EmbeddingResults.Chunk(
- new TextEmbeddingByteResults.Embedding(values),
+ new DenseEmbeddingByteResults.Embedding(values),
new ChunkedInference.TextOffset(0, input.length())
)
);
@@ -233,7 +233,7 @@ public static ChunkedInferenceEmbedding randomChunkedInferenceEmbeddingFloat(Mod
}
chunks.add(
new EmbeddingResults.Chunk(
- new TextEmbeddingFloatResults.Embedding(values),
+ new DenseEmbeddingFloatResults.Embedding(values),
new ChunkedInference.TextOffset(0, input.length())
)
);
@@ -415,8 +415,8 @@ public static ChunkedInference toChunkedResult(
ChunkedInference.TextOffset offset = createOffset(useLegacyFormat, entryChunk, matchedText);
double[] values = parseDenseVector(entryChunk.rawEmbeddings(), embeddingLength, field.contentType());
EmbeddingResults.Embedding> embedding = switch (elementType) {
- case FLOAT -> new TextEmbeddingFloatResults.Embedding(FloatConversionUtils.floatArrayOf(values));
- case BYTE, BIT -> new TextEmbeddingByteResults.Embedding(byteArrayOf(values));
+ case FLOAT -> new DenseEmbeddingFloatResults.Embedding(FloatConversionUtils.floatArrayOf(values));
+ case BYTE, BIT -> new DenseEmbeddingByteResults.Embedding(byteArrayOf(values));
};
chunks.add(new EmbeddingResults.Chunk(embedding, offset));
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceKnnVectorQueryBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceKnnVectorQueryBuilderTests.java
index 20354aa9b2dc7..de8583132319b 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceKnnVectorQueryBuilderTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceKnnVectorQueryBuilderTests.java
@@ -22,7 +22,7 @@
import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor;
import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
import org.elasticsearch.search.vectors.VectorData;
-import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
+import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults;
import org.elasticsearch.xpack.core.ml.vectors.TextEmbeddingQueryVectorBuilder;
import org.elasticsearch.xpack.inference.mapper.SemanticTextField;
@@ -175,8 +175,8 @@ public void testInterceptAndRewrite() throws Exception {
new FullyQualifiedInferenceId(LOCAL_CLUSTER_GROUP_KEY, DENSE_INFERENCE_ID)
);
assertThat(inferenceResults, notNullValue());
- assertThat(inferenceResults, instanceOf(MlTextEmbeddingResults.class));
- VectorData queryVector = new VectorData(((MlTextEmbeddingResults) inferenceResults).getInferenceAsFloat());
+ assertThat(inferenceResults, instanceOf(MlDenseEmbeddingResults.class));
+ VectorData queryVector = new VectorData(((MlDenseEmbeddingResults) inferenceResults).getInferenceAsFloat());
// Perform data node rewrite on test index 1
final QueryRewriteContext indexMetadataContextTestIndex1 = createIndexMetadataContext(
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/MockInferenceClient.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/MockInferenceClient.java
index 17438d9786ba3..b22191735f1d9 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/MockInferenceClient.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/MockInferenceClient.java
@@ -20,11 +20,11 @@
import org.elasticsearch.test.client.NoOpClient;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
-import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
+import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults;
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
import java.util.Arrays;
@@ -58,8 +58,8 @@ protected void
InferenceResults inferenceResults = generateInferenceResults(inferenceId, input);
if (inferenceResults instanceof TextExpansionResults textExpansionResults) {
inferenceServiceResults = SparseEmbeddingResults.of(List.of(textExpansionResults));
- } else if (inferenceResults instanceof MlTextEmbeddingResults mlTextEmbeddingResults) {
- inferenceServiceResults = TextEmbeddingFloatResults.of(List.of(mlTextEmbeddingResults));
+ } else if (inferenceResults instanceof MlDenseEmbeddingResults mlDenseEmbeddingResults) {
+ inferenceServiceResults = DenseEmbeddingFloatResults.of(List.of(mlDenseEmbeddingResults));
} else {
throw new IllegalStateException("Unexpected inference results type [" + inferenceResults.getWriteableName() + "]");
}
@@ -126,6 +126,6 @@ private static InferenceResults generateTextEmbeddingResults(MinimalServiceSetti
double[] embedding = new double[embeddingSize];
Arrays.fill(embedding, Byte.MIN_VALUE); // Always use a byte value so that the embedding is valid regardless of the element type
- return new MlTextEmbeddingResults(DEFAULT_RESULTS_FIELD, embedding, false);
+ return new MlDenseEmbeddingResults(DEFAULT_RESULTS_FIELD, embedding, false);
}
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java
index b2d7218720a57..160261eb1b1cc 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java
@@ -67,10 +67,10 @@
import org.elasticsearch.xcontent.json.JsonXContent;
import org.elasticsearch.xpack.core.XPackClientPlugin;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
-import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
+import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults;
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
import org.elasticsearch.xpack.inference.InferencePlugin;
import org.elasticsearch.xpack.inference.mapper.SemanticTextField;
@@ -348,9 +348,9 @@ private InferenceAction.Response generateTextEmbeddingInferenceResponse() {
int inferenceLength = DenseVectorFieldMapperTestUtils.getEmbeddingLength(denseVectorElementType, TEXT_EMBEDDING_DIMENSION_COUNT);
double[] inference = new double[inferenceLength];
Arrays.fill(inference, 1.0);
- MlTextEmbeddingResults textEmbeddingResults = new MlTextEmbeddingResults(DEFAULT_RESULTS_FIELD, inference, false);
+ MlDenseEmbeddingResults textEmbeddingResults = new MlDenseEmbeddingResults(DEFAULT_RESULTS_FIELD, inference, false);
- return new InferenceAction.Response(TextEmbeddingFloatResults.of(List.of(textEmbeddingResults)));
+ return new InferenceAction.Response(DenseEmbeddingFloatResults.of(List.of(textEmbeddingResults)));
}
@Override
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/BaseInferenceActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/BaseInferenceActionTests.java
index 4e60b09530684..ece25247bb5b5 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/BaseInferenceActionTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/BaseInferenceActionTests.java
@@ -23,7 +23,7 @@
import org.elasticsearch.xpack.core.inference.InferenceContext;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.action.InferenceActionProxy;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults;
import org.elasticsearch.xpack.inference.InferencePlugin;
import org.junit.Before;
@@ -234,7 +234,7 @@ public void testExtractProductUseCase_EmptyWhenHeaderValueEmpty() {
static InferenceAction.Response createResponse() {
return new InferenceAction.Response(
- new TextEmbeddingByteResults(List.of(new TextEmbeddingByteResults.Embedding(new byte[] { (byte) -1 })))
+ new DenseEmbeddingByteResults(List.of(new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) -1 })))
);
}
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java
index c9f5ea738b33b..94f65acb67cff 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java
@@ -33,9 +33,9 @@
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.InputTypeTests;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
@@ -516,7 +516,7 @@ private void testChunkedInfer(TaskType taskType, ChunkingSettings chunkingSettin
var firstResult = results.getFirst();
assertThat(firstResult, instanceOf(ChunkedInferenceEmbedding.class));
Class> expectedClass = switch (taskType) {
- case TEXT_EMBEDDING -> TextEmbeddingFloatResults.Chunk.class;
+ case TEXT_EMBEDDING -> DenseEmbeddingFloatResults.Chunk.class;
case SPARSE_EMBEDDING -> SparseEmbeddingResults.Chunk.class;
default -> null;
};
@@ -650,10 +650,10 @@ private AlibabaCloudSearchModel createEmbeddingsModel(
) {
public ExecutableAction accept(AlibabaCloudSearchActionVisitor visitor, Map taskSettings) {
return (inferenceInputs, timeout, listener) -> {
- TextEmbeddingFloatResults results = new TextEmbeddingFloatResults(
+ DenseEmbeddingFloatResults results = new DenseEmbeddingFloatResults(
List.of(
- new TextEmbeddingFloatResults.Embedding(new float[] { 0.0123f, -0.0123f }),
- new TextEmbeddingFloatResults.Embedding(new float[] { 0.0456f, -0.0456f })
+ new DenseEmbeddingFloatResults.Embedding(new float[] { 0.0123f, -0.0123f }),
+ new DenseEmbeddingFloatResults.Embedding(new float[] { 0.0456f, -0.0456f })
)
);
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/action/AlibabaCloudSearchActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/action/AlibabaCloudSearchActionCreatorTests.java
index b09fbf43a8ca4..dec6539732156 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/action/AlibabaCloudSearchActionCreatorTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/action/AlibabaCloudSearchActionCreatorTests.java
@@ -21,11 +21,11 @@
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
import org.elasticsearch.xpack.core.inference.results.RankedDocsResultsTests;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
@@ -50,9 +50,9 @@
import java.util.concurrent.TimeUnit;
import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
+import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.core.inference.results.RankedDocsResultsTests.buildExpectationRerank;
import static org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests.buildExpectationSparseEmbeddings;
-import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.hamcrest.Matchers.is;
@@ -88,7 +88,7 @@ public void testExecute_withTextEmbeddingsAction_Success() {
float[] values = { 0.1111111f, 0.2222222f, 0.3333333f };
doAnswer(invocation -> {
ActionListener listener = invocation.getArgument(3);
- listener.onResponse(new TextEmbeddingFloatResults(List.of(new TextEmbeddingFloatResults.Embedding(values))));
+ listener.onResponse(new DenseEmbeddingFloatResults(List.of(new DenseEmbeddingFloatResults.Embedding(values))));
return Void.TYPE;
}).when(sender).send(any(), any(), any(), any());
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/response/AlibabaCloudSearchEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/response/AlibabaCloudSearchEmbeddingsResponseEntityTests.java
index ed8a1185bd846..28e5c9422a970 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/response/AlibabaCloudSearchEmbeddingsResponseEntityTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/response/AlibabaCloudSearchEmbeddingsResponseEntityTests.java
@@ -9,7 +9,7 @@
import org.apache.http.HttpResponse;
import org.elasticsearch.test.ESTestCase;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.request.AlibabaCloudSearchRequest;
@@ -50,14 +50,14 @@ public void testFromResponse_CreatesResultsForASingleItem() throws IOException,
URI uri = new URI("mock_uri");
when(request.getURI()).thenReturn(uri);
- TextEmbeddingFloatResults parsedResults = AlibabaCloudSearchEmbeddingsResponseEntity.fromResponse(
+ DenseEmbeddingFloatResults parsedResults = AlibabaCloudSearchEmbeddingsResponseEntity.fromResponse(
request,
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
assertThat(
parsedResults.embeddings(),
- is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { -0.02868066355586052f, 0.022033605724573135f })))
+ is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { -0.02868066355586052f, 0.022033605724573135f })))
);
}
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java
index 04c8bca2287e5..33d2d120e1b8d 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java
@@ -38,7 +38,7 @@
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.Utils;
import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
@@ -70,7 +70,7 @@
import static org.elasticsearch.common.xcontent.XContentHelper.toXContent;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
-import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
+import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.Utils.getInvalidModel;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
@@ -1026,7 +1026,7 @@ public void testInfer_SendsRequest_ForTitanEmbeddingsModel() throws IOException
);
var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender()
) {
- var results = new TextEmbeddingFloatResults(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.123F, 0.678F })));
+ var results = new DenseEmbeddingFloatResults(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.123F, 0.678F })));
requestSender.enqueue(results);
PlainActionFuture listener = new PlainActionFuture<>();
service.infer(
@@ -1067,8 +1067,8 @@ public void testInfer_SendsRequest_ForCohereEmbeddingsModel() throws IOException
)
) {
try (var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender()) {
- var results = new TextEmbeddingFloatResults(
- List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.123F, 0.678F }))
+ var results = new DenseEmbeddingFloatResults(
+ List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.123F, 0.678F }))
);
requestSender.enqueue(results);
@@ -1343,14 +1343,14 @@ private void testChunkedInfer(AmazonBedrockEmbeddingsModel model) throws IOExcep
) {
try (var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender()) {
{
- var mockResults1 = new TextEmbeddingFloatResults(
- List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.123F, 0.678F }))
+ var mockResults1 = new DenseEmbeddingFloatResults(
+ List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.123F, 0.678F }))
);
requestSender.enqueue(mockResults1);
}
{
- var mockResults2 = new TextEmbeddingFloatResults(
- List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.223F, 0.278F }))
+ var mockResults2 = new DenseEmbeddingFloatResults(
+ List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.223F, 0.278F }))
);
requestSender.enqueue(mockResults2);
}
@@ -1373,10 +1373,10 @@ private void testChunkedInfer(AmazonBedrockEmbeddingsModel model) throws IOExcep
var floatResult = (ChunkedInferenceEmbedding) results.get(0);
assertThat(floatResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().get(0).offset());
- assertThat(floatResult.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class));
+ assertThat(floatResult.chunks().get(0).embedding(), instanceOf(DenseEmbeddingFloatResults.Embedding.class));
assertArrayEquals(
new float[] { 0.123F, 0.678F },
- ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
+ ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
0.0f
);
}
@@ -1385,10 +1385,10 @@ private void testChunkedInfer(AmazonBedrockEmbeddingsModel model) throws IOExcep
var floatResult = (ChunkedInferenceEmbedding) results.get(1);
assertThat(floatResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().get(0).offset());
- assertThat(floatResult.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class));
+ assertThat(floatResult.chunks().get(0).embedding(), instanceOf(DenseEmbeddingFloatResults.Embedding.class));
assertArrayEquals(
new float[] { 0.223F, 0.278F },
- ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
+ ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
0.0f
);
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/action/AmazonBedrockActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/action/AmazonBedrockActionCreatorTests.java
index 5dd42dc66485f..8f1ba6b277c37 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/action/AmazonBedrockActionCreatorTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/action/AmazonBedrockActionCreatorTests.java
@@ -15,7 +15,7 @@
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.InputTypeTests;
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
@@ -33,7 +33,7 @@
import java.util.concurrent.TimeUnit;
import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
-import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
+import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors;
import static org.hamcrest.Matchers.is;
@@ -53,8 +53,8 @@ public void shutdown() throws IOException {
public void testEmbeddingsRequestAction_Titan() throws IOException {
var serviceComponents = ServiceComponentsTests.createWithEmptySettings(threadPool);
- var mockedFloatResults = List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F }));
- var mockedResult = new TextEmbeddingFloatResults(mockedFloatResults);
+ var mockedFloatResults = List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F }));
+ var mockedResult = new DenseEmbeddingFloatResults(mockedFloatResults);
try (var sender = new AmazonBedrockMockRequestSender()) {
sender.enqueue(mockedResult);
var creator = new AmazonBedrockActionCreator(sender, serviceComponents, TIMEOUT);
@@ -91,8 +91,8 @@ public void testEmbeddingsRequestAction_Titan() throws IOException {
public void testEmbeddingsRequestAction_Cohere() throws IOException {
var serviceComponents = ServiceComponentsTests.createWithEmptySettings(threadPool);
- var mockedFloatResults = List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F }));
- var mockedResult = new TextEmbeddingFloatResults(mockedFloatResults);
+ var mockedFloatResults = List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F }));
+ var mockedResult = new DenseEmbeddingFloatResults(mockedFloatResults);
try (var sender = new AmazonBedrockMockRequestSender()) {
sender.enqueue(mockedResult);
var creator = new AmazonBedrockActionCreator(sender, serviceComponents, TIMEOUT);
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockExecutorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockExecutorTests.java
index 19fe2ca784b1c..d267f80080885 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockExecutorTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockExecutorTests.java
@@ -34,7 +34,7 @@
import java.util.List;
import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
-import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
+import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.common.TruncatorTests.createTruncator;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.is;
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockRequestSenderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockRequestSenderTests.java
index d608b6ec9fb8e..38a9ca17fec7b 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockRequestSenderTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockRequestSenderTests.java
@@ -32,7 +32,7 @@
import java.util.concurrent.atomic.AtomicReference;
import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
-import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
+import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.services.amazonbedrock.client.AmazonBedrockExecutorTests.TEST_AMAZON_TITAN_EMBEDDINGS_RESULT;
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java
index 06668b91e1965..734d66a20bf56 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java
@@ -39,8 +39,8 @@
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
@@ -1347,10 +1347,10 @@ private void testChunkedInfer(AzureAiStudioEmbeddingsModel model) throws IOExcep
var floatResult = (ChunkedInferenceEmbedding) results.get(0);
assertThat(floatResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().get(0).offset());
- assertThat(floatResult.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class));
+ assertThat(floatResult.chunks().get(0).embedding(), instanceOf(DenseEmbeddingFloatResults.Embedding.class));
assertArrayEquals(
new float[] { 0.0123f, -0.0123f },
- ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
+ ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
0.0f
);
}
@@ -1359,10 +1359,10 @@ private void testChunkedInfer(AzureAiStudioEmbeddingsModel model) throws IOExcep
var floatResult = (ChunkedInferenceEmbedding) results.get(1);
assertThat(floatResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().get(0).offset());
- assertThat(floatResult.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class));
+ assertThat(floatResult.chunks().get(0).embedding(), instanceOf(DenseEmbeddingFloatResults.Embedding.class));
assertArrayEquals(
new float[] { 1.0123f, -1.0123f },
- ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
+ ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
0.0f
);
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/action/AzureAiStudioActionAndCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/action/AzureAiStudioActionAndCreatorTests.java
index 975661d180795..ef82cc18946b1 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/action/AzureAiStudioActionAndCreatorTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/action/AzureAiStudioActionAndCreatorTests.java
@@ -46,7 +46,7 @@
import java.util.concurrent.TimeUnit;
import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
-import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
+import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/response/AzureAiStudioEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/response/AzureAiStudioEmbeddingsResponseEntityTests.java
index 3da3598a4637a..e2ebd72b17da5 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/response/AzureAiStudioEmbeddingsResponseEntityTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/response/AzureAiStudioEmbeddingsResponseEntityTests.java
@@ -9,7 +9,7 @@
import org.apache.http.HttpResponse;
import org.elasticsearch.test.ESTestCase;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.request.Request;
@@ -50,11 +50,11 @@ public void testFromResponse_CreatesResultsForASingleItem() throws IOException {
var entity = new AzureAiStudioEmbeddingsResponseEntity();
- var parsedResults = (TextEmbeddingFloatResults) entity.apply(
+ var parsedResults = (DenseEmbeddingFloatResults) entity.apply(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
- assertThat(parsedResults.embeddings(), is(List.of(TextEmbeddingFloatResults.Embedding.of(List.of(0.014539449F, -0.015288644F)))));
+ assertThat(parsedResults.embeddings(), is(List.of(DenseEmbeddingFloatResults.Embedding.of(List.of(0.014539449F, -0.015288644F)))));
}
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java
index b5d8fb887c62c..6c0b4f272650a 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java
@@ -37,7 +37,7 @@
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
@@ -63,7 +63,7 @@
import static org.elasticsearch.common.xcontent.XContentHelper.toXContent;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
-import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
+import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.Utils.getInvalidModel;
import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors;
@@ -1007,10 +1007,10 @@ private void testChunkedInfer(AzureOpenAiEmbeddingsModel model) throws IOExcepti
var floatResult = (ChunkedInferenceEmbedding) results.get(0);
assertThat(floatResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().get(0).offset());
- assertThat(floatResult.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class));
+ assertThat(floatResult.chunks().get(0).embedding(), instanceOf(DenseEmbeddingFloatResults.Embedding.class));
assertArrayEquals(
new float[] { 0.123f, -0.123f },
- ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
+ ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
0.0f
);
}
@@ -1019,10 +1019,10 @@ private void testChunkedInfer(AzureOpenAiEmbeddingsModel model) throws IOExcepti
var floatResult = (ChunkedInferenceEmbedding) results.get(1);
assertThat(floatResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().get(0).offset());
- assertThat(floatResult.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class));
+ assertThat(floatResult.chunks().get(0).embedding(), instanceOf(DenseEmbeddingFloatResults.Embedding.class));
assertArrayEquals(
new float[] { 1.123f, -1.123f },
- ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
+ ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
0.0f
);
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiActionCreatorTests.java
index 1724747f9861a..23c773ab0d61b 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiActionCreatorTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiActionCreatorTests.java
@@ -42,7 +42,7 @@
import java.util.concurrent.TimeUnit;
import static org.elasticsearch.core.Strings.format;
-import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
+import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiEmbeddingsActionTests.java
index 20ac31398de2f..c6db0f441d348 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiEmbeddingsActionTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiEmbeddingsActionTests.java
@@ -44,7 +44,7 @@
import java.util.List;
import java.util.concurrent.TimeUnit;
-import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
+import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java
index 0f897500698ee..a340236e31eaa 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java
@@ -39,8 +39,8 @@
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
@@ -69,7 +69,7 @@
import static org.elasticsearch.common.xcontent.XContentHelper.toXContent;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
-import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
+import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.Utils.getInvalidModel;
import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors;
@@ -1351,7 +1351,7 @@ private void testChunkedInfer(CohereEmbeddingsModel model) throws IOException {
assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().get(0).offset());
assertArrayEquals(
new float[] { 0.123f, -0.123f },
- ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
+ ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
0.0f
);
}
@@ -1362,7 +1362,7 @@ private void testChunkedInfer(CohereEmbeddingsModel model) throws IOException {
assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().get(0).offset());
assertArrayEquals(
new float[] { 0.223f, -0.223f },
- ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
+ ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
0.0f
);
}
@@ -1448,10 +1448,10 @@ public void testChunkedInfer_BatchesCalls_Bytes() throws IOException {
var byteResult = (ChunkedInferenceEmbedding) results.get(0);
assertThat(byteResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, 1), byteResult.chunks().get(0).offset());
- assertThat(byteResult.chunks().get(0).embedding(), instanceOf(TextEmbeddingByteResults.Embedding.class));
+ assertThat(byteResult.chunks().get(0).embedding(), instanceOf(DenseEmbeddingByteResults.Embedding.class));
assertArrayEquals(
new byte[] { 23, -23 },
- ((TextEmbeddingByteResults.Embedding) byteResult.chunks().get(0).embedding()).values()
+ ((DenseEmbeddingByteResults.Embedding) byteResult.chunks().get(0).embedding()).values()
);
}
{
@@ -1459,10 +1459,10 @@ public void testChunkedInfer_BatchesCalls_Bytes() throws IOException {
var byteResult = (ChunkedInferenceEmbedding) results.get(1);
assertThat(byteResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, 2), byteResult.chunks().get(0).offset());
- assertThat(byteResult.chunks().get(0).embedding(), instanceOf(TextEmbeddingByteResults.Embedding.class));
+ assertThat(byteResult.chunks().get(0).embedding(), instanceOf(DenseEmbeddingByteResults.Embedding.class));
assertArrayEquals(
new byte[] { 24, -24 },
- ((TextEmbeddingByteResults.Embedding) byteResult.chunks().get(0).embedding()).values()
+ ((DenseEmbeddingByteResults.Embedding) byteResult.chunks().get(0).embedding()).values()
);
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreatorTests.java
index f3d816d9a118d..a6f3e6a07d66c 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreatorTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreatorTests.java
@@ -40,7 +40,7 @@
import java.util.concurrent.TimeUnit;
import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
-import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
+import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereEmbeddingsActionTests.java
index f5191ecac5ce6..008bce852b8d6 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereEmbeddingsActionTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereEmbeddingsActionTests.java
@@ -44,8 +44,8 @@
import java.util.Map;
import java.util.concurrent.TimeUnit;
-import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationByte;
-import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
+import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationByte;
+import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/response/CohereEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/response/CohereEmbeddingsResponseEntityTests.java
index 6df356bfe0a80..78bf38f204bf7 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/response/CohereEmbeddingsResponseEntityTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/response/CohereEmbeddingsResponseEntityTests.java
@@ -10,9 +10,9 @@
import org.apache.http.HttpResponse;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.test.ESTestCase;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBitResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingBitResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.hamcrest.MatcherAssert;
@@ -56,10 +56,10 @@ public void testFromResponse_CreatesResultsForASingleItem() throws IOException {
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
- MatcherAssert.assertThat(parsedResults, instanceOf(TextEmbeddingFloatResults.class));
+ MatcherAssert.assertThat(parsedResults, instanceOf(DenseEmbeddingFloatResults.class));
MatcherAssert.assertThat(
- ((TextEmbeddingFloatResults) parsedResults).embeddings(),
- is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { -0.0018434525F, 0.01777649F })))
+ ((DenseEmbeddingFloatResults) parsedResults).embeddings(),
+ is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { -0.0018434525F, 0.01777649F })))
);
}
@@ -90,14 +90,14 @@ public void testFromResponse_CreatesResultsForASingleItem_ObjectFormat() throws
}
""";
- TextEmbeddingFloatResults parsedResults = (TextEmbeddingFloatResults) CohereEmbeddingsResponseEntity.fromResponse(
+ DenseEmbeddingFloatResults parsedResults = (DenseEmbeddingFloatResults) CohereEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
MatcherAssert.assertThat(
parsedResults.embeddings(),
- is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { -0.0018434525F, 0.01777649F })))
+ is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { -0.0018434525F, 0.01777649F })))
);
}
@@ -134,14 +134,14 @@ public void testFromResponse_UsesTheFirstValidEmbeddingsEntry() throws IOExcepti
}
""";
- TextEmbeddingFloatResults parsedResults = (TextEmbeddingFloatResults) CohereEmbeddingsResponseEntity.fromResponse(
+ DenseEmbeddingFloatResults parsedResults = (DenseEmbeddingFloatResults) CohereEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
MatcherAssert.assertThat(
parsedResults.embeddings(),
- is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { -0.0018434525F, 0.01777649F })))
+ is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { -0.0018434525F, 0.01777649F })))
);
}
@@ -178,14 +178,14 @@ public void testFromResponse_UsesTheFirstValidEmbeddingsEntryInt8_WithInvalidFir
}
""";
- TextEmbeddingByteResults parsedResults = (TextEmbeddingByteResults) CohereEmbeddingsResponseEntity.fromResponse(
+ DenseEmbeddingByteResults parsedResults = (DenseEmbeddingByteResults) CohereEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
MatcherAssert.assertThat(
parsedResults.embeddings(),
- is(List.of(new TextEmbeddingByteResults.Embedding(new byte[] { (byte) -1, (byte) 0 })))
+ is(List.of(new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) -1, (byte) 0 })))
);
}
@@ -216,14 +216,14 @@ public void testFromResponse_ParsesBytes() throws IOException {
}
""";
- TextEmbeddingByteResults parsedResults = (TextEmbeddingByteResults) CohereEmbeddingsResponseEntity.fromResponse(
+ DenseEmbeddingByteResults parsedResults = (DenseEmbeddingByteResults) CohereEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
MatcherAssert.assertThat(
parsedResults.embeddings(),
- is(List.of(new TextEmbeddingByteResults.Embedding(new byte[] { (byte) -1, (byte) 0 })))
+ is(List.of(new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) -1, (byte) 0 })))
);
}
@@ -257,14 +257,14 @@ public void testFromResponse_ParsesBytes_FromBinaryEmbeddingsEntry() throws IOEx
}
""";
- TextEmbeddingBitResults parsedResults = (TextEmbeddingBitResults) CohereEmbeddingsResponseEntity.fromResponse(
+ DenseEmbeddingBitResults parsedResults = (DenseEmbeddingBitResults) CohereEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
MatcherAssert.assertThat(
parsedResults.embeddings(),
- is(List.of(new TextEmbeddingByteResults.Embedding(new byte[] { (byte) -55, (byte) 74, (byte) 101, (byte) 67, (byte) 83 })))
+ is(List.of(new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) -55, (byte) 74, (byte) 101, (byte) 67, (byte) 83 })))
);
}
@@ -297,7 +297,7 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException
}
""";
- TextEmbeddingFloatResults parsedResults = (TextEmbeddingFloatResults) CohereEmbeddingsResponseEntity.fromResponse(
+ DenseEmbeddingFloatResults parsedResults = (DenseEmbeddingFloatResults) CohereEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
@@ -306,8 +306,8 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException
parsedResults.embeddings(),
is(
List.of(
- new TextEmbeddingFloatResults.Embedding(new float[] { -0.0018434525F, 0.01777649F }),
- new TextEmbeddingFloatResults.Embedding(new float[] { -0.123F, 0.123F })
+ new DenseEmbeddingFloatResults.Embedding(new float[] { -0.0018434525F, 0.01777649F }),
+ new DenseEmbeddingFloatResults.Embedding(new float[] { -0.123F, 0.123F })
)
)
);
@@ -344,7 +344,7 @@ public void testFromResponse_CreatesResultsForMultipleItems_ObjectFormat() throw
}
""";
- TextEmbeddingFloatResults parsedResults = (TextEmbeddingFloatResults) CohereEmbeddingsResponseEntity.fromResponse(
+ DenseEmbeddingFloatResults parsedResults = (DenseEmbeddingFloatResults) CohereEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
@@ -353,8 +353,8 @@ public void testFromResponse_CreatesResultsForMultipleItems_ObjectFormat() throw
parsedResults.embeddings(),
is(
List.of(
- new TextEmbeddingFloatResults.Embedding(new float[] { -0.0018434525F, 0.01777649F }),
- new TextEmbeddingFloatResults.Embedding(new float[] { -0.123F, 0.123F })
+ new DenseEmbeddingFloatResults.Embedding(new float[] { -0.0018434525F, 0.01777649F }),
+ new DenseEmbeddingFloatResults.Embedding(new float[] { -0.123F, 0.123F })
)
)
);
@@ -397,7 +397,7 @@ public void testFromResponse_CreatesResultsForMultipleItems_ObjectFormat_Binary(
}
""";
- TextEmbeddingBitResults parsedResults = (TextEmbeddingBitResults) CohereEmbeddingsResponseEntity.fromResponse(
+ DenseEmbeddingBitResults parsedResults = (DenseEmbeddingBitResults) CohereEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
@@ -406,8 +406,8 @@ public void testFromResponse_CreatesResultsForMultipleItems_ObjectFormat_Binary(
parsedResults.embeddings(),
is(
List.of(
- new TextEmbeddingByteResults.Embedding(new byte[] { (byte) -55, (byte) 74, (byte) 101, (byte) 67 }),
- new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 34, (byte) -64, (byte) 97, (byte) 65, (byte) -42 })
+ new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) -55, (byte) 74, (byte) 101, (byte) 67 }),
+ new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) 34, (byte) -64, (byte) 97, (byte) 65, (byte) -42 })
)
)
);
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomModelTests.java
index 1c5c13e2086c2..fb9802c6938d5 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomModelTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomModelTests.java
@@ -14,7 +14,7 @@
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.inference.services.custom.response.CustomResponseParser;
-import org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser;
+import org.elasticsearch.xpack.inference.services.custom.response.DenseEmbeddingResponseParser;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
import org.hamcrest.MatcherAssert;
@@ -94,7 +94,7 @@ public static CustomModel createModel(
public static CustomModel getTestModel() {
return getTestModel(
TaskType.TEXT_EMBEDDING,
- new TextEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT)
+ new DenseEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT)
);
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java
index 65d9de30576ff..74c3e0f57e73c 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java
@@ -26,10 +26,10 @@
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.ServiceFields;
import org.elasticsearch.xpack.inference.services.custom.response.CompletionResponseParser;
+import org.elasticsearch.xpack.inference.services.custom.response.DenseEmbeddingResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.NoopResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.SparseEmbeddingResponseParser;
-import org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
import org.hamcrest.MatcherAssert;
@@ -60,7 +60,7 @@ public static CustomServiceSettings createRandom(String inputUrl) {
var requestContentString = randomAlphaOfLength(10);
var responseJsonParser = switch (taskType) {
- case TEXT_EMBEDDING -> new TextEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT);
+ case TEXT_EMBEDDING -> new DenseEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT);
case SPARSE_EMBEDDING -> new SparseEmbeddingResponseParser(
"$.result.sparse_embeddings[*].embedding[*].token_id",
"$.result.sparse_embeddings[*].embedding[*].weights"
@@ -100,7 +100,7 @@ public void testFromMap() {
var queryParameters = List.of(List.of("key", "value"));
String requestContentString = "request body";
- var responseParser = new TextEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT);
+ var responseParser = new DenseEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT);
var settings = CustomServiceSettings.fromMap(
new HashMap<>(
@@ -124,7 +124,7 @@ public void testFromMap() {
Map.of(
CustomServiceSettings.JSON_PARSER,
new HashMap<>(
- Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding")
+ Map.of(DenseEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding")
)
)
),
@@ -158,7 +158,7 @@ public void testFromMap_EmbeddingType_Bit() {
String url = "http://www.abc.com";
String requestContentString = "request body";
- var responseParser = new TextEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.BIT);
+ var responseParser = new DenseEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.BIT);
var settings = CustomServiceSettings.fromMap(
new HashMap<>(
@@ -173,9 +173,9 @@ public void testFromMap_EmbeddingType_Bit() {
CustomServiceSettings.JSON_PARSER,
new HashMap<>(
Map.of(
- TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS,
+ DenseEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS,
"$.result.embeddings[*].embedding",
- TextEmbeddingResponseParser.EMBEDDING_TYPE,
+ DenseEmbeddingResponseParser.EMBEDDING_TYPE,
CustomServiceEmbeddingType.BIT.toString()
)
)
@@ -207,7 +207,7 @@ public void testFromMap_EmbeddingType_Binary() {
String url = "http://www.abc.com";
String requestContentString = "request body";
- var responseParser = new TextEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.BINARY);
+ var responseParser = new DenseEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.BINARY);
var settings = CustomServiceSettings.fromMap(
new HashMap<>(
@@ -222,9 +222,9 @@ public void testFromMap_EmbeddingType_Binary() {
CustomServiceSettings.JSON_PARSER,
new HashMap<>(
Map.of(
- TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS,
+ DenseEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS,
"$.result.embeddings[*].embedding",
- TextEmbeddingResponseParser.EMBEDDING_TYPE,
+ DenseEmbeddingResponseParser.EMBEDDING_TYPE,
CustomServiceEmbeddingType.BINARY.toString()
)
)
@@ -256,7 +256,7 @@ public void testFromMap_EmbeddingType_Byte() {
String url = "http://www.abc.com";
String requestContentString = "request body";
- var responseParser = new TextEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.BYTE);
+ var responseParser = new DenseEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.BYTE);
var settings = CustomServiceSettings.fromMap(
new HashMap<>(
@@ -271,9 +271,9 @@ public void testFromMap_EmbeddingType_Byte() {
CustomServiceSettings.JSON_PARSER,
new HashMap<>(
Map.of(
- TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS,
+ DenseEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS,
"$.result.embeddings[*].embedding",
- TextEmbeddingResponseParser.EMBEDDING_TYPE,
+ DenseEmbeddingResponseParser.EMBEDDING_TYPE,
CustomServiceEmbeddingType.BYTE.toString()
)
)
@@ -367,7 +367,7 @@ public void testFromMap_Completion_ThrowsWhenEmbeddingIsIncludedInMap() {
Map.of(
CompletionResponseParser.COMPLETION_PARSER_RESULT,
"$.result.text",
- TextEmbeddingResponseParser.EMBEDDING_TYPE,
+ DenseEmbeddingResponseParser.EMBEDDING_TYPE,
"byte"
)
)
@@ -393,7 +393,7 @@ public void testFromMap_WithOptionalsNotSpecified() {
String url = "http://www.abc.com";
String requestContentString = "request body";
- var responseParser = new TextEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT);
+ var responseParser = new DenseEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT);
var settings = CustomServiceSettings.fromMap(
new HashMap<>(
@@ -407,7 +407,7 @@ public void testFromMap_WithOptionalsNotSpecified() {
Map.of(
CustomServiceSettings.JSON_PARSER,
new HashMap<>(
- Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding")
+ Map.of(DenseEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding")
)
)
)
@@ -445,7 +445,7 @@ public void testFromMap_RemovesNullValues_FromMaps() {
String requestContentString = "request body";
- var responseParser = new TextEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT);
+ var responseParser = new DenseEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT);
var settings = CustomServiceSettings.fromMap(
new HashMap<>(
@@ -467,7 +467,7 @@ public void testFromMap_RemovesNullValues_FromMaps() {
Map.of(
CustomServiceSettings.JSON_PARSER,
new HashMap<>(
- Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding")
+ Map.of(DenseEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding")
)
)
)
@@ -519,7 +519,7 @@ public void testFromMap_ReturnsError_IfHeadersContainsNonStringValues() {
Map.of(
CustomServiceSettings.JSON_PARSER,
new HashMap<>(
- Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding")
+ Map.of(DenseEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding")
)
)
)
@@ -566,7 +566,7 @@ public void testFromMap_ReturnsError_IfQueryParamsContainsNonStringValues() {
Map.of(
CustomServiceSettings.JSON_PARSER,
new HashMap<>(
- Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding")
+ Map.of(DenseEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding")
)
)
)
@@ -604,7 +604,7 @@ public void testFromMap_ReturnsError_IfRequestMapIsMissing() {
Map.of(
CustomServiceSettings.JSON_PARSER,
new HashMap<>(
- Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding")
+ Map.of(DenseEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding")
)
)
)
@@ -636,7 +636,7 @@ public void testFromMap_ReturnsError_IfResponseMapIsMissing() {
Map.of(
CustomServiceSettings.JSON_PARSER,
new HashMap<>(
- Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding")
+ Map.of(DenseEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding")
)
)
)
@@ -675,7 +675,7 @@ public void testFromMap_ReturnsError_IfJsonParserMapIsNotEmptyAfterParsing() {
CustomServiceSettings.JSON_PARSER,
new HashMap<>(
Map.of(
- TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS,
+ DenseEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS,
"$.result.embeddings[*].embedding",
"key",
"value"
@@ -717,7 +717,7 @@ public void testFromMap_ReturnsError_IfResponseMapIsNotEmptyAfterParsing() {
Map.of(
CustomServiceSettings.JSON_PARSER,
new HashMap<>(
- Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding")
+ Map.of(DenseEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding")
),
"key",
"value"
@@ -757,7 +757,7 @@ public void testFromMap_ReturnsError_IfTaskTypeIsInvalid() {
Map.of(
CustomServiceSettings.JSON_PARSER,
new HashMap<>(
- Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding")
+ Map.of(DenseEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding")
)
)
)
@@ -779,7 +779,7 @@ public void testXContent() throws IOException {
Map.of("key", "value"),
null,
"string",
- new TextEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT),
+ new DenseEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT),
null
);
@@ -868,7 +868,7 @@ public void testXContent_WithInputTypeTranslationValues() throws IOException {
Map.of("key", "value"),
null,
"string",
- new TextEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT),
+ new DenseEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT),
null,
null,
new InputTypeTranslator(Map.of(InputType.SEARCH, "do_search", InputType.INGEST, "do_ingest"), "a_default")
@@ -915,7 +915,7 @@ public void testXContent_BatchSize11() throws IOException {
Map.of("key", "value"),
null,
"string",
- new TextEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT),
+ new DenseEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT),
null,
11,
InputTypeTranslator.EMPTY_TRANSLATOR
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java
index 44bf17c3ac96d..1d43330381fa4 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java
@@ -28,9 +28,9 @@
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
@@ -39,9 +39,9 @@
import org.elasticsearch.xpack.inference.services.ServiceFields;
import org.elasticsearch.xpack.inference.services.custom.response.CompletionResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.CustomResponseParser;
+import org.elasticsearch.xpack.inference.services.custom.response.DenseEmbeddingResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.SparseEmbeddingResponseParser;
-import org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
import org.hamcrest.CoreMatchers;
import org.hamcrest.Matchers;
@@ -138,7 +138,7 @@ private static void assertTextEmbeddingModel(Model model, boolean modelIncludesS
var customModel = assertCommonModelFields(model, modelIncludesSecrets);
assertThat(customModel.getTaskType(), is(TaskType.TEXT_EMBEDDING));
- assertThat(customModel.getServiceSettings().getResponseJsonParser(), instanceOf(TextEmbeddingResponseParser.class));
+ assertThat(customModel.getServiceSettings().getResponseJsonParser(), instanceOf(DenseEmbeddingResponseParser.class));
}
private static CustomModel assertCommonModelFields(Model model, boolean modelIncludesSecrets) {
@@ -204,7 +204,7 @@ private static Map createServiceSettingsMap(TaskType taskType) {
private static Map createResponseParserMap(TaskType taskType) {
return switch (taskType) {
case TEXT_EMBEDDING -> new HashMap<>(
- Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding")
+ Map.of(DenseEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding")
);
case COMPLETION -> new HashMap<>(Map.of(CompletionResponseParser.COMPLETION_PARSER_RESULT, "$.result.text"));
case SPARSE_EMBEDDING -> new HashMap<>(
@@ -240,18 +240,18 @@ private static Map createSecretSettingsMap() {
private static CustomModel createInternalEmbeddingModel(SimilarityMeasure similarityMeasure) {
return createInternalEmbeddingModel(
similarityMeasure,
- new TextEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT),
+ new DenseEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT),
"http://www.abc.com"
);
}
- private static CustomModel createInternalEmbeddingModel(TextEmbeddingResponseParser parser, String url) {
+ private static CustomModel createInternalEmbeddingModel(DenseEmbeddingResponseParser parser, String url) {
return createInternalEmbeddingModel(SimilarityMeasure.DOT_PRODUCT, parser, url);
}
private static CustomModel createInternalEmbeddingModel(
@Nullable SimilarityMeasure similarityMeasure,
- TextEmbeddingResponseParser parser,
+ DenseEmbeddingResponseParser parser,
String url
) {
var inferenceId = "inference_id";
@@ -276,7 +276,7 @@ private static CustomModel createInternalEmbeddingModel(
private static CustomModel createInternalEmbeddingModel(
@Nullable SimilarityMeasure similarityMeasure,
- TextEmbeddingResponseParser parser,
+ DenseEmbeddingResponseParser parser,
String url,
@Nullable ChunkingSettings chunkingSettings,
@Nullable Integer batchSize
@@ -336,7 +336,7 @@ public void testInfer_ReturnsAnError_WithoutParsingTheResponseBody() throws IOEx
webServer.enqueue(new MockResponse().setResponseCode(400).setBody(responseJson));
var model = createInternalEmbeddingModel(
- new TextEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT),
+ new DenseEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT),
getUrl(webServer)
);
PlainActionFuture listener = new PlainActionFuture<>();
@@ -394,7 +394,7 @@ public void testInfer_HandlesTextEmbeddingRequest_OpenAI_Format() throws IOExcep
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
var model = createInternalEmbeddingModel(
- new TextEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT),
+ new DenseEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT),
getUrl(webServer)
);
PlainActionFuture listener = new PlainActionFuture<>();
@@ -412,12 +412,12 @@ public void testInfer_HandlesTextEmbeddingRequest_OpenAI_Format() throws IOExcep
);
InferenceServiceResults results = listener.actionGet(TIMEOUT);
- assertThat(results, instanceOf(TextEmbeddingFloatResults.class));
+ assertThat(results, instanceOf(DenseEmbeddingFloatResults.class));
- var embeddingResults = (TextEmbeddingFloatResults) results;
+ var embeddingResults = (DenseEmbeddingFloatResults) results;
assertThat(
embeddingResults.embeddings(),
- is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F })))
+ is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F })))
);
}
}
@@ -676,7 +676,7 @@ public void testParseRequestConfig_ThrowsAValidationError_WhenReplacementDoesNot
public void testChunkedInfer_ChunkingSettingsSet() throws IOException {
var model = createInternalEmbeddingModel(
SimilarityMeasure.DOT_PRODUCT,
- new TextEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT),
+ new DenseEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT),
getUrl(webServer),
ChunkingSettingsTests.createRandomChunkingSettings(),
2
@@ -732,10 +732,10 @@ public void testChunkedInfer_ChunkingSettingsSet() throws IOException {
var floatResult = (ChunkedInferenceEmbedding) results.get(0);
assertThat(floatResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().get(0).offset());
- assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class));
+ assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class));
assertArrayEquals(
new float[] { 0.123f, -0.123f },
- ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
+ ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
0.0f
);
}
@@ -744,10 +744,10 @@ public void testChunkedInfer_ChunkingSettingsSet() throws IOException {
var floatResult = (ChunkedInferenceEmbedding) results.get(1);
assertThat(floatResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().get(0).offset());
- assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class));
+ assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class));
assertArrayEquals(
new float[] { 0.223f, -0.223f },
- ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
+ ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
0.0f
);
}
@@ -762,7 +762,7 @@ public void testChunkedInfer_ChunkingSettingsSet() throws IOException {
public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException {
var model = createInternalEmbeddingModel(
- new TextEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT),
+ new DenseEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT),
getUrl(webServer)
);
String responseJson = """
@@ -807,10 +807,10 @@ public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException {
var floatResult = (ChunkedInferenceEmbedding) results.get(0);
assertThat(floatResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().get(0).offset());
- assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class));
+ assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class));
assertArrayEquals(
new float[] { 0.123f, -0.123f },
- ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
+ ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
0.0f
);
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java
index 8b35979c3daf5..8ad0b028d3b34 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java
@@ -27,8 +27,8 @@
import org.elasticsearch.xpack.inference.services.custom.CustomTaskSettings;
import org.elasticsearch.xpack.inference.services.custom.InputTypeTranslator;
import org.elasticsearch.xpack.inference.services.custom.QueryParameters;
+import org.elasticsearch.xpack.inference.services.custom.response.DenseEmbeddingResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser;
-import org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
import java.io.IOException;
@@ -62,7 +62,7 @@ public void testCreateRequest() throws IOException {
headers,
new QueryParameters(List.of(new QueryParameters.Parameter("key", "value"), new QueryParameters.Parameter("key", "value2"))),
requestContentString,
- new TextEmbeddingResponseParser("$.result.embeddings", CustomServiceEmbeddingType.FLOAT),
+ new DenseEmbeddingResponseParser("$.result.embeddings", CustomServiceEmbeddingType.FLOAT),
new RateLimitSettings(10_000),
null,
new InputTypeTranslator(Map.of(InputType.INGEST, "value"), "default")
@@ -122,7 +122,7 @@ public void testCreateRequest_QueryParametersAreEscaped_AndEncoded() throws IOEx
)
),
requestContentString,
- new TextEmbeddingResponseParser("$.result.embeddings", CustomServiceEmbeddingType.FLOAT),
+ new DenseEmbeddingResponseParser("$.result.embeddings", CustomServiceEmbeddingType.FLOAT),
new RateLimitSettings(10_000),
null,
new InputTypeTranslator(Map.of(InputType.INGEST, "value"), "default")
@@ -180,7 +180,7 @@ public void testCreateRequest_SecretsInTheJsonBody_AreEncodedCorrectly() throws
headers,
new QueryParameters(List.of(new QueryParameters.Parameter("key", "value"), new QueryParameters.Parameter("key", "value2"))),
requestContentString,
- new TextEmbeddingResponseParser("$.result.embeddings", CustomServiceEmbeddingType.FLOAT),
+ new DenseEmbeddingResponseParser("$.result.embeddings", CustomServiceEmbeddingType.FLOAT),
new RateLimitSettings(10_000)
);
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseEntityTests.java
index e53add6733aca..b475a770c7cfe 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseEntityTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseEntityTests.java
@@ -13,9 +13,9 @@
import org.elasticsearch.inference.WeightedToken;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
@@ -62,7 +62,7 @@ public void testFromTextEmbeddingResponse() throws IOException {
var model = CustomModelTests.getTestModel(
TaskType.TEXT_EMBEDDING,
- new TextEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT)
+ new DenseEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT)
);
var request = new CustomRequest(
EmbeddingParameters.of(new EmbeddingsInput(List.of("abc"), null), model.getServiceSettings().getInputTypeTranslator()),
@@ -73,10 +73,10 @@ public void testFromTextEmbeddingResponse() throws IOException {
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
- assertThat(results, instanceOf(TextEmbeddingFloatResults.class));
+ assertThat(results, instanceOf(DenseEmbeddingFloatResults.class));
assertThat(
- ((TextEmbeddingFloatResults) results).embeddings(),
- is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { -0.02868066355586052f, 0.022033605724573135f })))
+ ((DenseEmbeddingFloatResults) results).embeddings(),
+ is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { -0.02868066355586052f, 0.022033605724573135f })))
);
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParserTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/DenseEmbeddingResponseParserTests.java
similarity index 71%
rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParserTests.java
rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/DenseEmbeddingResponseParserTests.java
index 7796b5e1e7f6b..cffb2a5db6af5 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParserTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/DenseEmbeddingResponseParserTests.java
@@ -16,9 +16,9 @@
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentType;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBitResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingBitResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.services.custom.CustomServiceEmbeddingType;
@@ -29,24 +29,24 @@
import java.util.List;
import java.util.Map;
-import static org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser.EMBEDDING_TYPE;
-import static org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS;
+import static org.elasticsearch.xpack.inference.services.custom.response.DenseEmbeddingResponseParser.EMBEDDING_TYPE;
+import static org.elasticsearch.xpack.inference.services.custom.response.DenseEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS;
import static org.hamcrest.Matchers.is;
import static org.mockito.Mockito.mock;
-public class TextEmbeddingResponseParserTests extends AbstractBWCWireSerializationTestCase {
+public class DenseEmbeddingResponseParserTests extends AbstractBWCWireSerializationTestCase {
private static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_TYPE = TransportVersion.fromName(
"ml_inference_custom_service_embedding_type"
);
- public static TextEmbeddingResponseParser createRandom() {
- return new TextEmbeddingResponseParser("$." + randomAlphaOfLength(5), randomFrom(CustomServiceEmbeddingType.values()));
+ public static DenseEmbeddingResponseParser createRandom() {
+ return new DenseEmbeddingResponseParser("$." + randomAlphaOfLength(5), randomFrom(CustomServiceEmbeddingType.values()));
}
public void testFromMap() {
var validation = new ValidationException();
- var parser = TextEmbeddingResponseParser.fromMap(
+ var parser = DenseEmbeddingResponseParser.fromMap(
new HashMap<>(
Map.of(
TEXT_EMBEDDING_PARSER_EMBEDDINGS,
@@ -59,14 +59,14 @@ public void testFromMap() {
validation
);
- assertThat(parser, is(new TextEmbeddingResponseParser("$.result[*].embeddings", CustomServiceEmbeddingType.BIT)));
+ assertThat(parser, is(new DenseEmbeddingResponseParser("$.result[*].embeddings", CustomServiceEmbeddingType.BIT)));
}
public void testFromMap_ThrowsException_WhenRequiredFieldIsNotPresent() {
var validation = new ValidationException();
var exception = expectThrows(
ValidationException.class,
- () -> TextEmbeddingResponseParser.fromMap(new HashMap<>(Map.of("some_field", "$.result[*].embeddings")), "scope", validation)
+ () -> DenseEmbeddingResponseParser.fromMap(new HashMap<>(Map.of("some_field", "$.result[*].embeddings")), "scope", validation)
);
assertThat(
@@ -76,7 +76,7 @@ public void testFromMap_ThrowsException_WhenRequiredFieldIsNotPresent() {
}
public void testToXContent() throws IOException {
- var entity = new TextEmbeddingResponseParser("$.result.path", CustomServiceEmbeddingType.BINARY);
+ var entity = new DenseEmbeddingResponseParser("$.result.path", CustomServiceEmbeddingType.BINARY);
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
{
@@ -120,14 +120,18 @@ public void testParse() throws IOException {
}
""";
- var parser = new TextEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT);
- TextEmbeddingFloatResults parsedResults = (TextEmbeddingFloatResults) parser.parse(
+ var parser = new DenseEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT);
+ DenseEmbeddingFloatResults parsedResults = (DenseEmbeddingFloatResults) parser.parse(
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
assertThat(
parsedResults,
- is(new TextEmbeddingFloatResults(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }))))
+ is(
+ new DenseEmbeddingFloatResults(
+ List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }))
+ )
+ )
);
}
@@ -153,12 +157,15 @@ public void testParseByte() throws IOException {
}
""";
- var parser = new TextEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.BYTE);
- TextEmbeddingByteResults parsedResults = (TextEmbeddingByteResults) parser.parse(
+ var parser = new DenseEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.BYTE);
+ DenseEmbeddingByteResults parsedResults = (DenseEmbeddingByteResults) parser.parse(
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
- assertThat(parsedResults, is(new TextEmbeddingByteResults(List.of(new TextEmbeddingByteResults.Embedding(new byte[] { 1, -2 })))));
+ assertThat(
+ parsedResults,
+ is(new DenseEmbeddingByteResults(List.of(new DenseEmbeddingByteResults.Embedding(new byte[] { 1, -2 }))))
+ );
}
public void testParseBit() throws IOException {
@@ -183,12 +190,12 @@ public void testParseBit() throws IOException {
}
""";
- var parser = new TextEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.BIT);
- TextEmbeddingBitResults parsedResults = (TextEmbeddingBitResults) parser.parse(
+ var parser = new DenseEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.BIT);
+ DenseEmbeddingBitResults parsedResults = (DenseEmbeddingBitResults) parser.parse(
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
- assertThat(parsedResults, is(new TextEmbeddingBitResults(List.of(new TextEmbeddingByteResults.Embedding(new byte[] { 1, -2 })))));
+ assertThat(parsedResults, is(new DenseEmbeddingBitResults(List.of(new DenseEmbeddingByteResults.Embedding(new byte[] { 1, -2 })))));
}
public void testParse_MultipleEmbeddings() throws IOException {
@@ -221,18 +228,18 @@ public void testParse_MultipleEmbeddings() throws IOException {
}
""";
- var parser = new TextEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT);
- TextEmbeddingFloatResults parsedResults = (TextEmbeddingFloatResults) parser.parse(
+ var parser = new DenseEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT);
+ DenseEmbeddingFloatResults parsedResults = (DenseEmbeddingFloatResults) parser.parse(
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
assertThat(
parsedResults,
is(
- new TextEmbeddingFloatResults(
+ new DenseEmbeddingFloatResults(
List.of(
- new TextEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }),
- new TextEmbeddingFloatResults.Embedding(new float[] { 1F, -2F })
+ new DenseEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }),
+ new DenseEmbeddingFloatResults.Embedding(new float[] { 1F, -2F })
)
)
)
@@ -269,7 +276,7 @@ public void testParse_ThrowsException_WhenExtractedField_IsNotAListOfFloats() {
}
""";
- var parser = new TextEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT);
+ var parser = new DenseEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT);
var exception = expectThrows(
IllegalArgumentException.class,
() -> parser.parse(new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)))
@@ -303,7 +310,7 @@ public void testParse_ThrowsException_WhenExtractedField_IsNotAList() {
}
""";
- var parser = new TextEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT);
+ var parser = new DenseEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT);
var exception = expectThrows(
IllegalArgumentException.class,
() -> parser.parse(new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)))
@@ -319,25 +326,25 @@ public void testParse_ThrowsException_WhenExtractedField_IsNotAList() {
}
@Override
- protected TextEmbeddingResponseParser mutateInstanceForVersion(TextEmbeddingResponseParser instance, TransportVersion version) {
+ protected DenseEmbeddingResponseParser mutateInstanceForVersion(DenseEmbeddingResponseParser instance, TransportVersion version) {
if (version.supports(ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_TYPE) == false) {
- return new TextEmbeddingResponseParser(instance.getTextEmbeddingsPath(), CustomServiceEmbeddingType.FLOAT);
+ return new DenseEmbeddingResponseParser(instance.getTextEmbeddingsPath(), CustomServiceEmbeddingType.FLOAT);
}
return instance;
}
@Override
- protected Writeable.Reader instanceReader() {
- return TextEmbeddingResponseParser::new;
+ protected Writeable.Reader instanceReader() {
+ return DenseEmbeddingResponseParser::new;
}
@Override
- protected TextEmbeddingResponseParser createTestInstance() {
+ protected DenseEmbeddingResponseParser createTestInstance() {
return createRandom();
}
@Override
- protected TextEmbeddingResponseParser mutateInstance(TextEmbeddingResponseParser instance) throws IOException {
- return randomValueOtherThan(instance, TextEmbeddingResponseParserTests::createRandom);
+ protected DenseEmbeddingResponseParser mutateInstance(DenseEmbeddingResponseParser instance) throws IOException {
+ return randomValueOtherThan(instance, DenseEmbeddingResponseParserTests::createRandom);
}
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java
index 760312fe2d97b..5f14ef1435837 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java
@@ -40,8 +40,8 @@
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
import org.elasticsearch.xpack.inference.InferencePlugin;
import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
@@ -889,9 +889,9 @@ public void testChunkedInfer_BatchesCallsChunkingSettingsSet() throws IOExceptio
var denseResult = (ChunkedInferenceEmbedding) results.getFirst();
assertThat(denseResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, "hello world".length()), denseResult.chunks().getFirst().offset());
- assertThat(denseResult.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class));
+ assertThat(denseResult.chunks().get(0).embedding(), instanceOf(DenseEmbeddingFloatResults.Embedding.class));
- var embedding = (TextEmbeddingFloatResults.Embedding) denseResult.chunks().get(0).embedding();
+ var embedding = (DenseEmbeddingFloatResults.Embedding) denseResult.chunks().get(0).embedding();
assertArrayEquals(new float[] { 0.123f, -0.456f, 0.789f }, embedding.values(), 0.0f);
}
@@ -901,9 +901,9 @@ public void testChunkedInfer_BatchesCallsChunkingSettingsSet() throws IOExceptio
var denseResult = (ChunkedInferenceEmbedding) results.get(1);
assertThat(denseResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, "dense embedding".length()), denseResult.chunks().getFirst().offset());
- assertThat(denseResult.chunks().getFirst().embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class));
+ assertThat(denseResult.chunks().getFirst().embedding(), instanceOf(DenseEmbeddingFloatResults.Embedding.class));
- var embedding = (TextEmbeddingFloatResults.Embedding) denseResult.chunks().get(0).embedding();
+ var embedding = (DenseEmbeddingFloatResults.Embedding) denseResult.chunks().get(0).embedding();
assertArrayEquals(new float[] { 0.987f, -0.654f, 0.321f }, embedding.values(), 0.0f);
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java
index 1d91713eef931..ab715b7f73e96 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java
@@ -20,9 +20,9 @@
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.RankedDocsResultsTests;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
@@ -298,8 +298,8 @@ public void testExecute_ReturnsSuccessfulResponse_ForDenseTextEmbeddingsAction()
var result = listener.actionGet(TIMEOUT);
- assertThat(result, instanceOf(TextEmbeddingFloatResults.class));
- var textEmbeddingResults = (TextEmbeddingFloatResults) result;
+ assertThat(result, instanceOf(DenseEmbeddingFloatResults.class));
+ var textEmbeddingResults = (DenseEmbeddingFloatResults) result;
assertThat(textEmbeddingResults.embeddings(), hasSize(2));
var firstEmbedding = textEmbeddingResults.embeddings().get(0);
@@ -354,8 +354,8 @@ public void testExecute_ReturnsSuccessfulResponse_ForDenseTextEmbeddingsAction_W
var result = listener.actionGet(TIMEOUT);
- assertThat(result, instanceOf(TextEmbeddingFloatResults.class));
- var textEmbeddingResults = (TextEmbeddingFloatResults) result;
+ assertThat(result, instanceOf(DenseEmbeddingFloatResults.class));
+ var textEmbeddingResults = (DenseEmbeddingFloatResults) result;
assertThat(textEmbeddingResults.embeddings(), hasSize(1));
var embedding = textEmbeddingResults.embeddings().get(0);
@@ -447,8 +447,8 @@ public void testExecute_ReturnsSuccessfulResponse_ForDenseTextEmbeddingsAction_E
var result = listener.actionGet(TIMEOUT);
- assertThat(result, instanceOf(TextEmbeddingFloatResults.class));
- var textEmbeddingResults = (TextEmbeddingFloatResults) result;
+ assertThat(result, instanceOf(DenseEmbeddingFloatResults.class));
+ var textEmbeddingResults = (DenseEmbeddingFloatResults) result;
assertThat(textEmbeddingResults.embeddings(), hasSize(0));
assertThat(webServer.requests(), hasSize(1));
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceDenseTextEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceDenseTextEmbeddingsResponseEntityTests.java
index 2883a1ab73c21..79721b95af067 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceDenseTextEmbeddingsResponseEntityTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceDenseTextEmbeddingsResponseEntityTests.java
@@ -9,7 +9,7 @@
import org.apache.http.HttpResponse;
import org.elasticsearch.test.ESTestCase;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.external.response.elastic.ElasticInferenceServiceDenseTextEmbeddingsResponseEntity;
@@ -35,7 +35,7 @@ public void testDenseTextEmbeddingsResponse_SingleEmbeddingInData_NoMeta() throw
}
""";
- TextEmbeddingFloatResults parsedResults = ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.fromResponse(
+ DenseEmbeddingFloatResults parsedResults = ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
@@ -64,7 +64,7 @@ public void testDenseTextEmbeddingsResponse_MultipleEmbeddingsInData_NoMeta() th
}
""";
- TextEmbeddingFloatResults parsedResults = ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.fromResponse(
+ DenseEmbeddingFloatResults parsedResults = ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
@@ -85,7 +85,7 @@ public void testDenseTextEmbeddingsResponse_EmptyData() throws Exception {
}
""";
- TextEmbeddingFloatResults parsedResults = ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.fromResponse(
+ DenseEmbeddingFloatResults parsedResults = ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
@@ -111,7 +111,7 @@ public void testDenseTextEmbeddingsResponse_SingleEmbeddingInData_IgnoresMeta()
}
""";
- TextEmbeddingFloatResults parsedResults = ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.fromResponse(
+ DenseEmbeddingFloatResults parsedResults = ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java
index 2cc0323e6d913..0db705c1770ba 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java
@@ -53,9 +53,9 @@
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
import org.elasticsearch.xpack.core.ml.MachineLearningField;
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction;
@@ -71,8 +71,8 @@
import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentTests;
import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
-import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
-import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResultsTests;
+import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults;
+import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResultsTests;
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResultsTests;
import org.elasticsearch.xpack.core.ml.inference.results.TextSimilarityInferenceResultsTests;
@@ -1109,8 +1109,8 @@ public void testChunkInfer_E5ChunkingSettingsSet() throws InterruptedException {
@SuppressWarnings("unchecked")
private void testChunkInfer_e5(ChunkingSettings chunkingSettings) throws InterruptedException {
var mlTrainedModelResults = new ArrayList();
- mlTrainedModelResults.add(MlTextEmbeddingResultsTests.createRandomResults());
- mlTrainedModelResults.add(MlTextEmbeddingResultsTests.createRandomResults());
+ mlTrainedModelResults.add(MlDenseEmbeddingResultsTests.createRandomResults());
+ mlTrainedModelResults.add(MlDenseEmbeddingResultsTests.createRandomResults());
var response = new InferModelAction.Response(mlTrainedModelResults, "foo", true);
Client client = mock(Client.class);
@@ -1136,20 +1136,20 @@ private void testChunkInfer_e5(ChunkingSettings chunkingSettings) throws Interru
assertThat(chunkedResponse.get(0), instanceOf(ChunkedInferenceEmbedding.class));
var result1 = (ChunkedInferenceEmbedding) chunkedResponse.get(0);
assertThat(result1.chunks(), hasSize(1));
- assertThat(result1.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class));
+ assertThat(result1.chunks().get(0).embedding(), instanceOf(DenseEmbeddingFloatResults.Embedding.class));
assertArrayEquals(
- ((MlTextEmbeddingResults) mlTrainedModelResults.get(0)).getInferenceAsFloat(),
- ((TextEmbeddingFloatResults.Embedding) result1.chunks().get(0).embedding()).values(),
+ ((MlDenseEmbeddingResults) mlTrainedModelResults.get(0)).getInferenceAsFloat(),
+ ((DenseEmbeddingFloatResults.Embedding) result1.chunks().get(0).embedding()).values(),
0.0001f
);
assertEquals(new ChunkedInference.TextOffset(0, 1), result1.chunks().get(0).offset());
assertThat(chunkedResponse.get(1), instanceOf(ChunkedInferenceEmbedding.class));
var result2 = (ChunkedInferenceEmbedding) chunkedResponse.get(1);
assertThat(result2.chunks(), hasSize(1));
- assertThat(result2.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class));
+ assertThat(result2.chunks().get(0).embedding(), instanceOf(DenseEmbeddingFloatResults.Embedding.class));
assertArrayEquals(
- ((MlTextEmbeddingResults) mlTrainedModelResults.get(1)).getInferenceAsFloat(),
- ((TextEmbeddingFloatResults.Embedding) result2.chunks().get(0).embedding()).values(),
+ ((MlDenseEmbeddingResults) mlTrainedModelResults.get(1)).getInferenceAsFloat(),
+ ((DenseEmbeddingFloatResults.Embedding) result2.chunks().get(0).embedding()).values(),
0.0001f
);
assertEquals(new ChunkedInference.TextOffset(0, 2), result2.chunks().get(0).offset());
@@ -1377,8 +1377,8 @@ public void testChunkInferSetsTokenization() {
@SuppressWarnings("unchecked")
public void testChunkInfer_FailsBatch() throws InterruptedException {
var mlTrainedModelResults = new ArrayList();
- mlTrainedModelResults.add(MlTextEmbeddingResultsTests.createRandomResults());
- mlTrainedModelResults.add(MlTextEmbeddingResultsTests.createRandomResults());
+ mlTrainedModelResults.add(MlDenseEmbeddingResultsTests.createRandomResults());
+ mlTrainedModelResults.add(MlDenseEmbeddingResultsTests.createRandomResults());
mlTrainedModelResults.add(new ErrorInferenceResults(new RuntimeException("boom")));
var response = new InferModelAction.Response(mlTrainedModelResults, "foo", true);
@@ -1454,7 +1454,7 @@ public void testChunkingLargeDocument() throws InterruptedException {
var listener = (ActionListener) invocationOnMock.getArguments()[2];
var mlTrainedModelResults = new ArrayList();
for (int i = 0; i < request.numberOfDocuments(); i++) {
- mlTrainedModelResults.add(MlTextEmbeddingResultsTests.createRandomResults());
+ mlTrainedModelResults.add(MlDenseEmbeddingResultsTests.createRandomResults());
}
var response = new InferModelAction.Response(mlTrainedModelResults, "foo", true);
listener.onResponse(response);
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java
index be087f73f8d5b..7a5fef5ff0995 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java
@@ -39,7 +39,7 @@
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
@@ -66,7 +66,7 @@
import static org.elasticsearch.common.xcontent.XContentHelper.toXContent;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
-import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
+import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.Utils.getInvalidModel;
import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors;
@@ -941,11 +941,11 @@ private void testChunkedInfer(String modelId, String apiKey, GoogleAiStudioEmbed
var floatResult = (ChunkedInferenceEmbedding) results.get(0);
assertThat(floatResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, input.get(0).input().length()), floatResult.chunks().get(0).offset());
- assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class));
+ assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class));
assertTrue(
Arrays.equals(
new float[] { 0.0123f, -0.0123f },
- ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values()
+ ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values()
)
);
}
@@ -956,11 +956,11 @@ private void testChunkedInfer(String modelId, String apiKey, GoogleAiStudioEmbed
var floatResult = (ChunkedInferenceEmbedding) results.get(1);
assertThat(floatResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, input.get(1).input().length()), floatResult.chunks().get(0).offset());
- assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class));
+ assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class));
assertTrue(
Arrays.equals(
new float[] { 0.0456f, -0.0456f },
- ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values()
+ ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values()
)
);
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/action/GoogleAiStudioEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/action/GoogleAiStudioEmbeddingsActionTests.java
index 38e33c546b291..9e6b5547a6b1d 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/action/GoogleAiStudioEmbeddingsActionTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/action/GoogleAiStudioEmbeddingsActionTests.java
@@ -40,7 +40,7 @@
import java.util.Map;
import java.util.concurrent.TimeUnit;
-import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
+import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/response/GoogleAiStudioEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/response/GoogleAiStudioEmbeddingsResponseEntityTests.java
index eca4a369c29c8..6bfb33769602f 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/response/GoogleAiStudioEmbeddingsResponseEntityTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/response/GoogleAiStudioEmbeddingsResponseEntityTests.java
@@ -9,7 +9,7 @@
import org.apache.http.HttpResponse;
import org.elasticsearch.test.ESTestCase;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.request.Request;
@@ -36,12 +36,12 @@ public void testFromResponse_CreatesResultsForASingleItem() throws IOException {
}
""";
- TextEmbeddingFloatResults parsedResults = GoogleAiStudioEmbeddingsResponseEntity.fromResponse(
+ DenseEmbeddingFloatResults parsedResults = GoogleAiStudioEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
- assertThat(parsedResults.embeddings(), is(List.of(TextEmbeddingFloatResults.Embedding.of(List.of(-0.00606332F, 0.058092743F)))));
+ assertThat(parsedResults.embeddings(), is(List.of(DenseEmbeddingFloatResults.Embedding.of(List.of(-0.00606332F, 0.058092743F)))));
}
public void testFromResponse_CreatesResultsForMultipleItems() throws IOException {
@@ -64,7 +64,7 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException
}
""";
- TextEmbeddingFloatResults parsedResults = GoogleAiStudioEmbeddingsResponseEntity.fromResponse(
+ DenseEmbeddingFloatResults parsedResults = GoogleAiStudioEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
@@ -73,8 +73,8 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException
parsedResults.embeddings(),
is(
List.of(
- TextEmbeddingFloatResults.Embedding.of(List.of(-0.00606332F, 0.058092743F)),
- TextEmbeddingFloatResults.Embedding.of(List.of(0.030681048F, 0.01714732F))
+ DenseEmbeddingFloatResults.Embedding.of(List.of(-0.00606332F, 0.058092743F)),
+ DenseEmbeddingFloatResults.Embedding.of(List.of(0.030681048F, 0.01714732F))
)
)
);
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/response/GoogleVertexAiEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/response/GoogleVertexAiEmbeddingsResponseEntityTests.java
index 8f19edb3031d7..c3a8a8b74078b 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/response/GoogleVertexAiEmbeddingsResponseEntityTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/response/GoogleVertexAiEmbeddingsResponseEntityTests.java
@@ -9,7 +9,7 @@
import org.apache.http.HttpResponse;
import org.elasticsearch.test.ESTestCase;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.request.Request;
@@ -42,12 +42,12 @@ public void testFromResponse_CreatesResultsForASingleItem() throws IOException {
}
""";
- TextEmbeddingFloatResults parsedResults = GoogleVertexAiEmbeddingsResponseEntity.fromResponse(
+ DenseEmbeddingFloatResults parsedResults = GoogleVertexAiEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
- assertThat(parsedResults.embeddings(), is(List.of(TextEmbeddingFloatResults.Embedding.of(List.of(-0.123F, 0.123F)))));
+ assertThat(parsedResults.embeddings(), is(List.of(DenseEmbeddingFloatResults.Embedding.of(List.of(-0.123F, 0.123F)))));
}
public void testFromResponse_CreatesResultsForMultipleItems() throws IOException {
@@ -82,7 +82,7 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException
}
""";
- TextEmbeddingFloatResults parsedResults = GoogleVertexAiEmbeddingsResponseEntity.fromResponse(
+ DenseEmbeddingFloatResults parsedResults = GoogleVertexAiEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
@@ -91,8 +91,8 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException
parsedResults.embeddings(),
is(
List.of(
- TextEmbeddingFloatResults.Embedding.of(List.of(-0.123F, 0.123F)),
- TextEmbeddingFloatResults.Embedding.of(List.of(-0.456F, 0.456F))
+ DenseEmbeddingFloatResults.Embedding.of(List.of(-0.123F, 0.123F)),
+ DenseEmbeddingFloatResults.Embedding.of(List.of(-0.456F, 0.456F))
)
)
);
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java
index 67534d83a2e15..5d60a87200f3e 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java
@@ -42,8 +42,8 @@
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
@@ -76,7 +76,7 @@
import static org.elasticsearch.common.xcontent.XContentHelper.toXContent;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS;
-import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
+import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.Utils.getInvalidModel;
import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors;
@@ -1212,10 +1212,10 @@ public void testChunkedInfer_CallsInfer_TextEmbedding_ConvertsFloatResponse() th
var embeddingResult = (ChunkedInferenceEmbedding) result;
assertThat(embeddingResult.chunks(), hasSize(1));
assertThat(embeddingResult.chunks().get(0).offset(), is(new ChunkedInference.TextOffset(0, "abc".length())));
- assertThat(embeddingResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class));
+ assertThat(embeddingResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class));
assertArrayEquals(
new float[] { -0.0123f, 0.0123f },
- ((TextEmbeddingFloatResults.Embedding) embeddingResult.chunks().get(0).embedding()).values(),
+ ((DenseEmbeddingFloatResults.Embedding) embeddingResult.chunks().get(0).embedding()).values(),
0.001f
);
assertThat(webServer.requests(), hasSize(1));
@@ -1266,10 +1266,10 @@ public void testChunkedInfer() throws IOException {
var floatResult = (ChunkedInferenceEmbedding) results.get(0);
assertThat(floatResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, 3), floatResult.chunks().get(0).offset());
- assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class));
+ assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class));
assertArrayEquals(
new float[] { 0.123f, -0.123f },
- ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
+ ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
0.0f
);
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java
index 1194f3a5fa95c..1225e40a0fc0c 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java
@@ -19,8 +19,8 @@
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests;
import org.elasticsearch.xpack.inference.InputTypeTests;
import org.elasticsearch.xpack.inference.common.TruncatorTests;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
@@ -231,7 +231,10 @@ public void testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction() throws I
var result = listener.actionGet(TIMEOUT);
- assertThat(result.asMap(), is(TextEmbeddingFloatResultsTests.buildExpectationFloat(List.of(new float[] { -0.0123F, 0.123F }))));
+ assertThat(
+ result.asMap(),
+ is(DenseEmbeddingFloatResultsTests.buildExpectationFloat(List.of(new float[] { -0.0123F, 0.123F })))
+ );
assertThat(webServer.requests(), hasSize(1));
assertNull(webServer.requests().get(0).getUri().getQuery());
@@ -416,7 +419,10 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating() throws IOExc
var result = listener.actionGet(TIMEOUT);
- assertThat(result.asMap(), is(TextEmbeddingFloatResultsTests.buildExpectationFloat(List.of(new float[] { -0.0123F, 0.123F }))));
+ assertThat(
+ result.asMap(),
+ is(DenseEmbeddingFloatResultsTests.buildExpectationFloat(List.of(new float[] { -0.0123F, 0.123F })))
+ );
assertThat(webServer.requests(), hasSize(2));
{
@@ -478,7 +484,10 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException {
var result = listener.actionGet(TIMEOUT);
- assertThat(result.asMap(), is(TextEmbeddingFloatResultsTests.buildExpectationFloat(List.of(new float[] { -0.0123F, 0.123F }))));
+ assertThat(
+ result.asMap(),
+ is(DenseEmbeddingFloatResultsTests.buildExpectationFloat(List.of(new float[] { -0.0123F, 0.123F })))
+ );
assertThat(webServer.requests(), hasSize(1));
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceEmbeddingsResponseEntityTests.java
index 61e035326d163..e157038f2f244 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceEmbeddingsResponseEntityTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceEmbeddingsResponseEntityTests.java
@@ -10,7 +10,7 @@
import org.apache.http.HttpResponse;
import org.elasticsearch.common.ParsingException;
import org.elasticsearch.test.ESTestCase;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.request.Request;
@@ -32,14 +32,14 @@ public void testFromResponse_CreatesResultsForASingleItem_ArrayFormat() throws I
]
""";
- TextEmbeddingFloatResults parsedResults = HuggingFaceEmbeddingsResponseEntity.fromResponse(
+ DenseEmbeddingFloatResults parsedResults = HuggingFaceEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
assertThat(
parsedResults.embeddings(),
- is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F })))
+ is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F })))
);
}
@@ -55,14 +55,14 @@ public void testFromResponse_CreatesResultsForASingleItem_ObjectFormat() throws
}
""";
- TextEmbeddingFloatResults parsedResults = HuggingFaceEmbeddingsResponseEntity.fromResponse(
+ DenseEmbeddingFloatResults parsedResults = HuggingFaceEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
assertThat(
parsedResults.embeddings(),
- is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F })))
+ is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F })))
);
}
@@ -80,7 +80,7 @@ public void testFromResponse_CreatesResultsForMultipleItems_ArrayFormat() throws
]
""";
- TextEmbeddingFloatResults parsedResults = HuggingFaceEmbeddingsResponseEntity.fromResponse(
+ DenseEmbeddingFloatResults parsedResults = HuggingFaceEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
@@ -89,8 +89,8 @@ public void testFromResponse_CreatesResultsForMultipleItems_ArrayFormat() throws
parsedResults.embeddings(),
is(
List.of(
- new TextEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }),
- new TextEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F })
+ new DenseEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }),
+ new DenseEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F })
)
)
);
@@ -112,7 +112,7 @@ public void testFromResponse_CreatesResultsForMultipleItems_ObjectFormat() throw
}
""";
- TextEmbeddingFloatResults parsedResults = HuggingFaceEmbeddingsResponseEntity.fromResponse(
+ DenseEmbeddingFloatResults parsedResults = HuggingFaceEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
@@ -121,8 +121,8 @@ public void testFromResponse_CreatesResultsForMultipleItems_ObjectFormat() throw
parsedResults.embeddings(),
is(
List.of(
- new TextEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }),
- new TextEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F })
+ new DenseEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }),
+ new DenseEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F })
)
)
);
@@ -255,12 +255,12 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsInt_ArrayFormat() throw
]
""";
- TextEmbeddingFloatResults parsedResults = HuggingFaceEmbeddingsResponseEntity.fromResponse(
+ DenseEmbeddingFloatResults parsedResults = HuggingFaceEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
- assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 1.0F }))));
+ assertThat(parsedResults.embeddings(), is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 1.0F }))));
}
public void testFromResponse_SucceedsWhenEmbeddingValueIsInt_ObjectFormat() throws IOException {
@@ -274,12 +274,12 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsInt_ObjectFormat() thro
}
""";
- TextEmbeddingFloatResults parsedResults = HuggingFaceEmbeddingsResponseEntity.fromResponse(
+ DenseEmbeddingFloatResults parsedResults = HuggingFaceEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
- assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 1.0F }))));
+ assertThat(parsedResults.embeddings(), is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 1.0F }))));
}
public void testFromResponse_SucceedsWhenEmbeddingValueIsLong_ArrayFormat() throws IOException {
@@ -291,12 +291,12 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsLong_ArrayFormat() thro
]
""";
- TextEmbeddingFloatResults parsedResults = HuggingFaceEmbeddingsResponseEntity.fromResponse(
+ DenseEmbeddingFloatResults parsedResults = HuggingFaceEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
- assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 4.0294965E10F }))));
+ assertThat(parsedResults.embeddings(), is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 4.0294965E10F }))));
}
public void testFromResponse_SucceedsWhenEmbeddingValueIsLong_ObjectFormat() throws IOException {
@@ -310,12 +310,12 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsLong_ObjectFormat() thr
}
""";
- TextEmbeddingFloatResults parsedResults = HuggingFaceEmbeddingsResponseEntity.fromResponse(
+ DenseEmbeddingFloatResults parsedResults = HuggingFaceEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
- assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 4.0294965E10F }))));
+ assertThat(parsedResults.embeddings(), is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 4.0294965E10F }))));
}
public void testFromResponse_FailsWhenEmbeddingValueIsAnObject_ObjectFormat() {
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java
index 6e24981e3f3b3..229bbb73eb14d 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java
@@ -38,7 +38,7 @@
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.common.Truncator;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
@@ -70,7 +70,7 @@
import static org.elasticsearch.common.xcontent.XContentHelper.toXContent;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
-import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
+import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.Utils.getInvalidModel;
import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors;
@@ -789,11 +789,11 @@ private void testChunkedInfer_Batches(ChunkingSettings chunkingSettings) throws
var floatResult = (ChunkedInferenceEmbedding) results.get(0);
assertThat(floatResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, input.get(0).input().length()), floatResult.chunks().get(0).offset());
- assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class));
+ assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class));
assertTrue(
Arrays.equals(
new float[] { 0.0123f, -0.0123f },
- ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values()
+ ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values()
)
);
}
@@ -804,11 +804,11 @@ private void testChunkedInfer_Batches(ChunkingSettings chunkingSettings) throws
var floatResult = (ChunkedInferenceEmbedding) results.get(1);
assertThat(floatResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, input.get(1).input().length()), floatResult.chunks().get(0).offset());
- assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class));
+ assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class));
assertTrue(
Arrays.equals(
new float[] { 0.0456f, -0.0456f },
- ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values()
+ ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values()
)
);
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/action/IbmWatsonxEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/action/IbmWatsonxEmbeddingsActionTests.java
index a0980a8151036..793bb535c0cc5 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/action/IbmWatsonxEmbeddingsActionTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/action/IbmWatsonxEmbeddingsActionTests.java
@@ -44,7 +44,7 @@
import java.util.Map;
import java.util.concurrent.TimeUnit;
-import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
+import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/response/IbmWatsonxEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/response/IbmWatsonxEmbeddingsResponseEntityTests.java
index db42e9c49e1e5..c3bdb8c686985 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/response/IbmWatsonxEmbeddingsResponseEntityTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/response/IbmWatsonxEmbeddingsResponseEntityTests.java
@@ -9,7 +9,7 @@
import org.apache.http.HttpResponse;
import org.elasticsearch.test.ESTestCase;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.request.Request;
@@ -36,12 +36,12 @@ public void testFromResponse_CreatesResultsForASingleItem() throws IOException {
}
""";
- TextEmbeddingFloatResults parsedResults = IbmWatsonxEmbeddingsResponseEntity.fromResponse(
+ DenseEmbeddingFloatResults parsedResults = IbmWatsonxEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
- assertThat(parsedResults.embeddings(), is(List.of(TextEmbeddingFloatResults.Embedding.of(List.of(-0.00606332F, 0.058092743F)))));
+ assertThat(parsedResults.embeddings(), is(List.of(DenseEmbeddingFloatResults.Embedding.of(List.of(-0.00606332F, 0.058092743F)))));
}
public void testFromResponse_CreatesResultsForMultipleItems() throws IOException {
@@ -66,7 +66,7 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException
}
""";
- TextEmbeddingFloatResults parsedResults = IbmWatsonxEmbeddingsResponseEntity.fromResponse(
+ DenseEmbeddingFloatResults parsedResults = IbmWatsonxEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
@@ -75,8 +75,8 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException
parsedResults.embeddings(),
is(
List.of(
- TextEmbeddingFloatResults.Embedding.of(List.of(-0.00606332F, 0.058092743F)),
- TextEmbeddingFloatResults.Embedding.of(List.of(0.030681048F, 0.01714732F))
+ DenseEmbeddingFloatResults.Embedding.of(List.of(-0.00606332F, 0.058092743F)),
+ DenseEmbeddingFloatResults.Embedding.of(List.of(0.030681048F, 0.01714732F))
)
)
);
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java
index d408e269219cb..a6a04692d33bf 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java
@@ -38,7 +38,7 @@
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
@@ -65,7 +65,7 @@
import static org.elasticsearch.common.xcontent.XContentHelper.toXContent;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
-import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
+import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.Utils.getInvalidModel;
import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors;
@@ -1688,10 +1688,10 @@ private void test_Embedding_ChunkedInfer_BatchesCalls(JinaAIEmbeddingsModel mode
var floatResult = (ChunkedInferenceEmbedding) results.get(0);
assertThat(floatResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().get(0).offset());
- assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class));
+ assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class));
assertArrayEquals(
new float[] { 0.123f, -0.123f },
- ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
+ ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
0.0f
);
}
@@ -1700,10 +1700,10 @@ private void test_Embedding_ChunkedInfer_BatchesCalls(JinaAIEmbeddingsModel mode
var floatResult = (ChunkedInferenceEmbedding) results.get(1);
assertThat(floatResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().get(0).offset());
- assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class));
+ assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class));
assertArrayEquals(
new float[] { 0.223f, -0.223f },
- ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
+ ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
0.0f
);
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/response/JinaAIEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/response/JinaAIEmbeddingsResponseEntityTests.java
index c1b19cb450789..4df61cdc439af 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/response/JinaAIEmbeddingsResponseEntityTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/response/JinaAIEmbeddingsResponseEntityTests.java
@@ -11,9 +11,9 @@
import org.elasticsearch.common.ParsingException;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.test.ESTestCase;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBitResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingBitResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.InputTypeTests;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingType;
@@ -69,10 +69,10 @@ public void testFromResponse_CreatesResultsForASingleItem() throws IOException {
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
- assertThat(parsedResults, instanceOf(TextEmbeddingFloatResults.class));
+ assertThat(parsedResults, instanceOf(DenseEmbeddingFloatResults.class));
assertThat(
- ((TextEmbeddingFloatResults) parsedResults).embeddings(),
- is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F })))
+ ((DenseEmbeddingFloatResults) parsedResults).embeddings(),
+ is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F })))
);
}
@@ -123,13 +123,13 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
- assertThat(parsedResults, instanceOf(TextEmbeddingFloatResults.class));
+ assertThat(parsedResults, instanceOf(DenseEmbeddingFloatResults.class));
assertThat(
- ((TextEmbeddingFloatResults) parsedResults).embeddings(),
+ ((DenseEmbeddingFloatResults) parsedResults).embeddings(),
is(
List.of(
- new TextEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }),
- new TextEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F })
+ new DenseEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }),
+ new DenseEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F })
)
)
);
@@ -363,8 +363,8 @@ public void testFromResponse_SucceedsWhenEmbeddingType_IsBinary() throws IOExcep
);
assertThat(
- ((TextEmbeddingBitResults) parsedResults).embeddings(),
- is(List.of(new TextEmbeddingByteResults.Embedding(new byte[] { (byte) -55, (byte) 74, (byte) 101, (byte) 67, (byte) 83 })))
+ ((DenseEmbeddingBitResults) parsedResults).embeddings(),
+ is(List.of(new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) -55, (byte) 74, (byte) 101, (byte) 67, (byte) 83 })))
);
}
@@ -411,8 +411,8 @@ public void testFromResponse_SucceedsWhenEmbeddingType_IsBit() throws IOExceptio
);
assertThat(
- ((TextEmbeddingBitResults) parsedResults).embeddings(),
- is(List.of(new TextEmbeddingByteResults.Embedding(new byte[] { (byte) -55, (byte) 74, (byte) 101, (byte) 67, (byte) 83 })))
+ ((DenseEmbeddingBitResults) parsedResults).embeddings(),
+ is(List.of(new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) -55, (byte) 74, (byte) 101, (byte) 67, (byte) 83 })))
);
}
@@ -504,7 +504,7 @@ public void testFieldsInDifferentOrderServer() throws IOException {
}
}""";
- TextEmbeddingFloatResults parsedResults = (TextEmbeddingFloatResults) JinaAIEmbeddingsResponseEntity.fromResponse(
+ DenseEmbeddingFloatResults parsedResults = (DenseEmbeddingFloatResults) JinaAIEmbeddingsResponseEntity.fromResponse(
JinaAIEmbeddingsRequestTests.createRequest(
List.of("abc"),
InputTypeTests.randomWithNull(),
@@ -525,9 +525,9 @@ public void testFieldsInDifferentOrderServer() throws IOException {
parsedResults.embeddings(),
is(
List.of(
- new TextEmbeddingFloatResults.Embedding(new float[] { -0.9F, 0.5F, 0.3F }),
- new TextEmbeddingFloatResults.Embedding(new float[] { 0.1F, 0.5F }),
- new TextEmbeddingFloatResults.Embedding(new float[] { 0.5F, 0.5F })
+ new DenseEmbeddingFloatResults.Embedding(new float[] { -0.9F, 0.5F, 0.3F }),
+ new DenseEmbeddingFloatResults.Embedding(new float[] { 0.1F, 0.5F }),
+ new DenseEmbeddingFloatResults.Embedding(new float[] { 0.5F, 0.5F })
)
)
);
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java
index f6a0232db529c..818ae94192ce1 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java
@@ -42,7 +42,7 @@
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
@@ -695,11 +695,11 @@ public void testChunkedInfer(LlamaEmbeddingsModel model) throws IOException {
assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class));
var floatResult = (ChunkedInferenceEmbedding) results.get(0);
assertThat(floatResult.chunks(), hasSize(1));
- assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class));
+ assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class));
assertTrue(
Arrays.equals(
new float[] { 0.010060793f, -0.0017529363f },
- ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values()
+ ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values()
)
);
}
@@ -707,11 +707,11 @@ public void testChunkedInfer(LlamaEmbeddingsModel model) throws IOException {
assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class));
var floatResult = (ChunkedInferenceEmbedding) results.get(1);
assertThat(floatResult.chunks(), hasSize(1));
- assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class));
+ assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class));
assertTrue(
Arrays.equals(
new float[] { 0.110060793f, -0.1017529363f },
- ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values()
+ ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values()
)
);
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionCreatorTests.java
index 5bf65870e1dfc..df2ce97614edc 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionCreatorTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionCreatorTests.java
@@ -19,7 +19,7 @@
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests;
import org.elasticsearch.xpack.inference.InputTypeTests;
import org.elasticsearch.xpack.inference.common.TruncatorTests;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
@@ -98,7 +98,10 @@ public void testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction() throws I
var result = listener.actionGet(TIMEOUT);
- assertThat(result.asMap(), is(TextEmbeddingFloatResultsTests.buildExpectationFloat(List.of(new float[] { -0.0123F, 0.123F }))));
+ assertThat(
+ result.asMap(),
+ is(DenseEmbeddingFloatResultsTests.buildExpectationFloat(List.of(new float[] { -0.0123F, 0.123F })))
+ );
assertEmbeddingsRequest();
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java
index db94ec5c9c2f5..3ff7bf1a78374 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java
@@ -41,7 +41,7 @@
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
import org.elasticsearch.xpack.inference.ModelConfigurationsTests;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
@@ -1135,11 +1135,11 @@ public void testChunkedInfer(MistralEmbeddingsModel model) throws IOException {
assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class));
var floatResult = (ChunkedInferenceEmbedding) results.get(0);
assertThat(floatResult.chunks(), hasSize(1));
- assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class));
+ assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class));
assertTrue(
Arrays.equals(
new float[] { 0.123f, -0.123f },
- ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values()
+ ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values()
)
);
}
@@ -1147,11 +1147,11 @@ public void testChunkedInfer(MistralEmbeddingsModel model) throws IOException {
assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class));
var floatResult = (ChunkedInferenceEmbedding) results.get(1);
assertThat(floatResult.chunks(), hasSize(1));
- assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class));
+ assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class));
assertTrue(
Arrays.equals(
new float[] { 0.223f, -0.223f },
- ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values()
+ ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values()
)
);
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java
index 0a03c7a231e8c..85c9c73d002b2 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java
@@ -43,7 +43,7 @@
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
@@ -84,7 +84,7 @@
import static org.elasticsearch.common.xcontent.XContentHelper.toXContent;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS;
-import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
+import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.Utils.getInvalidModel;
import static org.elasticsearch.xpack.inference.Utils.getRequestConfigMap;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors;
@@ -1049,11 +1049,11 @@ private void testChunkedInfer(OpenAiEmbeddingsModel model) throws IOException {
var floatResult = (ChunkedInferenceEmbedding) results.get(0);
assertThat(floatResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().get(0).offset());
- assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class));
+ assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class));
assertTrue(
Arrays.equals(
new float[] { 0.123f, -0.123f },
- ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values()
+ ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values()
)
);
}
@@ -1062,11 +1062,11 @@ private void testChunkedInfer(OpenAiEmbeddingsModel model) throws IOException {
var floatResult = (ChunkedInferenceEmbedding) results.get(1);
assertThat(floatResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().get(0).offset());
- assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class));
+ assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class));
assertTrue(
Arrays.equals(
new float[] { 0.223f, -0.223f },
- ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values()
+ ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values()
)
);
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiActionCreatorTests.java
index 5cdcf402835b4..1291a1ad4a36c 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiActionCreatorTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiActionCreatorTests.java
@@ -36,7 +36,7 @@
import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
-import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
+import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiEmbeddingsActionTests.java
index c34ebbde4ac8f..1851dc4cad2b3 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiEmbeddingsActionTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiEmbeddingsActionTests.java
@@ -40,7 +40,7 @@
import java.util.List;
import java.util.concurrent.TimeUnit;
-import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
+import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/response/OpenAiEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/response/OpenAiEmbeddingsResponseEntityTests.java
index f2a430eefc801..77e1f384509b4 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/response/OpenAiEmbeddingsResponseEntityTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/response/OpenAiEmbeddingsResponseEntityTests.java
@@ -10,7 +10,7 @@
import org.apache.http.HttpResponse;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xcontent.XContentParseException;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.request.Request;
@@ -45,14 +45,14 @@ public void testFromResponse_CreatesResultsForASingleItem() throws IOException {
}
""";
- TextEmbeddingFloatResults parsedResults = OpenAiEmbeddingsResponseEntity.fromResponse(
+ DenseEmbeddingFloatResults parsedResults = OpenAiEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
assertThat(
parsedResults.embeddings(),
- is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F })))
+ is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F })))
);
}
@@ -86,7 +86,7 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException
}
""";
- TextEmbeddingFloatResults parsedResults = OpenAiEmbeddingsResponseEntity.fromResponse(
+ DenseEmbeddingFloatResults parsedResults = OpenAiEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
@@ -95,8 +95,8 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException
parsedResults.embeddings(),
is(
List.of(
- new TextEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }),
- new TextEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F })
+ new DenseEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }),
+ new DenseEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F })
)
)
);
@@ -254,12 +254,12 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsInt() throws IOExceptio
}
""";
- TextEmbeddingFloatResults parsedResults = OpenAiEmbeddingsResponseEntity.fromResponse(
+ DenseEmbeddingFloatResults parsedResults = OpenAiEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
- assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 1.0F }))));
+ assertThat(parsedResults.embeddings(), is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 1.0F }))));
}
public void testFromResponse_SucceedsWhenEmbeddingValueIsLong() throws IOException {
@@ -283,12 +283,12 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsLong() throws IOExcepti
}
""";
- TextEmbeddingFloatResults parsedResults = OpenAiEmbeddingsResponseEntity.fromResponse(
+ DenseEmbeddingFloatResults parsedResults = OpenAiEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
- assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 4.0294965E10F }))));
+ assertThat(parsedResults.embeddings(), is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 4.0294965E10F }))));
}
public void testFromResponse_FailsWhenEmbeddingValueIsAnObject() {
@@ -365,7 +365,7 @@ public void testFieldsInDifferentOrderServer() throws IOException {
}
}""";
- TextEmbeddingFloatResults parsedResults = OpenAiEmbeddingsResponseEntity.fromResponse(
+ DenseEmbeddingFloatResults parsedResults = OpenAiEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), response.getBytes(StandardCharsets.UTF_8))
);
@@ -374,9 +374,9 @@ public void testFieldsInDifferentOrderServer() throws IOException {
parsedResults.embeddings(),
is(
List.of(
- new TextEmbeddingFloatResults.Embedding(new float[] { -0.9F, 0.5F, 0.3F }),
- new TextEmbeddingFloatResults.Embedding(new float[] { 0.1F, 0.5F }),
- new TextEmbeddingFloatResults.Embedding(new float[] { 0.5F, 0.5F })
+ new DenseEmbeddingFloatResults.Embedding(new float[] { -0.9F, 0.5F, 0.3F }),
+ new DenseEmbeddingFloatResults.Embedding(new float[] { 0.1F, 0.5F }),
+ new DenseEmbeddingFloatResults.Embedding(new float[] { 0.5F, 0.5F })
)
)
);
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerServiceTests.java
index 5d6bec1bcfbff..b391ce0cc7949 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerServiceTests.java
@@ -29,7 +29,7 @@
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests;
import org.elasticsearch.xpack.inference.InferencePlugin;
import org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkingSettings;
import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings;
@@ -449,7 +449,7 @@ public void testChunkedInfer() throws Exception {
var model = mockModelForChunking();
SageMakerSchema schema = mock();
- when(schema.response(any(), any(), any())).thenReturn(TextEmbeddingFloatResultsTests.createRandomResults());
+ when(schema.response(any(), any(), any())).thenReturn(DenseEmbeddingFloatResultsTests.createRandomResults());
when(schemas.schemaFor(model)).thenReturn(schema);
mockInvoke();
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticTextEmbeddingPayloadTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticTextEmbeddingPayloadTests.java
index ed0ee43266ab5..3ac0f20d4242c 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticTextEmbeddingPayloadTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticTextEmbeddingPayloadTests.java
@@ -9,8 +9,8 @@
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.inference.TaskType;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredServiceSchema;
@@ -66,8 +66,8 @@ public void testBitResponse() throws Exception {
assertThat(bitResults.embeddings().size(), is(1));
var embedding = bitResults.embeddings().get(0);
- assertThat(embedding, isA(TextEmbeddingByteResults.Embedding.class));
- assertThat(((TextEmbeddingByteResults.Embedding) embedding).values(), is(new byte[] { 23 }));
+ assertThat(embedding, isA(DenseEmbeddingByteResults.Embedding.class));
+ assertThat(((DenseEmbeddingByteResults.Embedding) embedding).values(), is(new byte[] { 23 }));
}
public void testByteResponse() throws Exception {
@@ -87,8 +87,8 @@ public void testByteResponse() throws Exception {
assertThat(byteResults.embeddings().size(), is(1));
var embedding = byteResults.embeddings().get(0);
- assertThat(embedding, isA(TextEmbeddingByteResults.Embedding.class));
- assertThat(((TextEmbeddingByteResults.Embedding) embedding).values(), is(new byte[] { 23 }));
+ assertThat(embedding, isA(DenseEmbeddingByteResults.Embedding.class));
+ assertThat(((DenseEmbeddingByteResults.Embedding) embedding).values(), is(new byte[] { 23 }));
}
public void testFloatResponse() throws Exception {
@@ -108,7 +108,7 @@ public void testFloatResponse() throws Exception {
assertThat(byteResults.embeddings().size(), is(1));
var embedding = byteResults.embeddings().get(0);
- assertThat(embedding, isA(TextEmbeddingFloatResults.Embedding.class));
- assertThat(((TextEmbeddingFloatResults.Embedding) embedding).values(), is(new float[] { 0.1F }));
+ assertThat(embedding, isA(DenseEmbeddingFloatResults.Embedding.class));
+ assertThat(((DenseEmbeddingFloatResults.Embedding) embedding).values(), is(new float[] { 0.1F }));
}
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiTextEmbeddingPayloadTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiTextEmbeddingPayloadTests.java
index 35b78b004618c..2435514e9f7e1 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiTextEmbeddingPayloadTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiTextEmbeddingPayloadTests.java
@@ -12,7 +12,7 @@
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.TaskType;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerInferenceRequest;
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemaPayloadTestCase;
@@ -145,11 +145,11 @@ public void testResponse() throws Exception {
.body(SdkBytes.fromString(responseJson, StandardCharsets.UTF_8))
.build();
- var textEmbeddingFloatResults = payload.responseBody(mock(), invokeEndpointResponse);
+ var denseEmbeddingFloatResults = payload.responseBody(mock(), invokeEndpointResponse);
assertThat(
- textEmbeddingFloatResults.embeddings(),
- is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F })))
+ denseEmbeddingFloatResults.embeddings(),
+ is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F })))
);
}
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/DenseEmbeddingModelValidatorTests.java
similarity index 90%
rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidatorTests.java
rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/DenseEmbeddingModelValidatorTests.java
index 45726f0789667..9ad4e81b20cb7 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidatorTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/DenseEmbeddingModelValidatorTests.java
@@ -16,10 +16,10 @@
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.inference.validation.ServiceIntegrationValidator;
import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResultsTests;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResultsTests;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.EmptyTaskSettingsTests;
import org.elasticsearch.xpack.inference.ModelConfigurationsTests;
import org.junit.Before;
@@ -38,7 +38,7 @@
import static org.mockito.Mockito.when;
import static org.mockito.MockitoAnnotations.openMocks;
-public class TextEmbeddingModelValidatorTests extends ESTestCase {
+public class DenseEmbeddingModelValidatorTests extends ESTestCase {
private static final TimeValue TIMEOUT = TimeValue.ONE_MINUTE;
@@ -53,13 +53,13 @@ public class TextEmbeddingModelValidatorTests extends ESTestCase {
@Mock
private ServiceSettings mockServiceSettings;
- private TextEmbeddingModelValidator underTest;
+ private DenseEmbeddingModelValidator underTest;
@Before
public void setup() {
openMocks(this);
- underTest = new TextEmbeddingModelValidator(mockServiceIntegrationValidator);
+ underTest = new DenseEmbeddingModelValidator(mockServiceIntegrationValidator);
when(mockInferenceService.updateModelWithEmbeddingDetails(eq(mockModel), anyInt())).thenReturn(mockModel);
when(mockActionListener.delegateFailureAndWrap(any())).thenCallRealMethod();
@@ -94,7 +94,7 @@ public void testValidate_ServiceReturnsNonTextEmbeddingResults() {
}
public void testValidate_RetrievingEmbeddingSizeThrowsIllegalStateException() {
- TextEmbeddingFloatResults results = new TextEmbeddingFloatResults(List.of());
+ DenseEmbeddingFloatResults results = new DenseEmbeddingFloatResults(List.of());
when(mockServiceSettings.dimensionsSetByUser()).thenReturn(true);
when(mockServiceSettings.dimensions()).thenReturn(randomNonNegativeInt());
@@ -107,7 +107,7 @@ public void testValidate_RetrievingEmbeddingSizeThrowsIllegalStateException() {
}
public void testValidate_DimensionsSetByUserDoNotEqualEmbeddingSize() {
- TextEmbeddingByteResults results = TextEmbeddingByteResultsTests.createRandomResults();
+ DenseEmbeddingByteResults results = DenseEmbeddingByteResultsTests.createRandomResults();
var dimensions = randomValueOtherThan(results.getFirstEmbeddingSize(), ESTestCase::randomNonNegativeInt);
when(mockServiceSettings.dimensionsSetByUser()).thenReturn(true);
@@ -131,7 +131,7 @@ public void testValidate_DimensionsNotSetByUser() {
}
private void mockSuccessfulValidation(Boolean dimensionsSetByUser) {
- TextEmbeddingByteResults results = TextEmbeddingByteResultsTests.createRandomResults();
+ DenseEmbeddingByteResults results = DenseEmbeddingByteResultsTests.createRandomResults();
when(mockModel.getConfigurations()).thenReturn(ModelConfigurationsTests.createRandomInstance());
when(mockModel.getTaskSettings()).thenReturn(EmptyTaskSettingsTests.createRandom());
when(mockServiceSettings.dimensionsSetByUser()).thenReturn(dimensionsSetByUser);
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidatorTests.java
index 30e7a33757c16..c9db9907792fd 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidatorTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidatorTests.java
@@ -16,7 +16,7 @@
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.test.ESTestCase;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingResults;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests;
import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandEmbeddingModel;
import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalTextEmbeddingServiceSettings;
@@ -118,7 +118,7 @@ public void testValidate_ElandTextEmbeddingModelValidationFails() {
public void testValidate_ElandTextEmbeddingModelValidationSucceedsAndDimensionsSetByUserValid() {
var dimensions = randomIntBetween(1, 10);
- var mockInferenceServiceResults = mock(TextEmbeddingResults.class);
+ var mockInferenceServiceResults = mock(DenseEmbeddingResults.class);
var mockUpdatedModel = mock(CustomElandEmbeddingModel.class);
when(mockInferenceServiceResults.getFirstEmbeddingSize()).thenReturn(dimensions);
CustomElandEmbeddingModel customElandEmbeddingModel = createCustomElandEmbeddingModel(true, dimensions);
@@ -151,7 +151,7 @@ public void testValidate_ElandTextEmbeddingModelValidationSucceedsAndDimensionsS
public void testValidate_ElandTextEmbeddingModelValidationSucceedsAndDimensionsSetByUserInvalid() {
var dimensions = randomIntBetween(1, 10);
- var mockInferenceServiceResults = mock(TextEmbeddingResults.class);
+ var mockInferenceServiceResults = mock(DenseEmbeddingResults.class);
when(mockInferenceServiceResults.getFirstEmbeddingSize()).thenReturn(
randomValueOtherThan(dimensions, () -> randomIntBetween(1, 10))
);
@@ -207,7 +207,7 @@ public void testValidate_ElandTextEmbeddingAndValidationReturnsInvalidResultsTyp
public void testValidate_ElandTextEmbeddingModelDimensionsNotSetByUser() {
var dimensions = randomIntBetween(1, 10);
- var mockInferenceServiceResults = mock(TextEmbeddingResults.class);
+ var mockInferenceServiceResults = mock(DenseEmbeddingResults.class);
when(mockInferenceServiceResults.getFirstEmbeddingSize()).thenReturn(dimensions);
CustomElandEmbeddingModel customElandEmbeddingModel = createCustomElandEmbeddingModel(false, null);
@@ -239,7 +239,7 @@ public void testValidate_ElandTextEmbeddingModelDimensionsNotSetByUser() {
}
public void testValidate_ElandTextEmbeddingModelAndEmbeddingSizeRetrievalThrowsException() {
- var mockInferenceServiceResults = mock(TextEmbeddingResults.class);
+ var mockInferenceServiceResults = mock(DenseEmbeddingResults.class);
when(mockInferenceServiceResults.getFirstEmbeddingSize()).thenThrow(ElasticsearchStatusException.class);
CustomElandEmbeddingModel customElandEmbeddingModel = createCustomElandEmbeddingModel(false, null);
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilderTests.java
index b4d563b565eee..804984c866542 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilderTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilderTests.java
@@ -95,7 +95,7 @@ public void testBuildModelValidator_ValidTaskType() {
private Map> taskTypeToModelValidatorClassMap() {
return Map.of(
TaskType.TEXT_EMBEDDING,
- TextEmbeddingModelValidator.class,
+ DenseEmbeddingModelValidator.class,
TaskType.SPARSE_EMBEDDING,
SimpleModelValidator.class,
TaskType.RERANK,
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java
index 9bbfcb4d58667..0e9059132f23d 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java
@@ -37,7 +37,7 @@
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
@@ -63,7 +63,7 @@
import static org.elasticsearch.common.xcontent.XContentHelper.toXContent;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
-import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
+import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.Utils.getInvalidModel;
import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors;
@@ -1654,10 +1654,10 @@ private void test_Embedding_ChunkedInfer_BatchesCalls(VoyageAIEmbeddingsModel mo
var floatResult = (ChunkedInferenceEmbedding) results.getFirst();
assertThat(floatResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().getFirst().offset());
- assertThat(floatResult.chunks().get(0).embedding(), CoreMatchers.instanceOf(TextEmbeddingFloatResults.Embedding.class));
+ assertThat(floatResult.chunks().get(0).embedding(), CoreMatchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class));
assertArrayEquals(
new float[] { 0.123f, -0.123f },
- ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
+ ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
0.0f
);
}
@@ -1666,10 +1666,13 @@ private void test_Embedding_ChunkedInfer_BatchesCalls(VoyageAIEmbeddingsModel mo
var floatResult = (ChunkedInferenceEmbedding) results.get(1);
assertThat(floatResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().getFirst().offset());
- assertThat(floatResult.chunks().getFirst().embedding(), CoreMatchers.instanceOf(TextEmbeddingFloatResults.Embedding.class));
+ assertThat(
+ floatResult.chunks().getFirst().embedding(),
+ CoreMatchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class)
+ );
assertArrayEquals(
new float[] { 0.223f, -0.223f },
- ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
+ ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
0.0f
);
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIActionCreatorTests.java
index afbb5532d5fdd..86d6fab29842b 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIActionCreatorTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIActionCreatorTests.java
@@ -37,7 +37,7 @@
import java.util.Map;
import java.util.concurrent.TimeUnit;
-import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
+import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIEmbeddingsActionTests.java
index 7050c2d131e17..b326664c527c1 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIEmbeddingsActionTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIEmbeddingsActionTests.java
@@ -46,9 +46,9 @@
import java.util.Map;
import java.util.concurrent.TimeUnit;
-import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationBinary;
-import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationByte;
-import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
+import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationBinary;
+import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationByte;
+import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/response/VoyageAIEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/response/VoyageAIEmbeddingsResponseEntityTests.java
index 80a00737ddf52..7fa6b0c7bece3 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/response/VoyageAIEmbeddingsResponseEntityTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/response/VoyageAIEmbeddingsResponseEntityTests.java
@@ -11,7 +11,7 @@
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xcontent.XContentParseException;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.InputTypeTests;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.services.voyageai.request.VoyageAIEmbeddingsRequest;
@@ -60,8 +60,8 @@ public void testFromResponse_CreatesResultsForASingleItem() throws IOException {
);
assertThat(
- ((TextEmbeddingFloatResults) parsedResults).embeddings(),
- is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F })))
+ ((DenseEmbeddingFloatResults) parsedResults).embeddings(),
+ is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F })))
);
}
@@ -106,11 +106,11 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException
);
assertThat(
- ((TextEmbeddingFloatResults) parsedResults).embeddings(),
+ ((DenseEmbeddingFloatResults) parsedResults).embeddings(),
is(
List.of(
- new TextEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }),
- new TextEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F })
+ new DenseEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }),
+ new DenseEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F })
)
)
);
@@ -299,8 +299,8 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsInt() throws IOExceptio
);
assertThat(
- ((TextEmbeddingFloatResults) parsedResults).embeddings(),
- is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 1.0F })))
+ ((DenseEmbeddingFloatResults) parsedResults).embeddings(),
+ is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 1.0F })))
);
}
@@ -336,8 +336,8 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsLong() throws IOExcepti
);
assertThat(
- ((TextEmbeddingFloatResults) parsedResults).embeddings(),
- is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 4.0294965E10F })))
+ ((DenseEmbeddingFloatResults) parsedResults).embeddings(),
+ is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 4.0294965E10F })))
);
}
@@ -427,15 +427,15 @@ public void testFieldsInDifferentOrderServer() throws IOException {
new HttpResult(mock(HttpResponse.class), response.getBytes(StandardCharsets.UTF_8))
);
- assertThat(parsedResults, instanceOf(TextEmbeddingFloatResults.class));
+ assertThat(parsedResults, instanceOf(DenseEmbeddingFloatResults.class));
assertThat(
- ((TextEmbeddingFloatResults) parsedResults).embeddings(),
+ ((DenseEmbeddingFloatResults) parsedResults).embeddings(),
is(
List.of(
- new TextEmbeddingFloatResults.Embedding(new float[] { -0.9F, 0.5F, 0.3F }),
- new TextEmbeddingFloatResults.Embedding(new float[] { 0.1F, 0.5F }),
- new TextEmbeddingFloatResults.Embedding(new float[] { 0.5F, 0.5F })
+ new DenseEmbeddingFloatResults.Embedding(new float[] { -0.9F, 0.5F, 0.3F }),
+ new DenseEmbeddingFloatResults.Embedding(new float[] { 0.1F, 0.5F }),
+ new DenseEmbeddingFloatResults.Embedding(new float[] { 0.5F, 0.5F })
)
)
);
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextEmbeddingProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextEmbeddingProcessor.java
index c28fc8f44c3fa..49598aa620464 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextEmbeddingProcessor.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextEmbeddingProcessor.java
@@ -9,7 +9,7 @@
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextEmbeddingFloatResults;
-import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
+import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
@@ -85,7 +85,7 @@ static InferenceResults processResult(
tokenization.anyTruncated()
);
} else {
- return new MlTextEmbeddingResults(
+ return new MlDenseEmbeddingResults(
Optional.ofNullable(resultsField).orElse(DEFAULT_RESULTS_FIELD),
pyTorchResult.getInferenceResult()[0][0],
tokenization.anyTruncated()
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextEmbeddingProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextEmbeddingProcessorTests.java
index 8369412580b88..6f66a05187fee 100644
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextEmbeddingProcessorTests.java
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextEmbeddingProcessorTests.java
@@ -10,7 +10,7 @@
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextEmbeddingFloatResults;
import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextExpansionResults;
-import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
+import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenization;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizationResult;
@@ -40,9 +40,9 @@ public void testSingleResult() {
var tokenization = tokenizer.tokenize(input, Tokenization.Truncate.NONE, 0, 0, null);
var tokenizationResult = new BertTokenizationResult(TextExpansionProcessorTests.TEST_CASED_VOCAB, tokenization, 0);
var inferenceResult = TextEmbeddingProcessor.processResult(tokenizationResult, pytorchResult, "foo", false);
- assertThat(inferenceResult, instanceOf(MlTextEmbeddingResults.class));
+ assertThat(inferenceResult, instanceOf(MlDenseEmbeddingResults.class));
- var result = (MlTextEmbeddingResults) inferenceResult;
+ var result = (MlDenseEmbeddingResults) inferenceResult;
assertThat(result.getInference().length, greaterThan(0));
}
}
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/vectors/TextEmbeddingQueryVectorBuilderTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/vectors/TextEmbeddingQueryVectorBuilderTests.java
index 86b4e8c588e51..fe269b6bcb0f5 100644
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/vectors/TextEmbeddingQueryVectorBuilderTests.java
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/vectors/TextEmbeddingQueryVectorBuilderTests.java
@@ -17,7 +17,7 @@
import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings;
-import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
+import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults;
import org.elasticsearch.xpack.core.ml.vectors.TextEmbeddingQueryVectorBuilder;
import org.elasticsearch.xpack.ml.MachineLearningTests;
@@ -53,7 +53,7 @@ public ActionResponse createResponse(float[] array, TextEmbeddingQueryVectorBuil
embedding[i] = array[i];
}
return new InferModelAction.Response(
- List.of(new MlTextEmbeddingResults("foo", embedding, randomBoolean())),
+ List.of(new MlDenseEmbeddingResults("foo", embedding, randomBoolean())),
builder.getModelId(),
true
);
From 43a6230a2f6738b5d23192980efa0d052e244cf6 Mon Sep 17 00:00:00 2001
From: elasticsearchmachine
Date: Thu, 9 Oct 2025 22:06:43 +0000
Subject: [PATCH 2/3] [CI] Auto commit changes from spotless
---
.../xpack/core/security/authc/AuthenticationTests.java | 1 -
1 file changed, 1 deletion(-)
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authc/AuthenticationTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authc/AuthenticationTests.java
index 3462663266e9a..447140f2f9571 100644
--- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authc/AuthenticationTests.java
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authc/AuthenticationTests.java
@@ -51,7 +51,6 @@
import static org.hamcrest.Matchers.hasKey;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
-import static org.hamcrest.Matchers.notNullValue;
import static org.hamcrest.Matchers.nullValue;
import static org.hamcrest.Matchers.sameInstance;
From e7ee947f13443a545a708396447d434cee4d2f66 Mon Sep 17 00:00:00 2001
From: donalevans
Date: Thu, 9 Oct 2025 15:38:15 -0700
Subject: [PATCH 3/3] Rename an added ocurrence
---
.../esql/inference/InferenceFunctionEvaluatorTests.java | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluatorTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluatorTests.java
index 875c136ab861c..ac847f50ce353 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluatorTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluatorTests.java
@@ -16,7 +16,7 @@
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
@@ -250,7 +250,7 @@ private float[] randomEmbedding(int length) {
}
private InferenceAction.Response inferenceResponse(float[] embedding) {
- TextEmbeddingFloatResults.Embedding embeddingResult = new TextEmbeddingFloatResults.Embedding(embedding);
- return new InferenceAction.Response(new TextEmbeddingFloatResults(List.of(embeddingResult)));
+ DenseEmbeddingFloatResults.Embedding embeddingResult = new DenseEmbeddingFloatResults.Embedding(embedding);
+ return new InferenceAction.Response(new DenseEmbeddingFloatResults(List.of(embeddingResult)));
}
}