diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceServiceResults.java b/server/src/main/java/org/elasticsearch/inference/InferenceServiceResults.java
index 3746960ad8f78..2d1b932271f25 100644
--- a/server/src/main/java/org/elasticsearch/inference/InferenceServiceResults.java
+++ b/server/src/main/java/org/elasticsearch/inference/InferenceServiceResults.java
@@ -24,7 +24,7 @@ public interface InferenceServiceResults extends NamedWriteable, ChunkedToXConte
/**
*
Transform the result to match the format required for the TransportCoordinatedInferenceAction.
- * TransportCoordinatedInferenceAction expects an ml plugin TextEmbeddingResults or SparseEmbeddingResults.
+ * TransportCoordinatedInferenceAction expects an ml plugin DenseEmbeddingResults or SparseEmbeddingResults.
*/
default List extends InferenceResults> transformToCoordinationFormat() {
throw new UnsupportedOperationException("transformToCoordinationFormat() is not implemented");
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingBitResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingBitResults.java
similarity index 75%
rename from x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingBitResults.java
rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingBitResults.java
index 37fca12f1697a..0792bf90dbe12 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingBitResults.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingBitResults.java
@@ -14,7 +14,7 @@
import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.xcontent.ToXContent;
-import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
+import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults;
import java.io.IOException;
import java.util.Iterator;
@@ -24,9 +24,10 @@
import java.util.Objects;
/**
- * Writes a text embedding result in the follow json format
+ * Writes a dense embedding result in the follow json format.
+ *
* {
- * "text_embedding_bytes": [
+ * "text_embedding_bits": [
* {
* "embedding": [
* 23
@@ -39,17 +40,19 @@
* }
* ]
* }
+ *
*/
-// Note: inheriting from TextEmbeddingByteResults gives a bad implementation of the
+// Note: inheriting from DenseEmbeddingByteResults gives a bad implementation of the
// Embedding.merge method for bits. TODO: implement a proper merge method
-public record TextEmbeddingBitResults(List embeddings)
+public record DenseEmbeddingBitResults(List embeddings)
implements
- TextEmbeddingResults {
+ DenseEmbeddingResults {
+ // This name is a holdover from before this class was renamed
public static final String NAME = "text_embedding_service_bit_results";
public static final String TEXT_EMBEDDING_BITS = "text_embedding_bits";
- public TextEmbeddingBitResults(StreamInput in) throws IOException {
- this(in.readCollectionAsList(TextEmbeddingByteResults.Embedding::new));
+ public DenseEmbeddingBitResults(StreamInput in) throws IOException {
+ this(in.readCollectionAsList(DenseEmbeddingByteResults.Embedding::new));
}
@Override
@@ -79,7 +82,7 @@ public String getWriteableName() {
@Override
public List extends InferenceResults> transformToCoordinationFormat() {
return embeddings.stream()
- .map(embedding -> new MlTextEmbeddingResults(TEXT_EMBEDDING_BITS, embedding.toDoubleArray(), false))
+ .map(embedding -> new MlDenseEmbeddingResults(TEXT_EMBEDDING_BITS, embedding.toDoubleArray(), false))
.toList();
}
@@ -94,7 +97,7 @@ public Map asMap() {
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
- TextEmbeddingBitResults that = (TextEmbeddingBitResults) o;
+ DenseEmbeddingBitResults that = (DenseEmbeddingBitResults) o;
return Objects.equals(embeddings, that.embeddings);
}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingByteResults.java
similarity index 90%
rename from x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResults.java
rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingByteResults.java
index 54f858cb20ae0..9e72dc9a7b2b4 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResults.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingByteResults.java
@@ -20,7 +20,7 @@
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContent;
import org.elasticsearch.xcontent.XContentBuilder;
-import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
+import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults;
import java.io.IOException;
import java.util.Arrays;
@@ -31,7 +31,8 @@
import java.util.Objects;
/**
- * Writes a text embedding result in the follow json format
+ * Writes a dense embedding result in the follow json format
+ *
* {
* "text_embedding_bytes": [
* {
@@ -46,13 +47,15 @@
* }
* ]
* }
+ *
*/
-public record TextEmbeddingByteResults(List embeddings) implements TextEmbeddingResults {
+public record DenseEmbeddingByteResults(List embeddings) implements DenseEmbeddingResults {
+ // This name is a holdover from before this class was renamed
public static final String NAME = "text_embedding_service_byte_results";
public static final String TEXT_EMBEDDING_BYTES = "text_embedding_bytes";
- public TextEmbeddingByteResults(StreamInput in) throws IOException {
- this(in.readCollectionAsList(TextEmbeddingByteResults.Embedding::new));
+ public DenseEmbeddingByteResults(StreamInput in) throws IOException {
+ this(in.readCollectionAsList(DenseEmbeddingByteResults.Embedding::new));
}
@Override
@@ -81,7 +84,7 @@ public String getWriteableName() {
@Override
public List extends InferenceResults> transformToCoordinationFormat() {
return embeddings.stream()
- .map(embedding -> new MlTextEmbeddingResults(TEXT_EMBEDDING_BYTES, embedding.toDoubleArray(), false))
+ .map(embedding -> new MlDenseEmbeddingResults(TEXT_EMBEDDING_BYTES, embedding.toDoubleArray(), false))
.toList();
}
@@ -96,7 +99,7 @@ public Map asMap() {
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
- TextEmbeddingByteResults that = (TextEmbeddingByteResults) o;
+ DenseEmbeddingByteResults that = (DenseEmbeddingByteResults) o;
return Objects.equals(embeddings, that.embeddings);
}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingFloatResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingFloatResults.java
similarity index 84%
rename from x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingFloatResults.java
rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingFloatResults.java
index 9dbdccd26e5d2..d14f38259208c 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingFloatResults.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingFloatResults.java
@@ -23,7 +23,7 @@
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContent;
import org.elasticsearch.xcontent.XContentBuilder;
-import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
+import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults;
import java.io.IOException;
import java.util.ArrayList;
@@ -35,7 +35,8 @@
import java.util.Objects;
/**
- * Writes a text embedding result in the follow json format
+ * Writes a dense embedding result in the follow json format
+ *
* {
* "text_embedding": [
* {
@@ -50,20 +51,24 @@
* }
* ]
* }
+ *
*/
-public record TextEmbeddingFloatResults(List embeddings) implements TextEmbeddingResults {
+public record DenseEmbeddingFloatResults(List embeddings)
+ implements
+ DenseEmbeddingResults {
+ // This name is a holdover from before this class was renamed
public static final String NAME = "text_embedding_service_results";
public static final String TEXT_EMBEDDING = TaskType.TEXT_EMBEDDING.toString();
- public TextEmbeddingFloatResults(StreamInput in) throws IOException {
- this(in.readCollectionAsList(TextEmbeddingFloatResults.Embedding::new));
+ public DenseEmbeddingFloatResults(StreamInput in) throws IOException {
+ this(in.readCollectionAsList(DenseEmbeddingFloatResults.Embedding::new));
}
- public static TextEmbeddingFloatResults of(List extends InferenceResults> results) {
+ public static DenseEmbeddingFloatResults of(List extends InferenceResults> results) {
List embeddings = new ArrayList<>(results.size());
for (InferenceResults result : results) {
- if (result instanceof MlTextEmbeddingResults embeddingResult) {
- embeddings.add(TextEmbeddingFloatResults.Embedding.of(embeddingResult));
+ if (result instanceof MlDenseEmbeddingResults embeddingResult) {
+ embeddings.add(DenseEmbeddingFloatResults.Embedding.of(embeddingResult));
} else if (result instanceof org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults errorResult) {
if (errorResult.getException() instanceof ElasticsearchStatusException statusException) {
throw statusException;
@@ -76,11 +81,15 @@ public static TextEmbeddingFloatResults of(List extends InferenceResults> resu
}
} else {
throw new IllegalArgumentException(
- "Received invalid inference result, of type " + result.getClass().getName() + " but expected TextEmbeddingResults."
+ "Received invalid inference result, of type "
+ + result.getClass().getName()
+ + " but expected "
+ + MlDenseEmbeddingResults.class.getName()
+ + "."
);
}
}
- return new TextEmbeddingFloatResults(embeddings);
+ return new DenseEmbeddingFloatResults(embeddings);
}
@Override
@@ -108,7 +117,7 @@ public String getWriteableName() {
@Override
public List extends InferenceResults> transformToCoordinationFormat() {
- return embeddings.stream().map(embedding -> new MlTextEmbeddingResults(TEXT_EMBEDDING, embedding.asDoubleArray(), false)).toList();
+ return embeddings.stream().map(embedding -> new MlDenseEmbeddingResults(TEXT_EMBEDDING, embedding.asDoubleArray(), false)).toList();
}
public Map asMap() {
@@ -122,7 +131,7 @@ public Map asMap() {
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
- TextEmbeddingFloatResults that = (TextEmbeddingFloatResults) o;
+ DenseEmbeddingFloatResults that = (DenseEmbeddingFloatResults) o;
return Objects.equals(embeddings, that.embeddings);
}
@@ -148,7 +157,7 @@ public Embedding(StreamInput in) throws IOException {
this(in.readFloatArray());
}
- public static Embedding of(MlTextEmbeddingResults embeddingResult) {
+ public static Embedding of(MlDenseEmbeddingResults embeddingResult) {
float[] embeddingAsArray = embeddingResult.getInferenceAsFloat();
return new Embedding(embeddingAsArray);
}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingResults.java
similarity index 66%
rename from x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java
rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingResults.java
index ea4e45ec67407..af6abb357bd98 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingResults.java
@@ -7,11 +7,11 @@
package org.elasticsearch.xpack.core.inference.results;
-public interface TextEmbeddingResults> extends EmbeddingResults {
+public interface DenseEmbeddingResults> extends EmbeddingResults {
/**
- * Returns the first text embedding entry in the result list's array size.
- * @return the size of the text embedding
+ * Returns the first embedding entry in the result list's array size.
+ * @return the size of the embedding
* @throws IllegalStateException if the list of embeddings is empty
*/
int getFirstEmbeddingSize() throws IllegalStateException;
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java
index 667d7bf63efc9..f2952fee36cda 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java
@@ -25,7 +25,7 @@
import org.elasticsearch.xpack.core.ml.inference.results.FillMaskResults;
import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextEmbeddingFloatResults;
import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextExpansionResults;
-import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
+import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults;
import org.elasticsearch.xpack.core.ml.inference.results.NerResults;
import org.elasticsearch.xpack.core.ml.inference.results.NlpClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.PyTorchPassThroughResults;
@@ -669,7 +669,7 @@ public List getNamedWriteables() {
);
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class, TextExpansionResults.NAME, TextExpansionResults::new));
namedWriteables.add(
- new NamedWriteableRegistry.Entry(InferenceResults.class, MlTextEmbeddingResults.NAME, MlTextEmbeddingResults::new)
+ new NamedWriteableRegistry.Entry(InferenceResults.class, MlDenseEmbeddingResults.NAME, MlDenseEmbeddingResults::new)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/MlTextEmbeddingResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/MlDenseEmbeddingResults.java
similarity index 87%
rename from x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/MlTextEmbeddingResults.java
rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/MlDenseEmbeddingResults.java
index 0c0fa6f3f690e..b839ade030e1e 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/MlTextEmbeddingResults.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/MlDenseEmbeddingResults.java
@@ -16,20 +16,21 @@
import java.util.Map;
import java.util.Objects;
-public class MlTextEmbeddingResults extends NlpInferenceResults {
+public class MlDenseEmbeddingResults extends NlpInferenceResults {
+ // This name is a holdover from before this class was renamed
public static final String NAME = "text_embedding_result";
private final String resultsField;
private final double[] inference;
- public MlTextEmbeddingResults(String resultsField, double[] inference, boolean isTruncated) {
+ public MlDenseEmbeddingResults(String resultsField, double[] inference, boolean isTruncated) {
super(isTruncated);
this.inference = inference;
this.resultsField = resultsField;
}
- public MlTextEmbeddingResults(StreamInput in) throws IOException {
+ public MlDenseEmbeddingResults(StreamInput in) throws IOException {
super(in);
inference = in.readDoubleArray();
resultsField = in.readString();
@@ -89,7 +90,7 @@ public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
if (super.equals(o) == false) return false;
- MlTextEmbeddingResults that = (MlTextEmbeddingResults) o;
+ MlDenseEmbeddingResults that = (MlDenseEmbeddingResults) o;
return Objects.equals(resultsField, that.resultsField) && Arrays.equals(inference, that.inference);
}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/vectors/TextEmbeddingQueryVectorBuilder.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/vectors/TextEmbeddingQueryVectorBuilder.java
index 1f4967d95c9b7..a5b085e877a73 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/vectors/TextEmbeddingQueryVectorBuilder.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/vectors/TextEmbeddingQueryVectorBuilder.java
@@ -20,7 +20,7 @@
import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings;
-import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
+import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfigUpdate;
@@ -117,14 +117,14 @@ public void buildVector(Client client, ActionListener listener) {
return;
}
- if (response.getInferenceResults().get(0) instanceof MlTextEmbeddingResults textEmbeddingResults) {
+ if (response.getInferenceResults().get(0) instanceof MlDenseEmbeddingResults textEmbeddingResults) {
listener.onResponse(textEmbeddingResults.getInferenceAsFloat());
} else if (response.getInferenceResults().get(0) instanceof WarningInferenceResults warning) {
listener.onFailure(new IllegalStateException(warning.getWarning()));
} else {
throw new IllegalArgumentException(
"expected a result of type ["
- + MlTextEmbeddingResults.NAME
+ + MlDenseEmbeddingResults.NAME
+ "] received ["
+ response.getInferenceResults().get(0).getWriteableName()
+ "]. Is ["
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingBitResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingBitResultsTests.java
similarity index 55%
rename from x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingBitResultsTests.java
rename to x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingBitResultsTests.java
index 61b49075702a2..9e5814825d659 100644
--- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingBitResultsTests.java
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingBitResultsTests.java
@@ -10,7 +10,7 @@
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
-import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
+import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults;
import java.io.IOException;
import java.util.ArrayList;
@@ -19,19 +19,19 @@
import static org.hamcrest.Matchers.is;
-public class TextEmbeddingBitResultsTests extends AbstractWireSerializingTestCase {
- public static TextEmbeddingBitResults createRandomResults() {
+public class DenseEmbeddingBitResultsTests extends AbstractWireSerializingTestCase {
+ public static DenseEmbeddingBitResults createRandomResults() {
int embeddings = randomIntBetween(1, 10);
- List embeddingResults = new ArrayList<>(embeddings);
+ List embeddingResults = new ArrayList<>(embeddings);
for (int i = 0; i < embeddings; i++) {
embeddingResults.add(createRandomEmbedding());
}
- return new TextEmbeddingBitResults(embeddingResults);
+ return new DenseEmbeddingBitResults(embeddingResults);
}
- private static TextEmbeddingByteResults.Embedding createRandomEmbedding() {
+ private static DenseEmbeddingByteResults.Embedding createRandomEmbedding() {
int columns = randomIntBetween(1, 10);
byte[] bytes = new byte[columns];
@@ -39,11 +39,11 @@ private static TextEmbeddingByteResults.Embedding createRandomEmbedding() {
bytes[i] = randomByte();
}
- return new TextEmbeddingByteResults.Embedding(bytes);
+ return new DenseEmbeddingByteResults.Embedding(bytes);
}
public void testToXContent_CreatesTheRightFormatForASingleEmbedding() throws IOException {
- var entity = new TextEmbeddingBitResults(List.of(new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 23 })));
+ var entity = new DenseEmbeddingBitResults(List.of(new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) 23 })));
String xContentResult = Strings.toString(entity, true, true);
assertThat(xContentResult, is("""
@@ -59,10 +59,10 @@ public void testToXContent_CreatesTheRightFormatForASingleEmbedding() throws IOE
}
public void testToXContent_CreatesTheRightFormatForMultipleEmbeddings() throws IOException {
- var entity = new TextEmbeddingBitResults(
+ var entity = new DenseEmbeddingBitResults(
List.of(
- new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 23 }),
- new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 24 })
+ new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) 23 }),
+ new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) 24 })
)
);
@@ -85,10 +85,10 @@ public void testToXContent_CreatesTheRightFormatForMultipleEmbeddings() throws I
}
public void testTransformToCoordinationFormat() {
- var results = new TextEmbeddingBitResults(
+ var results = new DenseEmbeddingBitResults(
List.of(
- new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 23, (byte) 24 }),
- new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 25, (byte) 26 })
+ new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) 23, (byte) 24 }),
+ new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) 25, (byte) 26 })
)
).transformToCoordinationFormat();
@@ -96,18 +96,18 @@ public void testTransformToCoordinationFormat() {
results,
is(
List.of(
- new MlTextEmbeddingResults(TextEmbeddingBitResults.TEXT_EMBEDDING_BITS, new double[] { 23F, 24F }, false),
- new MlTextEmbeddingResults(TextEmbeddingBitResults.TEXT_EMBEDDING_BITS, new double[] { 25F, 26F }, false)
+ new MlDenseEmbeddingResults(DenseEmbeddingBitResults.TEXT_EMBEDDING_BITS, new double[] { 23F, 24F }, false),
+ new MlDenseEmbeddingResults(DenseEmbeddingBitResults.TEXT_EMBEDDING_BITS, new double[] { 25F, 26F }, false)
)
)
);
}
public void testGetFirstEmbeddingSize() {
- var firstEmbeddingSize = new TextEmbeddingBitResults(
+ var firstEmbeddingSize = new DenseEmbeddingBitResults(
List.of(
- new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 23, (byte) 24 }),
- new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 25, (byte) 26 })
+ new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) 23, (byte) 24 }),
+ new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) 25, (byte) 26 })
)
).getFirstEmbeddingSize();
@@ -115,33 +115,33 @@ public void testGetFirstEmbeddingSize() {
}
@Override
- protected Writeable.Reader instanceReader() {
- return TextEmbeddingBitResults::new;
+ protected Writeable.Reader instanceReader() {
+ return DenseEmbeddingBitResults::new;
}
@Override
- protected TextEmbeddingBitResults createTestInstance() {
+ protected DenseEmbeddingBitResults createTestInstance() {
return createRandomResults();
}
@Override
- protected TextEmbeddingBitResults mutateInstance(TextEmbeddingBitResults instance) throws IOException {
+ protected DenseEmbeddingBitResults mutateInstance(DenseEmbeddingBitResults instance) throws IOException {
// if true we reduce the embeddings list by a random amount, if false we add an embedding to the list
if (randomBoolean()) {
// -1 to remove at least one item from the list
int end = randomInt(instance.embeddings().size() - 1);
- return new TextEmbeddingBitResults(instance.embeddings().subList(0, end));
+ return new DenseEmbeddingBitResults(instance.embeddings().subList(0, end));
} else {
- List embeddings = new ArrayList<>(instance.embeddings());
+ List embeddings = new ArrayList<>(instance.embeddings());
embeddings.add(createRandomEmbedding());
- return new TextEmbeddingBitResults(embeddings);
+ return new DenseEmbeddingBitResults(embeddings);
}
}
public static Map buildExpectationByte(List> embeddings) {
return Map.of(
- TextEmbeddingBitResults.TEXT_EMBEDDING_BITS,
- embeddings.stream().map(embedding -> Map.of(TextEmbeddingByteResults.Embedding.EMBEDDING, embedding)).toList()
+ DenseEmbeddingBitResults.TEXT_EMBEDDING_BITS,
+ embeddings.stream().map(embedding -> Map.of(DenseEmbeddingByteResults.Embedding.EMBEDDING, embedding)).toList()
);
}
}
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingByteResultsTests.java
similarity index 50%
rename from x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResultsTests.java
rename to x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingByteResultsTests.java
index 60f45399cfb32..53cd8690b7dc0 100644
--- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResultsTests.java
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingByteResultsTests.java
@@ -10,7 +10,7 @@
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
-import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
+import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults;
import java.io.IOException;
import java.util.ArrayList;
@@ -20,19 +20,19 @@
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
-public class TextEmbeddingByteResultsTests extends AbstractWireSerializingTestCase {
- public static TextEmbeddingByteResults createRandomResults() {
+public class DenseEmbeddingByteResultsTests extends AbstractWireSerializingTestCase {
+ public static DenseEmbeddingByteResults createRandomResults() {
int embeddings = randomIntBetween(1, 10);
- List embeddingResults = new ArrayList<>(embeddings);
+ List embeddingResults = new ArrayList<>(embeddings);
for (int i = 0; i < embeddings; i++) {
embeddingResults.add(createRandomEmbedding());
}
- return new TextEmbeddingByteResults(embeddingResults);
+ return new DenseEmbeddingByteResults(embeddingResults);
}
- private static TextEmbeddingByteResults.Embedding createRandomEmbedding() {
+ private static DenseEmbeddingByteResults.Embedding createRandomEmbedding() {
int columns = randomIntBetween(1, 10);
byte[] bytes = new byte[columns];
@@ -40,11 +40,11 @@ private static TextEmbeddingByteResults.Embedding createRandomEmbedding() {
bytes[i] = randomByte();
}
- return new TextEmbeddingByteResults.Embedding(bytes);
+ return new DenseEmbeddingByteResults.Embedding(bytes);
}
public void testToXContent_CreatesTheRightFormatForASingleEmbedding() throws IOException {
- var entity = new TextEmbeddingByteResults(List.of(new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 23 })));
+ var entity = new DenseEmbeddingByteResults(List.of(new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) 23 })));
String xContentResult = Strings.toString(entity, true, true);
assertThat(xContentResult, is("""
@@ -60,10 +60,10 @@ public void testToXContent_CreatesTheRightFormatForASingleEmbedding() throws IOE
}
public void testToXContent_CreatesTheRightFormatForMultipleEmbeddings() throws IOException {
- var entity = new TextEmbeddingByteResults(
+ var entity = new DenseEmbeddingByteResults(
List.of(
- new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 23 }),
- new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 24 })
+ new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) 23 }),
+ new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) 24 })
)
);
@@ -86,10 +86,10 @@ public void testToXContent_CreatesTheRightFormatForMultipleEmbeddings() throws I
}
public void testTransformToCoordinationFormat() {
- var results = new TextEmbeddingByteResults(
+ var results = new DenseEmbeddingByteResults(
List.of(
- new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 23, (byte) 24 }),
- new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 25, (byte) 26 })
+ new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) 23, (byte) 24 }),
+ new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) 25, (byte) 26 })
)
).transformToCoordinationFormat();
@@ -97,18 +97,18 @@ public void testTransformToCoordinationFormat() {
results,
is(
List.of(
- new MlTextEmbeddingResults(TextEmbeddingByteResults.TEXT_EMBEDDING_BYTES, new double[] { 23F, 24F }, false),
- new MlTextEmbeddingResults(TextEmbeddingByteResults.TEXT_EMBEDDING_BYTES, new double[] { 25F, 26F }, false)
+ new MlDenseEmbeddingResults(DenseEmbeddingByteResults.TEXT_EMBEDDING_BYTES, new double[] { 23F, 24F }, false),
+ new MlDenseEmbeddingResults(DenseEmbeddingByteResults.TEXT_EMBEDDING_BYTES, new double[] { 25F, 26F }, false)
)
)
);
}
public void testGetFirstEmbeddingSize() {
- var firstEmbeddingSize = new TextEmbeddingByteResults(
+ var firstEmbeddingSize = new DenseEmbeddingByteResults(
List.of(
- new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 23, (byte) 24 }),
- new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 25, (byte) 26 })
+ new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) 23, (byte) 24 }),
+ new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) 25, (byte) 26 })
)
).getFirstEmbeddingSize();
@@ -116,43 +116,43 @@ public void testGetFirstEmbeddingSize() {
}
public void testEmbeddingMerge() {
- TextEmbeddingByteResults.Embedding embedding1 = new TextEmbeddingByteResults.Embedding(new byte[] { 1, 1, -128 });
- TextEmbeddingByteResults.Embedding embedding2 = new TextEmbeddingByteResults.Embedding(new byte[] { 1, 0, 127 });
- TextEmbeddingByteResults.Embedding embedding3 = new TextEmbeddingByteResults.Embedding(new byte[] { 0, 0, 100 });
- TextEmbeddingByteResults.Embedding mergedEmbedding = embedding1.merge(embedding2);
- assertThat(mergedEmbedding, equalTo(new TextEmbeddingByteResults.Embedding(new byte[] { 1, 1, 0 })));
+ DenseEmbeddingByteResults.Embedding embedding1 = new DenseEmbeddingByteResults.Embedding(new byte[] { 1, 1, -128 });
+ DenseEmbeddingByteResults.Embedding embedding2 = new DenseEmbeddingByteResults.Embedding(new byte[] { 1, 0, 127 });
+ DenseEmbeddingByteResults.Embedding embedding3 = new DenseEmbeddingByteResults.Embedding(new byte[] { 0, 0, 100 });
+ DenseEmbeddingByteResults.Embedding mergedEmbedding = embedding1.merge(embedding2);
+ assertThat(mergedEmbedding, equalTo(new DenseEmbeddingByteResults.Embedding(new byte[] { 1, 1, 0 })));
mergedEmbedding = mergedEmbedding.merge(embedding3);
- assertThat(mergedEmbedding, equalTo(new TextEmbeddingByteResults.Embedding(new byte[] { 1, 0, 33 })));
+ assertThat(mergedEmbedding, equalTo(new DenseEmbeddingByteResults.Embedding(new byte[] { 1, 0, 33 })));
}
@Override
- protected Writeable.Reader instanceReader() {
- return TextEmbeddingByteResults::new;
+ protected Writeable.Reader instanceReader() {
+ return DenseEmbeddingByteResults::new;
}
@Override
- protected TextEmbeddingByteResults createTestInstance() {
+ protected DenseEmbeddingByteResults createTestInstance() {
return createRandomResults();
}
@Override
- protected TextEmbeddingByteResults mutateInstance(TextEmbeddingByteResults instance) throws IOException {
+ protected DenseEmbeddingByteResults mutateInstance(DenseEmbeddingByteResults instance) throws IOException {
// if true we reduce the embeddings list by a random amount, if false we add an embedding to the list
if (randomBoolean()) {
// -1 to remove at least one item from the list
int end = randomInt(instance.embeddings().size() - 1);
- return new TextEmbeddingByteResults(instance.embeddings().subList(0, end));
+ return new DenseEmbeddingByteResults(instance.embeddings().subList(0, end));
} else {
- List embeddings = new ArrayList<>(instance.embeddings());
+ List embeddings = new ArrayList<>(instance.embeddings());
embeddings.add(createRandomEmbedding());
- return new TextEmbeddingByteResults(embeddings);
+ return new DenseEmbeddingByteResults(embeddings);
}
}
public static Map buildExpectationByte(List> embeddings) {
return Map.of(
- TextEmbeddingByteResults.TEXT_EMBEDDING_BYTES,
- embeddings.stream().map(embedding -> Map.of(TextEmbeddingByteResults.Embedding.EMBEDDING, embedding)).toList()
+ DenseEmbeddingByteResults.TEXT_EMBEDDING_BYTES,
+ embeddings.stream().map(embedding -> Map.of(DenseEmbeddingByteResults.Embedding.EMBEDDING, embedding)).toList()
);
}
}
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingFloatResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingFloatResultsTests.java
similarity index 50%
rename from x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingFloatResultsTests.java
rename to x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingFloatResultsTests.java
index 8cdd98bcdebc6..4481692127979 100644
--- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingFloatResultsTests.java
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingFloatResultsTests.java
@@ -10,7 +10,7 @@
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
-import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
+import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults;
import java.io.IOException;
import java.util.ArrayList;
@@ -20,30 +20,30 @@
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
-public class TextEmbeddingFloatResultsTests extends AbstractWireSerializingTestCase {
- public static TextEmbeddingFloatResults createRandomResults() {
+public class DenseEmbeddingFloatResultsTests extends AbstractWireSerializingTestCase {
+ public static DenseEmbeddingFloatResults createRandomResults() {
int embeddings = randomIntBetween(1, 10);
- List embeddingResults = new ArrayList<>(embeddings);
+ List embeddingResults = new ArrayList<>(embeddings);
for (int i = 0; i < embeddings; i++) {
embeddingResults.add(createRandomEmbedding());
}
- return new TextEmbeddingFloatResults(embeddingResults);
+ return new DenseEmbeddingFloatResults(embeddingResults);
}
- private static TextEmbeddingFloatResults.Embedding createRandomEmbedding() {
+ private static DenseEmbeddingFloatResults.Embedding createRandomEmbedding() {
int columns = randomIntBetween(1, 10);
float[] floats = new float[columns];
for (int i = 0; i < columns; i++) {
floats[i] = randomFloat();
}
- return new TextEmbeddingFloatResults.Embedding(floats);
+ return new DenseEmbeddingFloatResults.Embedding(floats);
}
public void testToXContent_CreatesTheRightFormatForASingleEmbedding() throws IOException {
- var entity = new TextEmbeddingFloatResults(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.1F })));
+ var entity = new DenseEmbeddingFloatResults(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.1F })));
String xContentResult = Strings.toString(entity, true, true);
assertThat(xContentResult, is("""
@@ -59,10 +59,10 @@ public void testToXContent_CreatesTheRightFormatForASingleEmbedding() throws IOE
}
public void testToXContent_CreatesTheRightFormatForMultipleEmbeddings() throws IOException {
- var entity = new TextEmbeddingFloatResults(
+ var entity = new DenseEmbeddingFloatResults(
List.of(
- new TextEmbeddingFloatResults.Embedding(new float[] { 0.1F }),
- new TextEmbeddingFloatResults.Embedding(new float[] { 0.2F })
+ new DenseEmbeddingFloatResults.Embedding(new float[] { 0.1F }),
+ new DenseEmbeddingFloatResults.Embedding(new float[] { 0.2F })
)
);
@@ -86,10 +86,10 @@ public void testToXContent_CreatesTheRightFormatForMultipleEmbeddings() throws I
}
public void testTransformToCoordinationFormat() {
- var results = new TextEmbeddingFloatResults(
+ var results = new DenseEmbeddingFloatResults(
List.of(
- new TextEmbeddingFloatResults.Embedding(new float[] { 0.1F, 0.2F }),
- new TextEmbeddingFloatResults.Embedding(new float[] { 0.3F, 0.4F })
+ new DenseEmbeddingFloatResults.Embedding(new float[] { 0.1F, 0.2F }),
+ new DenseEmbeddingFloatResults.Embedding(new float[] { 0.3F, 0.4F })
)
).transformToCoordinationFormat();
@@ -97,18 +97,18 @@ public void testTransformToCoordinationFormat() {
results,
is(
List.of(
- new MlTextEmbeddingResults(TextEmbeddingFloatResults.TEXT_EMBEDDING, new double[] { 0.1F, 0.2F }, false),
- new MlTextEmbeddingResults(TextEmbeddingFloatResults.TEXT_EMBEDDING, new double[] { 0.3F, 0.4F }, false)
+ new MlDenseEmbeddingResults(DenseEmbeddingFloatResults.TEXT_EMBEDDING, new double[] { 0.1F, 0.2F }, false),
+ new MlDenseEmbeddingResults(DenseEmbeddingFloatResults.TEXT_EMBEDDING, new double[] { 0.3F, 0.4F }, false)
)
)
);
}
public void testGetFirstEmbeddingSize() {
- var firstEmbeddingSize = new TextEmbeddingFloatResults(
+ var firstEmbeddingSize = new DenseEmbeddingFloatResults(
List.of(
- new TextEmbeddingFloatResults.Embedding(new float[] { 0.1F, 0.2F }),
- new TextEmbeddingFloatResults.Embedding(new float[] { 0.3F, 0.4F })
+ new DenseEmbeddingFloatResults.Embedding(new float[] { 0.1F, 0.2F }),
+ new DenseEmbeddingFloatResults.Embedding(new float[] { 0.3F, 0.4F })
)
).getFirstEmbeddingSize();
@@ -116,51 +116,54 @@ public void testGetFirstEmbeddingSize() {
}
public void testEmbeddingMerge() {
- TextEmbeddingFloatResults.Embedding embedding1 = new TextEmbeddingFloatResults.Embedding(new float[] { 0.1f, 0.2f, 0.3f, 0.4f });
- TextEmbeddingFloatResults.Embedding embedding2 = new TextEmbeddingFloatResults.Embedding(new float[] { 0.0f, 0.4f, 0.1f, 1.0f });
- TextEmbeddingFloatResults.Embedding embedding3 = new TextEmbeddingFloatResults.Embedding(new float[] { 0.2f, 0.9f, 0.8f, 0.1f });
- TextEmbeddingFloatResults.Embedding mergedEmbedding = embedding1.merge(embedding2);
- assertThat(mergedEmbedding, equalTo(new TextEmbeddingFloatResults.Embedding(new float[] { 0.05f, 0.3f, 0.2f, 0.7f })));
+ DenseEmbeddingFloatResults.Embedding embedding1 = new DenseEmbeddingFloatResults.Embedding(new float[] { 0.1f, 0.2f, 0.3f, 0.4f });
+ DenseEmbeddingFloatResults.Embedding embedding2 = new DenseEmbeddingFloatResults.Embedding(new float[] { 0.0f, 0.4f, 0.1f, 1.0f });
+ DenseEmbeddingFloatResults.Embedding embedding3 = new DenseEmbeddingFloatResults.Embedding(new float[] { 0.2f, 0.9f, 0.8f, 0.1f });
+ DenseEmbeddingFloatResults.Embedding mergedEmbedding = embedding1.merge(embedding2);
+ assertThat(mergedEmbedding, equalTo(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.05f, 0.3f, 0.2f, 0.7f })));
mergedEmbedding = mergedEmbedding.merge(embedding3);
- assertThat(mergedEmbedding, equalTo(new TextEmbeddingFloatResults.Embedding(new float[] { 0.1f, 0.5f, 0.4f, 0.5f })));
+ assertThat(mergedEmbedding, equalTo(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.1f, 0.5f, 0.4f, 0.5f })));
}
@Override
- protected Writeable.Reader instanceReader() {
- return TextEmbeddingFloatResults::new;
+ protected Writeable.Reader instanceReader() {
+ return DenseEmbeddingFloatResults::new;
}
@Override
- protected TextEmbeddingFloatResults createTestInstance() {
+ protected DenseEmbeddingFloatResults createTestInstance() {
return createRandomResults();
}
@Override
- protected TextEmbeddingFloatResults mutateInstance(TextEmbeddingFloatResults instance) throws IOException {
+ protected DenseEmbeddingFloatResults mutateInstance(DenseEmbeddingFloatResults instance) throws IOException {
// if true we reduce the embeddings list by a random amount, if false we add an embedding to the list
if (randomBoolean()) {
// -1 to remove at least one item from the list
int end = randomInt(instance.embeddings().size() - 1);
- return new TextEmbeddingFloatResults(instance.embeddings().subList(0, end));
+ return new DenseEmbeddingFloatResults(instance.embeddings().subList(0, end));
} else {
- List embeddings = new ArrayList<>(instance.embeddings());
+ List embeddings = new ArrayList<>(instance.embeddings());
embeddings.add(createRandomEmbedding());
- return new TextEmbeddingFloatResults(embeddings);
+ return new DenseEmbeddingFloatResults(embeddings);
}
}
public static Map buildExpectationFloat(List embeddings) {
- return Map.of(TextEmbeddingFloatResults.TEXT_EMBEDDING, embeddings.stream().map(TextEmbeddingFloatResults.Embedding::new).toList());
+ return Map.of(
+ DenseEmbeddingFloatResults.TEXT_EMBEDDING,
+ embeddings.stream().map(DenseEmbeddingFloatResults.Embedding::new).toList()
+ );
}
public static Map buildExpectationByte(List embeddings) {
return Map.of(
- TextEmbeddingByteResults.TEXT_EMBEDDING_BYTES,
- embeddings.stream().map(TextEmbeddingByteResults.Embedding::new).toList()
+ DenseEmbeddingByteResults.TEXT_EMBEDDING_BYTES,
+ embeddings.stream().map(DenseEmbeddingByteResults.Embedding::new).toList()
);
}
public static Map buildExpectationBinary(List embeddings) {
- return Map.of("text_embedding_bits", embeddings.stream().map(TextEmbeddingByteResults.Embedding::new).toList());
+ return Map.of("text_embedding_bits", embeddings.stream().map(DenseEmbeddingByteResults.Embedding::new).toList());
}
}
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionResponseTests.java
index 87049d6bde90c..b897dc8aef6dd 100644
--- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionResponseTests.java
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionResponseTests.java
@@ -17,8 +17,8 @@
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResultsTests;
import org.elasticsearch.xpack.core.ml.inference.results.FillMaskResults;
import org.elasticsearch.xpack.core.ml.inference.results.FillMaskResultsTests;
-import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
-import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResultsTests;
+import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults;
+import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResultsTests;
import org.elasticsearch.xpack.core.ml.inference.results.NerResults;
import org.elasticsearch.xpack.core.ml.inference.results.NerResultsTests;
import org.elasticsearch.xpack.core.ml.inference.results.PyTorchPassThroughResults;
@@ -50,7 +50,7 @@ public class InferModelActionResponseTests extends AbstractWireSerializingTestCa
PyTorchPassThroughResults.NAME,
QuestionAnsweringInferenceResults.NAME,
RegressionInferenceResults.NAME,
- MlTextEmbeddingResults.NAME,
+ MlDenseEmbeddingResults.NAME,
TextExpansionResults.NAME,
TextSimilarityInferenceResults.NAME,
WarningInferenceResults.NAME
@@ -87,7 +87,7 @@ private static InferenceResults randomInferenceResult(String resultType) {
case PyTorchPassThroughResults.NAME -> PyTorchPassThroughResultsTests.createRandomResults();
case QuestionAnsweringInferenceResults.NAME -> QuestionAnsweringInferenceResultsTests.createRandomResults();
case RegressionInferenceResults.NAME -> RegressionInferenceResultsTests.createRandomResults();
- case MlTextEmbeddingResults.NAME -> MlTextEmbeddingResultsTests.createRandomResults();
+ case MlDenseEmbeddingResults.NAME -> MlDenseEmbeddingResultsTests.createRandomResults();
case TextExpansionResults.NAME -> TextExpansionResultsTests.createRandomResults();
case TextSimilarityInferenceResults.NAME -> TextSimilarityInferenceResultsTests.createRandomResults();
case WarningInferenceResults.NAME -> WarningInferenceResultsTests.createRandomResults();
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferTrainedModelDeploymentResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferTrainedModelDeploymentResponseTests.java
index 62d26ead2b641..80d905fe1d83f 100644
--- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferTrainedModelDeploymentResponseTests.java
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferTrainedModelDeploymentResponseTests.java
@@ -13,7 +13,7 @@
import org.elasticsearch.xcontent.NamedXContentRegistry;
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
-import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResultsTests;
+import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResultsTests;
import org.junit.Before;
import java.util.List;
@@ -49,10 +49,10 @@ protected Writeable.Reader instanceR
protected InferTrainedModelDeploymentAction.Response createTestInstance() {
return new InferTrainedModelDeploymentAction.Response(
List.of(
- MlTextEmbeddingResultsTests.createRandomResults(),
- MlTextEmbeddingResultsTests.createRandomResults(),
- MlTextEmbeddingResultsTests.createRandomResults(),
- MlTextEmbeddingResultsTests.createRandomResults()
+ MlDenseEmbeddingResultsTests.createRandomResults(),
+ MlDenseEmbeddingResultsTests.createRandomResults(),
+ MlDenseEmbeddingResultsTests.createRandomResults(),
+ MlDenseEmbeddingResultsTests.createRandomResults()
)
);
}
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/MlTextEmbeddingResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/MlDenseEmbeddingResultsTests.java
similarity index 68%
rename from x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/MlTextEmbeddingResultsTests.java
rename to x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/MlDenseEmbeddingResultsTests.java
index 3338609eebdc3..b5aee244771ea 100644
--- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/MlTextEmbeddingResultsTests.java
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/MlDenseEmbeddingResultsTests.java
@@ -16,35 +16,35 @@
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.is;
-public class MlTextEmbeddingResultsTests extends InferenceResultsTestCase {
+public class MlDenseEmbeddingResultsTests extends InferenceResultsTestCase {
- public static MlTextEmbeddingResults createRandomResults() {
+ public static MlDenseEmbeddingResults createRandomResults() {
int columns = randomIntBetween(1, 10);
double[] arr = new double[columns];
for (int i = 0; i < columns; i++) {
arr[i] = randomDouble();
}
- return new MlTextEmbeddingResults(DEFAULT_RESULTS_FIELD, arr, randomBoolean());
+ return new MlDenseEmbeddingResults(DEFAULT_RESULTS_FIELD, arr, randomBoolean());
}
@Override
- protected Writeable.Reader instanceReader() {
- return MlTextEmbeddingResults::new;
+ protected Writeable.Reader instanceReader() {
+ return MlDenseEmbeddingResults::new;
}
@Override
- protected MlTextEmbeddingResults createTestInstance() {
+ protected MlDenseEmbeddingResults createTestInstance() {
return createRandomResults();
}
@Override
- protected MlTextEmbeddingResults mutateInstance(MlTextEmbeddingResults instance) {
+ protected MlDenseEmbeddingResults mutateInstance(MlDenseEmbeddingResults instance) {
return null;// TODO implement https://github.com/elastic/elasticsearch/issues/25929
}
public void testAsMap() {
- MlTextEmbeddingResults testInstance = createTestInstance();
+ MlDenseEmbeddingResults testInstance = createTestInstance();
Map asMap = testInstance.asMap();
int size = testInstance.isTruncated ? 2 : 1;
assertThat(asMap.keySet(), hasSize(size));
@@ -55,7 +55,7 @@ public void testAsMap() {
}
@Override
- void assertFieldValues(MlTextEmbeddingResults createdInstance, IngestDocument document, String parentField, String resultsField) {
+ void assertFieldValues(MlDenseEmbeddingResults createdInstance, IngestDocument document, String parentField, String resultsField) {
assertArrayEquals(document.getFieldValue(parentField + resultsField, double[].class), createdInstance.getInference(), 1e-10);
}
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorOutputBuilder.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorOutputBuilder.java
index a2b0a32e77b05..6cc27f931cfc8 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorOutputBuilder.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorOutputBuilder.java
@@ -12,14 +12,14 @@
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingResults;
import org.elasticsearch.xpack.esql.inference.InferenceOperator;
/**
* {@link TextEmbeddingOperatorOutputBuilder} builds the output page for text embedding by converting
- * {@link TextEmbeddingResults} into a {@link FloatBlock} containing dense vector embeddings.
+ * {@link DenseEmbeddingResults} into a {@link FloatBlock} containing dense vector embeddings.
*/
class TextEmbeddingOperatorOutputBuilder implements InferenceOperator.OutputBuilder {
private final Page inputPage;
@@ -39,7 +39,7 @@ public void close() {
* Adds an inference response to the output builder.
*
*
- * If the response is null or not of type {@link TextEmbeddingResults} an {@link IllegalStateException} is thrown.
+ * If the response is null or not of type {@link DenseEmbeddingResults} an {@link IllegalStateException} is thrown.
* Else, the embedding vector is added to the output block as a multi-value position.
*
*
@@ -55,7 +55,7 @@ public void addInferenceResponse(InferenceAction.Response inferenceResponse) {
return;
}
- TextEmbeddingResults> embeddingResults = inferenceResults(inferenceResponse);
+ DenseEmbeddingResults> embeddingResults = inferenceResults(inferenceResponse);
var embeddings = embeddingResults.embeddings();
if (embeddings.isEmpty()) {
@@ -82,21 +82,25 @@ public Page buildOutput() {
return inputPage.appendBlock(outputBlock);
}
- private TextEmbeddingResults> inferenceResults(InferenceAction.Response inferenceResponse) {
- return InferenceOperator.OutputBuilder.inferenceResults(inferenceResponse, TextEmbeddingResults.class);
+ private DenseEmbeddingResults> inferenceResults(InferenceAction.Response inferenceResponse) {
+ return InferenceOperator.OutputBuilder.inferenceResults(inferenceResponse, DenseEmbeddingResults.class);
}
/**
* Extracts the embedding as a float array from the embedding result.
*/
- private static float[] getEmbeddingAsFloatArray(TextEmbeddingResults> embedding) {
+ private static float[] getEmbeddingAsFloatArray(DenseEmbeddingResults> embedding) {
return switch (embedding.embeddings().get(0)) {
- case TextEmbeddingFloatResults.Embedding floatEmbedding -> floatEmbedding.values();
- case TextEmbeddingByteResults.Embedding byteEmbedding -> toFloatArray(byteEmbedding.values());
+ case DenseEmbeddingFloatResults.Embedding floatEmbedding -> floatEmbedding.values();
+ case DenseEmbeddingByteResults.Embedding byteEmbedding -> toFloatArray(byteEmbedding.values());
default -> throw new IllegalArgumentException(
"Unsupported embedding type: "
+ embedding.embeddings().get(0).getClass().getName()
- + ". Expected TextEmbeddingFloatResults.Embedding or TextEmbeddingByteResults.Embedding."
+ + ". Expected "
+ + DenseEmbeddingFloatResults.Embedding.class.getName()
+ + " or "
+ + DenseEmbeddingByteResults.Embedding.class.getName()
+ + "."
);
};
}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluatorTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluatorTests.java
index 875c136ab861c..ac847f50ce353 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluatorTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluatorTests.java
@@ -16,7 +16,7 @@
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
@@ -250,7 +250,7 @@ private float[] randomEmbedding(int length) {
}
private InferenceAction.Response inferenceResponse(float[] embedding) {
- TextEmbeddingFloatResults.Embedding embeddingResult = new TextEmbeddingFloatResults.Embedding(embedding);
- return new InferenceAction.Response(new TextEmbeddingFloatResults(List.of(embeddingResult)));
+ DenseEmbeddingFloatResults.Embedding embeddingResult = new DenseEmbeddingFloatResults.Embedding(embedding);
+ return new InferenceAction.Response(new DenseEmbeddingFloatResults(List.of(embeddingResult)));
}
}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorOutputBuilderTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorOutputBuilderTests.java
index ea77c6bed3c38..ac1fb5daa36a5 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorOutputBuilderTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorOutputBuilderTests.java
@@ -15,8 +15,8 @@
import org.elasticsearch.compute.test.RandomBlock;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import java.util.List;
@@ -182,7 +182,7 @@ private void assertByteEmbeddingContent(FloatBlock block, byte[][] expectedByteE
int firstValueIndex = block.getFirstValueIndex(currentPos);
for (int i = 0; i < expectedByteEmbeddings[currentPos].length; i++) {
float actualValue = block.getFloat(firstValueIndex + i);
- // Convert byte to float the same way as TextEmbeddingByteResults.Embedding.toFloatArray()
+ // Convert byte to float the same way as DenseEmbeddingByteResults.Embedding.toFloatArray()
float expectedValue = expectedByteEmbeddings[currentPos][i];
assertThat(actualValue, equalTo(expectedValue));
}
@@ -206,20 +206,20 @@ private byte[] randomByteEmbedding(int dimension) {
}
private static InferenceAction.Response createFloatEmbeddingResponse(float[] embedding) {
- var embeddingResult = new TextEmbeddingFloatResults.Embedding(embedding);
- var textEmbeddingResults = new TextEmbeddingFloatResults(List.of(embeddingResult));
- return new InferenceAction.Response(textEmbeddingResults);
+ var embeddingResult = new DenseEmbeddingFloatResults.Embedding(embedding);
+ var denseEmbeddingResults = new DenseEmbeddingFloatResults(List.of(embeddingResult));
+ return new InferenceAction.Response(denseEmbeddingResults);
}
private static InferenceAction.Response createByteEmbeddingResponse(byte[] embedding) {
- var embeddingResult = new TextEmbeddingByteResults.Embedding(embedding);
- var textEmbeddingResults = new TextEmbeddingByteResults(List.of(embeddingResult));
- return new InferenceAction.Response(textEmbeddingResults);
+ var embeddingResult = new DenseEmbeddingByteResults.Embedding(embedding);
+ var denseEmbeddingResults = new DenseEmbeddingByteResults(List.of(embeddingResult));
+ return new InferenceAction.Response(denseEmbeddingResults);
}
private static InferenceAction.Response createEmptyFloatEmbeddingResponse() {
- var textEmbeddingResults = new TextEmbeddingFloatResults(List.of());
- return new InferenceAction.Response(textEmbeddingResults);
+ var denseEmbeddingResults = new DenseEmbeddingFloatResults(List.of());
+ return new InferenceAction.Response(denseEmbeddingResults);
}
private Page randomInputPage(int positionCount, int columnCount) {
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorTests.java
index 6ff9a90b70b16..06441be9e7148 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorTests.java
@@ -13,7 +13,7 @@
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.Operator;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.esql.inference.InferenceOperatorTestCase;
import org.hamcrest.Matcher;
import org.junit.Before;
@@ -23,7 +23,7 @@
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
-public class TextEmbeddingOperatorTests extends InferenceOperatorTestCase {
+public class TextEmbeddingOperatorTests extends InferenceOperatorTestCase {
private static final String SIMPLE_INFERENCE_ID = "test_text_embedding";
private static final int EMBEDDING_DIMENSION = 384; // Common embedding dimension
@@ -89,15 +89,15 @@ private void assertTextEmbeddingResults(Page inputPage, Page resultPage) {
}
@Override
- protected TextEmbeddingFloatResults mockInferenceResult(InferenceAction.Request request) {
+ protected DenseEmbeddingFloatResults mockInferenceResult(InferenceAction.Request request) {
// For text embedding, we expect one input text per request
String inputText = request.getInput().get(0);
// Generate a deterministic mock embedding based on the input text
float[] mockEmbedding = generateMockEmbedding(inputText, EMBEDDING_DIMENSION);
- var embeddingResult = new TextEmbeddingFloatResults.Embedding(mockEmbedding);
- return new TextEmbeddingFloatResults(List.of(embeddingResult));
+ var embeddingResult = new DenseEmbeddingFloatResults.Embedding(mockEmbedding);
+ return new DenseEmbeddingFloatResults(List.of(embeddingResult));
}
@Override
diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java
index 051b6dbf3e8fa..45ecb3dedf3f1 100644
--- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java
+++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java
@@ -36,7 +36,7 @@
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
@@ -167,22 +167,22 @@ public void chunkedInfer(
}
}
- private TextEmbeddingFloatResults makeResults(List input, ServiceSettings serviceSettings) {
- List embeddings = new ArrayList<>();
+ private DenseEmbeddingFloatResults makeResults(List input, ServiceSettings serviceSettings) {
+ List embeddings = new ArrayList<>();
for (String inputString : input) {
List floatEmbeddings = generateEmbedding(inputString, serviceSettings.dimensions(), serviceSettings.elementType());
- embeddings.add(TextEmbeddingFloatResults.Embedding.of(floatEmbeddings));
+ embeddings.add(DenseEmbeddingFloatResults.Embedding.of(floatEmbeddings));
}
- return new TextEmbeddingFloatResults(embeddings);
+ return new DenseEmbeddingFloatResults(embeddings);
}
private List makeChunkedResults(List inputs, ServiceSettings serviceSettings) {
var results = new ArrayList();
for (ChunkInferenceInput input : inputs) {
List chunkedInput = chunkInputs(input);
- List chunks = chunkedInput.stream()
+ List chunks = chunkedInput.stream()
.map(
- c -> new TextEmbeddingFloatResults.Chunk(
+ c -> new DenseEmbeddingFloatResults.Chunk(
makeResults(List.of(c.input()), serviceSettings).embeddings().get(0),
new ChunkedInference.TextOffset(c.startOffset(), c.endOffset())
)
diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java
index 28a191a1bbfac..9ea2301abfa0c 100644
--- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java
+++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java
@@ -35,9 +35,9 @@
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.inference.DequeUtils;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
import java.io.IOException;
import java.util.ArrayList;
@@ -138,9 +138,9 @@ public void infer(
)
);
} else {
- // Return text embedding results when creating a sparse_embedding inference endpoint to allow creation validation to
- // pass. This is required to test that streaming fails for a sparse_embedding endpoint.
- listener.onResponse(makeTextEmbeddingResults(input));
+ // Return dense embedding results when creating a sparse_embedding inference endpoint to allow creation validation
+ // to pass. This is required to test that streaming fails for a sparse_embedding endpoint.
+ listener.onResponse(makeDenseEmbeddingResults(input));
}
}
default -> listener.onFailure(
@@ -189,16 +189,16 @@ public void cancel() {}
});
}
- private TextEmbeddingFloatResults makeTextEmbeddingResults(List input) {
- var embeddings = new ArrayList();
+ private DenseEmbeddingFloatResults makeDenseEmbeddingResults(List input) {
+ var embeddings = new ArrayList();
for (int i = 0; i < input.size(); i++) {
var values = new float[5];
for (int j = 0; j < 5; j++) {
values[j] = random.nextFloat();
}
- embeddings.add(new TextEmbeddingFloatResults.Embedding(values));
+ embeddings.add(new DenseEmbeddingFloatResults.Embedding(values));
}
- return new TextEmbeddingFloatResults(embeddings);
+ return new DenseEmbeddingFloatResults(embeddings);
}
private InferenceServiceResults.Result completionChunk(String delta) {
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java
index b1887cff763fa..f227140e431bd 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java
@@ -20,12 +20,12 @@
import org.elasticsearch.inference.TaskSettings;
import org.elasticsearch.inference.UnifiedCompletionRequest;
import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
import org.elasticsearch.xpack.inference.chunking.NoneChunkingSettings;
import org.elasticsearch.xpack.inference.chunking.RecursiveChunkingSettings;
@@ -71,10 +71,10 @@
import org.elasticsearch.xpack.inference.services.custom.CustomTaskSettings;
import org.elasticsearch.xpack.inference.services.custom.response.CompletionResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.CustomResponseParser;
+import org.elasticsearch.xpack.inference.services.custom.response.DenseEmbeddingResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.NoopResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.SparseEmbeddingResponseParser;
-import org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser;
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekChatCompletionModel;
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankServiceSettings;
@@ -198,7 +198,11 @@ private static void addCustomNamedWriteables(List
namedWriteables.add(new NamedWriteableRegistry.Entry(SecretSettings.class, CustomSecretSettings.NAME, CustomSecretSettings::new));
namedWriteables.add(
- new NamedWriteableRegistry.Entry(CustomResponseParser.class, TextEmbeddingResponseParser.NAME, TextEmbeddingResponseParser::new)
+ new NamedWriteableRegistry.Entry(
+ CustomResponseParser.class,
+ DenseEmbeddingResponseParser.NAME,
+ DenseEmbeddingResponseParser::new
+ )
);
namedWriteables.add(
@@ -649,10 +653,14 @@ private static void addInferenceResultsNamedWriteables(List
*
*/
- public static TextEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException {
+ public static DenseEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException {
try (var p = XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, response.body())) {
- return EmbeddingFloatResult.PARSER.apply(p, null).toTextEmbeddingFloatResults();
+ return EmbeddingFloatResult.PARSER.apply(p, null).toDenseEmbeddingFloatResults();
}
}
@@ -81,9 +81,9 @@ public record EmbeddingFloatResult(List embeddingResu
}, new ParseField("data"), org.elasticsearch.xcontent.ObjectParser.ValueType.OBJECT_ARRAY);
}
- public TextEmbeddingFloatResults toTextEmbeddingFloatResults() {
- return new TextEmbeddingFloatResults(
- embeddingResults.stream().map(entry -> TextEmbeddingFloatResults.Embedding.of(entry.embedding)).toList()
+ public DenseEmbeddingFloatResults toDenseEmbeddingFloatResults() {
+ return new DenseEmbeddingFloatResults(
+ embeddingResults.stream().map(entry -> DenseEmbeddingFloatResults.Embedding.of(entry.embedding)).toList()
);
}
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java
index 1a8b162eb1b46..1e9a66c776b53 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java
@@ -80,7 +80,7 @@
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
-import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
+import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults;
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder;
import org.elasticsearch.xpack.inference.highlight.SemanticTextHighlighter;
@@ -967,13 +967,13 @@ yield new SparseVectorQueryBuilder(
);
}
case TEXT_EMBEDDING -> {
- if (inferenceResults instanceof MlTextEmbeddingResults == false) {
+ if (inferenceResults instanceof MlDenseEmbeddingResults == false) {
throw new IllegalArgumentException(
- generateQueryInferenceResultsTypeMismatchMessage(inferenceResults, MlTextEmbeddingResults.NAME)
+ generateQueryInferenceResultsTypeMismatchMessage(inferenceResults, MlDenseEmbeddingResults.NAME)
);
}
- MlTextEmbeddingResults textEmbeddingResults = (MlTextEmbeddingResults) inferenceResults;
+ MlDenseEmbeddingResults textEmbeddingResults = (MlDenseEmbeddingResults) inferenceResults;
float[] inference = textEmbeddingResults.getInferenceAsFloat();
int dimensions = modelSettings.elementType() == DenseVectorFieldMapper.ElementType.BIT
? inference.length * Byte.SIZE // Bit vectors encode 8 dimensions into each byte value
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceKnnVectorQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceKnnVectorQueryBuilder.java
index 210ab2b67f9c9..cd590b8410234 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceKnnVectorQueryBuilder.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceKnnVectorQueryBuilder.java
@@ -25,7 +25,7 @@
import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
import org.elasticsearch.search.vectors.QueryVectorBuilder;
import org.elasticsearch.search.vectors.VectorData;
-import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
+import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults;
import org.elasticsearch.xpack.core.ml.vectors.TextEmbeddingQueryVectorBuilder;
import org.elasticsearch.xpack.inference.mapper.SemanticTextField;
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
@@ -195,7 +195,7 @@ private QueryBuilder querySemanticTextField(String clusterAlias, SemanticTextFie
fullyQualifiedInferenceId = new FullyQualifiedInferenceId(clusterAlias, semanticTextFieldType.getSearchInferenceId());
}
- MlTextEmbeddingResults textEmbeddingResults = getTextEmbeddingResults(fullyQualifiedInferenceId);
+ MlDenseEmbeddingResults textEmbeddingResults = getTextEmbeddingResults(fullyQualifiedInferenceId);
queryVector = new VectorData(textEmbeddingResults.getInferenceAsFloat());
}
@@ -226,7 +226,7 @@ private QueryBuilder queryNonSemanticTextField() {
throw new IllegalStateException("No query vector or query vector builder model ID specified");
}
- MlTextEmbeddingResults textEmbeddingResults = getTextEmbeddingResults(fullyQualifiedInferenceId);
+ MlDenseEmbeddingResults textEmbeddingResults = getTextEmbeddingResults(fullyQualifiedInferenceId);
queryVector = new VectorData(textEmbeddingResults.getInferenceAsFloat());
}
@@ -244,20 +244,20 @@ private QueryBuilder queryNonSemanticTextField() {
return knnQuery;
}
- private MlTextEmbeddingResults getTextEmbeddingResults(FullyQualifiedInferenceId fullyQualifiedInferenceId) {
+ private MlDenseEmbeddingResults getTextEmbeddingResults(FullyQualifiedInferenceId fullyQualifiedInferenceId) {
InferenceResults inferenceResults = inferenceResultsMap.get(fullyQualifiedInferenceId);
if (inferenceResults == null) {
throw new IllegalStateException("Could not find inference results from inference endpoint [" + fullyQualifiedInferenceId + "]");
- } else if (inferenceResults instanceof MlTextEmbeddingResults == false) {
+ } else if (inferenceResults instanceof MlDenseEmbeddingResults == false) {
throw new IllegalArgumentException(
"Expected query inference results to be of type ["
- + MlTextEmbeddingResults.NAME
+ + MlDenseEmbeddingResults.NAME
+ "], got ["
+ inferenceResults.getWriteableName()
+ "]. Are you specifying a compatible inference endpoint? Has the inference endpoint configuration changed?"
);
}
- return (MlTextEmbeddingResults) inferenceResults;
+ return (MlDenseEmbeddingResults) inferenceResults;
}
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java
index fddcf75ce3ba3..52abe808ee4f2 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java
@@ -33,7 +33,7 @@
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
-import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
+import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults;
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
import org.elasticsearch.xpack.inference.InferenceException;
@@ -499,7 +499,7 @@ private static InferenceResults validateAndConvertInferenceResults(
InferenceResults inferenceResults = inferenceResultsList.getFirst();
if (inferenceResults instanceof TextExpansionResults == false
- && inferenceResults instanceof MlTextEmbeddingResults == false
+ && inferenceResults instanceof MlDenseEmbeddingResults == false
&& inferenceResults instanceof ErrorInferenceResults == false
&& inferenceResults instanceof WarningInferenceResults == false) {
return new ErrorInferenceResults(
@@ -507,7 +507,7 @@ private static InferenceResults validateAndConvertInferenceResults(
"Expected query inference results to be of type ["
+ TextExpansionResults.NAME
+ "] or ["
- + MlTextEmbeddingResults.NAME
+ + MlDenseEmbeddingResults.NAME
+ "], got ["
+ inferenceResults.getWriteableName()
+ "]. Has the inference endpoint ["
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/response/AlibabaCloudSearchEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/response/AlibabaCloudSearchEmbeddingsResponseEntity.java
index 4e73f03e2898b..0a314202c922c 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/response/AlibabaCloudSearchEmbeddingsResponseEntity.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/response/AlibabaCloudSearchEmbeddingsResponseEntity.java
@@ -9,7 +9,7 @@
import org.elasticsearch.common.xcontent.XContentParserUtils;
import org.elasticsearch.xcontent.XContentParser;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.request.Request;
@@ -70,20 +70,20 @@ public class AlibabaCloudSearchEmbeddingsResponseEntity extends AlibabaCloudSear
*
*
*/
- public static TextEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException {
+ public static DenseEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException {
return fromResponse(request, response, parser -> {
positionParserAtTokenAfterField(parser, "embeddings", FAILED_TO_FIND_FIELD_TEMPLATE);
- List embeddingList = XContentParserUtils.parseList(
+ List embeddingList = XContentParserUtils.parseList(
parser,
AlibabaCloudSearchEmbeddingsResponseEntity::parseEmbeddingObject
);
- return new TextEmbeddingFloatResults(embeddingList);
+ return new DenseEmbeddingFloatResults(embeddingList);
});
}
- private static TextEmbeddingFloatResults.Embedding parseEmbeddingObject(XContentParser parser) throws IOException {
+ private static DenseEmbeddingFloatResults.Embedding parseEmbeddingObject(XContentParser parser) throws IOException {
XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
positionParserAtTokenAfterField(parser, "embedding", FAILED_TO_FIND_FIELD_TEMPLATE);
@@ -95,7 +95,7 @@ private static TextEmbeddingFloatResults.Embedding parseEmbeddingObject(XContent
// if there are additional fields within this object, lets skip them, so we can begin parsing the next embedding array
parser.skipChildren();
- return TextEmbeddingFloatResults.Embedding.of(embeddingValues);
+ return DenseEmbeddingFloatResults.Embedding.of(embeddingValues);
}
private static float parseEmbeddingList(XContentParser parser) throws IOException {
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/response/embeddings/AmazonBedrockEmbeddingsResponse.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/response/embeddings/AmazonBedrockEmbeddingsResponse.java
index 831bf9938c211..61ec5f0c39790 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/response/embeddings/AmazonBedrockEmbeddingsResponse.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/response/embeddings/AmazonBedrockEmbeddingsResponse.java
@@ -16,7 +16,7 @@
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.response.XContentUtils;
import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProvider;
import org.elasticsearch.xpack.inference.services.amazonbedrock.request.AmazonBedrockRequest;
@@ -48,7 +48,7 @@ public InferenceServiceResults accept(AmazonBedrockRequest request) {
throw new ElasticsearchException("unexpected request type [" + request.getClass() + "]");
}
- public static TextEmbeddingFloatResults fromResponse(InvokeModelResponse response, AmazonBedrockProvider provider) {
+ public static DenseEmbeddingFloatResults fromResponse(InvokeModelResponse response, AmazonBedrockProvider provider) {
var charset = StandardCharsets.UTF_8;
var bodyText = String.valueOf(charset.decode(response.body().asByteBuffer()));
@@ -63,13 +63,13 @@ public static TextEmbeddingFloatResults fromResponse(InvokeModelResponse respons
var embeddingList = parseEmbeddings(jsonParser, provider);
- return new TextEmbeddingFloatResults(embeddingList);
+ return new DenseEmbeddingFloatResults(embeddingList);
} catch (IOException e) {
throw new ElasticsearchException(e);
}
}
- private static List parseEmbeddings(XContentParser jsonParser, AmazonBedrockProvider provider)
+ private static List parseEmbeddings(XContentParser jsonParser, AmazonBedrockProvider provider)
throws IOException {
switch (provider) {
case AMAZONTITAN -> {
@@ -82,7 +82,7 @@ private static List parseEmbeddings(XConten
}
}
- private static List parseTitanEmbeddings(XContentParser parser) throws IOException {
+ private static List parseTitanEmbeddings(XContentParser parser) throws IOException {
/*
Titan response:
{
@@ -92,11 +92,11 @@ private static List parseTitanEmbeddings(XC
*/
positionParserAtTokenAfterField(parser, "embedding", FAILED_TO_FIND_FIELD_TEMPLATE);
List embeddingValuesList = parseList(parser, XContentUtils::parseFloat);
- var embeddingValues = TextEmbeddingFloatResults.Embedding.of(embeddingValuesList);
+ var embeddingValues = DenseEmbeddingFloatResults.Embedding.of(embeddingValuesList);
return List.of(embeddingValues);
}
- private static List parseCohereEmbeddings(XContentParser parser) throws IOException {
+ private static List parseCohereEmbeddings(XContentParser parser) throws IOException {
/*
Cohere response:
{
@@ -111,7 +111,7 @@ private static List parseCohereEmbeddings(X
*/
positionParserAtTokenAfterField(parser, "embeddings", FAILED_TO_FIND_FIELD_TEMPLATE);
- List embeddingList = parseList(
+ List embeddingList = parseList(
parser,
AmazonBedrockEmbeddingsResponse::parseCohereEmbeddingsListItem
);
@@ -119,9 +119,9 @@ private static List parseCohereEmbeddings(X
return embeddingList;
}
- private static TextEmbeddingFloatResults.Embedding parseCohereEmbeddingsListItem(XContentParser parser) throws IOException {
+ private static DenseEmbeddingFloatResults.Embedding parseCohereEmbeddingsListItem(XContentParser parser) throws IOException {
List embeddingValuesList = parseList(parser, XContentUtils::parseFloat);
- return TextEmbeddingFloatResults.Embedding.of(embeddingValuesList);
+ return DenseEmbeddingFloatResults.Embedding.of(embeddingValuesList);
}
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/response/CohereEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/response/CohereEmbeddingsResponseEntity.java
index b4a2e142b3792..a5a174e80b289 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/response/CohereEmbeddingsResponseEntity.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/response/CohereEmbeddingsResponseEntity.java
@@ -15,9 +15,9 @@
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBitResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingBitResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.external.response.XContentUtils;
@@ -189,20 +189,20 @@ private static InferenceServiceResults parseBitEmbeddingsArray(XContentParser pa
// Cohere returns array of binary embeddings encoded as bytes with int8 precision so we can reuse the byte parser
var embeddingList = parseList(parser, CohereEmbeddingsResponseEntity::parseByteArrayEntry);
- return new TextEmbeddingBitResults(embeddingList);
+ return new DenseEmbeddingBitResults(embeddingList);
}
private static InferenceServiceResults parseByteEmbeddingsArray(XContentParser parser) throws IOException {
var embeddingList = parseList(parser, CohereEmbeddingsResponseEntity::parseByteArrayEntry);
- return new TextEmbeddingByteResults(embeddingList);
+ return new DenseEmbeddingByteResults(embeddingList);
}
- private static TextEmbeddingByteResults.Embedding parseByteArrayEntry(XContentParser parser) throws IOException {
+ private static DenseEmbeddingByteResults.Embedding parseByteArrayEntry(XContentParser parser) throws IOException {
ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
List embeddingValuesList = parseList(parser, CohereEmbeddingsResponseEntity::parseEmbeddingInt8Entry);
- return TextEmbeddingByteResults.Embedding.of(embeddingValuesList);
+ return DenseEmbeddingByteResults.Embedding.of(embeddingValuesList);
}
private static Byte parseEmbeddingInt8Entry(XContentParser parser) throws IOException {
@@ -223,13 +223,13 @@ private static void checkByteBounds(short value) {
private static InferenceServiceResults parseFloatEmbeddingsArray(XContentParser parser) throws IOException {
var embeddingList = parseList(parser, CohereEmbeddingsResponseEntity::parseFloatArrayEntry);
- return new TextEmbeddingFloatResults(embeddingList);
+ return new DenseEmbeddingFloatResults(embeddingList);
}
- private static TextEmbeddingFloatResults.Embedding parseFloatArrayEntry(XContentParser parser) throws IOException {
+ private static DenseEmbeddingFloatResults.Embedding parseFloatArrayEntry(XContentParser parser) throws IOException {
ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
List embeddingValuesList = parseList(parser, XContentUtils::parseFloat);
- return TextEmbeddingFloatResults.Embedding.of(embeddingValuesList);
+ return DenseEmbeddingFloatResults.Embedding.of(embeddingValuesList);
}
private CohereEmbeddingsResponseEntity() {}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java
index 5986b2104ffb1..44901b49a7e47 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java
@@ -25,10 +25,10 @@
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.custom.response.CompletionResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.CustomResponseParser;
+import org.elasticsearch.xpack.inference.services.custom.response.DenseEmbeddingResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.NoopResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.SparseEmbeddingResponseParser;
-import org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser;
import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
@@ -501,7 +501,7 @@ private static CustomResponseParser extractResponseParser(
}
return switch (taskType) {
- case TEXT_EMBEDDING -> TextEmbeddingResponseParser.fromMap(responseParserMap, RESPONSE_SCOPE, validationException);
+ case TEXT_EMBEDDING -> DenseEmbeddingResponseParser.fromMap(responseParserMap, RESPONSE_SCOPE, validationException);
case SPARSE_EMBEDDING -> SparseEmbeddingResponseParser.fromMap(responseParserMap, RESPONSE_SCOPE, validationException);
case RERANK -> RerankResponseParser.fromMap(responseParserMap, RESPONSE_SCOPE, validationException);
case COMPLETION -> CompletionResponseParser.fromMap(responseParserMap, RESPONSE_SCOPE, validationException);
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParser.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/DenseEmbeddingResponseParser.java
similarity index 80%
rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParser.java
rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/DenseEmbeddingResponseParser.java
index f665c0be81511..18aaf186b4a3f 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParser.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/DenseEmbeddingResponseParser.java
@@ -14,9 +14,9 @@
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xcontent.XContentBuilder;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBitResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingBitResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.common.MapPathExtractor;
import org.elasticsearch.xpack.inference.services.custom.CustomServiceEmbeddingType;
@@ -31,8 +31,8 @@
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString;
import static org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings.JSON_PARSER;
-public class TextEmbeddingResponseParser extends BaseCustomResponseParser {
-
+public class DenseEmbeddingResponseParser extends BaseCustomResponseParser {
+ // This name is a holdover from before this class was renamed
public static final String NAME = "text_embedding_response_parser";
public static final String TEXT_EMBEDDING_PARSER_EMBEDDINGS = "text_embeddings";
public static final String EMBEDDING_TYPE = "embedding_type";
@@ -41,7 +41,7 @@ public class TextEmbeddingResponseParser extends BaseCustomResponseParser {
"ml_inference_custom_service_embedding_type"
);
- public static TextEmbeddingResponseParser fromMap(
+ public static DenseEmbeddingResponseParser fromMap(
Map responseParserMap,
String scope,
ValidationException validationException
@@ -70,18 +70,18 @@ public static TextEmbeddingResponseParser fromMap(
throw validationException;
}
- return new TextEmbeddingResponseParser(path, embeddingType);
+ return new DenseEmbeddingResponseParser(path, embeddingType);
}
private final String textEmbeddingsPath;
private final CustomServiceEmbeddingType embeddingType;
- public TextEmbeddingResponseParser(String textEmbeddingsPath, CustomServiceEmbeddingType embeddingType) {
+ public DenseEmbeddingResponseParser(String textEmbeddingsPath, CustomServiceEmbeddingType embeddingType) {
this.textEmbeddingsPath = Objects.requireNonNull(textEmbeddingsPath);
this.embeddingType = Objects.requireNonNull(embeddingType);
}
- public TextEmbeddingResponseParser(StreamInput in) throws IOException {
+ public DenseEmbeddingResponseParser(StreamInput in) throws IOException {
this.textEmbeddingsPath = in.readString();
if (in.getTransportVersion().supports(ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_TYPE)) {
this.embeddingType = in.readEnum(CustomServiceEmbeddingType.class);
@@ -122,7 +122,7 @@ public CustomServiceEmbeddingType getEmbeddingType() {
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
- TextEmbeddingResponseParser that = (TextEmbeddingResponseParser) o;
+ DenseEmbeddingResponseParser that = (DenseEmbeddingResponseParser) o;
return Objects.equals(textEmbeddingsPath, that.textEmbeddingsPath) && Objects.equals(embeddingType, that.embeddingType);
}
@@ -174,7 +174,7 @@ private interface EmbeddingConverter {
private static class FloatEmbeddings implements EmbeddingConverter {
- private final List embeddings;
+ private final List embeddings;
FloatEmbeddings() {
this.embeddings = new ArrayList<>();
@@ -182,17 +182,17 @@ private static class FloatEmbeddings implements EmbeddingConverter {
public void toEmbedding(Object entry, String fieldName) {
var embeddingsAsListFloats = convertToListOfFloats(entry, fieldName);
- embeddings.add(TextEmbeddingFloatResults.Embedding.of(embeddingsAsListFloats));
+ embeddings.add(DenseEmbeddingFloatResults.Embedding.of(embeddingsAsListFloats));
}
- public TextEmbeddingFloatResults getResults() {
- return new TextEmbeddingFloatResults(embeddings);
+ public DenseEmbeddingFloatResults getResults() {
+ return new DenseEmbeddingFloatResults(embeddings);
}
}
private static class ByteEmbeddings implements EmbeddingConverter {
- private final List embeddings;
+ private final List embeddings;
ByteEmbeddings() {
this.embeddings = new ArrayList<>();
@@ -200,17 +200,17 @@ private static class ByteEmbeddings implements EmbeddingConverter {
public void toEmbedding(Object entry, String fieldName) {
var convertedEmbeddings = convertToListOfBytes(entry, fieldName);
- this.embeddings.add(TextEmbeddingByteResults.Embedding.of(convertedEmbeddings));
+ this.embeddings.add(DenseEmbeddingByteResults.Embedding.of(convertedEmbeddings));
}
- public TextEmbeddingByteResults getResults() {
- return new TextEmbeddingByteResults(embeddings);
+ public DenseEmbeddingByteResults getResults() {
+ return new DenseEmbeddingByteResults(embeddings);
}
}
private static class BitEmbeddings implements EmbeddingConverter {
- private final List embeddings;
+ private final List embeddings;
BitEmbeddings() {
this.embeddings = new ArrayList<>();
@@ -218,11 +218,11 @@ private static class BitEmbeddings implements EmbeddingConverter {
public void toEmbedding(Object entry, String fieldName) {
var convertedEmbeddings = convertToListOfBits(entry, fieldName);
- this.embeddings.add(TextEmbeddingByteResults.Embedding.of(convertedEmbeddings));
+ this.embeddings.add(DenseEmbeddingByteResults.Embedding.of(convertedEmbeddings));
}
- public TextEmbeddingBitResults getResults() {
- return new TextEmbeddingBitResults(embeddings);
+ public DenseEmbeddingBitResults getResults() {
+ return new DenseEmbeddingBitResults(embeddings);
}
}
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java
index 2c1ee96b519a3..646ee520de83f 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java
@@ -38,15 +38,15 @@
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.XPackSettings;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats;
import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
-import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
+import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults;
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.EmptyConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
@@ -642,7 +642,7 @@ public void inferTextEmbedding(
);
ActionListener mlResultsListener = listener.delegateFailureAndWrap(
- (l, inferenceResult) -> l.onResponse(TextEmbeddingFloatResults.of(inferenceResult.getInferenceResults()))
+ (l, inferenceResult) -> l.onResponse(DenseEmbeddingFloatResults.of(inferenceResult.getInferenceResults()))
);
var maybeDeployListener = mlResultsListener.delegateResponse(
@@ -772,22 +772,22 @@ private static void translateToChunkedResult(
ActionListener chunkPartListener
) {
if (taskType == TaskType.TEXT_EMBEDDING) {
- var translated = new ArrayList();
+ var translated = new ArrayList();
for (var inferenceResult : inferenceResults) {
- if (inferenceResult instanceof MlTextEmbeddingResults mlTextEmbeddingResult) {
- translated.add(new TextEmbeddingFloatResults.Embedding(mlTextEmbeddingResult.getInferenceAsFloat()));
+ if (inferenceResult instanceof MlDenseEmbeddingResults mlTextEmbeddingResult) {
+ translated.add(new DenseEmbeddingFloatResults.Embedding(mlTextEmbeddingResult.getInferenceAsFloat()));
} else if (inferenceResult instanceof ErrorInferenceResults error) {
chunkPartListener.onFailure(error.getException());
return;
} else {
chunkPartListener.onFailure(
- createInvalidChunkedResultException(MlTextEmbeddingResults.NAME, inferenceResult.getWriteableName())
+ createInvalidChunkedResultException(MlDenseEmbeddingResults.NAME, inferenceResult.getWriteableName())
);
return;
}
}
- chunkPartListener.onResponse(new TextEmbeddingFloatResults(translated));
+ chunkPartListener.onResponse(new DenseEmbeddingFloatResults(translated));
} else { // sparse
var translated = new ArrayList();
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/response/GoogleAiStudioEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/response/GoogleAiStudioEmbeddingsResponseEntity.java
index 499fe9ae0c6c7..67527698bf02c 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/response/GoogleAiStudioEmbeddingsResponseEntity.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/response/GoogleAiStudioEmbeddingsResponseEntity.java
@@ -12,7 +12,7 @@
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.external.response.XContentUtils;
@@ -70,7 +70,7 @@ public class GoogleAiStudioEmbeddingsResponseEntity {
*
*/
- public static TextEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException {
+ public static DenseEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException {
var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) {
@@ -81,16 +81,16 @@ public static TextEmbeddingFloatResults fromResponse(Request request, HttpResult
positionParserAtTokenAfterField(jsonParser, "embeddings", FAILED_TO_FIND_FIELD_TEMPLATE);
- List embeddingList = parseList(
+ List embeddingList = parseList(
jsonParser,
GoogleAiStudioEmbeddingsResponseEntity::parseEmbeddingObject
);
- return new TextEmbeddingFloatResults(embeddingList);
+ return new DenseEmbeddingFloatResults(embeddingList);
}
}
- private static TextEmbeddingFloatResults.Embedding parseEmbeddingObject(XContentParser parser) throws IOException {
+ private static DenseEmbeddingFloatResults.Embedding parseEmbeddingObject(XContentParser parser) throws IOException {
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
positionParserAtTokenAfterField(parser, "values", FAILED_TO_FIND_FIELD_TEMPLATE);
@@ -99,7 +99,7 @@ private static TextEmbeddingFloatResults.Embedding parseEmbeddingObject(XContent
// parse and discard the rest of the object
consumeUntilObjectEnd(parser);
- return TextEmbeddingFloatResults.Embedding.of(embeddingValuesList);
+ return DenseEmbeddingFloatResults.Embedding.of(embeddingValuesList);
}
private GoogleAiStudioEmbeddingsResponseEntity() {}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/response/GoogleVertexAiEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/response/GoogleVertexAiEmbeddingsResponseEntity.java
index b4038e42c62cb..94272815e8db2 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/response/GoogleVertexAiEmbeddingsResponseEntity.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/response/GoogleVertexAiEmbeddingsResponseEntity.java
@@ -13,7 +13,7 @@
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.request.Request;
@@ -64,7 +64,7 @@ public class GoogleVertexAiEmbeddingsResponseEntity {
*
*/
- public static TextEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException {
+ public static DenseEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException {
var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) {
@@ -75,16 +75,16 @@ public static TextEmbeddingFloatResults fromResponse(Request request, HttpResult
positionParserAtTokenAfterField(jsonParser, "predictions", FAILED_TO_FIND_FIELD_TEMPLATE);
- List embeddingList = parseList(
+ List embeddingList = parseList(
jsonParser,
GoogleVertexAiEmbeddingsResponseEntity::parseEmbeddingObject
);
- return new TextEmbeddingFloatResults(embeddingList);
+ return new DenseEmbeddingFloatResults(embeddingList);
}
}
- private static TextEmbeddingFloatResults.Embedding parseEmbeddingObject(XContentParser parser) throws IOException {
+ private static DenseEmbeddingFloatResults.Embedding parseEmbeddingObject(XContentParser parser) throws IOException {
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
positionParserAtTokenAfterField(parser, "embeddings", FAILED_TO_FIND_FIELD_TEMPLATE);
@@ -99,7 +99,7 @@ private static TextEmbeddingFloatResults.Embedding parseEmbeddingObject(XContent
consumeUntilObjectEnd(parser);
consumeUntilObjectEnd(parser);
- return TextEmbeddingFloatResults.Embedding.of(embeddingValueList);
+ return DenseEmbeddingFloatResults.Embedding.of(embeddingValueList);
}
private static float parseEmbeddingList(XContentParser parser) throws IOException {
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java
index 081d5c63b84ff..681058d8c21a5 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java
@@ -26,9 +26,9 @@
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.EmbeddingResults;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
@@ -127,8 +127,8 @@ private static List translateToChunkedResults(
List inputs,
InferenceServiceResults inferenceResults
) {
- if (inferenceResults instanceof TextEmbeddingFloatResults textEmbeddingResults) {
- validateInputSizeAgainstEmbeddings(ChunkInferenceInput.inputs(inputs), textEmbeddingResults.embeddings().size());
+ if (inferenceResults instanceof DenseEmbeddingFloatResults denseEmbeddingResults) {
+ validateInputSizeAgainstEmbeddings(ChunkInferenceInput.inputs(inputs), denseEmbeddingResults.embeddings().size());
var results = new ArrayList(inputs.size());
@@ -137,7 +137,7 @@ private static List translateToChunkedResults(
new ChunkedInferenceEmbedding(
List.of(
new EmbeddingResults.Chunk(
- textEmbeddingResults.embeddings().get(i),
+ denseEmbeddingResults.embeddings().get(i),
new ChunkedInference.TextOffset(0, inputs.get(i).input().length())
)
)
@@ -153,7 +153,7 @@ private static List translateToChunkedResults(
} else {
String expectedClasses = Strings.format(
"One of [%s,%s]",
- TextEmbeddingFloatResults.class.getSimpleName(),
+ DenseEmbeddingFloatResults.class.getSimpleName(),
SparseEmbeddingResults.class.getSimpleName()
);
throw createInvalidChunkedResultException(expectedClasses, inferenceResults.getWriteableName());
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceEmbeddingsResponseEntity.java
index baf1e884108fb..126d5b03fcde4 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceEmbeddingsResponseEntity.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceEmbeddingsResponseEntity.java
@@ -12,7 +12,7 @@
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.external.response.XContentUtils;
@@ -33,7 +33,7 @@ public class HuggingFaceEmbeddingsResponseEntity {
* Parse the response from hugging face. The known formats are an array of arrays and object with an {@code embeddings} field containing
* an array of arrays.
*/
- public static TextEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException {
+ public static DenseEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException {
var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) {
@@ -91,13 +91,13 @@ public static TextEmbeddingFloatResults fromResponse(Request request, HttpResult
* sentence-transformers/all-MiniLM-L6-v2
* sentence-transformers/all-MiniLM-L12-v2
*/
- private static TextEmbeddingFloatResults parseArrayFormat(XContentParser parser) throws IOException {
- List embeddingList = parseList(
+ private static DenseEmbeddingFloatResults parseArrayFormat(XContentParser parser) throws IOException {
+ List embeddingList = parseList(
parser,
HuggingFaceEmbeddingsResponseEntity::parseEmbeddingEntry
);
- return new TextEmbeddingFloatResults(embeddingList);
+ return new DenseEmbeddingFloatResults(embeddingList);
}
/**
@@ -136,22 +136,22 @@ private static TextEmbeddingFloatResults parseArrayFormat(XContentParser parser)
* intfloat/multilingual-e5-small
* sentence-transformers/all-mpnet-base-v2
*/
- private static TextEmbeddingFloatResults parseObjectFormat(XContentParser parser) throws IOException {
+ private static DenseEmbeddingFloatResults parseObjectFormat(XContentParser parser) throws IOException {
positionParserAtTokenAfterField(parser, "embeddings", FAILED_TO_FIND_FIELD_TEMPLATE);
- List embeddingList = parseList(
+ List embeddingList = parseList(
parser,
HuggingFaceEmbeddingsResponseEntity::parseEmbeddingEntry
);
- return new TextEmbeddingFloatResults(embeddingList);
+ return new DenseEmbeddingFloatResults(embeddingList);
}
- private static TextEmbeddingFloatResults.Embedding parseEmbeddingEntry(XContentParser parser) throws IOException {
+ private static DenseEmbeddingFloatResults.Embedding parseEmbeddingEntry(XContentParser parser) throws IOException {
ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
List embeddingValuesList = parseList(parser, XContentUtils::parseFloat);
- return TextEmbeddingFloatResults.Embedding.of(embeddingValuesList);
+ return DenseEmbeddingFloatResults.Embedding.of(embeddingValuesList);
}
private HuggingFaceEmbeddingsResponseEntity() {}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/response/IbmWatsonxEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/response/IbmWatsonxEmbeddingsResponseEntity.java
index 4fda9d5661a2c..d12f44932ed6e 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/response/IbmWatsonxEmbeddingsResponseEntity.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/response/IbmWatsonxEmbeddingsResponseEntity.java
@@ -12,7 +12,7 @@
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.external.response.XContentUtils;
@@ -30,7 +30,7 @@ public class IbmWatsonxEmbeddingsResponseEntity {
private static final String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in IBM watsonx embeddings response";
- public static TextEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException {
+ public static DenseEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException {
var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) {
@@ -41,16 +41,16 @@ public static TextEmbeddingFloatResults fromResponse(Request request, HttpResult
positionParserAtTokenAfterField(jsonParser, "results", FAILED_TO_FIND_FIELD_TEMPLATE);
- List embeddingList = parseList(
+ List embeddingList = parseList(
jsonParser,
IbmWatsonxEmbeddingsResponseEntity::parseEmbeddingObject
);
- return new TextEmbeddingFloatResults(embeddingList);
+ return new DenseEmbeddingFloatResults(embeddingList);
}
}
- private static TextEmbeddingFloatResults.Embedding parseEmbeddingObject(XContentParser parser) throws IOException {
+ private static DenseEmbeddingFloatResults.Embedding parseEmbeddingObject(XContentParser parser) throws IOException {
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
positionParserAtTokenAfterField(parser, "embedding", FAILED_TO_FIND_FIELD_TEMPLATE);
@@ -59,7 +59,7 @@ private static TextEmbeddingFloatResults.Embedding parseEmbeddingObject(XContent
// parse and discard the rest of the object
consumeUntilObjectEnd(parser);
- return TextEmbeddingFloatResults.Embedding.of(embeddingValuesList);
+ return DenseEmbeddingFloatResults.Embedding.of(embeddingValuesList);
}
private IbmWatsonxEmbeddingsResponseEntity() {}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/response/JinaAIEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/response/JinaAIEmbeddingsResponseEntity.java
index 8eee003accba0..f9c80d2593642 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/response/JinaAIEmbeddingsResponseEntity.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/response/JinaAIEmbeddingsResponseEntity.java
@@ -15,9 +15,9 @@
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBitResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingBitResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.external.response.XContentUtils;
@@ -126,15 +126,15 @@ public static InferenceServiceResults fromResponse(Request request, HttpResult r
}
private static InferenceServiceResults parseFloatDataObject(XContentParser jsonParser) throws IOException {
- List embeddingList = parseList(
+ List embeddingList = parseList(
jsonParser,
JinaAIEmbeddingsResponseEntity::parseFloatEmbeddingObject
);
- return new TextEmbeddingFloatResults(embeddingList);
+ return new DenseEmbeddingFloatResults(embeddingList);
}
- private static TextEmbeddingFloatResults.Embedding parseFloatEmbeddingObject(XContentParser parser) throws IOException {
+ private static DenseEmbeddingFloatResults.Embedding parseFloatEmbeddingObject(XContentParser parser) throws IOException {
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
positionParserAtTokenAfterField(parser, "embedding", FAILED_TO_FIND_FIELD_TEMPLATE);
@@ -143,19 +143,19 @@ private static TextEmbeddingFloatResults.Embedding parseFloatEmbeddingObject(XCo
// parse and discard the rest of the object
consumeUntilObjectEnd(parser);
- return TextEmbeddingFloatResults.Embedding.of(embeddingValuesList);
+ return DenseEmbeddingFloatResults.Embedding.of(embeddingValuesList);
}
private static InferenceServiceResults parseBitDataObject(XContentParser jsonParser) throws IOException {
- List embeddingList = parseList(
+ List embeddingList = parseList(
jsonParser,
JinaAIEmbeddingsResponseEntity::parseBitEmbeddingObject
);
- return new TextEmbeddingBitResults(embeddingList);
+ return new DenseEmbeddingBitResults(embeddingList);
}
- private static TextEmbeddingByteResults.Embedding parseBitEmbeddingObject(XContentParser parser) throws IOException {
+ private static DenseEmbeddingByteResults.Embedding parseBitEmbeddingObject(XContentParser parser) throws IOException {
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
positionParserAtTokenAfterField(parser, "embedding", FAILED_TO_FIND_FIELD_TEMPLATE);
@@ -164,7 +164,7 @@ private static TextEmbeddingByteResults.Embedding parseBitEmbeddingObject(XConte
// parse and discard the rest of the object
consumeUntilObjectEnd(parser);
- return TextEmbeddingByteResults.Embedding.of(embeddingList);
+ return DenseEmbeddingByteResults.Embedding.of(embeddingList);
}
private static Byte parseEmbeddingInt8Entry(XContentParser parser) throws IOException {
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/response/OpenAiEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/response/OpenAiEmbeddingsResponseEntity.java
index b8130545a711d..3298ca310c01d 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/response/OpenAiEmbeddingsResponseEntity.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/response/OpenAiEmbeddingsResponseEntity.java
@@ -12,7 +12,7 @@
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.request.Request;
@@ -65,9 +65,9 @@ public class OpenAiEmbeddingsResponseEntity {
*
*
*/
- public static TextEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException {
+ public static DenseEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException {
try (var p = XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, response.body())) {
- return EmbeddingFloatResult.PARSER.apply(p, null).toTextEmbeddingFloatResults();
+ return EmbeddingFloatResult.PARSER.apply(p, null).toDenseEmbeddingFloatResults();
}
}
@@ -83,9 +83,9 @@ public record EmbeddingFloatResult(List embeddingResu
PARSER.declareObjectArray(constructorArg(), EmbeddingFloatResultEntry.PARSER::apply, new ParseField("data"));
}
- public TextEmbeddingFloatResults toTextEmbeddingFloatResults() {
- return new TextEmbeddingFloatResults(
- embeddingResults.stream().map(entry -> TextEmbeddingFloatResults.Embedding.of(entry.embedding)).toList()
+ public DenseEmbeddingFloatResults toDenseEmbeddingFloatResults() {
+ return new DenseEmbeddingFloatResults(
+ embeddingResults.stream().map(entry -> DenseEmbeddingFloatResults.Embedding.of(entry.embedding)).toList()
);
}
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerStoredServiceSchema.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerStoredServiceSchema.java
index 70ed23e215dd6..94d96b3db6d26 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerStoredServiceSchema.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerStoredServiceSchema.java
@@ -12,6 +12,7 @@
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.inference.services.validation.DenseEmbeddingModelValidator;
/**
* Contains any model-specific settings that are stored in SageMakerServiceSettings.
@@ -72,7 +73,7 @@ default boolean isFragment() {
/**
* If this Schema supports Text Embeddings, then we need to implement this.
- * {@link org.elasticsearch.xpack.inference.services.validation.TextEmbeddingModelValidator} will set the dimensions if the user
+ * {@link DenseEmbeddingModelValidator} will set the dimensions if the user
* does not do it, so we need to store the dimensions and flip the {@link #dimensionsSetByUser()} boolean.
*/
default SageMakerStoredServiceSchema updateModelWithEmbeddingDetails(Integer dimensions) {
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticTextEmbeddingPayload.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticTextEmbeddingPayload.java
index a5fd194f12109..dbbf82f5703b9 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticTextEmbeddingPayload.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticTextEmbeddingPayload.java
@@ -24,10 +24,10 @@
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParserConfiguration;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBitResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingBitResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingResults;
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerInferenceRequest;
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredServiceSchema;
@@ -92,7 +92,7 @@ public Stream namedWriteables() {
}
@Override
- public TextEmbeddingResults> responseBody(SageMakerModel model, InvokeEndpointResponse response) throws Exception {
+ public DenseEmbeddingResults> responseBody(SageMakerModel model, InvokeEndpointResponse response) throws Exception {
try (var p = jsonXContent.createParser(XContentParserConfiguration.EMPTY, response.body().asInputStream())) {
return switch (model.apiServiceSettings().elementType()) {
case BIT -> TextEmbeddingBinary.PARSER.apply(p, null);
@@ -103,7 +103,7 @@ public TextEmbeddingResults> responseBody(SageMakerModel model, InvokeEndpoint
}
/**
- * Reads binary format (it says bytes, but the lengths are different)
+ * Reads binary format
* {
* "text_embedding_bits": [
* {
@@ -120,12 +120,12 @@ public TextEmbeddingResults> responseBody(SageMakerModel model, InvokeEndpoint
* }
*/
private static class TextEmbeddingBinary {
- private static final ParseField TEXT_EMBEDDING_BITS = new ParseField(TextEmbeddingBitResults.TEXT_EMBEDDING_BITS);
+ private static final ParseField TEXT_EMBEDDING_BITS = new ParseField(DenseEmbeddingBitResults.TEXT_EMBEDDING_BITS);
@SuppressWarnings("unchecked")
- private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(
- TextEmbeddingBitResults.class.getSimpleName(),
+ private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(
+ DenseEmbeddingBitResults.class.getSimpleName(),
IGNORE_UNKNOWN_FIELDS,
- args -> new TextEmbeddingBitResults((List) args[0])
+ args -> new DenseEmbeddingBitResults((List) args[0])
);
static {
@@ -151,20 +151,20 @@ private static class TextEmbeddingBinary {
* }
*/
private static class TextEmbeddingBytes {
- private static final ParseField TEXT_EMBEDDING_BYTES = new ParseField("text_embedding_bytes");
+ private static final ParseField TEXT_EMBEDDING_BYTES = new ParseField(DenseEmbeddingByteResults.TEXT_EMBEDDING_BYTES);
@SuppressWarnings("unchecked")
- private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(
- TextEmbeddingByteResults.class.getSimpleName(),
+ private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(
+ DenseEmbeddingByteResults.class.getSimpleName(),
IGNORE_UNKNOWN_FIELDS,
- args -> new TextEmbeddingByteResults((List) args[0])
+ args -> new DenseEmbeddingByteResults((List) args[0])
);
@SuppressWarnings("unchecked")
- private static final ConstructingObjectParser BYTE_PARSER =
+ private static final ConstructingObjectParser BYTE_PARSER =
new ConstructingObjectParser<>(
- TextEmbeddingByteResults.Embedding.class.getSimpleName(),
+ DenseEmbeddingByteResults.Embedding.class.getSimpleName(),
IGNORE_UNKNOWN_FIELDS,
- args -> TextEmbeddingByteResults.Embedding.of((List) args[0])
+ args -> DenseEmbeddingByteResults.Embedding.of((List) args[0])
);
static {
@@ -197,20 +197,20 @@ private static class TextEmbeddingBytes {
* }
*/
private static class TextEmbeddingFloat {
- private static final ParseField TEXT_EMBEDDING_FLOAT = new ParseField("text_embedding");
+ private static final ParseField TEXT_EMBEDDING_FLOAT = new ParseField(DenseEmbeddingFloatResults.TEXT_EMBEDDING);
@SuppressWarnings("unchecked")
- private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(
- TextEmbeddingByteResults.class.getSimpleName(),
+ private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(
+ DenseEmbeddingFloatResults.class.getSimpleName(),
IGNORE_UNKNOWN_FIELDS,
- args -> new TextEmbeddingFloatResults((List) args[0])
+ args -> new DenseEmbeddingFloatResults((List) args[0])
);
@SuppressWarnings("unchecked")
- private static final ConstructingObjectParser FLOAT_PARSER =
+ private static final ConstructingObjectParser FLOAT_PARSER =
new ConstructingObjectParser<>(
- TextEmbeddingFloatResults.Embedding.class.getSimpleName(),
+ DenseEmbeddingFloatResults.Embedding.class.getSimpleName(),
IGNORE_UNKNOWN_FIELDS,
- args -> TextEmbeddingFloatResults.Embedding.of((List) args[0])
+ args -> DenseEmbeddingFloatResults.Embedding.of((List) args[0])
);
static {
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiTextEmbeddingPayload.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiTextEmbeddingPayload.java
index f060c7a75797b..b4c0b62f7ca12 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiTextEmbeddingPayload.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiTextEmbeddingPayload.java
@@ -25,7 +25,7 @@
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.json.JsonXContent;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.services.openai.response.OpenAiEmbeddingsResponseEntity;
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerInferenceRequest;
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
@@ -116,9 +116,9 @@ public SdkBytes requestBytes(SageMakerModel model, SageMakerInferenceRequest req
}
@Override
- public TextEmbeddingFloatResults responseBody(SageMakerModel model, InvokeEndpointResponse response) throws Exception {
+ public DenseEmbeddingFloatResults responseBody(SageMakerModel model, InvokeEndpointResponse response) throws Exception {
try (var p = jsonXContent.createParser(XContentParserConfiguration.EMPTY, response.body().asInputStream())) {
- return OpenAiEmbeddingsResponseEntity.EmbeddingFloatResult.PARSER.apply(p, null).toTextEmbeddingFloatResults();
+ return OpenAiEmbeddingsResponseEntity.EmbeddingFloatResult.PARSER.apply(p, null).toDenseEmbeddingFloatResults();
}
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/DenseEmbeddingModelValidator.java
similarity index 85%
rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidator.java
rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/DenseEmbeddingModelValidator.java
index ce9df7376ebcb..fb7b415758ca5 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidator.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/DenseEmbeddingModelValidator.java
@@ -16,14 +16,14 @@
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.validation.ServiceIntegrationValidator;
import org.elasticsearch.rest.RestStatus;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingResults;
-public class TextEmbeddingModelValidator implements ModelValidator {
+public class DenseEmbeddingModelValidator implements ModelValidator {
private final ServiceIntegrationValidator serviceIntegrationValidator;
- public TextEmbeddingModelValidator(ServiceIntegrationValidator serviceIntegrationValidator) {
+ public DenseEmbeddingModelValidator(ServiceIntegrationValidator serviceIntegrationValidator) {
this.serviceIntegrationValidator = serviceIntegrationValidator;
}
@@ -35,7 +35,7 @@ public void validate(InferenceService service, Model model, TimeValue timeout, A
}
private Model postValidate(InferenceService service, Model model, InferenceServiceResults results) {
- if (results instanceof TextEmbeddingResults> embeddingResults) {
+ if (results instanceof DenseEmbeddingResults> embeddingResults) {
var serviceSettings = model.getServiceSettings();
var dimensions = serviceSettings.dimensions();
int embeddingSize = getEmbeddingSize(embeddingResults);
@@ -60,7 +60,7 @@ private Model postValidate(InferenceService service, Model model, InferenceServi
throw new ElasticsearchStatusException(
"Validation call did not return expected results type."
+ "Expected a result of type ["
- + TextEmbeddingFloatResults.NAME
+ + DenseEmbeddingFloatResults.NAME
+ "] got ["
+ (results == null ? "null" : results.getWriteableName())
+ "]",
@@ -69,7 +69,7 @@ private Model postValidate(InferenceService service, Model model, InferenceServi
}
}
- private int getEmbeddingSize(TextEmbeddingResults> embeddingResults) {
+ private int getEmbeddingSize(DenseEmbeddingResults> embeddingResults) {
int embeddingSize;
try {
embeddingSize = embeddingResults.getFirstEmbeddingSize();
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidator.java
index fa0e1b3e590a4..6078531eec3f9 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidator.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidator.java
@@ -17,8 +17,8 @@
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.validation.ServiceIntegrationValidator;
import org.elasticsearch.rest.RestStatus;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingResults;
import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandEmbeddingModel;
public class ElasticsearchInternalServiceModelValidator implements ModelValidator {
@@ -54,7 +54,7 @@ public void validate(InferenceService service, Model model, TimeValue timeout, A
}
private Model postValidate(InferenceService service, Model model, InferenceServiceResults results) {
- if (results instanceof TextEmbeddingResults> embeddingResults) {
+ if (results instanceof DenseEmbeddingResults> embeddingResults) {
var serviceSettings = model.getServiceSettings();
var dimensions = serviceSettings.dimensions();
int embeddingSize = getEmbeddingSize(embeddingResults);
@@ -79,7 +79,7 @@ private Model postValidate(InferenceService service, Model model, InferenceServi
throw new ElasticsearchStatusException(
"Validation call did not return expected results type."
+ "Expected a result of type ["
- + TextEmbeddingFloatResults.NAME
+ + DenseEmbeddingFloatResults.NAME
+ "] got ["
+ (results == null ? "null" : results.getWriteableName())
+ "]",
@@ -88,7 +88,7 @@ private Model postValidate(InferenceService service, Model model, InferenceServi
}
}
- private int getEmbeddingSize(TextEmbeddingResults> embeddingResults) {
+ private int getEmbeddingSize(DenseEmbeddingResults> embeddingResults) {
int embeddingSize;
try {
embeddingSize = embeddingResults.getFirstEmbeddingSize();
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilder.java
index fac9ee5e9c1c1..59fcdac82c2a4 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilder.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilder.java
@@ -36,7 +36,7 @@ private static ModelValidator buildModelValidatorForTaskType(TaskType taskType,
switch (taskType) {
case TEXT_EMBEDDING -> {
- return new TextEmbeddingModelValidator(
+ return new DenseEmbeddingModelValidator(
Objects.requireNonNullElse(validatorFromService, new SimpleServiceIntegrationValidator())
);
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/response/VoyageAIEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/response/VoyageAIEmbeddingsResponseEntity.java
index f9ba5fd58d21a..61436d509e45a 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/response/VoyageAIEmbeddingsResponseEntity.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/response/VoyageAIEmbeddingsResponseEntity.java
@@ -15,9 +15,9 @@
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBitResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingBitResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingType;
@@ -75,9 +75,9 @@ private static void checkByteBounds(Integer value) {
}
}
- public TextEmbeddingByteResults.Embedding toInferenceByteEmbedding() {
+ public DenseEmbeddingByteResults.Embedding toInferenceByteEmbedding() {
embedding.forEach(EmbeddingInt8ResultEntry::checkByteBounds);
- return TextEmbeddingByteResults.Embedding.of(embedding.stream().map(Integer::byteValue).toList());
+ return DenseEmbeddingByteResults.Embedding.of(embedding.stream().map(Integer::byteValue).toList());
}
}
@@ -108,8 +108,8 @@ record EmbeddingFloatResultEntry(Integer index, List embedding) {
PARSER.declareFloatArray(constructorArg(), new ParseField("embedding"));
}
- public TextEmbeddingFloatResults.Embedding toInferenceFloatEmbedding() {
- return TextEmbeddingFloatResults.Embedding.of(embedding);
+ public DenseEmbeddingFloatResults.Embedding toInferenceFloatEmbedding() {
+ return DenseEmbeddingFloatResults.Embedding.of(embedding);
}
}
@@ -166,22 +166,22 @@ public static InferenceServiceResults fromResponse(Request request, HttpResult r
if (embeddingType == null || embeddingType == VoyageAIEmbeddingType.FLOAT) {
var embeddingResult = EmbeddingFloatResult.PARSER.apply(jsonParser, null);
- List embeddingList = embeddingResult.entries.stream()
+ List embeddingList = embeddingResult.entries.stream()
.map(EmbeddingFloatResultEntry::toInferenceFloatEmbedding)
.toList();
- return new TextEmbeddingFloatResults(embeddingList);
+ return new DenseEmbeddingFloatResults(embeddingList);
} else if (embeddingType == VoyageAIEmbeddingType.INT8) {
var embeddingResult = EmbeddingInt8Result.PARSER.apply(jsonParser, null);
- List embeddingList = embeddingResult.entries.stream()
+ List embeddingList = embeddingResult.entries.stream()
.map(EmbeddingInt8ResultEntry::toInferenceByteEmbedding)
.toList();
- return new TextEmbeddingByteResults(embeddingList);
+ return new DenseEmbeddingByteResults(embeddingList);
} else if (embeddingType == VoyageAIEmbeddingType.BIT || embeddingType == VoyageAIEmbeddingType.BINARY) {
var embeddingResult = EmbeddingInt8Result.PARSER.apply(jsonParser, null);
- List embeddingList = embeddingResult.entries.stream()
+ List embeddingList = embeddingResult.entries.stream()
.map(EmbeddingInt8ResultEntry::toInferenceByteEmbedding)
.toList();
- return new TextEmbeddingBitResults(embeddingList);
+ return new DenseEmbeddingBitResults(embeddingList);
} else {
throw new IllegalArgumentException(
"Illegal embedding_type value: " + embeddingType + ". Supported types are: " + VALID_EMBEDDING_TYPES_STRING
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionResponseTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionResponseTests.java
index c7ff03b05e975..429cf65991b13 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionResponseTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionResponseTests.java
@@ -11,8 +11,8 @@
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests;
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.inference.InferenceNamedWriteablesProvider;
@@ -39,7 +39,7 @@ protected Writeable.Reader instanceReader() {
@Override
protected InferenceAction.Response createTestInstance() {
var result = randomBoolean()
- ? TextEmbeddingFloatResultsTests.createRandomResults()
+ ? DenseEmbeddingFloatResultsTests.createRandomResults()
: SparseEmbeddingResultsTests.createRandomResults();
return new InferenceAction.Response(result);
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java
index 411d992adfa3d..9d450912d9c4e 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java
@@ -14,10 +14,10 @@
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingBitResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBitResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
import org.hamcrest.Matchers;
import java.util.ArrayList;
@@ -393,12 +393,12 @@ public void testVeryLongInput_Float() {
// Produce inference results for each request, with increasing weights.
float weight = 0f;
for (var batch : batches) {
- var embeddings = new ArrayList();
+ var embeddings = new ArrayList();
for (int i = 0; i < batch.batch().requests().size(); i++) {
weight += 1 / 16384f;
- embeddings.add(new TextEmbeddingFloatResults.Embedding(new float[] { weight }));
+ embeddings.add(new DenseEmbeddingFloatResults.Embedding(new float[] { weight }));
}
- batch.listener().onResponse(new TextEmbeddingFloatResults(embeddings));
+ batch.listener().onResponse(new DenseEmbeddingFloatResults(embeddings));
}
assertNotNull(finalListener.results);
@@ -410,8 +410,10 @@ public void testVeryLongInput_Float() {
ChunkedInferenceEmbedding chunkedEmbedding = (ChunkedInferenceEmbedding) inference;
assertThat(chunkedEmbedding.chunks(), hasSize(1));
assertThat(getMatchedText(inputs.get(0).input(), chunkedEmbedding.chunks().get(0).offset()), equalTo("1st small"));
- assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class));
- TextEmbeddingFloatResults.Embedding embedding = (TextEmbeddingFloatResults.Embedding) chunkedEmbedding.chunks().get(0).embedding();
+ assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(DenseEmbeddingFloatResults.Embedding.class));
+ DenseEmbeddingFloatResults.Embedding embedding = (DenseEmbeddingFloatResults.Embedding) chunkedEmbedding.chunks()
+ .get(0)
+ .embedding();
assertThat(embedding.values(), equalTo(new float[] { 1 / 16384f }));
// The very long passage "word0 word1 ... word199999" is split into 10000 chunks for
@@ -427,8 +429,8 @@ public void testVeryLongInput_Float() {
// is the average of the weights 2/16384 ... 21/16384.
assertThat(getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(0).offset()), startsWith("word0 word1 "));
assertThat(getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(0).offset()), endsWith(" word398 word399"));
- assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class));
- embedding = (TextEmbeddingFloatResults.Embedding) chunkedEmbedding.chunks().get(0).embedding();
+ assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(DenseEmbeddingFloatResults.Embedding.class));
+ embedding = (DenseEmbeddingFloatResults.Embedding) chunkedEmbedding.chunks().get(0).embedding();
assertThat(embedding.values(), equalTo(new float[] { (2 + 21) / (2 * 16384f) }));
// The last merged chunk consists of 19 small chunks (so 380 words) and the weight
@@ -438,8 +440,8 @@ public void testVeryLongInput_Float() {
startsWith(" word199620 word199621 ")
);
assertThat(getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(511).offset()), endsWith(" word199998 word199999"));
- assertThat(chunkedEmbedding.chunks().get(511).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class));
- embedding = (TextEmbeddingFloatResults.Embedding) chunkedEmbedding.chunks().get(511).embedding();
+ assertThat(chunkedEmbedding.chunks().get(511).embedding(), instanceOf(DenseEmbeddingFloatResults.Embedding.class));
+ embedding = (DenseEmbeddingFloatResults.Embedding) chunkedEmbedding.chunks().get(511).embedding();
assertThat(embedding.values(), equalTo(new float[] { (9983 + 10001) / (2 * 16384f) }));
// The last input has the token with weight 10002/16384.
@@ -448,8 +450,8 @@ public void testVeryLongInput_Float() {
chunkedEmbedding = (ChunkedInferenceEmbedding) inference;
assertThat(chunkedEmbedding.chunks(), hasSize(1));
assertThat(getMatchedText(inputs.get(2).input(), chunkedEmbedding.chunks().get(0).offset()), equalTo("2nd small"));
- assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class));
- embedding = (TextEmbeddingFloatResults.Embedding) chunkedEmbedding.chunks().get(0).embedding();
+ assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(DenseEmbeddingFloatResults.Embedding.class));
+ embedding = (DenseEmbeddingFloatResults.Embedding) chunkedEmbedding.chunks().get(0).embedding();
assertThat(embedding.values(), equalTo(new float[] { 10002 / 16384f }));
}
@@ -484,12 +486,12 @@ public void testVeryLongInput_Byte() {
// Produce inference results for each request, with increasing weights.
byte weight = 0;
for (var batch : batches) {
- var embeddings = new ArrayList();
+ var embeddings = new ArrayList();
for (int i = 0; i < batch.batch().requests().size(); i++) {
weight += 1;
- embeddings.add(new TextEmbeddingByteResults.Embedding(new byte[] { weight }));
+ embeddings.add(new DenseEmbeddingByteResults.Embedding(new byte[] { weight }));
}
- batch.listener().onResponse(new TextEmbeddingByteResults(embeddings));
+ batch.listener().onResponse(new DenseEmbeddingByteResults(embeddings));
}
assertNotNull(finalListener.results);
@@ -501,8 +503,8 @@ public void testVeryLongInput_Byte() {
ChunkedInferenceEmbedding chunkedEmbedding = (ChunkedInferenceEmbedding) inference;
assertThat(chunkedEmbedding.chunks(), hasSize(1));
assertThat(getMatchedText(inputs.get(0).input(), chunkedEmbedding.chunks().get(0).offset()), equalTo("1st small"));
- assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(TextEmbeddingByteResults.Embedding.class));
- TextEmbeddingByteResults.Embedding embedding = (TextEmbeddingByteResults.Embedding) chunkedEmbedding.chunks().get(0).embedding();
+ assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(DenseEmbeddingByteResults.Embedding.class));
+ DenseEmbeddingByteResults.Embedding embedding = (DenseEmbeddingByteResults.Embedding) chunkedEmbedding.chunks().get(0).embedding();
assertThat(embedding.values(), equalTo(new byte[] { 1 }));
// The very long passage "word0 word1 ... word199999" is split into 10000 chunks for
@@ -518,8 +520,8 @@ public void testVeryLongInput_Byte() {
// is the average of the weights 2 ... 21, so 11.5, which is rounded to 12.
assertThat(getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(0).offset()), startsWith("word0 word1 "));
assertThat(getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(0).offset()), endsWith(" word398 word399"));
- assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(TextEmbeddingByteResults.Embedding.class));
- embedding = (TextEmbeddingByteResults.Embedding) chunkedEmbedding.chunks().get(0).embedding();
+ assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(DenseEmbeddingByteResults.Embedding.class));
+ embedding = (DenseEmbeddingByteResults.Embedding) chunkedEmbedding.chunks().get(0).embedding();
assertThat(embedding.values(), equalTo(new byte[] { 12 }));
// The last merged chunk consists of 19 small chunks (so 380 words) and the weight
@@ -530,8 +532,8 @@ public void testVeryLongInput_Byte() {
startsWith(" word199620 word199621 ")
);
assertThat(getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(511).offset()), endsWith(" word199998 word199999"));
- assertThat(chunkedEmbedding.chunks().get(511).embedding(), instanceOf(TextEmbeddingByteResults.Embedding.class));
- embedding = (TextEmbeddingByteResults.Embedding) chunkedEmbedding.chunks().get(511).embedding();
+ assertThat(chunkedEmbedding.chunks().get(511).embedding(), instanceOf(DenseEmbeddingByteResults.Embedding.class));
+ embedding = (DenseEmbeddingByteResults.Embedding) chunkedEmbedding.chunks().get(511).embedding();
assertThat(embedding.values(), equalTo(new byte[] { 8 }));
// The last input has the token with weight 10002 % 256 = 18
@@ -540,8 +542,8 @@ public void testVeryLongInput_Byte() {
chunkedEmbedding = (ChunkedInferenceEmbedding) inference;
assertThat(chunkedEmbedding.chunks(), hasSize(1));
assertThat(getMatchedText(inputs.get(2).input(), chunkedEmbedding.chunks().get(0).offset()), equalTo("2nd small"));
- assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(TextEmbeddingByteResults.Embedding.class));
- embedding = (TextEmbeddingByteResults.Embedding) chunkedEmbedding.chunks().get(0).embedding();
+ assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(DenseEmbeddingByteResults.Embedding.class));
+ embedding = (DenseEmbeddingByteResults.Embedding) chunkedEmbedding.chunks().get(0).embedding();
assertThat(embedding.values(), equalTo(new byte[] { 18 }));
}
@@ -570,18 +572,18 @@ public void testMergingListener_Float() {
// 4 inputs in 2 batches
{
- var embeddings = new ArrayList();
+ var embeddings = new ArrayList();
for (int i = 0; i < batchSize; i++) {
- embeddings.add(new TextEmbeddingFloatResults.Embedding(new float[] { randomFloat() }));
+ embeddings.add(new DenseEmbeddingFloatResults.Embedding(new float[] { randomFloat() }));
}
- batches.get(0).listener().onResponse(new TextEmbeddingFloatResults(embeddings));
+ batches.get(0).listener().onResponse(new DenseEmbeddingFloatResults(embeddings));
}
{
- var embeddings = new ArrayList();
+ var embeddings = new ArrayList();
for (int i = 0; i < 4; i++) { // 4 requests in the 2nd batch
- embeddings.add(new TextEmbeddingFloatResults.Embedding(new float[] { randomFloat() }));
+ embeddings.add(new DenseEmbeddingFloatResults.Embedding(new float[] { randomFloat() }));
}
- batches.get(1).listener().onResponse(new TextEmbeddingFloatResults(embeddings));
+ batches.get(1).listener().onResponse(new DenseEmbeddingFloatResults(embeddings));
}
assertNotNull(finalListener.results);
@@ -650,18 +652,18 @@ public void testMergingListener_Byte() {
// 4 inputs in 2 batches
{
- var embeddings = new ArrayList();
+ var embeddings = new ArrayList();
for (int i = 0; i < batchSize; i++) {
- embeddings.add(new TextEmbeddingByteResults.Embedding(new byte[] { randomByte() }));
+ embeddings.add(new DenseEmbeddingByteResults.Embedding(new byte[] { randomByte() }));
}
- batches.get(0).listener().onResponse(new TextEmbeddingByteResults(embeddings));
+ batches.get(0).listener().onResponse(new DenseEmbeddingByteResults(embeddings));
}
{
- var embeddings = new ArrayList();
+ var embeddings = new ArrayList();
for (int i = 0; i < 4; i++) { // 4 requests in the 2nd batch
- embeddings.add(new TextEmbeddingByteResults.Embedding(new byte[] { randomByte() }));
+ embeddings.add(new DenseEmbeddingByteResults.Embedding(new byte[] { randomByte() }));
}
- batches.get(1).listener().onResponse(new TextEmbeddingByteResults(embeddings));
+ batches.get(1).listener().onResponse(new DenseEmbeddingByteResults(embeddings));
}
assertNotNull(finalListener.results);
@@ -727,18 +729,18 @@ public void testMergingListener_Bit() {
// 4 inputs in 2 batches
{
- var embeddings = new ArrayList();
+ var embeddings = new ArrayList();
for (int i = 0; i < batchSize; i++) {
- embeddings.add(new TextEmbeddingByteResults.Embedding(new byte[] { randomByte() }));
+ embeddings.add(new DenseEmbeddingByteResults.Embedding(new byte[] { randomByte() }));
}
- batches.get(0).listener().onResponse(new TextEmbeddingBitResults(embeddings));
+ batches.get(0).listener().onResponse(new DenseEmbeddingBitResults(embeddings));
}
{
- var embeddings = new ArrayList();
+ var embeddings = new ArrayList();
for (int i = 0; i < 4; i++) { // 4 requests in the 2nd batch
- embeddings.add(new TextEmbeddingByteResults.Embedding(new byte[] { randomByte() }));
+ embeddings.add(new DenseEmbeddingByteResults.Embedding(new byte[] { randomByte() }));
}
- batches.get(1).listener().onResponse(new TextEmbeddingBitResults(embeddings));
+ batches.get(1).listener().onResponse(new DenseEmbeddingBitResults(embeddings));
}
assertNotNull(finalListener.results);
@@ -892,10 +894,10 @@ public void onFailure(Exception e) {
var batches = new EmbeddingRequestChunker<>(inputs, 10, 100, 0).batchRequestsWithListeners(listener);
assertThat(batches, hasSize(1));
- var embeddings = new ArrayList();
- embeddings.add(new TextEmbeddingFloatResults.Embedding(new float[] { randomFloat() }));
- embeddings.add(new TextEmbeddingFloatResults.Embedding(new float[] { randomFloat() }));
- batches.get(0).listener().onResponse(new TextEmbeddingFloatResults(embeddings));
+ var embeddings = new ArrayList();
+ embeddings.add(new DenseEmbeddingFloatResults.Embedding(new float[] { randomFloat() }));
+ embeddings.add(new DenseEmbeddingFloatResults.Embedding(new float[] { randomFloat() }));
+ batches.get(0).listener().onResponse(new DenseEmbeddingFloatResults(embeddings));
assertEquals("Error the number of embedding responses [2] does not equal the number of requests [3]", failureMessage.get());
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java
index 027a19aca6d1f..f104e7b87136e 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java
@@ -48,7 +48,7 @@
import java.util.concurrent.atomic.AtomicReference;
import static org.elasticsearch.core.Strings.format;
-import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
+import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java
index d1499f4009d0a..d596677f4125f 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java
@@ -25,10 +25,10 @@
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.EmbeddingResults;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
import org.elasticsearch.xpack.core.utils.FloatConversionUtils;
import org.elasticsearch.xpack.inference.chunking.NoneChunkingSettings;
import org.elasticsearch.xpack.inference.chunking.SentenceBoundaryChunkingSettings;
@@ -211,7 +211,7 @@ public static ChunkedInferenceEmbedding randomChunkedInferenceEmbeddingByte(Mode
}
chunks.add(
new EmbeddingResults.Chunk(
- new TextEmbeddingByteResults.Embedding(values),
+ new DenseEmbeddingByteResults.Embedding(values),
new ChunkedInference.TextOffset(0, input.length())
)
);
@@ -233,7 +233,7 @@ public static ChunkedInferenceEmbedding randomChunkedInferenceEmbeddingFloat(Mod
}
chunks.add(
new EmbeddingResults.Chunk(
- new TextEmbeddingFloatResults.Embedding(values),
+ new DenseEmbeddingFloatResults.Embedding(values),
new ChunkedInference.TextOffset(0, input.length())
)
);
@@ -415,8 +415,8 @@ public static ChunkedInference toChunkedResult(
ChunkedInference.TextOffset offset = createOffset(useLegacyFormat, entryChunk, matchedText);
double[] values = parseDenseVector(entryChunk.rawEmbeddings(), embeddingLength, field.contentType());
EmbeddingResults.Embedding> embedding = switch (elementType) {
- case FLOAT -> new TextEmbeddingFloatResults.Embedding(FloatConversionUtils.floatArrayOf(values));
- case BYTE, BIT -> new TextEmbeddingByteResults.Embedding(byteArrayOf(values));
+ case FLOAT -> new DenseEmbeddingFloatResults.Embedding(FloatConversionUtils.floatArrayOf(values));
+ case BYTE, BIT -> new DenseEmbeddingByteResults.Embedding(byteArrayOf(values));
};
chunks.add(new EmbeddingResults.Chunk(embedding, offset));
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceKnnVectorQueryBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceKnnVectorQueryBuilderTests.java
index 20354aa9b2dc7..de8583132319b 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceKnnVectorQueryBuilderTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceKnnVectorQueryBuilderTests.java
@@ -22,7 +22,7 @@
import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor;
import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
import org.elasticsearch.search.vectors.VectorData;
-import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
+import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults;
import org.elasticsearch.xpack.core.ml.vectors.TextEmbeddingQueryVectorBuilder;
import org.elasticsearch.xpack.inference.mapper.SemanticTextField;
@@ -175,8 +175,8 @@ public void testInterceptAndRewrite() throws Exception {
new FullyQualifiedInferenceId(LOCAL_CLUSTER_GROUP_KEY, DENSE_INFERENCE_ID)
);
assertThat(inferenceResults, notNullValue());
- assertThat(inferenceResults, instanceOf(MlTextEmbeddingResults.class));
- VectorData queryVector = new VectorData(((MlTextEmbeddingResults) inferenceResults).getInferenceAsFloat());
+ assertThat(inferenceResults, instanceOf(MlDenseEmbeddingResults.class));
+ VectorData queryVector = new VectorData(((MlDenseEmbeddingResults) inferenceResults).getInferenceAsFloat());
// Perform data node rewrite on test index 1
final QueryRewriteContext indexMetadataContextTestIndex1 = createIndexMetadataContext(
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/MockInferenceClient.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/MockInferenceClient.java
index 17438d9786ba3..b22191735f1d9 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/MockInferenceClient.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/MockInferenceClient.java
@@ -20,11 +20,11 @@
import org.elasticsearch.test.client.NoOpClient;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
-import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
+import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults;
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
import java.util.Arrays;
@@ -58,8 +58,8 @@ protected void
InferenceResults inferenceResults = generateInferenceResults(inferenceId, input);
if (inferenceResults instanceof TextExpansionResults textExpansionResults) {
inferenceServiceResults = SparseEmbeddingResults.of(List.of(textExpansionResults));
- } else if (inferenceResults instanceof MlTextEmbeddingResults mlTextEmbeddingResults) {
- inferenceServiceResults = TextEmbeddingFloatResults.of(List.of(mlTextEmbeddingResults));
+ } else if (inferenceResults instanceof MlDenseEmbeddingResults mlDenseEmbeddingResults) {
+ inferenceServiceResults = DenseEmbeddingFloatResults.of(List.of(mlDenseEmbeddingResults));
} else {
throw new IllegalStateException("Unexpected inference results type [" + inferenceResults.getWriteableName() + "]");
}
@@ -126,6 +126,6 @@ private static InferenceResults generateTextEmbeddingResults(MinimalServiceSetti
double[] embedding = new double[embeddingSize];
Arrays.fill(embedding, Byte.MIN_VALUE); // Always use a byte value so that the embedding is valid regardless of the element type
- return new MlTextEmbeddingResults(DEFAULT_RESULTS_FIELD, embedding, false);
+ return new MlDenseEmbeddingResults(DEFAULT_RESULTS_FIELD, embedding, false);
}
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java
index b2d7218720a57..160261eb1b1cc 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java
@@ -67,10 +67,10 @@
import org.elasticsearch.xcontent.json.JsonXContent;
import org.elasticsearch.xpack.core.XPackClientPlugin;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
-import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
+import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults;
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
import org.elasticsearch.xpack.inference.InferencePlugin;
import org.elasticsearch.xpack.inference.mapper.SemanticTextField;
@@ -348,9 +348,9 @@ private InferenceAction.Response generateTextEmbeddingInferenceResponse() {
int inferenceLength = DenseVectorFieldMapperTestUtils.getEmbeddingLength(denseVectorElementType, TEXT_EMBEDDING_DIMENSION_COUNT);
double[] inference = new double[inferenceLength];
Arrays.fill(inference, 1.0);
- MlTextEmbeddingResults textEmbeddingResults = new MlTextEmbeddingResults(DEFAULT_RESULTS_FIELD, inference, false);
+ MlDenseEmbeddingResults textEmbeddingResults = new MlDenseEmbeddingResults(DEFAULT_RESULTS_FIELD, inference, false);
- return new InferenceAction.Response(TextEmbeddingFloatResults.of(List.of(textEmbeddingResults)));
+ return new InferenceAction.Response(DenseEmbeddingFloatResults.of(List.of(textEmbeddingResults)));
}
@Override
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/BaseInferenceActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/BaseInferenceActionTests.java
index 4e60b09530684..ece25247bb5b5 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/BaseInferenceActionTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/BaseInferenceActionTests.java
@@ -23,7 +23,7 @@
import org.elasticsearch.xpack.core.inference.InferenceContext;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.action.InferenceActionProxy;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults;
import org.elasticsearch.xpack.inference.InferencePlugin;
import org.junit.Before;
@@ -234,7 +234,7 @@ public void testExtractProductUseCase_EmptyWhenHeaderValueEmpty() {
static InferenceAction.Response createResponse() {
return new InferenceAction.Response(
- new TextEmbeddingByteResults(List.of(new TextEmbeddingByteResults.Embedding(new byte[] { (byte) -1 })))
+ new DenseEmbeddingByteResults(List.of(new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) -1 })))
);
}
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java
index c9f5ea738b33b..94f65acb67cff 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java
@@ -33,9 +33,9 @@
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.InputTypeTests;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
@@ -516,7 +516,7 @@ private void testChunkedInfer(TaskType taskType, ChunkingSettings chunkingSettin
var firstResult = results.getFirst();
assertThat(firstResult, instanceOf(ChunkedInferenceEmbedding.class));
Class> expectedClass = switch (taskType) {
- case TEXT_EMBEDDING -> TextEmbeddingFloatResults.Chunk.class;
+ case TEXT_EMBEDDING -> DenseEmbeddingFloatResults.Chunk.class;
case SPARSE_EMBEDDING -> SparseEmbeddingResults.Chunk.class;
default -> null;
};
@@ -650,10 +650,10 @@ private AlibabaCloudSearchModel createEmbeddingsModel(
) {
public ExecutableAction accept(AlibabaCloudSearchActionVisitor visitor, Map taskSettings) {
return (inferenceInputs, timeout, listener) -> {
- TextEmbeddingFloatResults results = new TextEmbeddingFloatResults(
+ DenseEmbeddingFloatResults results = new DenseEmbeddingFloatResults(
List.of(
- new TextEmbeddingFloatResults.Embedding(new float[] { 0.0123f, -0.0123f }),
- new TextEmbeddingFloatResults.Embedding(new float[] { 0.0456f, -0.0456f })
+ new DenseEmbeddingFloatResults.Embedding(new float[] { 0.0123f, -0.0123f }),
+ new DenseEmbeddingFloatResults.Embedding(new float[] { 0.0456f, -0.0456f })
)
);
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/action/AlibabaCloudSearchActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/action/AlibabaCloudSearchActionCreatorTests.java
index b09fbf43a8ca4..dec6539732156 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/action/AlibabaCloudSearchActionCreatorTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/action/AlibabaCloudSearchActionCreatorTests.java
@@ -21,11 +21,11 @@
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
import org.elasticsearch.xpack.core.inference.results.RankedDocsResultsTests;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
@@ -50,9 +50,9 @@
import java.util.concurrent.TimeUnit;
import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
+import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.core.inference.results.RankedDocsResultsTests.buildExpectationRerank;
import static org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests.buildExpectationSparseEmbeddings;
-import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.hamcrest.Matchers.is;
@@ -88,7 +88,7 @@ public void testExecute_withTextEmbeddingsAction_Success() {
float[] values = { 0.1111111f, 0.2222222f, 0.3333333f };
doAnswer(invocation -> {
ActionListener listener = invocation.getArgument(3);
- listener.onResponse(new TextEmbeddingFloatResults(List.of(new TextEmbeddingFloatResults.Embedding(values))));
+ listener.onResponse(new DenseEmbeddingFloatResults(List.of(new DenseEmbeddingFloatResults.Embedding(values))));
return Void.TYPE;
}).when(sender).send(any(), any(), any(), any());
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/response/AlibabaCloudSearchEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/response/AlibabaCloudSearchEmbeddingsResponseEntityTests.java
index ed8a1185bd846..28e5c9422a970 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/response/AlibabaCloudSearchEmbeddingsResponseEntityTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/response/AlibabaCloudSearchEmbeddingsResponseEntityTests.java
@@ -9,7 +9,7 @@
import org.apache.http.HttpResponse;
import org.elasticsearch.test.ESTestCase;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.request.AlibabaCloudSearchRequest;
@@ -50,14 +50,14 @@ public void testFromResponse_CreatesResultsForASingleItem() throws IOException,
URI uri = new URI("mock_uri");
when(request.getURI()).thenReturn(uri);
- TextEmbeddingFloatResults parsedResults = AlibabaCloudSearchEmbeddingsResponseEntity.fromResponse(
+ DenseEmbeddingFloatResults parsedResults = AlibabaCloudSearchEmbeddingsResponseEntity.fromResponse(
request,
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
assertThat(
parsedResults.embeddings(),
- is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { -0.02868066355586052f, 0.022033605724573135f })))
+ is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { -0.02868066355586052f, 0.022033605724573135f })))
);
}
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java
index 04c8bca2287e5..33d2d120e1b8d 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java
@@ -38,7 +38,7 @@
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.Utils;
import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
@@ -70,7 +70,7 @@
import static org.elasticsearch.common.xcontent.XContentHelper.toXContent;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
-import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
+import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.Utils.getInvalidModel;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
@@ -1026,7 +1026,7 @@ public void testInfer_SendsRequest_ForTitanEmbeddingsModel() throws IOException
);
var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender()
) {
- var results = new TextEmbeddingFloatResults(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.123F, 0.678F })));
+ var results = new DenseEmbeddingFloatResults(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.123F, 0.678F })));
requestSender.enqueue(results);
PlainActionFuture listener = new PlainActionFuture<>();
service.infer(
@@ -1067,8 +1067,8 @@ public void testInfer_SendsRequest_ForCohereEmbeddingsModel() throws IOException
)
) {
try (var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender()) {
- var results = new TextEmbeddingFloatResults(
- List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.123F, 0.678F }))
+ var results = new DenseEmbeddingFloatResults(
+ List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.123F, 0.678F }))
);
requestSender.enqueue(results);
@@ -1343,14 +1343,14 @@ private void testChunkedInfer(AmazonBedrockEmbeddingsModel model) throws IOExcep
) {
try (var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender()) {
{
- var mockResults1 = new TextEmbeddingFloatResults(
- List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.123F, 0.678F }))
+ var mockResults1 = new DenseEmbeddingFloatResults(
+ List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.123F, 0.678F }))
);
requestSender.enqueue(mockResults1);
}
{
- var mockResults2 = new TextEmbeddingFloatResults(
- List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.223F, 0.278F }))
+ var mockResults2 = new DenseEmbeddingFloatResults(
+ List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.223F, 0.278F }))
);
requestSender.enqueue(mockResults2);
}
@@ -1373,10 +1373,10 @@ private void testChunkedInfer(AmazonBedrockEmbeddingsModel model) throws IOExcep
var floatResult = (ChunkedInferenceEmbedding) results.get(0);
assertThat(floatResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().get(0).offset());
- assertThat(floatResult.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class));
+ assertThat(floatResult.chunks().get(0).embedding(), instanceOf(DenseEmbeddingFloatResults.Embedding.class));
assertArrayEquals(
new float[] { 0.123F, 0.678F },
- ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
+ ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
0.0f
);
}
@@ -1385,10 +1385,10 @@ private void testChunkedInfer(AmazonBedrockEmbeddingsModel model) throws IOExcep
var floatResult = (ChunkedInferenceEmbedding) results.get(1);
assertThat(floatResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().get(0).offset());
- assertThat(floatResult.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class));
+ assertThat(floatResult.chunks().get(0).embedding(), instanceOf(DenseEmbeddingFloatResults.Embedding.class));
assertArrayEquals(
new float[] { 0.223F, 0.278F },
- ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
+ ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
0.0f
);
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/action/AmazonBedrockActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/action/AmazonBedrockActionCreatorTests.java
index 5dd42dc66485f..8f1ba6b277c37 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/action/AmazonBedrockActionCreatorTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/action/AmazonBedrockActionCreatorTests.java
@@ -15,7 +15,7 @@
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.InputTypeTests;
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
@@ -33,7 +33,7 @@
import java.util.concurrent.TimeUnit;
import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
-import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
+import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors;
import static org.hamcrest.Matchers.is;
@@ -53,8 +53,8 @@ public void shutdown() throws IOException {
public void testEmbeddingsRequestAction_Titan() throws IOException {
var serviceComponents = ServiceComponentsTests.createWithEmptySettings(threadPool);
- var mockedFloatResults = List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F }));
- var mockedResult = new TextEmbeddingFloatResults(mockedFloatResults);
+ var mockedFloatResults = List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F }));
+ var mockedResult = new DenseEmbeddingFloatResults(mockedFloatResults);
try (var sender = new AmazonBedrockMockRequestSender()) {
sender.enqueue(mockedResult);
var creator = new AmazonBedrockActionCreator(sender, serviceComponents, TIMEOUT);
@@ -91,8 +91,8 @@ public void testEmbeddingsRequestAction_Titan() throws IOException {
public void testEmbeddingsRequestAction_Cohere() throws IOException {
var serviceComponents = ServiceComponentsTests.createWithEmptySettings(threadPool);
- var mockedFloatResults = List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F }));
- var mockedResult = new TextEmbeddingFloatResults(mockedFloatResults);
+ var mockedFloatResults = List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F }));
+ var mockedResult = new DenseEmbeddingFloatResults(mockedFloatResults);
try (var sender = new AmazonBedrockMockRequestSender()) {
sender.enqueue(mockedResult);
var creator = new AmazonBedrockActionCreator(sender, serviceComponents, TIMEOUT);
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockExecutorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockExecutorTests.java
index 19fe2ca784b1c..d267f80080885 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockExecutorTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockExecutorTests.java
@@ -34,7 +34,7 @@
import java.util.List;
import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
-import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
+import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.common.TruncatorTests.createTruncator;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.is;
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockRequestSenderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockRequestSenderTests.java
index d608b6ec9fb8e..38a9ca17fec7b 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockRequestSenderTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockRequestSenderTests.java
@@ -32,7 +32,7 @@
import java.util.concurrent.atomic.AtomicReference;
import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
-import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
+import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.services.amazonbedrock.client.AmazonBedrockExecutorTests.TEST_AMAZON_TITAN_EMBEDDINGS_RESULT;
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java
index 06668b91e1965..734d66a20bf56 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java
@@ -39,8 +39,8 @@
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
@@ -1347,10 +1347,10 @@ private void testChunkedInfer(AzureAiStudioEmbeddingsModel model) throws IOExcep
var floatResult = (ChunkedInferenceEmbedding) results.get(0);
assertThat(floatResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().get(0).offset());
- assertThat(floatResult.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class));
+ assertThat(floatResult.chunks().get(0).embedding(), instanceOf(DenseEmbeddingFloatResults.Embedding.class));
assertArrayEquals(
new float[] { 0.0123f, -0.0123f },
- ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
+ ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
0.0f
);
}
@@ -1359,10 +1359,10 @@ private void testChunkedInfer(AzureAiStudioEmbeddingsModel model) throws IOExcep
var floatResult = (ChunkedInferenceEmbedding) results.get(1);
assertThat(floatResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().get(0).offset());
- assertThat(floatResult.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class));
+ assertThat(floatResult.chunks().get(0).embedding(), instanceOf(DenseEmbeddingFloatResults.Embedding.class));
assertArrayEquals(
new float[] { 1.0123f, -1.0123f },
- ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
+ ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
0.0f
);
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/action/AzureAiStudioActionAndCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/action/AzureAiStudioActionAndCreatorTests.java
index 975661d180795..ef82cc18946b1 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/action/AzureAiStudioActionAndCreatorTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/action/AzureAiStudioActionAndCreatorTests.java
@@ -46,7 +46,7 @@
import java.util.concurrent.TimeUnit;
import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
-import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
+import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/response/AzureAiStudioEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/response/AzureAiStudioEmbeddingsResponseEntityTests.java
index 3da3598a4637a..e2ebd72b17da5 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/response/AzureAiStudioEmbeddingsResponseEntityTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/response/AzureAiStudioEmbeddingsResponseEntityTests.java
@@ -9,7 +9,7 @@
import org.apache.http.HttpResponse;
import org.elasticsearch.test.ESTestCase;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.request.Request;
@@ -50,11 +50,11 @@ public void testFromResponse_CreatesResultsForASingleItem() throws IOException {
var entity = new AzureAiStudioEmbeddingsResponseEntity();
- var parsedResults = (TextEmbeddingFloatResults) entity.apply(
+ var parsedResults = (DenseEmbeddingFloatResults) entity.apply(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
- assertThat(parsedResults.embeddings(), is(List.of(TextEmbeddingFloatResults.Embedding.of(List.of(0.014539449F, -0.015288644F)))));
+ assertThat(parsedResults.embeddings(), is(List.of(DenseEmbeddingFloatResults.Embedding.of(List.of(0.014539449F, -0.015288644F)))));
}
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java
index b5d8fb887c62c..6c0b4f272650a 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java
@@ -37,7 +37,7 @@
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
@@ -63,7 +63,7 @@
import static org.elasticsearch.common.xcontent.XContentHelper.toXContent;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
-import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
+import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.Utils.getInvalidModel;
import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors;
@@ -1007,10 +1007,10 @@ private void testChunkedInfer(AzureOpenAiEmbeddingsModel model) throws IOExcepti
var floatResult = (ChunkedInferenceEmbedding) results.get(0);
assertThat(floatResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().get(0).offset());
- assertThat(floatResult.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class));
+ assertThat(floatResult.chunks().get(0).embedding(), instanceOf(DenseEmbeddingFloatResults.Embedding.class));
assertArrayEquals(
new float[] { 0.123f, -0.123f },
- ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
+ ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
0.0f
);
}
@@ -1019,10 +1019,10 @@ private void testChunkedInfer(AzureOpenAiEmbeddingsModel model) throws IOExcepti
var floatResult = (ChunkedInferenceEmbedding) results.get(1);
assertThat(floatResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().get(0).offset());
- assertThat(floatResult.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class));
+ assertThat(floatResult.chunks().get(0).embedding(), instanceOf(DenseEmbeddingFloatResults.Embedding.class));
assertArrayEquals(
new float[] { 1.123f, -1.123f },
- ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
+ ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
0.0f
);
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiActionCreatorTests.java
index 1724747f9861a..23c773ab0d61b 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiActionCreatorTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiActionCreatorTests.java
@@ -42,7 +42,7 @@
import java.util.concurrent.TimeUnit;
import static org.elasticsearch.core.Strings.format;
-import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
+import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiEmbeddingsActionTests.java
index 20ac31398de2f..c6db0f441d348 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiEmbeddingsActionTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiEmbeddingsActionTests.java
@@ -44,7 +44,7 @@
import java.util.List;
import java.util.concurrent.TimeUnit;
-import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
+import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java
index 0f897500698ee..a340236e31eaa 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java
@@ -39,8 +39,8 @@
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
@@ -69,7 +69,7 @@
import static org.elasticsearch.common.xcontent.XContentHelper.toXContent;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
-import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
+import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.Utils.getInvalidModel;
import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors;
@@ -1351,7 +1351,7 @@ private void testChunkedInfer(CohereEmbeddingsModel model) throws IOException {
assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().get(0).offset());
assertArrayEquals(
new float[] { 0.123f, -0.123f },
- ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
+ ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
0.0f
);
}
@@ -1362,7 +1362,7 @@ private void testChunkedInfer(CohereEmbeddingsModel model) throws IOException {
assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().get(0).offset());
assertArrayEquals(
new float[] { 0.223f, -0.223f },
- ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
+ ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
0.0f
);
}
@@ -1448,10 +1448,10 @@ public void testChunkedInfer_BatchesCalls_Bytes() throws IOException {
var byteResult = (ChunkedInferenceEmbedding) results.get(0);
assertThat(byteResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, 1), byteResult.chunks().get(0).offset());
- assertThat(byteResult.chunks().get(0).embedding(), instanceOf(TextEmbeddingByteResults.Embedding.class));
+ assertThat(byteResult.chunks().get(0).embedding(), instanceOf(DenseEmbeddingByteResults.Embedding.class));
assertArrayEquals(
new byte[] { 23, -23 },
- ((TextEmbeddingByteResults.Embedding) byteResult.chunks().get(0).embedding()).values()
+ ((DenseEmbeddingByteResults.Embedding) byteResult.chunks().get(0).embedding()).values()
);
}
{
@@ -1459,10 +1459,10 @@ public void testChunkedInfer_BatchesCalls_Bytes() throws IOException {
var byteResult = (ChunkedInferenceEmbedding) results.get(1);
assertThat(byteResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, 2), byteResult.chunks().get(0).offset());
- assertThat(byteResult.chunks().get(0).embedding(), instanceOf(TextEmbeddingByteResults.Embedding.class));
+ assertThat(byteResult.chunks().get(0).embedding(), instanceOf(DenseEmbeddingByteResults.Embedding.class));
assertArrayEquals(
new byte[] { 24, -24 },
- ((TextEmbeddingByteResults.Embedding) byteResult.chunks().get(0).embedding()).values()
+ ((DenseEmbeddingByteResults.Embedding) byteResult.chunks().get(0).embedding()).values()
);
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreatorTests.java
index f3d816d9a118d..a6f3e6a07d66c 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreatorTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreatorTests.java
@@ -40,7 +40,7 @@
import java.util.concurrent.TimeUnit;
import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
-import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
+import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereEmbeddingsActionTests.java
index f5191ecac5ce6..008bce852b8d6 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereEmbeddingsActionTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereEmbeddingsActionTests.java
@@ -44,8 +44,8 @@
import java.util.Map;
import java.util.concurrent.TimeUnit;
-import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationByte;
-import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
+import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationByte;
+import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/response/CohereEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/response/CohereEmbeddingsResponseEntityTests.java
index 6df356bfe0a80..78bf38f204bf7 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/response/CohereEmbeddingsResponseEntityTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/response/CohereEmbeddingsResponseEntityTests.java
@@ -10,9 +10,9 @@
import org.apache.http.HttpResponse;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.test.ESTestCase;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBitResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingBitResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.hamcrest.MatcherAssert;
@@ -56,10 +56,10 @@ public void testFromResponse_CreatesResultsForASingleItem() throws IOException {
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
- MatcherAssert.assertThat(parsedResults, instanceOf(TextEmbeddingFloatResults.class));
+ MatcherAssert.assertThat(parsedResults, instanceOf(DenseEmbeddingFloatResults.class));
MatcherAssert.assertThat(
- ((TextEmbeddingFloatResults) parsedResults).embeddings(),
- is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { -0.0018434525F, 0.01777649F })))
+ ((DenseEmbeddingFloatResults) parsedResults).embeddings(),
+ is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { -0.0018434525F, 0.01777649F })))
);
}
@@ -90,14 +90,14 @@ public void testFromResponse_CreatesResultsForASingleItem_ObjectFormat() throws
}
""";
- TextEmbeddingFloatResults parsedResults = (TextEmbeddingFloatResults) CohereEmbeddingsResponseEntity.fromResponse(
+ DenseEmbeddingFloatResults parsedResults = (DenseEmbeddingFloatResults) CohereEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
MatcherAssert.assertThat(
parsedResults.embeddings(),
- is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { -0.0018434525F, 0.01777649F })))
+ is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { -0.0018434525F, 0.01777649F })))
);
}
@@ -134,14 +134,14 @@ public void testFromResponse_UsesTheFirstValidEmbeddingsEntry() throws IOExcepti
}
""";
- TextEmbeddingFloatResults parsedResults = (TextEmbeddingFloatResults) CohereEmbeddingsResponseEntity.fromResponse(
+ DenseEmbeddingFloatResults parsedResults = (DenseEmbeddingFloatResults) CohereEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
MatcherAssert.assertThat(
parsedResults.embeddings(),
- is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { -0.0018434525F, 0.01777649F })))
+ is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { -0.0018434525F, 0.01777649F })))
);
}
@@ -178,14 +178,14 @@ public void testFromResponse_UsesTheFirstValidEmbeddingsEntryInt8_WithInvalidFir
}
""";
- TextEmbeddingByteResults parsedResults = (TextEmbeddingByteResults) CohereEmbeddingsResponseEntity.fromResponse(
+ DenseEmbeddingByteResults parsedResults = (DenseEmbeddingByteResults) CohereEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
MatcherAssert.assertThat(
parsedResults.embeddings(),
- is(List.of(new TextEmbeddingByteResults.Embedding(new byte[] { (byte) -1, (byte) 0 })))
+ is(List.of(new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) -1, (byte) 0 })))
);
}
@@ -216,14 +216,14 @@ public void testFromResponse_ParsesBytes() throws IOException {
}
""";
- TextEmbeddingByteResults parsedResults = (TextEmbeddingByteResults) CohereEmbeddingsResponseEntity.fromResponse(
+ DenseEmbeddingByteResults parsedResults = (DenseEmbeddingByteResults) CohereEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
MatcherAssert.assertThat(
parsedResults.embeddings(),
- is(List.of(new TextEmbeddingByteResults.Embedding(new byte[] { (byte) -1, (byte) 0 })))
+ is(List.of(new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) -1, (byte) 0 })))
);
}
@@ -257,14 +257,14 @@ public void testFromResponse_ParsesBytes_FromBinaryEmbeddingsEntry() throws IOEx
}
""";
- TextEmbeddingBitResults parsedResults = (TextEmbeddingBitResults) CohereEmbeddingsResponseEntity.fromResponse(
+ DenseEmbeddingBitResults parsedResults = (DenseEmbeddingBitResults) CohereEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
MatcherAssert.assertThat(
parsedResults.embeddings(),
- is(List.of(new TextEmbeddingByteResults.Embedding(new byte[] { (byte) -55, (byte) 74, (byte) 101, (byte) 67, (byte) 83 })))
+ is(List.of(new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) -55, (byte) 74, (byte) 101, (byte) 67, (byte) 83 })))
);
}
@@ -297,7 +297,7 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException
}
""";
- TextEmbeddingFloatResults parsedResults = (TextEmbeddingFloatResults) CohereEmbeddingsResponseEntity.fromResponse(
+ DenseEmbeddingFloatResults parsedResults = (DenseEmbeddingFloatResults) CohereEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
@@ -306,8 +306,8 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException
parsedResults.embeddings(),
is(
List.of(
- new TextEmbeddingFloatResults.Embedding(new float[] { -0.0018434525F, 0.01777649F }),
- new TextEmbeddingFloatResults.Embedding(new float[] { -0.123F, 0.123F })
+ new DenseEmbeddingFloatResults.Embedding(new float[] { -0.0018434525F, 0.01777649F }),
+ new DenseEmbeddingFloatResults.Embedding(new float[] { -0.123F, 0.123F })
)
)
);
@@ -344,7 +344,7 @@ public void testFromResponse_CreatesResultsForMultipleItems_ObjectFormat() throw
}
""";
- TextEmbeddingFloatResults parsedResults = (TextEmbeddingFloatResults) CohereEmbeddingsResponseEntity.fromResponse(
+ DenseEmbeddingFloatResults parsedResults = (DenseEmbeddingFloatResults) CohereEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
@@ -353,8 +353,8 @@ public void testFromResponse_CreatesResultsForMultipleItems_ObjectFormat() throw
parsedResults.embeddings(),
is(
List.of(
- new TextEmbeddingFloatResults.Embedding(new float[] { -0.0018434525F, 0.01777649F }),
- new TextEmbeddingFloatResults.Embedding(new float[] { -0.123F, 0.123F })
+ new DenseEmbeddingFloatResults.Embedding(new float[] { -0.0018434525F, 0.01777649F }),
+ new DenseEmbeddingFloatResults.Embedding(new float[] { -0.123F, 0.123F })
)
)
);
@@ -397,7 +397,7 @@ public void testFromResponse_CreatesResultsForMultipleItems_ObjectFormat_Binary(
}
""";
- TextEmbeddingBitResults parsedResults = (TextEmbeddingBitResults) CohereEmbeddingsResponseEntity.fromResponse(
+ DenseEmbeddingBitResults parsedResults = (DenseEmbeddingBitResults) CohereEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
@@ -406,8 +406,8 @@ public void testFromResponse_CreatesResultsForMultipleItems_ObjectFormat_Binary(
parsedResults.embeddings(),
is(
List.of(
- new TextEmbeddingByteResults.Embedding(new byte[] { (byte) -55, (byte) 74, (byte) 101, (byte) 67 }),
- new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 34, (byte) -64, (byte) 97, (byte) 65, (byte) -42 })
+ new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) -55, (byte) 74, (byte) 101, (byte) 67 }),
+ new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) 34, (byte) -64, (byte) 97, (byte) 65, (byte) -42 })
)
)
);
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomModelTests.java
index 1c5c13e2086c2..fb9802c6938d5 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomModelTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomModelTests.java
@@ -14,7 +14,7 @@
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.inference.services.custom.response.CustomResponseParser;
-import org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser;
+import org.elasticsearch.xpack.inference.services.custom.response.DenseEmbeddingResponseParser;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
import org.hamcrest.MatcherAssert;
@@ -94,7 +94,7 @@ public static CustomModel createModel(
public static CustomModel getTestModel() {
return getTestModel(
TaskType.TEXT_EMBEDDING,
- new TextEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT)
+ new DenseEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT)
);
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java
index 65d9de30576ff..74c3e0f57e73c 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java
@@ -26,10 +26,10 @@
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.ServiceFields;
import org.elasticsearch.xpack.inference.services.custom.response.CompletionResponseParser;
+import org.elasticsearch.xpack.inference.services.custom.response.DenseEmbeddingResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.NoopResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.SparseEmbeddingResponseParser;
-import org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
import org.hamcrest.MatcherAssert;
@@ -60,7 +60,7 @@ public static CustomServiceSettings createRandom(String inputUrl) {
var requestContentString = randomAlphaOfLength(10);
var responseJsonParser = switch (taskType) {
- case TEXT_EMBEDDING -> new TextEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT);
+ case TEXT_EMBEDDING -> new DenseEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT);
case SPARSE_EMBEDDING -> new SparseEmbeddingResponseParser(
"$.result.sparse_embeddings[*].embedding[*].token_id",
"$.result.sparse_embeddings[*].embedding[*].weights"
@@ -100,7 +100,7 @@ public void testFromMap() {
var queryParameters = List.of(List.of("key", "value"));
String requestContentString = "request body";
- var responseParser = new TextEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT);
+ var responseParser = new DenseEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT);
var settings = CustomServiceSettings.fromMap(
new HashMap<>(
@@ -124,7 +124,7 @@ public void testFromMap() {
Map.of(
CustomServiceSettings.JSON_PARSER,
new HashMap<>(
- Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding")
+ Map.of(DenseEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding")
)
)
),
@@ -158,7 +158,7 @@ public void testFromMap_EmbeddingType_Bit() {
String url = "http://www.abc.com";
String requestContentString = "request body";
- var responseParser = new TextEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.BIT);
+ var responseParser = new DenseEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.BIT);
var settings = CustomServiceSettings.fromMap(
new HashMap<>(
@@ -173,9 +173,9 @@ public void testFromMap_EmbeddingType_Bit() {
CustomServiceSettings.JSON_PARSER,
new HashMap<>(
Map.of(
- TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS,
+ DenseEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS,
"$.result.embeddings[*].embedding",
- TextEmbeddingResponseParser.EMBEDDING_TYPE,
+ DenseEmbeddingResponseParser.EMBEDDING_TYPE,
CustomServiceEmbeddingType.BIT.toString()
)
)
@@ -207,7 +207,7 @@ public void testFromMap_EmbeddingType_Binary() {
String url = "http://www.abc.com";
String requestContentString = "request body";
- var responseParser = new TextEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.BINARY);
+ var responseParser = new DenseEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.BINARY);
var settings = CustomServiceSettings.fromMap(
new HashMap<>(
@@ -222,9 +222,9 @@ public void testFromMap_EmbeddingType_Binary() {
CustomServiceSettings.JSON_PARSER,
new HashMap<>(
Map.of(
- TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS,
+ DenseEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS,
"$.result.embeddings[*].embedding",
- TextEmbeddingResponseParser.EMBEDDING_TYPE,
+ DenseEmbeddingResponseParser.EMBEDDING_TYPE,
CustomServiceEmbeddingType.BINARY.toString()
)
)
@@ -256,7 +256,7 @@ public void testFromMap_EmbeddingType_Byte() {
String url = "http://www.abc.com";
String requestContentString = "request body";
- var responseParser = new TextEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.BYTE);
+ var responseParser = new DenseEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.BYTE);
var settings = CustomServiceSettings.fromMap(
new HashMap<>(
@@ -271,9 +271,9 @@ public void testFromMap_EmbeddingType_Byte() {
CustomServiceSettings.JSON_PARSER,
new HashMap<>(
Map.of(
- TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS,
+ DenseEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS,
"$.result.embeddings[*].embedding",
- TextEmbeddingResponseParser.EMBEDDING_TYPE,
+ DenseEmbeddingResponseParser.EMBEDDING_TYPE,
CustomServiceEmbeddingType.BYTE.toString()
)
)
@@ -367,7 +367,7 @@ public void testFromMap_Completion_ThrowsWhenEmbeddingIsIncludedInMap() {
Map.of(
CompletionResponseParser.COMPLETION_PARSER_RESULT,
"$.result.text",
- TextEmbeddingResponseParser.EMBEDDING_TYPE,
+ DenseEmbeddingResponseParser.EMBEDDING_TYPE,
"byte"
)
)
@@ -393,7 +393,7 @@ public void testFromMap_WithOptionalsNotSpecified() {
String url = "http://www.abc.com";
String requestContentString = "request body";
- var responseParser = new TextEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT);
+ var responseParser = new DenseEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT);
var settings = CustomServiceSettings.fromMap(
new HashMap<>(
@@ -407,7 +407,7 @@ public void testFromMap_WithOptionalsNotSpecified() {
Map.of(
CustomServiceSettings.JSON_PARSER,
new HashMap<>(
- Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding")
+ Map.of(DenseEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding")
)
)
)
@@ -445,7 +445,7 @@ public void testFromMap_RemovesNullValues_FromMaps() {
String requestContentString = "request body";
- var responseParser = new TextEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT);
+ var responseParser = new DenseEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT);
var settings = CustomServiceSettings.fromMap(
new HashMap<>(
@@ -467,7 +467,7 @@ public void testFromMap_RemovesNullValues_FromMaps() {
Map.of(
CustomServiceSettings.JSON_PARSER,
new HashMap<>(
- Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding")
+ Map.of(DenseEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding")
)
)
)
@@ -519,7 +519,7 @@ public void testFromMap_ReturnsError_IfHeadersContainsNonStringValues() {
Map.of(
CustomServiceSettings.JSON_PARSER,
new HashMap<>(
- Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding")
+ Map.of(DenseEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding")
)
)
)
@@ -566,7 +566,7 @@ public void testFromMap_ReturnsError_IfQueryParamsContainsNonStringValues() {
Map.of(
CustomServiceSettings.JSON_PARSER,
new HashMap<>(
- Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding")
+ Map.of(DenseEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding")
)
)
)
@@ -604,7 +604,7 @@ public void testFromMap_ReturnsError_IfRequestMapIsMissing() {
Map.of(
CustomServiceSettings.JSON_PARSER,
new HashMap<>(
- Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding")
+ Map.of(DenseEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding")
)
)
)
@@ -636,7 +636,7 @@ public void testFromMap_ReturnsError_IfResponseMapIsMissing() {
Map.of(
CustomServiceSettings.JSON_PARSER,
new HashMap<>(
- Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding")
+ Map.of(DenseEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding")
)
)
)
@@ -675,7 +675,7 @@ public void testFromMap_ReturnsError_IfJsonParserMapIsNotEmptyAfterParsing() {
CustomServiceSettings.JSON_PARSER,
new HashMap<>(
Map.of(
- TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS,
+ DenseEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS,
"$.result.embeddings[*].embedding",
"key",
"value"
@@ -717,7 +717,7 @@ public void testFromMap_ReturnsError_IfResponseMapIsNotEmptyAfterParsing() {
Map.of(
CustomServiceSettings.JSON_PARSER,
new HashMap<>(
- Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding")
+ Map.of(DenseEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding")
),
"key",
"value"
@@ -757,7 +757,7 @@ public void testFromMap_ReturnsError_IfTaskTypeIsInvalid() {
Map.of(
CustomServiceSettings.JSON_PARSER,
new HashMap<>(
- Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding")
+ Map.of(DenseEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding")
)
)
)
@@ -779,7 +779,7 @@ public void testXContent() throws IOException {
Map.of("key", "value"),
null,
"string",
- new TextEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT),
+ new DenseEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT),
null
);
@@ -868,7 +868,7 @@ public void testXContent_WithInputTypeTranslationValues() throws IOException {
Map.of("key", "value"),
null,
"string",
- new TextEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT),
+ new DenseEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT),
null,
null,
new InputTypeTranslator(Map.of(InputType.SEARCH, "do_search", InputType.INGEST, "do_ingest"), "a_default")
@@ -915,7 +915,7 @@ public void testXContent_BatchSize11() throws IOException {
Map.of("key", "value"),
null,
"string",
- new TextEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT),
+ new DenseEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT),
null,
11,
InputTypeTranslator.EMPTY_TRANSLATOR
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java
index 44bf17c3ac96d..1d43330381fa4 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java
@@ -28,9 +28,9 @@
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
@@ -39,9 +39,9 @@
import org.elasticsearch.xpack.inference.services.ServiceFields;
import org.elasticsearch.xpack.inference.services.custom.response.CompletionResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.CustomResponseParser;
+import org.elasticsearch.xpack.inference.services.custom.response.DenseEmbeddingResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.SparseEmbeddingResponseParser;
-import org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
import org.hamcrest.CoreMatchers;
import org.hamcrest.Matchers;
@@ -138,7 +138,7 @@ private static void assertTextEmbeddingModel(Model model, boolean modelIncludesS
var customModel = assertCommonModelFields(model, modelIncludesSecrets);
assertThat(customModel.getTaskType(), is(TaskType.TEXT_EMBEDDING));
- assertThat(customModel.getServiceSettings().getResponseJsonParser(), instanceOf(TextEmbeddingResponseParser.class));
+ assertThat(customModel.getServiceSettings().getResponseJsonParser(), instanceOf(DenseEmbeddingResponseParser.class));
}
private static CustomModel assertCommonModelFields(Model model, boolean modelIncludesSecrets) {
@@ -204,7 +204,7 @@ private static Map createServiceSettingsMap(TaskType taskType) {
private static Map createResponseParserMap(TaskType taskType) {
return switch (taskType) {
case TEXT_EMBEDDING -> new HashMap<>(
- Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding")
+ Map.of(DenseEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding")
);
case COMPLETION -> new HashMap<>(Map.of(CompletionResponseParser.COMPLETION_PARSER_RESULT, "$.result.text"));
case SPARSE_EMBEDDING -> new HashMap<>(
@@ -240,18 +240,18 @@ private static Map createSecretSettingsMap() {
private static CustomModel createInternalEmbeddingModel(SimilarityMeasure similarityMeasure) {
return createInternalEmbeddingModel(
similarityMeasure,
- new TextEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT),
+ new DenseEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT),
"http://www.abc.com"
);
}
- private static CustomModel createInternalEmbeddingModel(TextEmbeddingResponseParser parser, String url) {
+ private static CustomModel createInternalEmbeddingModel(DenseEmbeddingResponseParser parser, String url) {
return createInternalEmbeddingModel(SimilarityMeasure.DOT_PRODUCT, parser, url);
}
private static CustomModel createInternalEmbeddingModel(
@Nullable SimilarityMeasure similarityMeasure,
- TextEmbeddingResponseParser parser,
+ DenseEmbeddingResponseParser parser,
String url
) {
var inferenceId = "inference_id";
@@ -276,7 +276,7 @@ private static CustomModel createInternalEmbeddingModel(
private static CustomModel createInternalEmbeddingModel(
@Nullable SimilarityMeasure similarityMeasure,
- TextEmbeddingResponseParser parser,
+ DenseEmbeddingResponseParser parser,
String url,
@Nullable ChunkingSettings chunkingSettings,
@Nullable Integer batchSize
@@ -336,7 +336,7 @@ public void testInfer_ReturnsAnError_WithoutParsingTheResponseBody() throws IOEx
webServer.enqueue(new MockResponse().setResponseCode(400).setBody(responseJson));
var model = createInternalEmbeddingModel(
- new TextEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT),
+ new DenseEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT),
getUrl(webServer)
);
PlainActionFuture listener = new PlainActionFuture<>();
@@ -394,7 +394,7 @@ public void testInfer_HandlesTextEmbeddingRequest_OpenAI_Format() throws IOExcep
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
var model = createInternalEmbeddingModel(
- new TextEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT),
+ new DenseEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT),
getUrl(webServer)
);
PlainActionFuture listener = new PlainActionFuture<>();
@@ -412,12 +412,12 @@ public void testInfer_HandlesTextEmbeddingRequest_OpenAI_Format() throws IOExcep
);
InferenceServiceResults results = listener.actionGet(TIMEOUT);
- assertThat(results, instanceOf(TextEmbeddingFloatResults.class));
+ assertThat(results, instanceOf(DenseEmbeddingFloatResults.class));
- var embeddingResults = (TextEmbeddingFloatResults) results;
+ var embeddingResults = (DenseEmbeddingFloatResults) results;
assertThat(
embeddingResults.embeddings(),
- is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F })))
+ is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F })))
);
}
}
@@ -676,7 +676,7 @@ public void testParseRequestConfig_ThrowsAValidationError_WhenReplacementDoesNot
public void testChunkedInfer_ChunkingSettingsSet() throws IOException {
var model = createInternalEmbeddingModel(
SimilarityMeasure.DOT_PRODUCT,
- new TextEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT),
+ new DenseEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT),
getUrl(webServer),
ChunkingSettingsTests.createRandomChunkingSettings(),
2
@@ -732,10 +732,10 @@ public void testChunkedInfer_ChunkingSettingsSet() throws IOException {
var floatResult = (ChunkedInferenceEmbedding) results.get(0);
assertThat(floatResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().get(0).offset());
- assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class));
+ assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class));
assertArrayEquals(
new float[] { 0.123f, -0.123f },
- ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
+ ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
0.0f
);
}
@@ -744,10 +744,10 @@ public void testChunkedInfer_ChunkingSettingsSet() throws IOException {
var floatResult = (ChunkedInferenceEmbedding) results.get(1);
assertThat(floatResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().get(0).offset());
- assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class));
+ assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class));
assertArrayEquals(
new float[] { 0.223f, -0.223f },
- ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
+ ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
0.0f
);
}
@@ -762,7 +762,7 @@ public void testChunkedInfer_ChunkingSettingsSet() throws IOException {
public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException {
var model = createInternalEmbeddingModel(
- new TextEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT),
+ new DenseEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT),
getUrl(webServer)
);
String responseJson = """
@@ -807,10 +807,10 @@ public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException {
var floatResult = (ChunkedInferenceEmbedding) results.get(0);
assertThat(floatResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().get(0).offset());
- assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class));
+ assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class));
assertArrayEquals(
new float[] { 0.123f, -0.123f },
- ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
+ ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
0.0f
);
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java
index 8b35979c3daf5..8ad0b028d3b34 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java
@@ -27,8 +27,8 @@
import org.elasticsearch.xpack.inference.services.custom.CustomTaskSettings;
import org.elasticsearch.xpack.inference.services.custom.InputTypeTranslator;
import org.elasticsearch.xpack.inference.services.custom.QueryParameters;
+import org.elasticsearch.xpack.inference.services.custom.response.DenseEmbeddingResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser;
-import org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
import java.io.IOException;
@@ -62,7 +62,7 @@ public void testCreateRequest() throws IOException {
headers,
new QueryParameters(List.of(new QueryParameters.Parameter("key", "value"), new QueryParameters.Parameter("key", "value2"))),
requestContentString,
- new TextEmbeddingResponseParser("$.result.embeddings", CustomServiceEmbeddingType.FLOAT),
+ new DenseEmbeddingResponseParser("$.result.embeddings", CustomServiceEmbeddingType.FLOAT),
new RateLimitSettings(10_000),
null,
new InputTypeTranslator(Map.of(InputType.INGEST, "value"), "default")
@@ -122,7 +122,7 @@ public void testCreateRequest_QueryParametersAreEscaped_AndEncoded() throws IOEx
)
),
requestContentString,
- new TextEmbeddingResponseParser("$.result.embeddings", CustomServiceEmbeddingType.FLOAT),
+ new DenseEmbeddingResponseParser("$.result.embeddings", CustomServiceEmbeddingType.FLOAT),
new RateLimitSettings(10_000),
null,
new InputTypeTranslator(Map.of(InputType.INGEST, "value"), "default")
@@ -180,7 +180,7 @@ public void testCreateRequest_SecretsInTheJsonBody_AreEncodedCorrectly() throws
headers,
new QueryParameters(List.of(new QueryParameters.Parameter("key", "value"), new QueryParameters.Parameter("key", "value2"))),
requestContentString,
- new TextEmbeddingResponseParser("$.result.embeddings", CustomServiceEmbeddingType.FLOAT),
+ new DenseEmbeddingResponseParser("$.result.embeddings", CustomServiceEmbeddingType.FLOAT),
new RateLimitSettings(10_000)
);
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseEntityTests.java
index e53add6733aca..b475a770c7cfe 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseEntityTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseEntityTests.java
@@ -13,9 +13,9 @@
import org.elasticsearch.inference.WeightedToken;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
@@ -62,7 +62,7 @@ public void testFromTextEmbeddingResponse() throws IOException {
var model = CustomModelTests.getTestModel(
TaskType.TEXT_EMBEDDING,
- new TextEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT)
+ new DenseEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT)
);
var request = new CustomRequest(
EmbeddingParameters.of(new EmbeddingsInput(List.of("abc"), null), model.getServiceSettings().getInputTypeTranslator()),
@@ -73,10 +73,10 @@ public void testFromTextEmbeddingResponse() throws IOException {
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
- assertThat(results, instanceOf(TextEmbeddingFloatResults.class));
+ assertThat(results, instanceOf(DenseEmbeddingFloatResults.class));
assertThat(
- ((TextEmbeddingFloatResults) results).embeddings(),
- is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { -0.02868066355586052f, 0.022033605724573135f })))
+ ((DenseEmbeddingFloatResults) results).embeddings(),
+ is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { -0.02868066355586052f, 0.022033605724573135f })))
);
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParserTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/DenseEmbeddingResponseParserTests.java
similarity index 71%
rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParserTests.java
rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/DenseEmbeddingResponseParserTests.java
index 7796b5e1e7f6b..cffb2a5db6af5 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParserTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/DenseEmbeddingResponseParserTests.java
@@ -16,9 +16,9 @@
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentType;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBitResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingBitResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.services.custom.CustomServiceEmbeddingType;
@@ -29,24 +29,24 @@
import java.util.List;
import java.util.Map;
-import static org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser.EMBEDDING_TYPE;
-import static org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS;
+import static org.elasticsearch.xpack.inference.services.custom.response.DenseEmbeddingResponseParser.EMBEDDING_TYPE;
+import static org.elasticsearch.xpack.inference.services.custom.response.DenseEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS;
import static org.hamcrest.Matchers.is;
import static org.mockito.Mockito.mock;
-public class TextEmbeddingResponseParserTests extends AbstractBWCWireSerializationTestCase {
+public class DenseEmbeddingResponseParserTests extends AbstractBWCWireSerializationTestCase {
private static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_TYPE = TransportVersion.fromName(
"ml_inference_custom_service_embedding_type"
);
- public static TextEmbeddingResponseParser createRandom() {
- return new TextEmbeddingResponseParser("$." + randomAlphaOfLength(5), randomFrom(CustomServiceEmbeddingType.values()));
+ public static DenseEmbeddingResponseParser createRandom() {
+ return new DenseEmbeddingResponseParser("$." + randomAlphaOfLength(5), randomFrom(CustomServiceEmbeddingType.values()));
}
public void testFromMap() {
var validation = new ValidationException();
- var parser = TextEmbeddingResponseParser.fromMap(
+ var parser = DenseEmbeddingResponseParser.fromMap(
new HashMap<>(
Map.of(
TEXT_EMBEDDING_PARSER_EMBEDDINGS,
@@ -59,14 +59,14 @@ public void testFromMap() {
validation
);
- assertThat(parser, is(new TextEmbeddingResponseParser("$.result[*].embeddings", CustomServiceEmbeddingType.BIT)));
+ assertThat(parser, is(new DenseEmbeddingResponseParser("$.result[*].embeddings", CustomServiceEmbeddingType.BIT)));
}
public void testFromMap_ThrowsException_WhenRequiredFieldIsNotPresent() {
var validation = new ValidationException();
var exception = expectThrows(
ValidationException.class,
- () -> TextEmbeddingResponseParser.fromMap(new HashMap<>(Map.of("some_field", "$.result[*].embeddings")), "scope", validation)
+ () -> DenseEmbeddingResponseParser.fromMap(new HashMap<>(Map.of("some_field", "$.result[*].embeddings")), "scope", validation)
);
assertThat(
@@ -76,7 +76,7 @@ public void testFromMap_ThrowsException_WhenRequiredFieldIsNotPresent() {
}
public void testToXContent() throws IOException {
- var entity = new TextEmbeddingResponseParser("$.result.path", CustomServiceEmbeddingType.BINARY);
+ var entity = new DenseEmbeddingResponseParser("$.result.path", CustomServiceEmbeddingType.BINARY);
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
{
@@ -120,14 +120,18 @@ public void testParse() throws IOException {
}
""";
- var parser = new TextEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT);
- TextEmbeddingFloatResults parsedResults = (TextEmbeddingFloatResults) parser.parse(
+ var parser = new DenseEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT);
+ DenseEmbeddingFloatResults parsedResults = (DenseEmbeddingFloatResults) parser.parse(
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
assertThat(
parsedResults,
- is(new TextEmbeddingFloatResults(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }))))
+ is(
+ new DenseEmbeddingFloatResults(
+ List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }))
+ )
+ )
);
}
@@ -153,12 +157,15 @@ public void testParseByte() throws IOException {
}
""";
- var parser = new TextEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.BYTE);
- TextEmbeddingByteResults parsedResults = (TextEmbeddingByteResults) parser.parse(
+ var parser = new DenseEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.BYTE);
+ DenseEmbeddingByteResults parsedResults = (DenseEmbeddingByteResults) parser.parse(
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
- assertThat(parsedResults, is(new TextEmbeddingByteResults(List.of(new TextEmbeddingByteResults.Embedding(new byte[] { 1, -2 })))));
+ assertThat(
+ parsedResults,
+ is(new DenseEmbeddingByteResults(List.of(new DenseEmbeddingByteResults.Embedding(new byte[] { 1, -2 }))))
+ );
}
public void testParseBit() throws IOException {
@@ -183,12 +190,12 @@ public void testParseBit() throws IOException {
}
""";
- var parser = new TextEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.BIT);
- TextEmbeddingBitResults parsedResults = (TextEmbeddingBitResults) parser.parse(
+ var parser = new DenseEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.BIT);
+ DenseEmbeddingBitResults parsedResults = (DenseEmbeddingBitResults) parser.parse(
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
- assertThat(parsedResults, is(new TextEmbeddingBitResults(List.of(new TextEmbeddingByteResults.Embedding(new byte[] { 1, -2 })))));
+ assertThat(parsedResults, is(new DenseEmbeddingBitResults(List.of(new DenseEmbeddingByteResults.Embedding(new byte[] { 1, -2 })))));
}
public void testParse_MultipleEmbeddings() throws IOException {
@@ -221,18 +228,18 @@ public void testParse_MultipleEmbeddings() throws IOException {
}
""";
- var parser = new TextEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT);
- TextEmbeddingFloatResults parsedResults = (TextEmbeddingFloatResults) parser.parse(
+ var parser = new DenseEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT);
+ DenseEmbeddingFloatResults parsedResults = (DenseEmbeddingFloatResults) parser.parse(
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
assertThat(
parsedResults,
is(
- new TextEmbeddingFloatResults(
+ new DenseEmbeddingFloatResults(
List.of(
- new TextEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }),
- new TextEmbeddingFloatResults.Embedding(new float[] { 1F, -2F })
+ new DenseEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }),
+ new DenseEmbeddingFloatResults.Embedding(new float[] { 1F, -2F })
)
)
)
@@ -269,7 +276,7 @@ public void testParse_ThrowsException_WhenExtractedField_IsNotAListOfFloats() {
}
""";
- var parser = new TextEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT);
+ var parser = new DenseEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT);
var exception = expectThrows(
IllegalArgumentException.class,
() -> parser.parse(new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)))
@@ -303,7 +310,7 @@ public void testParse_ThrowsException_WhenExtractedField_IsNotAList() {
}
""";
- var parser = new TextEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT);
+ var parser = new DenseEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT);
var exception = expectThrows(
IllegalArgumentException.class,
() -> parser.parse(new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)))
@@ -319,25 +326,25 @@ public void testParse_ThrowsException_WhenExtractedField_IsNotAList() {
}
@Override
- protected TextEmbeddingResponseParser mutateInstanceForVersion(TextEmbeddingResponseParser instance, TransportVersion version) {
+ protected DenseEmbeddingResponseParser mutateInstanceForVersion(DenseEmbeddingResponseParser instance, TransportVersion version) {
if (version.supports(ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_TYPE) == false) {
- return new TextEmbeddingResponseParser(instance.getTextEmbeddingsPath(), CustomServiceEmbeddingType.FLOAT);
+ return new DenseEmbeddingResponseParser(instance.getTextEmbeddingsPath(), CustomServiceEmbeddingType.FLOAT);
}
return instance;
}
@Override
- protected Writeable.Reader instanceReader() {
- return TextEmbeddingResponseParser::new;
+ protected Writeable.Reader instanceReader() {
+ return DenseEmbeddingResponseParser::new;
}
@Override
- protected TextEmbeddingResponseParser createTestInstance() {
+ protected DenseEmbeddingResponseParser createTestInstance() {
return createRandom();
}
@Override
- protected TextEmbeddingResponseParser mutateInstance(TextEmbeddingResponseParser instance) throws IOException {
- return randomValueOtherThan(instance, TextEmbeddingResponseParserTests::createRandom);
+ protected DenseEmbeddingResponseParser mutateInstance(DenseEmbeddingResponseParser instance) throws IOException {
+ return randomValueOtherThan(instance, DenseEmbeddingResponseParserTests::createRandom);
}
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java
index b8a82d6a7a29c..551152aafa533 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java
@@ -40,8 +40,8 @@
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
import org.elasticsearch.xpack.inference.InferencePlugin;
import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
@@ -889,9 +889,9 @@ public void testChunkedInfer_BatchesCallsChunkingSettingsSet() throws IOExceptio
var denseResult = (ChunkedInferenceEmbedding) results.getFirst();
assertThat(denseResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, "hello world".length()), denseResult.chunks().getFirst().offset());
- assertThat(denseResult.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class));
+ assertThat(denseResult.chunks().get(0).embedding(), instanceOf(DenseEmbeddingFloatResults.Embedding.class));
- var embedding = (TextEmbeddingFloatResults.Embedding) denseResult.chunks().get(0).embedding();
+ var embedding = (DenseEmbeddingFloatResults.Embedding) denseResult.chunks().get(0).embedding();
assertArrayEquals(new float[] { 0.123f, -0.456f, 0.789f }, embedding.values(), 0.0f);
}
@@ -901,9 +901,9 @@ public void testChunkedInfer_BatchesCallsChunkingSettingsSet() throws IOExceptio
var denseResult = (ChunkedInferenceEmbedding) results.get(1);
assertThat(denseResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, "dense embedding".length()), denseResult.chunks().getFirst().offset());
- assertThat(denseResult.chunks().getFirst().embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class));
+ assertThat(denseResult.chunks().getFirst().embedding(), instanceOf(DenseEmbeddingFloatResults.Embedding.class));
- var embedding = (TextEmbeddingFloatResults.Embedding) denseResult.chunks().get(0).embedding();
+ var embedding = (DenseEmbeddingFloatResults.Embedding) denseResult.chunks().get(0).embedding();
assertArrayEquals(new float[] { 0.987f, -0.654f, 0.321f }, embedding.values(), 0.0f);
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java
index 1d91713eef931..ab715b7f73e96 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java
@@ -20,9 +20,9 @@
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.RankedDocsResultsTests;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
@@ -298,8 +298,8 @@ public void testExecute_ReturnsSuccessfulResponse_ForDenseTextEmbeddingsAction()
var result = listener.actionGet(TIMEOUT);
- assertThat(result, instanceOf(TextEmbeddingFloatResults.class));
- var textEmbeddingResults = (TextEmbeddingFloatResults) result;
+ assertThat(result, instanceOf(DenseEmbeddingFloatResults.class));
+ var textEmbeddingResults = (DenseEmbeddingFloatResults) result;
assertThat(textEmbeddingResults.embeddings(), hasSize(2));
var firstEmbedding = textEmbeddingResults.embeddings().get(0);
@@ -354,8 +354,8 @@ public void testExecute_ReturnsSuccessfulResponse_ForDenseTextEmbeddingsAction_W
var result = listener.actionGet(TIMEOUT);
- assertThat(result, instanceOf(TextEmbeddingFloatResults.class));
- var textEmbeddingResults = (TextEmbeddingFloatResults) result;
+ assertThat(result, instanceOf(DenseEmbeddingFloatResults.class));
+ var textEmbeddingResults = (DenseEmbeddingFloatResults) result;
assertThat(textEmbeddingResults.embeddings(), hasSize(1));
var embedding = textEmbeddingResults.embeddings().get(0);
@@ -447,8 +447,8 @@ public void testExecute_ReturnsSuccessfulResponse_ForDenseTextEmbeddingsAction_E
var result = listener.actionGet(TIMEOUT);
- assertThat(result, instanceOf(TextEmbeddingFloatResults.class));
- var textEmbeddingResults = (TextEmbeddingFloatResults) result;
+ assertThat(result, instanceOf(DenseEmbeddingFloatResults.class));
+ var textEmbeddingResults = (DenseEmbeddingFloatResults) result;
assertThat(textEmbeddingResults.embeddings(), hasSize(0));
assertThat(webServer.requests(), hasSize(1));
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceDenseTextEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceDenseTextEmbeddingsResponseEntityTests.java
index 2883a1ab73c21..79721b95af067 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceDenseTextEmbeddingsResponseEntityTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceDenseTextEmbeddingsResponseEntityTests.java
@@ -9,7 +9,7 @@
import org.apache.http.HttpResponse;
import org.elasticsearch.test.ESTestCase;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.external.response.elastic.ElasticInferenceServiceDenseTextEmbeddingsResponseEntity;
@@ -35,7 +35,7 @@ public void testDenseTextEmbeddingsResponse_SingleEmbeddingInData_NoMeta() throw
}
""";
- TextEmbeddingFloatResults parsedResults = ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.fromResponse(
+ DenseEmbeddingFloatResults parsedResults = ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
@@ -64,7 +64,7 @@ public void testDenseTextEmbeddingsResponse_MultipleEmbeddingsInData_NoMeta() th
}
""";
- TextEmbeddingFloatResults parsedResults = ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.fromResponse(
+ DenseEmbeddingFloatResults parsedResults = ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
@@ -85,7 +85,7 @@ public void testDenseTextEmbeddingsResponse_EmptyData() throws Exception {
}
""";
- TextEmbeddingFloatResults parsedResults = ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.fromResponse(
+ DenseEmbeddingFloatResults parsedResults = ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
@@ -111,7 +111,7 @@ public void testDenseTextEmbeddingsResponse_SingleEmbeddingInData_IgnoresMeta()
}
""";
- TextEmbeddingFloatResults parsedResults = ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.fromResponse(
+ DenseEmbeddingFloatResults parsedResults = ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java
index 2cc0323e6d913..0db705c1770ba 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java
@@ -53,9 +53,9 @@
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
import org.elasticsearch.xpack.core.ml.MachineLearningField;
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction;
@@ -71,8 +71,8 @@
import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentTests;
import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
-import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
-import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResultsTests;
+import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults;
+import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResultsTests;
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResultsTests;
import org.elasticsearch.xpack.core.ml.inference.results.TextSimilarityInferenceResultsTests;
@@ -1109,8 +1109,8 @@ public void testChunkInfer_E5ChunkingSettingsSet() throws InterruptedException {
@SuppressWarnings("unchecked")
private void testChunkInfer_e5(ChunkingSettings chunkingSettings) throws InterruptedException {
var mlTrainedModelResults = new ArrayList();
- mlTrainedModelResults.add(MlTextEmbeddingResultsTests.createRandomResults());
- mlTrainedModelResults.add(MlTextEmbeddingResultsTests.createRandomResults());
+ mlTrainedModelResults.add(MlDenseEmbeddingResultsTests.createRandomResults());
+ mlTrainedModelResults.add(MlDenseEmbeddingResultsTests.createRandomResults());
var response = new InferModelAction.Response(mlTrainedModelResults, "foo", true);
Client client = mock(Client.class);
@@ -1136,20 +1136,20 @@ private void testChunkInfer_e5(ChunkingSettings chunkingSettings) throws Interru
assertThat(chunkedResponse.get(0), instanceOf(ChunkedInferenceEmbedding.class));
var result1 = (ChunkedInferenceEmbedding) chunkedResponse.get(0);
assertThat(result1.chunks(), hasSize(1));
- assertThat(result1.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class));
+ assertThat(result1.chunks().get(0).embedding(), instanceOf(DenseEmbeddingFloatResults.Embedding.class));
assertArrayEquals(
- ((MlTextEmbeddingResults) mlTrainedModelResults.get(0)).getInferenceAsFloat(),
- ((TextEmbeddingFloatResults.Embedding) result1.chunks().get(0).embedding()).values(),
+ ((MlDenseEmbeddingResults) mlTrainedModelResults.get(0)).getInferenceAsFloat(),
+ ((DenseEmbeddingFloatResults.Embedding) result1.chunks().get(0).embedding()).values(),
0.0001f
);
assertEquals(new ChunkedInference.TextOffset(0, 1), result1.chunks().get(0).offset());
assertThat(chunkedResponse.get(1), instanceOf(ChunkedInferenceEmbedding.class));
var result2 = (ChunkedInferenceEmbedding) chunkedResponse.get(1);
assertThat(result2.chunks(), hasSize(1));
- assertThat(result2.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class));
+ assertThat(result2.chunks().get(0).embedding(), instanceOf(DenseEmbeddingFloatResults.Embedding.class));
assertArrayEquals(
- ((MlTextEmbeddingResults) mlTrainedModelResults.get(1)).getInferenceAsFloat(),
- ((TextEmbeddingFloatResults.Embedding) result2.chunks().get(0).embedding()).values(),
+ ((MlDenseEmbeddingResults) mlTrainedModelResults.get(1)).getInferenceAsFloat(),
+ ((DenseEmbeddingFloatResults.Embedding) result2.chunks().get(0).embedding()).values(),
0.0001f
);
assertEquals(new ChunkedInference.TextOffset(0, 2), result2.chunks().get(0).offset());
@@ -1377,8 +1377,8 @@ public void testChunkInferSetsTokenization() {
@SuppressWarnings("unchecked")
public void testChunkInfer_FailsBatch() throws InterruptedException {
var mlTrainedModelResults = new ArrayList();
- mlTrainedModelResults.add(MlTextEmbeddingResultsTests.createRandomResults());
- mlTrainedModelResults.add(MlTextEmbeddingResultsTests.createRandomResults());
+ mlTrainedModelResults.add(MlDenseEmbeddingResultsTests.createRandomResults());
+ mlTrainedModelResults.add(MlDenseEmbeddingResultsTests.createRandomResults());
mlTrainedModelResults.add(new ErrorInferenceResults(new RuntimeException("boom")));
var response = new InferModelAction.Response(mlTrainedModelResults, "foo", true);
@@ -1454,7 +1454,7 @@ public void testChunkingLargeDocument() throws InterruptedException {
var listener = (ActionListener) invocationOnMock.getArguments()[2];
var mlTrainedModelResults = new ArrayList();
for (int i = 0; i < request.numberOfDocuments(); i++) {
- mlTrainedModelResults.add(MlTextEmbeddingResultsTests.createRandomResults());
+ mlTrainedModelResults.add(MlDenseEmbeddingResultsTests.createRandomResults());
}
var response = new InferModelAction.Response(mlTrainedModelResults, "foo", true);
listener.onResponse(response);
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java
index be087f73f8d5b..7a5fef5ff0995 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java
@@ -39,7 +39,7 @@
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
@@ -66,7 +66,7 @@
import static org.elasticsearch.common.xcontent.XContentHelper.toXContent;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
-import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
+import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.Utils.getInvalidModel;
import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors;
@@ -941,11 +941,11 @@ private void testChunkedInfer(String modelId, String apiKey, GoogleAiStudioEmbed
var floatResult = (ChunkedInferenceEmbedding) results.get(0);
assertThat(floatResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, input.get(0).input().length()), floatResult.chunks().get(0).offset());
- assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class));
+ assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class));
assertTrue(
Arrays.equals(
new float[] { 0.0123f, -0.0123f },
- ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values()
+ ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values()
)
);
}
@@ -956,11 +956,11 @@ private void testChunkedInfer(String modelId, String apiKey, GoogleAiStudioEmbed
var floatResult = (ChunkedInferenceEmbedding) results.get(1);
assertThat(floatResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, input.get(1).input().length()), floatResult.chunks().get(0).offset());
- assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class));
+ assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class));
assertTrue(
Arrays.equals(
new float[] { 0.0456f, -0.0456f },
- ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values()
+ ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values()
)
);
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/action/GoogleAiStudioEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/action/GoogleAiStudioEmbeddingsActionTests.java
index 38e33c546b291..9e6b5547a6b1d 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/action/GoogleAiStudioEmbeddingsActionTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/action/GoogleAiStudioEmbeddingsActionTests.java
@@ -40,7 +40,7 @@
import java.util.Map;
import java.util.concurrent.TimeUnit;
-import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
+import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/response/GoogleAiStudioEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/response/GoogleAiStudioEmbeddingsResponseEntityTests.java
index eca4a369c29c8..6bfb33769602f 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/response/GoogleAiStudioEmbeddingsResponseEntityTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/response/GoogleAiStudioEmbeddingsResponseEntityTests.java
@@ -9,7 +9,7 @@
import org.apache.http.HttpResponse;
import org.elasticsearch.test.ESTestCase;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.request.Request;
@@ -36,12 +36,12 @@ public void testFromResponse_CreatesResultsForASingleItem() throws IOException {
}
""";
- TextEmbeddingFloatResults parsedResults = GoogleAiStudioEmbeddingsResponseEntity.fromResponse(
+ DenseEmbeddingFloatResults parsedResults = GoogleAiStudioEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
- assertThat(parsedResults.embeddings(), is(List.of(TextEmbeddingFloatResults.Embedding.of(List.of(-0.00606332F, 0.058092743F)))));
+ assertThat(parsedResults.embeddings(), is(List.of(DenseEmbeddingFloatResults.Embedding.of(List.of(-0.00606332F, 0.058092743F)))));
}
public void testFromResponse_CreatesResultsForMultipleItems() throws IOException {
@@ -64,7 +64,7 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException
}
""";
- TextEmbeddingFloatResults parsedResults = GoogleAiStudioEmbeddingsResponseEntity.fromResponse(
+ DenseEmbeddingFloatResults parsedResults = GoogleAiStudioEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
@@ -73,8 +73,8 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException
parsedResults.embeddings(),
is(
List.of(
- TextEmbeddingFloatResults.Embedding.of(List.of(-0.00606332F, 0.058092743F)),
- TextEmbeddingFloatResults.Embedding.of(List.of(0.030681048F, 0.01714732F))
+ DenseEmbeddingFloatResults.Embedding.of(List.of(-0.00606332F, 0.058092743F)),
+ DenseEmbeddingFloatResults.Embedding.of(List.of(0.030681048F, 0.01714732F))
)
)
);
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/response/GoogleVertexAiEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/response/GoogleVertexAiEmbeddingsResponseEntityTests.java
index 8f19edb3031d7..c3a8a8b74078b 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/response/GoogleVertexAiEmbeddingsResponseEntityTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/response/GoogleVertexAiEmbeddingsResponseEntityTests.java
@@ -9,7 +9,7 @@
import org.apache.http.HttpResponse;
import org.elasticsearch.test.ESTestCase;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.request.Request;
@@ -42,12 +42,12 @@ public void testFromResponse_CreatesResultsForASingleItem() throws IOException {
}
""";
- TextEmbeddingFloatResults parsedResults = GoogleVertexAiEmbeddingsResponseEntity.fromResponse(
+ DenseEmbeddingFloatResults parsedResults = GoogleVertexAiEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
- assertThat(parsedResults.embeddings(), is(List.of(TextEmbeddingFloatResults.Embedding.of(List.of(-0.123F, 0.123F)))));
+ assertThat(parsedResults.embeddings(), is(List.of(DenseEmbeddingFloatResults.Embedding.of(List.of(-0.123F, 0.123F)))));
}
public void testFromResponse_CreatesResultsForMultipleItems() throws IOException {
@@ -82,7 +82,7 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException
}
""";
- TextEmbeddingFloatResults parsedResults = GoogleVertexAiEmbeddingsResponseEntity.fromResponse(
+ DenseEmbeddingFloatResults parsedResults = GoogleVertexAiEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
@@ -91,8 +91,8 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException
parsedResults.embeddings(),
is(
List.of(
- TextEmbeddingFloatResults.Embedding.of(List.of(-0.123F, 0.123F)),
- TextEmbeddingFloatResults.Embedding.of(List.of(-0.456F, 0.456F))
+ DenseEmbeddingFloatResults.Embedding.of(List.of(-0.123F, 0.123F)),
+ DenseEmbeddingFloatResults.Embedding.of(List.of(-0.456F, 0.456F))
)
)
);
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java
index 67534d83a2e15..5d60a87200f3e 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java
@@ -42,8 +42,8 @@
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
@@ -76,7 +76,7 @@
import static org.elasticsearch.common.xcontent.XContentHelper.toXContent;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS;
-import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
+import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.Utils.getInvalidModel;
import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors;
@@ -1212,10 +1212,10 @@ public void testChunkedInfer_CallsInfer_TextEmbedding_ConvertsFloatResponse() th
var embeddingResult = (ChunkedInferenceEmbedding) result;
assertThat(embeddingResult.chunks(), hasSize(1));
assertThat(embeddingResult.chunks().get(0).offset(), is(new ChunkedInference.TextOffset(0, "abc".length())));
- assertThat(embeddingResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class));
+ assertThat(embeddingResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class));
assertArrayEquals(
new float[] { -0.0123f, 0.0123f },
- ((TextEmbeddingFloatResults.Embedding) embeddingResult.chunks().get(0).embedding()).values(),
+ ((DenseEmbeddingFloatResults.Embedding) embeddingResult.chunks().get(0).embedding()).values(),
0.001f
);
assertThat(webServer.requests(), hasSize(1));
@@ -1266,10 +1266,10 @@ public void testChunkedInfer() throws IOException {
var floatResult = (ChunkedInferenceEmbedding) results.get(0);
assertThat(floatResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, 3), floatResult.chunks().get(0).offset());
- assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class));
+ assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class));
assertArrayEquals(
new float[] { 0.123f, -0.123f },
- ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
+ ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
0.0f
);
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java
index 1194f3a5fa95c..1225e40a0fc0c 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java
@@ -19,8 +19,8 @@
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests;
import org.elasticsearch.xpack.inference.InputTypeTests;
import org.elasticsearch.xpack.inference.common.TruncatorTests;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
@@ -231,7 +231,10 @@ public void testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction() throws I
var result = listener.actionGet(TIMEOUT);
- assertThat(result.asMap(), is(TextEmbeddingFloatResultsTests.buildExpectationFloat(List.of(new float[] { -0.0123F, 0.123F }))));
+ assertThat(
+ result.asMap(),
+ is(DenseEmbeddingFloatResultsTests.buildExpectationFloat(List.of(new float[] { -0.0123F, 0.123F })))
+ );
assertThat(webServer.requests(), hasSize(1));
assertNull(webServer.requests().get(0).getUri().getQuery());
@@ -416,7 +419,10 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating() throws IOExc
var result = listener.actionGet(TIMEOUT);
- assertThat(result.asMap(), is(TextEmbeddingFloatResultsTests.buildExpectationFloat(List.of(new float[] { -0.0123F, 0.123F }))));
+ assertThat(
+ result.asMap(),
+ is(DenseEmbeddingFloatResultsTests.buildExpectationFloat(List.of(new float[] { -0.0123F, 0.123F })))
+ );
assertThat(webServer.requests(), hasSize(2));
{
@@ -478,7 +484,10 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException {
var result = listener.actionGet(TIMEOUT);
- assertThat(result.asMap(), is(TextEmbeddingFloatResultsTests.buildExpectationFloat(List.of(new float[] { -0.0123F, 0.123F }))));
+ assertThat(
+ result.asMap(),
+ is(DenseEmbeddingFloatResultsTests.buildExpectationFloat(List.of(new float[] { -0.0123F, 0.123F })))
+ );
assertThat(webServer.requests(), hasSize(1));
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceEmbeddingsResponseEntityTests.java
index 61e035326d163..e157038f2f244 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceEmbeddingsResponseEntityTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceEmbeddingsResponseEntityTests.java
@@ -10,7 +10,7 @@
import org.apache.http.HttpResponse;
import org.elasticsearch.common.ParsingException;
import org.elasticsearch.test.ESTestCase;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.request.Request;
@@ -32,14 +32,14 @@ public void testFromResponse_CreatesResultsForASingleItem_ArrayFormat() throws I
]
""";
- TextEmbeddingFloatResults parsedResults = HuggingFaceEmbeddingsResponseEntity.fromResponse(
+ DenseEmbeddingFloatResults parsedResults = HuggingFaceEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
assertThat(
parsedResults.embeddings(),
- is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F })))
+ is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F })))
);
}
@@ -55,14 +55,14 @@ public void testFromResponse_CreatesResultsForASingleItem_ObjectFormat() throws
}
""";
- TextEmbeddingFloatResults parsedResults = HuggingFaceEmbeddingsResponseEntity.fromResponse(
+ DenseEmbeddingFloatResults parsedResults = HuggingFaceEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
assertThat(
parsedResults.embeddings(),
- is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F })))
+ is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F })))
);
}
@@ -80,7 +80,7 @@ public void testFromResponse_CreatesResultsForMultipleItems_ArrayFormat() throws
]
""";
- TextEmbeddingFloatResults parsedResults = HuggingFaceEmbeddingsResponseEntity.fromResponse(
+ DenseEmbeddingFloatResults parsedResults = HuggingFaceEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
@@ -89,8 +89,8 @@ public void testFromResponse_CreatesResultsForMultipleItems_ArrayFormat() throws
parsedResults.embeddings(),
is(
List.of(
- new TextEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }),
- new TextEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F })
+ new DenseEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }),
+ new DenseEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F })
)
)
);
@@ -112,7 +112,7 @@ public void testFromResponse_CreatesResultsForMultipleItems_ObjectFormat() throw
}
""";
- TextEmbeddingFloatResults parsedResults = HuggingFaceEmbeddingsResponseEntity.fromResponse(
+ DenseEmbeddingFloatResults parsedResults = HuggingFaceEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
@@ -121,8 +121,8 @@ public void testFromResponse_CreatesResultsForMultipleItems_ObjectFormat() throw
parsedResults.embeddings(),
is(
List.of(
- new TextEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }),
- new TextEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F })
+ new DenseEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }),
+ new DenseEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F })
)
)
);
@@ -255,12 +255,12 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsInt_ArrayFormat() throw
]
""";
- TextEmbeddingFloatResults parsedResults = HuggingFaceEmbeddingsResponseEntity.fromResponse(
+ DenseEmbeddingFloatResults parsedResults = HuggingFaceEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
- assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 1.0F }))));
+ assertThat(parsedResults.embeddings(), is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 1.0F }))));
}
public void testFromResponse_SucceedsWhenEmbeddingValueIsInt_ObjectFormat() throws IOException {
@@ -274,12 +274,12 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsInt_ObjectFormat() thro
}
""";
- TextEmbeddingFloatResults parsedResults = HuggingFaceEmbeddingsResponseEntity.fromResponse(
+ DenseEmbeddingFloatResults parsedResults = HuggingFaceEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
- assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 1.0F }))));
+ assertThat(parsedResults.embeddings(), is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 1.0F }))));
}
public void testFromResponse_SucceedsWhenEmbeddingValueIsLong_ArrayFormat() throws IOException {
@@ -291,12 +291,12 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsLong_ArrayFormat() thro
]
""";
- TextEmbeddingFloatResults parsedResults = HuggingFaceEmbeddingsResponseEntity.fromResponse(
+ DenseEmbeddingFloatResults parsedResults = HuggingFaceEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
- assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 4.0294965E10F }))));
+ assertThat(parsedResults.embeddings(), is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 4.0294965E10F }))));
}
public void testFromResponse_SucceedsWhenEmbeddingValueIsLong_ObjectFormat() throws IOException {
@@ -310,12 +310,12 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsLong_ObjectFormat() thr
}
""";
- TextEmbeddingFloatResults parsedResults = HuggingFaceEmbeddingsResponseEntity.fromResponse(
+ DenseEmbeddingFloatResults parsedResults = HuggingFaceEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
- assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 4.0294965E10F }))));
+ assertThat(parsedResults.embeddings(), is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 4.0294965E10F }))));
}
public void testFromResponse_FailsWhenEmbeddingValueIsAnObject_ObjectFormat() {
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java
index 6e24981e3f3b3..229bbb73eb14d 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java
@@ -38,7 +38,7 @@
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.common.Truncator;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
@@ -70,7 +70,7 @@
import static org.elasticsearch.common.xcontent.XContentHelper.toXContent;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
-import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
+import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.Utils.getInvalidModel;
import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors;
@@ -789,11 +789,11 @@ private void testChunkedInfer_Batches(ChunkingSettings chunkingSettings) throws
var floatResult = (ChunkedInferenceEmbedding) results.get(0);
assertThat(floatResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, input.get(0).input().length()), floatResult.chunks().get(0).offset());
- assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class));
+ assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class));
assertTrue(
Arrays.equals(
new float[] { 0.0123f, -0.0123f },
- ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values()
+ ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values()
)
);
}
@@ -804,11 +804,11 @@ private void testChunkedInfer_Batches(ChunkingSettings chunkingSettings) throws
var floatResult = (ChunkedInferenceEmbedding) results.get(1);
assertThat(floatResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, input.get(1).input().length()), floatResult.chunks().get(0).offset());
- assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class));
+ assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class));
assertTrue(
Arrays.equals(
new float[] { 0.0456f, -0.0456f },
- ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values()
+ ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values()
)
);
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/action/IbmWatsonxEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/action/IbmWatsonxEmbeddingsActionTests.java
index a0980a8151036..793bb535c0cc5 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/action/IbmWatsonxEmbeddingsActionTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/action/IbmWatsonxEmbeddingsActionTests.java
@@ -44,7 +44,7 @@
import java.util.Map;
import java.util.concurrent.TimeUnit;
-import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
+import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/response/IbmWatsonxEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/response/IbmWatsonxEmbeddingsResponseEntityTests.java
index db42e9c49e1e5..c3bdb8c686985 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/response/IbmWatsonxEmbeddingsResponseEntityTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/response/IbmWatsonxEmbeddingsResponseEntityTests.java
@@ -9,7 +9,7 @@
import org.apache.http.HttpResponse;
import org.elasticsearch.test.ESTestCase;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.request.Request;
@@ -36,12 +36,12 @@ public void testFromResponse_CreatesResultsForASingleItem() throws IOException {
}
""";
- TextEmbeddingFloatResults parsedResults = IbmWatsonxEmbeddingsResponseEntity.fromResponse(
+ DenseEmbeddingFloatResults parsedResults = IbmWatsonxEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
- assertThat(parsedResults.embeddings(), is(List.of(TextEmbeddingFloatResults.Embedding.of(List.of(-0.00606332F, 0.058092743F)))));
+ assertThat(parsedResults.embeddings(), is(List.of(DenseEmbeddingFloatResults.Embedding.of(List.of(-0.00606332F, 0.058092743F)))));
}
public void testFromResponse_CreatesResultsForMultipleItems() throws IOException {
@@ -66,7 +66,7 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException
}
""";
- TextEmbeddingFloatResults parsedResults = IbmWatsonxEmbeddingsResponseEntity.fromResponse(
+ DenseEmbeddingFloatResults parsedResults = IbmWatsonxEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
@@ -75,8 +75,8 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException
parsedResults.embeddings(),
is(
List.of(
- TextEmbeddingFloatResults.Embedding.of(List.of(-0.00606332F, 0.058092743F)),
- TextEmbeddingFloatResults.Embedding.of(List.of(0.030681048F, 0.01714732F))
+ DenseEmbeddingFloatResults.Embedding.of(List.of(-0.00606332F, 0.058092743F)),
+ DenseEmbeddingFloatResults.Embedding.of(List.of(0.030681048F, 0.01714732F))
)
)
);
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java
index d408e269219cb..a6a04692d33bf 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java
@@ -38,7 +38,7 @@
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
@@ -65,7 +65,7 @@
import static org.elasticsearch.common.xcontent.XContentHelper.toXContent;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
-import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
+import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.Utils.getInvalidModel;
import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors;
@@ -1688,10 +1688,10 @@ private void test_Embedding_ChunkedInfer_BatchesCalls(JinaAIEmbeddingsModel mode
var floatResult = (ChunkedInferenceEmbedding) results.get(0);
assertThat(floatResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().get(0).offset());
- assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class));
+ assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class));
assertArrayEquals(
new float[] { 0.123f, -0.123f },
- ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
+ ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
0.0f
);
}
@@ -1700,10 +1700,10 @@ private void test_Embedding_ChunkedInfer_BatchesCalls(JinaAIEmbeddingsModel mode
var floatResult = (ChunkedInferenceEmbedding) results.get(1);
assertThat(floatResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().get(0).offset());
- assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class));
+ assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class));
assertArrayEquals(
new float[] { 0.223f, -0.223f },
- ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
+ ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
0.0f
);
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/response/JinaAIEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/response/JinaAIEmbeddingsResponseEntityTests.java
index c1b19cb450789..4df61cdc439af 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/response/JinaAIEmbeddingsResponseEntityTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/response/JinaAIEmbeddingsResponseEntityTests.java
@@ -11,9 +11,9 @@
import org.elasticsearch.common.ParsingException;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.test.ESTestCase;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBitResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingBitResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.InputTypeTests;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingType;
@@ -69,10 +69,10 @@ public void testFromResponse_CreatesResultsForASingleItem() throws IOException {
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
- assertThat(parsedResults, instanceOf(TextEmbeddingFloatResults.class));
+ assertThat(parsedResults, instanceOf(DenseEmbeddingFloatResults.class));
assertThat(
- ((TextEmbeddingFloatResults) parsedResults).embeddings(),
- is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F })))
+ ((DenseEmbeddingFloatResults) parsedResults).embeddings(),
+ is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F })))
);
}
@@ -123,13 +123,13 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
- assertThat(parsedResults, instanceOf(TextEmbeddingFloatResults.class));
+ assertThat(parsedResults, instanceOf(DenseEmbeddingFloatResults.class));
assertThat(
- ((TextEmbeddingFloatResults) parsedResults).embeddings(),
+ ((DenseEmbeddingFloatResults) parsedResults).embeddings(),
is(
List.of(
- new TextEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }),
- new TextEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F })
+ new DenseEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }),
+ new DenseEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F })
)
)
);
@@ -363,8 +363,8 @@ public void testFromResponse_SucceedsWhenEmbeddingType_IsBinary() throws IOExcep
);
assertThat(
- ((TextEmbeddingBitResults) parsedResults).embeddings(),
- is(List.of(new TextEmbeddingByteResults.Embedding(new byte[] { (byte) -55, (byte) 74, (byte) 101, (byte) 67, (byte) 83 })))
+ ((DenseEmbeddingBitResults) parsedResults).embeddings(),
+ is(List.of(new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) -55, (byte) 74, (byte) 101, (byte) 67, (byte) 83 })))
);
}
@@ -411,8 +411,8 @@ public void testFromResponse_SucceedsWhenEmbeddingType_IsBit() throws IOExceptio
);
assertThat(
- ((TextEmbeddingBitResults) parsedResults).embeddings(),
- is(List.of(new TextEmbeddingByteResults.Embedding(new byte[] { (byte) -55, (byte) 74, (byte) 101, (byte) 67, (byte) 83 })))
+ ((DenseEmbeddingBitResults) parsedResults).embeddings(),
+ is(List.of(new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) -55, (byte) 74, (byte) 101, (byte) 67, (byte) 83 })))
);
}
@@ -504,7 +504,7 @@ public void testFieldsInDifferentOrderServer() throws IOException {
}
}""";
- TextEmbeddingFloatResults parsedResults = (TextEmbeddingFloatResults) JinaAIEmbeddingsResponseEntity.fromResponse(
+ DenseEmbeddingFloatResults parsedResults = (DenseEmbeddingFloatResults) JinaAIEmbeddingsResponseEntity.fromResponse(
JinaAIEmbeddingsRequestTests.createRequest(
List.of("abc"),
InputTypeTests.randomWithNull(),
@@ -525,9 +525,9 @@ public void testFieldsInDifferentOrderServer() throws IOException {
parsedResults.embeddings(),
is(
List.of(
- new TextEmbeddingFloatResults.Embedding(new float[] { -0.9F, 0.5F, 0.3F }),
- new TextEmbeddingFloatResults.Embedding(new float[] { 0.1F, 0.5F }),
- new TextEmbeddingFloatResults.Embedding(new float[] { 0.5F, 0.5F })
+ new DenseEmbeddingFloatResults.Embedding(new float[] { -0.9F, 0.5F, 0.3F }),
+ new DenseEmbeddingFloatResults.Embedding(new float[] { 0.1F, 0.5F }),
+ new DenseEmbeddingFloatResults.Embedding(new float[] { 0.5F, 0.5F })
)
)
);
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java
index f6a0232db529c..818ae94192ce1 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java
@@ -42,7 +42,7 @@
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
@@ -695,11 +695,11 @@ public void testChunkedInfer(LlamaEmbeddingsModel model) throws IOException {
assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class));
var floatResult = (ChunkedInferenceEmbedding) results.get(0);
assertThat(floatResult.chunks(), hasSize(1));
- assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class));
+ assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class));
assertTrue(
Arrays.equals(
new float[] { 0.010060793f, -0.0017529363f },
- ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values()
+ ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values()
)
);
}
@@ -707,11 +707,11 @@ public void testChunkedInfer(LlamaEmbeddingsModel model) throws IOException {
assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class));
var floatResult = (ChunkedInferenceEmbedding) results.get(1);
assertThat(floatResult.chunks(), hasSize(1));
- assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class));
+ assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class));
assertTrue(
Arrays.equals(
new float[] { 0.110060793f, -0.1017529363f },
- ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values()
+ ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values()
)
);
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionCreatorTests.java
index 5bf65870e1dfc..df2ce97614edc 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionCreatorTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionCreatorTests.java
@@ -19,7 +19,7 @@
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests;
import org.elasticsearch.xpack.inference.InputTypeTests;
import org.elasticsearch.xpack.inference.common.TruncatorTests;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
@@ -98,7 +98,10 @@ public void testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction() throws I
var result = listener.actionGet(TIMEOUT);
- assertThat(result.asMap(), is(TextEmbeddingFloatResultsTests.buildExpectationFloat(List.of(new float[] { -0.0123F, 0.123F }))));
+ assertThat(
+ result.asMap(),
+ is(DenseEmbeddingFloatResultsTests.buildExpectationFloat(List.of(new float[] { -0.0123F, 0.123F })))
+ );
assertEmbeddingsRequest();
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java
index db94ec5c9c2f5..3ff7bf1a78374 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java
@@ -41,7 +41,7 @@
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
import org.elasticsearch.xpack.inference.ModelConfigurationsTests;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
@@ -1135,11 +1135,11 @@ public void testChunkedInfer(MistralEmbeddingsModel model) throws IOException {
assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class));
var floatResult = (ChunkedInferenceEmbedding) results.get(0);
assertThat(floatResult.chunks(), hasSize(1));
- assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class));
+ assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class));
assertTrue(
Arrays.equals(
new float[] { 0.123f, -0.123f },
- ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values()
+ ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values()
)
);
}
@@ -1147,11 +1147,11 @@ public void testChunkedInfer(MistralEmbeddingsModel model) throws IOException {
assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class));
var floatResult = (ChunkedInferenceEmbedding) results.get(1);
assertThat(floatResult.chunks(), hasSize(1));
- assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class));
+ assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class));
assertTrue(
Arrays.equals(
new float[] { 0.223f, -0.223f },
- ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values()
+ ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values()
)
);
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java
index 0a03c7a231e8c..85c9c73d002b2 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java
@@ -43,7 +43,7 @@
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
@@ -84,7 +84,7 @@
import static org.elasticsearch.common.xcontent.XContentHelper.toXContent;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS;
-import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
+import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.Utils.getInvalidModel;
import static org.elasticsearch.xpack.inference.Utils.getRequestConfigMap;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors;
@@ -1049,11 +1049,11 @@ private void testChunkedInfer(OpenAiEmbeddingsModel model) throws IOException {
var floatResult = (ChunkedInferenceEmbedding) results.get(0);
assertThat(floatResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().get(0).offset());
- assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class));
+ assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class));
assertTrue(
Arrays.equals(
new float[] { 0.123f, -0.123f },
- ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values()
+ ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values()
)
);
}
@@ -1062,11 +1062,11 @@ private void testChunkedInfer(OpenAiEmbeddingsModel model) throws IOException {
var floatResult = (ChunkedInferenceEmbedding) results.get(1);
assertThat(floatResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().get(0).offset());
- assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class));
+ assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class));
assertTrue(
Arrays.equals(
new float[] { 0.223f, -0.223f },
- ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values()
+ ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values()
)
);
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiActionCreatorTests.java
index 5cdcf402835b4..1291a1ad4a36c 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiActionCreatorTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiActionCreatorTests.java
@@ -36,7 +36,7 @@
import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
-import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
+import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiEmbeddingsActionTests.java
index c34ebbde4ac8f..1851dc4cad2b3 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiEmbeddingsActionTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiEmbeddingsActionTests.java
@@ -40,7 +40,7 @@
import java.util.List;
import java.util.concurrent.TimeUnit;
-import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
+import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/response/OpenAiEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/response/OpenAiEmbeddingsResponseEntityTests.java
index f2a430eefc801..77e1f384509b4 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/response/OpenAiEmbeddingsResponseEntityTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/response/OpenAiEmbeddingsResponseEntityTests.java
@@ -10,7 +10,7 @@
import org.apache.http.HttpResponse;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xcontent.XContentParseException;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.request.Request;
@@ -45,14 +45,14 @@ public void testFromResponse_CreatesResultsForASingleItem() throws IOException {
}
""";
- TextEmbeddingFloatResults parsedResults = OpenAiEmbeddingsResponseEntity.fromResponse(
+ DenseEmbeddingFloatResults parsedResults = OpenAiEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
assertThat(
parsedResults.embeddings(),
- is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F })))
+ is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F })))
);
}
@@ -86,7 +86,7 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException
}
""";
- TextEmbeddingFloatResults parsedResults = OpenAiEmbeddingsResponseEntity.fromResponse(
+ DenseEmbeddingFloatResults parsedResults = OpenAiEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
@@ -95,8 +95,8 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException
parsedResults.embeddings(),
is(
List.of(
- new TextEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }),
- new TextEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F })
+ new DenseEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }),
+ new DenseEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F })
)
)
);
@@ -254,12 +254,12 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsInt() throws IOExceptio
}
""";
- TextEmbeddingFloatResults parsedResults = OpenAiEmbeddingsResponseEntity.fromResponse(
+ DenseEmbeddingFloatResults parsedResults = OpenAiEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
- assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 1.0F }))));
+ assertThat(parsedResults.embeddings(), is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 1.0F }))));
}
public void testFromResponse_SucceedsWhenEmbeddingValueIsLong() throws IOException {
@@ -283,12 +283,12 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsLong() throws IOExcepti
}
""";
- TextEmbeddingFloatResults parsedResults = OpenAiEmbeddingsResponseEntity.fromResponse(
+ DenseEmbeddingFloatResults parsedResults = OpenAiEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
- assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 4.0294965E10F }))));
+ assertThat(parsedResults.embeddings(), is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 4.0294965E10F }))));
}
public void testFromResponse_FailsWhenEmbeddingValueIsAnObject() {
@@ -365,7 +365,7 @@ public void testFieldsInDifferentOrderServer() throws IOException {
}
}""";
- TextEmbeddingFloatResults parsedResults = OpenAiEmbeddingsResponseEntity.fromResponse(
+ DenseEmbeddingFloatResults parsedResults = OpenAiEmbeddingsResponseEntity.fromResponse(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), response.getBytes(StandardCharsets.UTF_8))
);
@@ -374,9 +374,9 @@ public void testFieldsInDifferentOrderServer() throws IOException {
parsedResults.embeddings(),
is(
List.of(
- new TextEmbeddingFloatResults.Embedding(new float[] { -0.9F, 0.5F, 0.3F }),
- new TextEmbeddingFloatResults.Embedding(new float[] { 0.1F, 0.5F }),
- new TextEmbeddingFloatResults.Embedding(new float[] { 0.5F, 0.5F })
+ new DenseEmbeddingFloatResults.Embedding(new float[] { -0.9F, 0.5F, 0.3F }),
+ new DenseEmbeddingFloatResults.Embedding(new float[] { 0.1F, 0.5F }),
+ new DenseEmbeddingFloatResults.Embedding(new float[] { 0.5F, 0.5F })
)
)
);
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerServiceTests.java
index 5d6bec1bcfbff..b391ce0cc7949 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerServiceTests.java
@@ -29,7 +29,7 @@
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests;
import org.elasticsearch.xpack.inference.InferencePlugin;
import org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkingSettings;
import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings;
@@ -449,7 +449,7 @@ public void testChunkedInfer() throws Exception {
var model = mockModelForChunking();
SageMakerSchema schema = mock();
- when(schema.response(any(), any(), any())).thenReturn(TextEmbeddingFloatResultsTests.createRandomResults());
+ when(schema.response(any(), any(), any())).thenReturn(DenseEmbeddingFloatResultsTests.createRandomResults());
when(schemas.schemaFor(model)).thenReturn(schema);
mockInvoke();
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticTextEmbeddingPayloadTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticTextEmbeddingPayloadTests.java
index ed0ee43266ab5..3ac0f20d4242c 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticTextEmbeddingPayloadTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticTextEmbeddingPayloadTests.java
@@ -9,8 +9,8 @@
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.inference.TaskType;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredServiceSchema;
@@ -66,8 +66,8 @@ public void testBitResponse() throws Exception {
assertThat(bitResults.embeddings().size(), is(1));
var embedding = bitResults.embeddings().get(0);
- assertThat(embedding, isA(TextEmbeddingByteResults.Embedding.class));
- assertThat(((TextEmbeddingByteResults.Embedding) embedding).values(), is(new byte[] { 23 }));
+ assertThat(embedding, isA(DenseEmbeddingByteResults.Embedding.class));
+ assertThat(((DenseEmbeddingByteResults.Embedding) embedding).values(), is(new byte[] { 23 }));
}
public void testByteResponse() throws Exception {
@@ -87,8 +87,8 @@ public void testByteResponse() throws Exception {
assertThat(byteResults.embeddings().size(), is(1));
var embedding = byteResults.embeddings().get(0);
- assertThat(embedding, isA(TextEmbeddingByteResults.Embedding.class));
- assertThat(((TextEmbeddingByteResults.Embedding) embedding).values(), is(new byte[] { 23 }));
+ assertThat(embedding, isA(DenseEmbeddingByteResults.Embedding.class));
+ assertThat(((DenseEmbeddingByteResults.Embedding) embedding).values(), is(new byte[] { 23 }));
}
public void testFloatResponse() throws Exception {
@@ -108,7 +108,7 @@ public void testFloatResponse() throws Exception {
assertThat(byteResults.embeddings().size(), is(1));
var embedding = byteResults.embeddings().get(0);
- assertThat(embedding, isA(TextEmbeddingFloatResults.Embedding.class));
- assertThat(((TextEmbeddingFloatResults.Embedding) embedding).values(), is(new float[] { 0.1F }));
+ assertThat(embedding, isA(DenseEmbeddingFloatResults.Embedding.class));
+ assertThat(((DenseEmbeddingFloatResults.Embedding) embedding).values(), is(new float[] { 0.1F }));
}
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiTextEmbeddingPayloadTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiTextEmbeddingPayloadTests.java
index 35b78b004618c..2435514e9f7e1 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiTextEmbeddingPayloadTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiTextEmbeddingPayloadTests.java
@@ -12,7 +12,7 @@
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.TaskType;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerInferenceRequest;
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemaPayloadTestCase;
@@ -145,11 +145,11 @@ public void testResponse() throws Exception {
.body(SdkBytes.fromString(responseJson, StandardCharsets.UTF_8))
.build();
- var textEmbeddingFloatResults = payload.responseBody(mock(), invokeEndpointResponse);
+ var denseEmbeddingFloatResults = payload.responseBody(mock(), invokeEndpointResponse);
assertThat(
- textEmbeddingFloatResults.embeddings(),
- is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F })))
+ denseEmbeddingFloatResults.embeddings(),
+ is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F })))
);
}
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/DenseEmbeddingModelValidatorTests.java
similarity index 90%
rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidatorTests.java
rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/DenseEmbeddingModelValidatorTests.java
index 45726f0789667..9ad4e81b20cb7 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidatorTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/DenseEmbeddingModelValidatorTests.java
@@ -16,10 +16,10 @@
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.inference.validation.ServiceIntegrationValidator;
import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResultsTests;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResultsTests;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.EmptyTaskSettingsTests;
import org.elasticsearch.xpack.inference.ModelConfigurationsTests;
import org.junit.Before;
@@ -38,7 +38,7 @@
import static org.mockito.Mockito.when;
import static org.mockito.MockitoAnnotations.openMocks;
-public class TextEmbeddingModelValidatorTests extends ESTestCase {
+public class DenseEmbeddingModelValidatorTests extends ESTestCase {
private static final TimeValue TIMEOUT = TimeValue.ONE_MINUTE;
@@ -53,13 +53,13 @@ public class TextEmbeddingModelValidatorTests extends ESTestCase {
@Mock
private ServiceSettings mockServiceSettings;
- private TextEmbeddingModelValidator underTest;
+ private DenseEmbeddingModelValidator underTest;
@Before
public void setup() {
openMocks(this);
- underTest = new TextEmbeddingModelValidator(mockServiceIntegrationValidator);
+ underTest = new DenseEmbeddingModelValidator(mockServiceIntegrationValidator);
when(mockInferenceService.updateModelWithEmbeddingDetails(eq(mockModel), anyInt())).thenReturn(mockModel);
when(mockActionListener.delegateFailureAndWrap(any())).thenCallRealMethod();
@@ -94,7 +94,7 @@ public void testValidate_ServiceReturnsNonTextEmbeddingResults() {
}
public void testValidate_RetrievingEmbeddingSizeThrowsIllegalStateException() {
- TextEmbeddingFloatResults results = new TextEmbeddingFloatResults(List.of());
+ DenseEmbeddingFloatResults results = new DenseEmbeddingFloatResults(List.of());
when(mockServiceSettings.dimensionsSetByUser()).thenReturn(true);
when(mockServiceSettings.dimensions()).thenReturn(randomNonNegativeInt());
@@ -107,7 +107,7 @@ public void testValidate_RetrievingEmbeddingSizeThrowsIllegalStateException() {
}
public void testValidate_DimensionsSetByUserDoNotEqualEmbeddingSize() {
- TextEmbeddingByteResults results = TextEmbeddingByteResultsTests.createRandomResults();
+ DenseEmbeddingByteResults results = DenseEmbeddingByteResultsTests.createRandomResults();
var dimensions = randomValueOtherThan(results.getFirstEmbeddingSize(), ESTestCase::randomNonNegativeInt);
when(mockServiceSettings.dimensionsSetByUser()).thenReturn(true);
@@ -131,7 +131,7 @@ public void testValidate_DimensionsNotSetByUser() {
}
private void mockSuccessfulValidation(Boolean dimensionsSetByUser) {
- TextEmbeddingByteResults results = TextEmbeddingByteResultsTests.createRandomResults();
+ DenseEmbeddingByteResults results = DenseEmbeddingByteResultsTests.createRandomResults();
when(mockModel.getConfigurations()).thenReturn(ModelConfigurationsTests.createRandomInstance());
when(mockModel.getTaskSettings()).thenReturn(EmptyTaskSettingsTests.createRandom());
when(mockServiceSettings.dimensionsSetByUser()).thenReturn(dimensionsSetByUser);
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidatorTests.java
index 30e7a33757c16..c9db9907792fd 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidatorTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidatorTests.java
@@ -16,7 +16,7 @@
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.test.ESTestCase;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingResults;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests;
import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandEmbeddingModel;
import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalTextEmbeddingServiceSettings;
@@ -118,7 +118,7 @@ public void testValidate_ElandTextEmbeddingModelValidationFails() {
public void testValidate_ElandTextEmbeddingModelValidationSucceedsAndDimensionsSetByUserValid() {
var dimensions = randomIntBetween(1, 10);
- var mockInferenceServiceResults = mock(TextEmbeddingResults.class);
+ var mockInferenceServiceResults = mock(DenseEmbeddingResults.class);
var mockUpdatedModel = mock(CustomElandEmbeddingModel.class);
when(mockInferenceServiceResults.getFirstEmbeddingSize()).thenReturn(dimensions);
CustomElandEmbeddingModel customElandEmbeddingModel = createCustomElandEmbeddingModel(true, dimensions);
@@ -151,7 +151,7 @@ public void testValidate_ElandTextEmbeddingModelValidationSucceedsAndDimensionsS
public void testValidate_ElandTextEmbeddingModelValidationSucceedsAndDimensionsSetByUserInvalid() {
var dimensions = randomIntBetween(1, 10);
- var mockInferenceServiceResults = mock(TextEmbeddingResults.class);
+ var mockInferenceServiceResults = mock(DenseEmbeddingResults.class);
when(mockInferenceServiceResults.getFirstEmbeddingSize()).thenReturn(
randomValueOtherThan(dimensions, () -> randomIntBetween(1, 10))
);
@@ -207,7 +207,7 @@ public void testValidate_ElandTextEmbeddingAndValidationReturnsInvalidResultsTyp
public void testValidate_ElandTextEmbeddingModelDimensionsNotSetByUser() {
var dimensions = randomIntBetween(1, 10);
- var mockInferenceServiceResults = mock(TextEmbeddingResults.class);
+ var mockInferenceServiceResults = mock(DenseEmbeddingResults.class);
when(mockInferenceServiceResults.getFirstEmbeddingSize()).thenReturn(dimensions);
CustomElandEmbeddingModel customElandEmbeddingModel = createCustomElandEmbeddingModel(false, null);
@@ -239,7 +239,7 @@ public void testValidate_ElandTextEmbeddingModelDimensionsNotSetByUser() {
}
public void testValidate_ElandTextEmbeddingModelAndEmbeddingSizeRetrievalThrowsException() {
- var mockInferenceServiceResults = mock(TextEmbeddingResults.class);
+ var mockInferenceServiceResults = mock(DenseEmbeddingResults.class);
when(mockInferenceServiceResults.getFirstEmbeddingSize()).thenThrow(ElasticsearchStatusException.class);
CustomElandEmbeddingModel customElandEmbeddingModel = createCustomElandEmbeddingModel(false, null);
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilderTests.java
index b4d563b565eee..804984c866542 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilderTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilderTests.java
@@ -95,7 +95,7 @@ public void testBuildModelValidator_ValidTaskType() {
private Map> taskTypeToModelValidatorClassMap() {
return Map.of(
TaskType.TEXT_EMBEDDING,
- TextEmbeddingModelValidator.class,
+ DenseEmbeddingModelValidator.class,
TaskType.SPARSE_EMBEDDING,
SimpleModelValidator.class,
TaskType.RERANK,
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java
index 9bbfcb4d58667..0e9059132f23d 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java
@@ -37,7 +37,7 @@
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
@@ -63,7 +63,7 @@
import static org.elasticsearch.common.xcontent.XContentHelper.toXContent;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
-import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
+import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.Utils.getInvalidModel;
import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors;
@@ -1654,10 +1654,10 @@ private void test_Embedding_ChunkedInfer_BatchesCalls(VoyageAIEmbeddingsModel mo
var floatResult = (ChunkedInferenceEmbedding) results.getFirst();
assertThat(floatResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().getFirst().offset());
- assertThat(floatResult.chunks().get(0).embedding(), CoreMatchers.instanceOf(TextEmbeddingFloatResults.Embedding.class));
+ assertThat(floatResult.chunks().get(0).embedding(), CoreMatchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class));
assertArrayEquals(
new float[] { 0.123f, -0.123f },
- ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
+ ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
0.0f
);
}
@@ -1666,10 +1666,13 @@ private void test_Embedding_ChunkedInfer_BatchesCalls(VoyageAIEmbeddingsModel mo
var floatResult = (ChunkedInferenceEmbedding) results.get(1);
assertThat(floatResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().getFirst().offset());
- assertThat(floatResult.chunks().getFirst().embedding(), CoreMatchers.instanceOf(TextEmbeddingFloatResults.Embedding.class));
+ assertThat(
+ floatResult.chunks().getFirst().embedding(),
+ CoreMatchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class)
+ );
assertArrayEquals(
new float[] { 0.223f, -0.223f },
- ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
+ ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
0.0f
);
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIActionCreatorTests.java
index afbb5532d5fdd..86d6fab29842b 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIActionCreatorTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIActionCreatorTests.java
@@ -37,7 +37,7 @@
import java.util.Map;
import java.util.concurrent.TimeUnit;
-import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
+import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIEmbeddingsActionTests.java
index 7050c2d131e17..b326664c527c1 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIEmbeddingsActionTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIEmbeddingsActionTests.java
@@ -46,9 +46,9 @@
import java.util.Map;
import java.util.concurrent.TimeUnit;
-import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationBinary;
-import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationByte;
-import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
+import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationBinary;
+import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationByte;
+import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/response/VoyageAIEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/response/VoyageAIEmbeddingsResponseEntityTests.java
index 80a00737ddf52..7fa6b0c7bece3 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/response/VoyageAIEmbeddingsResponseEntityTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/response/VoyageAIEmbeddingsResponseEntityTests.java
@@ -11,7 +11,7 @@
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xcontent.XContentParseException;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.InputTypeTests;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.services.voyageai.request.VoyageAIEmbeddingsRequest;
@@ -60,8 +60,8 @@ public void testFromResponse_CreatesResultsForASingleItem() throws IOException {
);
assertThat(
- ((TextEmbeddingFloatResults) parsedResults).embeddings(),
- is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F })))
+ ((DenseEmbeddingFloatResults) parsedResults).embeddings(),
+ is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F })))
);
}
@@ -106,11 +106,11 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException
);
assertThat(
- ((TextEmbeddingFloatResults) parsedResults).embeddings(),
+ ((DenseEmbeddingFloatResults) parsedResults).embeddings(),
is(
List.of(
- new TextEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }),
- new TextEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F })
+ new DenseEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }),
+ new DenseEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F })
)
)
);
@@ -299,8 +299,8 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsInt() throws IOExceptio
);
assertThat(
- ((TextEmbeddingFloatResults) parsedResults).embeddings(),
- is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 1.0F })))
+ ((DenseEmbeddingFloatResults) parsedResults).embeddings(),
+ is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 1.0F })))
);
}
@@ -336,8 +336,8 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsLong() throws IOExcepti
);
assertThat(
- ((TextEmbeddingFloatResults) parsedResults).embeddings(),
- is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 4.0294965E10F })))
+ ((DenseEmbeddingFloatResults) parsedResults).embeddings(),
+ is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 4.0294965E10F })))
);
}
@@ -427,15 +427,15 @@ public void testFieldsInDifferentOrderServer() throws IOException {
new HttpResult(mock(HttpResponse.class), response.getBytes(StandardCharsets.UTF_8))
);
- assertThat(parsedResults, instanceOf(TextEmbeddingFloatResults.class));
+ assertThat(parsedResults, instanceOf(DenseEmbeddingFloatResults.class));
assertThat(
- ((TextEmbeddingFloatResults) parsedResults).embeddings(),
+ ((DenseEmbeddingFloatResults) parsedResults).embeddings(),
is(
List.of(
- new TextEmbeddingFloatResults.Embedding(new float[] { -0.9F, 0.5F, 0.3F }),
- new TextEmbeddingFloatResults.Embedding(new float[] { 0.1F, 0.5F }),
- new TextEmbeddingFloatResults.Embedding(new float[] { 0.5F, 0.5F })
+ new DenseEmbeddingFloatResults.Embedding(new float[] { -0.9F, 0.5F, 0.3F }),
+ new DenseEmbeddingFloatResults.Embedding(new float[] { 0.1F, 0.5F }),
+ new DenseEmbeddingFloatResults.Embedding(new float[] { 0.5F, 0.5F })
)
)
);
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextEmbeddingProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextEmbeddingProcessor.java
index c28fc8f44c3fa..49598aa620464 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextEmbeddingProcessor.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextEmbeddingProcessor.java
@@ -9,7 +9,7 @@
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextEmbeddingFloatResults;
-import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
+import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
@@ -85,7 +85,7 @@ static InferenceResults processResult(
tokenization.anyTruncated()
);
} else {
- return new MlTextEmbeddingResults(
+ return new MlDenseEmbeddingResults(
Optional.ofNullable(resultsField).orElse(DEFAULT_RESULTS_FIELD),
pyTorchResult.getInferenceResult()[0][0],
tokenization.anyTruncated()
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextEmbeddingProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextEmbeddingProcessorTests.java
index 8369412580b88..6f66a05187fee 100644
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextEmbeddingProcessorTests.java
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextEmbeddingProcessorTests.java
@@ -10,7 +10,7 @@
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextEmbeddingFloatResults;
import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextExpansionResults;
-import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
+import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenization;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizationResult;
@@ -40,9 +40,9 @@ public void testSingleResult() {
var tokenization = tokenizer.tokenize(input, Tokenization.Truncate.NONE, 0, 0, null);
var tokenizationResult = new BertTokenizationResult(TextExpansionProcessorTests.TEST_CASED_VOCAB, tokenization, 0);
var inferenceResult = TextEmbeddingProcessor.processResult(tokenizationResult, pytorchResult, "foo", false);
- assertThat(inferenceResult, instanceOf(MlTextEmbeddingResults.class));
+ assertThat(inferenceResult, instanceOf(MlDenseEmbeddingResults.class));
- var result = (MlTextEmbeddingResults) inferenceResult;
+ var result = (MlDenseEmbeddingResults) inferenceResult;
assertThat(result.getInference().length, greaterThan(0));
}
}
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/vectors/TextEmbeddingQueryVectorBuilderTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/vectors/TextEmbeddingQueryVectorBuilderTests.java
index 86b4e8c588e51..fe269b6bcb0f5 100644
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/vectors/TextEmbeddingQueryVectorBuilderTests.java
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/vectors/TextEmbeddingQueryVectorBuilderTests.java
@@ -17,7 +17,7 @@
import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings;
-import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
+import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults;
import org.elasticsearch.xpack.core.ml.vectors.TextEmbeddingQueryVectorBuilder;
import org.elasticsearch.xpack.ml.MachineLearningTests;
@@ -53,7 +53,7 @@ public ActionResponse createResponse(float[] array, TextEmbeddingQueryVectorBuil
embedding[i] = array[i];
}
return new InferModelAction.Response(
- List.of(new MlTextEmbeddingResults("foo", embedding, randomBoolean())),
+ List.of(new MlDenseEmbeddingResults("foo", embedding, randomBoolean())),
builder.getModelId(),
true
);