diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceServiceResults.java b/server/src/main/java/org/elasticsearch/inference/InferenceServiceResults.java index 3746960ad8f78..2d1b932271f25 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceServiceResults.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceServiceResults.java @@ -24,7 +24,7 @@ public interface InferenceServiceResults extends NamedWriteable, ChunkedToXConte /** *

Transform the result to match the format required for the TransportCoordinatedInferenceAction. - * TransportCoordinatedInferenceAction expects an ml plugin TextEmbeddingResults or SparseEmbeddingResults.

+ * TransportCoordinatedInferenceAction expects an ml plugin DenseEmbeddingResults or SparseEmbeddingResults.

*/ default List transformToCoordinationFormat() { throw new UnsupportedOperationException("transformToCoordinationFormat() is not implemented"); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingBitResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingBitResults.java similarity index 75% rename from x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingBitResults.java rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingBitResults.java index 37fca12f1697a..0792bf90dbe12 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingBitResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingBitResults.java @@ -14,7 +14,7 @@ import org.elasticsearch.common.xcontent.ChunkedToXContentHelper; import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.xcontent.ToXContent; -import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults; +import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults; import java.io.IOException; import java.util.Iterator; @@ -24,9 +24,10 @@ import java.util.Objects; /** - * Writes a text embedding result in the follow json format + * Writes a dense embedding result in the follow json format. + *
  * {
- *     "text_embedding_bytes": [
+ *     "text_embedding_bits": [
  *         {
  *             "embedding": [
  *                 23
@@ -39,17 +40,19 @@
  *         }
  *     ]
  * }
+ * 
*/ -// Note: inheriting from TextEmbeddingByteResults gives a bad implementation of the +// Note: inheriting from DenseEmbeddingByteResults gives a bad implementation of the // Embedding.merge method for bits. TODO: implement a proper merge method -public record TextEmbeddingBitResults(List embeddings) +public record DenseEmbeddingBitResults(List embeddings) implements - TextEmbeddingResults { + DenseEmbeddingResults { + // This name is a holdover from before this class was renamed public static final String NAME = "text_embedding_service_bit_results"; public static final String TEXT_EMBEDDING_BITS = "text_embedding_bits"; - public TextEmbeddingBitResults(StreamInput in) throws IOException { - this(in.readCollectionAsList(TextEmbeddingByteResults.Embedding::new)); + public DenseEmbeddingBitResults(StreamInput in) throws IOException { + this(in.readCollectionAsList(DenseEmbeddingByteResults.Embedding::new)); } @Override @@ -79,7 +82,7 @@ public String getWriteableName() { @Override public List transformToCoordinationFormat() { return embeddings.stream() - .map(embedding -> new MlTextEmbeddingResults(TEXT_EMBEDDING_BITS, embedding.toDoubleArray(), false)) + .map(embedding -> new MlDenseEmbeddingResults(TEXT_EMBEDDING_BITS, embedding.toDoubleArray(), false)) .toList(); } @@ -94,7 +97,7 @@ public Map asMap() { public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; - TextEmbeddingBitResults that = (TextEmbeddingBitResults) o; + DenseEmbeddingBitResults that = (DenseEmbeddingBitResults) o; return Objects.equals(embeddings, that.embeddings); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingByteResults.java similarity index 90% rename from x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResults.java rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingByteResults.java index 54f858cb20ae0..9e72dc9a7b2b4 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingByteResults.java @@ -20,7 +20,7 @@ import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContent; import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults; +import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults; import java.io.IOException; import java.util.Arrays; @@ -31,7 +31,8 @@ import java.util.Objects; /** - * Writes a text embedding result in the follow json format + * Writes a dense embedding result in the follow json format + *
  * {
  *     "text_embedding_bytes": [
  *         {
@@ -46,13 +47,15 @@
  *         }
  *     ]
  * }
+ * 
*/ -public record TextEmbeddingByteResults(List embeddings) implements TextEmbeddingResults { +public record DenseEmbeddingByteResults(List embeddings) implements DenseEmbeddingResults { + // This name is a holdover from before this class was renamed public static final String NAME = "text_embedding_service_byte_results"; public static final String TEXT_EMBEDDING_BYTES = "text_embedding_bytes"; - public TextEmbeddingByteResults(StreamInput in) throws IOException { - this(in.readCollectionAsList(TextEmbeddingByteResults.Embedding::new)); + public DenseEmbeddingByteResults(StreamInput in) throws IOException { + this(in.readCollectionAsList(DenseEmbeddingByteResults.Embedding::new)); } @Override @@ -81,7 +84,7 @@ public String getWriteableName() { @Override public List transformToCoordinationFormat() { return embeddings.stream() - .map(embedding -> new MlTextEmbeddingResults(TEXT_EMBEDDING_BYTES, embedding.toDoubleArray(), false)) + .map(embedding -> new MlDenseEmbeddingResults(TEXT_EMBEDDING_BYTES, embedding.toDoubleArray(), false)) .toList(); } @@ -96,7 +99,7 @@ public Map asMap() { public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; - TextEmbeddingByteResults that = (TextEmbeddingByteResults) o; + DenseEmbeddingByteResults that = (DenseEmbeddingByteResults) o; return Objects.equals(embeddings, that.embeddings); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingFloatResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingFloatResults.java similarity index 84% rename from x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingFloatResults.java rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingFloatResults.java index 9dbdccd26e5d2..d14f38259208c 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingFloatResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingFloatResults.java @@ -23,7 +23,7 @@ import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContent; import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults; +import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults; import java.io.IOException; import java.util.ArrayList; @@ -35,7 +35,8 @@ import java.util.Objects; /** - * Writes a text embedding result in the follow json format + * Writes a dense embedding result in the follow json format + *
  * {
  *     "text_embedding": [
  *         {
@@ -50,20 +51,24 @@
  *         }
  *     ]
  * }
+ * 
*/ -public record TextEmbeddingFloatResults(List embeddings) implements TextEmbeddingResults { +public record DenseEmbeddingFloatResults(List embeddings) + implements + DenseEmbeddingResults { + // This name is a holdover from before this class was renamed public static final String NAME = "text_embedding_service_results"; public static final String TEXT_EMBEDDING = TaskType.TEXT_EMBEDDING.toString(); - public TextEmbeddingFloatResults(StreamInput in) throws IOException { - this(in.readCollectionAsList(TextEmbeddingFloatResults.Embedding::new)); + public DenseEmbeddingFloatResults(StreamInput in) throws IOException { + this(in.readCollectionAsList(DenseEmbeddingFloatResults.Embedding::new)); } - public static TextEmbeddingFloatResults of(List results) { + public static DenseEmbeddingFloatResults of(List results) { List embeddings = new ArrayList<>(results.size()); for (InferenceResults result : results) { - if (result instanceof MlTextEmbeddingResults embeddingResult) { - embeddings.add(TextEmbeddingFloatResults.Embedding.of(embeddingResult)); + if (result instanceof MlDenseEmbeddingResults embeddingResult) { + embeddings.add(DenseEmbeddingFloatResults.Embedding.of(embeddingResult)); } else if (result instanceof org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults errorResult) { if (errorResult.getException() instanceof ElasticsearchStatusException statusException) { throw statusException; @@ -76,11 +81,15 @@ public static TextEmbeddingFloatResults of(List resu } } else { throw new IllegalArgumentException( - "Received invalid inference result, of type " + result.getClass().getName() + " but expected TextEmbeddingResults." + "Received invalid inference result, of type " + + result.getClass().getName() + + " but expected " + + MlDenseEmbeddingResults.class.getName() + + "." ); } } - return new TextEmbeddingFloatResults(embeddings); + return new DenseEmbeddingFloatResults(embeddings); } @Override @@ -108,7 +117,7 @@ public String getWriteableName() { @Override public List transformToCoordinationFormat() { - return embeddings.stream().map(embedding -> new MlTextEmbeddingResults(TEXT_EMBEDDING, embedding.asDoubleArray(), false)).toList(); + return embeddings.stream().map(embedding -> new MlDenseEmbeddingResults(TEXT_EMBEDDING, embedding.asDoubleArray(), false)).toList(); } public Map asMap() { @@ -122,7 +131,7 @@ public Map asMap() { public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; - TextEmbeddingFloatResults that = (TextEmbeddingFloatResults) o; + DenseEmbeddingFloatResults that = (DenseEmbeddingFloatResults) o; return Objects.equals(embeddings, that.embeddings); } @@ -148,7 +157,7 @@ public Embedding(StreamInput in) throws IOException { this(in.readFloatArray()); } - public static Embedding of(MlTextEmbeddingResults embeddingResult) { + public static Embedding of(MlDenseEmbeddingResults embeddingResult) { float[] embeddingAsArray = embeddingResult.getInferenceAsFloat(); return new Embedding(embeddingAsArray); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingResults.java similarity index 66% rename from x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingResults.java index ea4e45ec67407..af6abb357bd98 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingResults.java @@ -7,11 +7,11 @@ package org.elasticsearch.xpack.core.inference.results; -public interface TextEmbeddingResults> extends EmbeddingResults { +public interface DenseEmbeddingResults> extends EmbeddingResults { /** - * Returns the first text embedding entry in the result list's array size. - * @return the size of the text embedding + * Returns the first embedding entry in the result list's array size. + * @return the size of the embedding * @throws IllegalStateException if the list of embeddings is empty */ int getFirstEmbeddingSize() throws IllegalStateException; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java index 667d7bf63efc9..f2952fee36cda 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java @@ -25,7 +25,7 @@ import org.elasticsearch.xpack.core.ml.inference.results.FillMaskResults; import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextEmbeddingFloatResults; import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextExpansionResults; -import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults; +import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults; import org.elasticsearch.xpack.core.ml.inference.results.NerResults; import org.elasticsearch.xpack.core.ml.inference.results.NlpClassificationInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.PyTorchPassThroughResults; @@ -669,7 +669,7 @@ public List getNamedWriteables() { ); namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class, TextExpansionResults.NAME, TextExpansionResults::new)); namedWriteables.add( - new NamedWriteableRegistry.Entry(InferenceResults.class, MlTextEmbeddingResults.NAME, MlTextEmbeddingResults::new) + new NamedWriteableRegistry.Entry(InferenceResults.class, MlDenseEmbeddingResults.NAME, MlDenseEmbeddingResults::new) ); namedWriteables.add( new NamedWriteableRegistry.Entry( diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/MlTextEmbeddingResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/MlDenseEmbeddingResults.java similarity index 87% rename from x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/MlTextEmbeddingResults.java rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/MlDenseEmbeddingResults.java index 0c0fa6f3f690e..b839ade030e1e 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/MlTextEmbeddingResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/MlDenseEmbeddingResults.java @@ -16,20 +16,21 @@ import java.util.Map; import java.util.Objects; -public class MlTextEmbeddingResults extends NlpInferenceResults { +public class MlDenseEmbeddingResults extends NlpInferenceResults { + // This name is a holdover from before this class was renamed public static final String NAME = "text_embedding_result"; private final String resultsField; private final double[] inference; - public MlTextEmbeddingResults(String resultsField, double[] inference, boolean isTruncated) { + public MlDenseEmbeddingResults(String resultsField, double[] inference, boolean isTruncated) { super(isTruncated); this.inference = inference; this.resultsField = resultsField; } - public MlTextEmbeddingResults(StreamInput in) throws IOException { + public MlDenseEmbeddingResults(StreamInput in) throws IOException { super(in); inference = in.readDoubleArray(); resultsField = in.readString(); @@ -89,7 +90,7 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; if (super.equals(o) == false) return false; - MlTextEmbeddingResults that = (MlTextEmbeddingResults) o; + MlDenseEmbeddingResults that = (MlDenseEmbeddingResults) o; return Objects.equals(resultsField, that.resultsField) && Arrays.equals(inference, that.inference); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/vectors/TextEmbeddingQueryVectorBuilder.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/vectors/TextEmbeddingQueryVectorBuilder.java index 1f4967d95c9b7..a5b085e877a73 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/vectors/TextEmbeddingQueryVectorBuilder.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/vectors/TextEmbeddingQueryVectorBuilder.java @@ -20,7 +20,7 @@ import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings; -import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults; +import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults; import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfigUpdate; @@ -117,14 +117,14 @@ public void buildVector(Client client, ActionListener listener) { return; } - if (response.getInferenceResults().get(0) instanceof MlTextEmbeddingResults textEmbeddingResults) { + if (response.getInferenceResults().get(0) instanceof MlDenseEmbeddingResults textEmbeddingResults) { listener.onResponse(textEmbeddingResults.getInferenceAsFloat()); } else if (response.getInferenceResults().get(0) instanceof WarningInferenceResults warning) { listener.onFailure(new IllegalStateException(warning.getWarning())); } else { throw new IllegalArgumentException( "expected a result of type [" - + MlTextEmbeddingResults.NAME + + MlDenseEmbeddingResults.NAME + "] received [" + response.getInferenceResults().get(0).getWriteableName() + "]. Is [" diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingBitResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingBitResultsTests.java similarity index 55% rename from x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingBitResultsTests.java rename to x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingBitResultsTests.java index 61b49075702a2..9e5814825d659 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingBitResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingBitResultsTests.java @@ -10,7 +10,7 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; -import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults; +import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults; import java.io.IOException; import java.util.ArrayList; @@ -19,19 +19,19 @@ import static org.hamcrest.Matchers.is; -public class TextEmbeddingBitResultsTests extends AbstractWireSerializingTestCase { - public static TextEmbeddingBitResults createRandomResults() { +public class DenseEmbeddingBitResultsTests extends AbstractWireSerializingTestCase { + public static DenseEmbeddingBitResults createRandomResults() { int embeddings = randomIntBetween(1, 10); - List embeddingResults = new ArrayList<>(embeddings); + List embeddingResults = new ArrayList<>(embeddings); for (int i = 0; i < embeddings; i++) { embeddingResults.add(createRandomEmbedding()); } - return new TextEmbeddingBitResults(embeddingResults); + return new DenseEmbeddingBitResults(embeddingResults); } - private static TextEmbeddingByteResults.Embedding createRandomEmbedding() { + private static DenseEmbeddingByteResults.Embedding createRandomEmbedding() { int columns = randomIntBetween(1, 10); byte[] bytes = new byte[columns]; @@ -39,11 +39,11 @@ private static TextEmbeddingByteResults.Embedding createRandomEmbedding() { bytes[i] = randomByte(); } - return new TextEmbeddingByteResults.Embedding(bytes); + return new DenseEmbeddingByteResults.Embedding(bytes); } public void testToXContent_CreatesTheRightFormatForASingleEmbedding() throws IOException { - var entity = new TextEmbeddingBitResults(List.of(new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 23 }))); + var entity = new DenseEmbeddingBitResults(List.of(new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) 23 }))); String xContentResult = Strings.toString(entity, true, true); assertThat(xContentResult, is(""" @@ -59,10 +59,10 @@ public void testToXContent_CreatesTheRightFormatForASingleEmbedding() throws IOE } public void testToXContent_CreatesTheRightFormatForMultipleEmbeddings() throws IOException { - var entity = new TextEmbeddingBitResults( + var entity = new DenseEmbeddingBitResults( List.of( - new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 23 }), - new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 24 }) + new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) 23 }), + new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) 24 }) ) ); @@ -85,10 +85,10 @@ public void testToXContent_CreatesTheRightFormatForMultipleEmbeddings() throws I } public void testTransformToCoordinationFormat() { - var results = new TextEmbeddingBitResults( + var results = new DenseEmbeddingBitResults( List.of( - new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 23, (byte) 24 }), - new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 25, (byte) 26 }) + new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) 23, (byte) 24 }), + new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) 25, (byte) 26 }) ) ).transformToCoordinationFormat(); @@ -96,18 +96,18 @@ public void testTransformToCoordinationFormat() { results, is( List.of( - new MlTextEmbeddingResults(TextEmbeddingBitResults.TEXT_EMBEDDING_BITS, new double[] { 23F, 24F }, false), - new MlTextEmbeddingResults(TextEmbeddingBitResults.TEXT_EMBEDDING_BITS, new double[] { 25F, 26F }, false) + new MlDenseEmbeddingResults(DenseEmbeddingBitResults.TEXT_EMBEDDING_BITS, new double[] { 23F, 24F }, false), + new MlDenseEmbeddingResults(DenseEmbeddingBitResults.TEXT_EMBEDDING_BITS, new double[] { 25F, 26F }, false) ) ) ); } public void testGetFirstEmbeddingSize() { - var firstEmbeddingSize = new TextEmbeddingBitResults( + var firstEmbeddingSize = new DenseEmbeddingBitResults( List.of( - new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 23, (byte) 24 }), - new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 25, (byte) 26 }) + new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) 23, (byte) 24 }), + new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) 25, (byte) 26 }) ) ).getFirstEmbeddingSize(); @@ -115,33 +115,33 @@ public void testGetFirstEmbeddingSize() { } @Override - protected Writeable.Reader instanceReader() { - return TextEmbeddingBitResults::new; + protected Writeable.Reader instanceReader() { + return DenseEmbeddingBitResults::new; } @Override - protected TextEmbeddingBitResults createTestInstance() { + protected DenseEmbeddingBitResults createTestInstance() { return createRandomResults(); } @Override - protected TextEmbeddingBitResults mutateInstance(TextEmbeddingBitResults instance) throws IOException { + protected DenseEmbeddingBitResults mutateInstance(DenseEmbeddingBitResults instance) throws IOException { // if true we reduce the embeddings list by a random amount, if false we add an embedding to the list if (randomBoolean()) { // -1 to remove at least one item from the list int end = randomInt(instance.embeddings().size() - 1); - return new TextEmbeddingBitResults(instance.embeddings().subList(0, end)); + return new DenseEmbeddingBitResults(instance.embeddings().subList(0, end)); } else { - List embeddings = new ArrayList<>(instance.embeddings()); + List embeddings = new ArrayList<>(instance.embeddings()); embeddings.add(createRandomEmbedding()); - return new TextEmbeddingBitResults(embeddings); + return new DenseEmbeddingBitResults(embeddings); } } public static Map buildExpectationByte(List> embeddings) { return Map.of( - TextEmbeddingBitResults.TEXT_EMBEDDING_BITS, - embeddings.stream().map(embedding -> Map.of(TextEmbeddingByteResults.Embedding.EMBEDDING, embedding)).toList() + DenseEmbeddingBitResults.TEXT_EMBEDDING_BITS, + embeddings.stream().map(embedding -> Map.of(DenseEmbeddingByteResults.Embedding.EMBEDDING, embedding)).toList() ); } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingByteResultsTests.java similarity index 50% rename from x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResultsTests.java rename to x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingByteResultsTests.java index 60f45399cfb32..53cd8690b7dc0 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingByteResultsTests.java @@ -10,7 +10,7 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; -import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults; +import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults; import java.io.IOException; import java.util.ArrayList; @@ -20,19 +20,19 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; -public class TextEmbeddingByteResultsTests extends AbstractWireSerializingTestCase { - public static TextEmbeddingByteResults createRandomResults() { +public class DenseEmbeddingByteResultsTests extends AbstractWireSerializingTestCase { + public static DenseEmbeddingByteResults createRandomResults() { int embeddings = randomIntBetween(1, 10); - List embeddingResults = new ArrayList<>(embeddings); + List embeddingResults = new ArrayList<>(embeddings); for (int i = 0; i < embeddings; i++) { embeddingResults.add(createRandomEmbedding()); } - return new TextEmbeddingByteResults(embeddingResults); + return new DenseEmbeddingByteResults(embeddingResults); } - private static TextEmbeddingByteResults.Embedding createRandomEmbedding() { + private static DenseEmbeddingByteResults.Embedding createRandomEmbedding() { int columns = randomIntBetween(1, 10); byte[] bytes = new byte[columns]; @@ -40,11 +40,11 @@ private static TextEmbeddingByteResults.Embedding createRandomEmbedding() { bytes[i] = randomByte(); } - return new TextEmbeddingByteResults.Embedding(bytes); + return new DenseEmbeddingByteResults.Embedding(bytes); } public void testToXContent_CreatesTheRightFormatForASingleEmbedding() throws IOException { - var entity = new TextEmbeddingByteResults(List.of(new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 23 }))); + var entity = new DenseEmbeddingByteResults(List.of(new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) 23 }))); String xContentResult = Strings.toString(entity, true, true); assertThat(xContentResult, is(""" @@ -60,10 +60,10 @@ public void testToXContent_CreatesTheRightFormatForASingleEmbedding() throws IOE } public void testToXContent_CreatesTheRightFormatForMultipleEmbeddings() throws IOException { - var entity = new TextEmbeddingByteResults( + var entity = new DenseEmbeddingByteResults( List.of( - new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 23 }), - new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 24 }) + new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) 23 }), + new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) 24 }) ) ); @@ -86,10 +86,10 @@ public void testToXContent_CreatesTheRightFormatForMultipleEmbeddings() throws I } public void testTransformToCoordinationFormat() { - var results = new TextEmbeddingByteResults( + var results = new DenseEmbeddingByteResults( List.of( - new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 23, (byte) 24 }), - new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 25, (byte) 26 }) + new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) 23, (byte) 24 }), + new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) 25, (byte) 26 }) ) ).transformToCoordinationFormat(); @@ -97,18 +97,18 @@ public void testTransformToCoordinationFormat() { results, is( List.of( - new MlTextEmbeddingResults(TextEmbeddingByteResults.TEXT_EMBEDDING_BYTES, new double[] { 23F, 24F }, false), - new MlTextEmbeddingResults(TextEmbeddingByteResults.TEXT_EMBEDDING_BYTES, new double[] { 25F, 26F }, false) + new MlDenseEmbeddingResults(DenseEmbeddingByteResults.TEXT_EMBEDDING_BYTES, new double[] { 23F, 24F }, false), + new MlDenseEmbeddingResults(DenseEmbeddingByteResults.TEXT_EMBEDDING_BYTES, new double[] { 25F, 26F }, false) ) ) ); } public void testGetFirstEmbeddingSize() { - var firstEmbeddingSize = new TextEmbeddingByteResults( + var firstEmbeddingSize = new DenseEmbeddingByteResults( List.of( - new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 23, (byte) 24 }), - new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 25, (byte) 26 }) + new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) 23, (byte) 24 }), + new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) 25, (byte) 26 }) ) ).getFirstEmbeddingSize(); @@ -116,43 +116,43 @@ public void testGetFirstEmbeddingSize() { } public void testEmbeddingMerge() { - TextEmbeddingByteResults.Embedding embedding1 = new TextEmbeddingByteResults.Embedding(new byte[] { 1, 1, -128 }); - TextEmbeddingByteResults.Embedding embedding2 = new TextEmbeddingByteResults.Embedding(new byte[] { 1, 0, 127 }); - TextEmbeddingByteResults.Embedding embedding3 = new TextEmbeddingByteResults.Embedding(new byte[] { 0, 0, 100 }); - TextEmbeddingByteResults.Embedding mergedEmbedding = embedding1.merge(embedding2); - assertThat(mergedEmbedding, equalTo(new TextEmbeddingByteResults.Embedding(new byte[] { 1, 1, 0 }))); + DenseEmbeddingByteResults.Embedding embedding1 = new DenseEmbeddingByteResults.Embedding(new byte[] { 1, 1, -128 }); + DenseEmbeddingByteResults.Embedding embedding2 = new DenseEmbeddingByteResults.Embedding(new byte[] { 1, 0, 127 }); + DenseEmbeddingByteResults.Embedding embedding3 = new DenseEmbeddingByteResults.Embedding(new byte[] { 0, 0, 100 }); + DenseEmbeddingByteResults.Embedding mergedEmbedding = embedding1.merge(embedding2); + assertThat(mergedEmbedding, equalTo(new DenseEmbeddingByteResults.Embedding(new byte[] { 1, 1, 0 }))); mergedEmbedding = mergedEmbedding.merge(embedding3); - assertThat(mergedEmbedding, equalTo(new TextEmbeddingByteResults.Embedding(new byte[] { 1, 0, 33 }))); + assertThat(mergedEmbedding, equalTo(new DenseEmbeddingByteResults.Embedding(new byte[] { 1, 0, 33 }))); } @Override - protected Writeable.Reader instanceReader() { - return TextEmbeddingByteResults::new; + protected Writeable.Reader instanceReader() { + return DenseEmbeddingByteResults::new; } @Override - protected TextEmbeddingByteResults createTestInstance() { + protected DenseEmbeddingByteResults createTestInstance() { return createRandomResults(); } @Override - protected TextEmbeddingByteResults mutateInstance(TextEmbeddingByteResults instance) throws IOException { + protected DenseEmbeddingByteResults mutateInstance(DenseEmbeddingByteResults instance) throws IOException { // if true we reduce the embeddings list by a random amount, if false we add an embedding to the list if (randomBoolean()) { // -1 to remove at least one item from the list int end = randomInt(instance.embeddings().size() - 1); - return new TextEmbeddingByteResults(instance.embeddings().subList(0, end)); + return new DenseEmbeddingByteResults(instance.embeddings().subList(0, end)); } else { - List embeddings = new ArrayList<>(instance.embeddings()); + List embeddings = new ArrayList<>(instance.embeddings()); embeddings.add(createRandomEmbedding()); - return new TextEmbeddingByteResults(embeddings); + return new DenseEmbeddingByteResults(embeddings); } } public static Map buildExpectationByte(List> embeddings) { return Map.of( - TextEmbeddingByteResults.TEXT_EMBEDDING_BYTES, - embeddings.stream().map(embedding -> Map.of(TextEmbeddingByteResults.Embedding.EMBEDDING, embedding)).toList() + DenseEmbeddingByteResults.TEXT_EMBEDDING_BYTES, + embeddings.stream().map(embedding -> Map.of(DenseEmbeddingByteResults.Embedding.EMBEDDING, embedding)).toList() ); } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingFloatResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingFloatResultsTests.java similarity index 50% rename from x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingFloatResultsTests.java rename to x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingFloatResultsTests.java index 8cdd98bcdebc6..4481692127979 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingFloatResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/DenseEmbeddingFloatResultsTests.java @@ -10,7 +10,7 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; -import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults; +import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults; import java.io.IOException; import java.util.ArrayList; @@ -20,30 +20,30 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; -public class TextEmbeddingFloatResultsTests extends AbstractWireSerializingTestCase { - public static TextEmbeddingFloatResults createRandomResults() { +public class DenseEmbeddingFloatResultsTests extends AbstractWireSerializingTestCase { + public static DenseEmbeddingFloatResults createRandomResults() { int embeddings = randomIntBetween(1, 10); - List embeddingResults = new ArrayList<>(embeddings); + List embeddingResults = new ArrayList<>(embeddings); for (int i = 0; i < embeddings; i++) { embeddingResults.add(createRandomEmbedding()); } - return new TextEmbeddingFloatResults(embeddingResults); + return new DenseEmbeddingFloatResults(embeddingResults); } - private static TextEmbeddingFloatResults.Embedding createRandomEmbedding() { + private static DenseEmbeddingFloatResults.Embedding createRandomEmbedding() { int columns = randomIntBetween(1, 10); float[] floats = new float[columns]; for (int i = 0; i < columns; i++) { floats[i] = randomFloat(); } - return new TextEmbeddingFloatResults.Embedding(floats); + return new DenseEmbeddingFloatResults.Embedding(floats); } public void testToXContent_CreatesTheRightFormatForASingleEmbedding() throws IOException { - var entity = new TextEmbeddingFloatResults(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.1F }))); + var entity = new DenseEmbeddingFloatResults(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.1F }))); String xContentResult = Strings.toString(entity, true, true); assertThat(xContentResult, is(""" @@ -59,10 +59,10 @@ public void testToXContent_CreatesTheRightFormatForASingleEmbedding() throws IOE } public void testToXContent_CreatesTheRightFormatForMultipleEmbeddings() throws IOException { - var entity = new TextEmbeddingFloatResults( + var entity = new DenseEmbeddingFloatResults( List.of( - new TextEmbeddingFloatResults.Embedding(new float[] { 0.1F }), - new TextEmbeddingFloatResults.Embedding(new float[] { 0.2F }) + new DenseEmbeddingFloatResults.Embedding(new float[] { 0.1F }), + new DenseEmbeddingFloatResults.Embedding(new float[] { 0.2F }) ) ); @@ -86,10 +86,10 @@ public void testToXContent_CreatesTheRightFormatForMultipleEmbeddings() throws I } public void testTransformToCoordinationFormat() { - var results = new TextEmbeddingFloatResults( + var results = new DenseEmbeddingFloatResults( List.of( - new TextEmbeddingFloatResults.Embedding(new float[] { 0.1F, 0.2F }), - new TextEmbeddingFloatResults.Embedding(new float[] { 0.3F, 0.4F }) + new DenseEmbeddingFloatResults.Embedding(new float[] { 0.1F, 0.2F }), + new DenseEmbeddingFloatResults.Embedding(new float[] { 0.3F, 0.4F }) ) ).transformToCoordinationFormat(); @@ -97,18 +97,18 @@ public void testTransformToCoordinationFormat() { results, is( List.of( - new MlTextEmbeddingResults(TextEmbeddingFloatResults.TEXT_EMBEDDING, new double[] { 0.1F, 0.2F }, false), - new MlTextEmbeddingResults(TextEmbeddingFloatResults.TEXT_EMBEDDING, new double[] { 0.3F, 0.4F }, false) + new MlDenseEmbeddingResults(DenseEmbeddingFloatResults.TEXT_EMBEDDING, new double[] { 0.1F, 0.2F }, false), + new MlDenseEmbeddingResults(DenseEmbeddingFloatResults.TEXT_EMBEDDING, new double[] { 0.3F, 0.4F }, false) ) ) ); } public void testGetFirstEmbeddingSize() { - var firstEmbeddingSize = new TextEmbeddingFloatResults( + var firstEmbeddingSize = new DenseEmbeddingFloatResults( List.of( - new TextEmbeddingFloatResults.Embedding(new float[] { 0.1F, 0.2F }), - new TextEmbeddingFloatResults.Embedding(new float[] { 0.3F, 0.4F }) + new DenseEmbeddingFloatResults.Embedding(new float[] { 0.1F, 0.2F }), + new DenseEmbeddingFloatResults.Embedding(new float[] { 0.3F, 0.4F }) ) ).getFirstEmbeddingSize(); @@ -116,51 +116,54 @@ public void testGetFirstEmbeddingSize() { } public void testEmbeddingMerge() { - TextEmbeddingFloatResults.Embedding embedding1 = new TextEmbeddingFloatResults.Embedding(new float[] { 0.1f, 0.2f, 0.3f, 0.4f }); - TextEmbeddingFloatResults.Embedding embedding2 = new TextEmbeddingFloatResults.Embedding(new float[] { 0.0f, 0.4f, 0.1f, 1.0f }); - TextEmbeddingFloatResults.Embedding embedding3 = new TextEmbeddingFloatResults.Embedding(new float[] { 0.2f, 0.9f, 0.8f, 0.1f }); - TextEmbeddingFloatResults.Embedding mergedEmbedding = embedding1.merge(embedding2); - assertThat(mergedEmbedding, equalTo(new TextEmbeddingFloatResults.Embedding(new float[] { 0.05f, 0.3f, 0.2f, 0.7f }))); + DenseEmbeddingFloatResults.Embedding embedding1 = new DenseEmbeddingFloatResults.Embedding(new float[] { 0.1f, 0.2f, 0.3f, 0.4f }); + DenseEmbeddingFloatResults.Embedding embedding2 = new DenseEmbeddingFloatResults.Embedding(new float[] { 0.0f, 0.4f, 0.1f, 1.0f }); + DenseEmbeddingFloatResults.Embedding embedding3 = new DenseEmbeddingFloatResults.Embedding(new float[] { 0.2f, 0.9f, 0.8f, 0.1f }); + DenseEmbeddingFloatResults.Embedding mergedEmbedding = embedding1.merge(embedding2); + assertThat(mergedEmbedding, equalTo(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.05f, 0.3f, 0.2f, 0.7f }))); mergedEmbedding = mergedEmbedding.merge(embedding3); - assertThat(mergedEmbedding, equalTo(new TextEmbeddingFloatResults.Embedding(new float[] { 0.1f, 0.5f, 0.4f, 0.5f }))); + assertThat(mergedEmbedding, equalTo(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.1f, 0.5f, 0.4f, 0.5f }))); } @Override - protected Writeable.Reader instanceReader() { - return TextEmbeddingFloatResults::new; + protected Writeable.Reader instanceReader() { + return DenseEmbeddingFloatResults::new; } @Override - protected TextEmbeddingFloatResults createTestInstance() { + protected DenseEmbeddingFloatResults createTestInstance() { return createRandomResults(); } @Override - protected TextEmbeddingFloatResults mutateInstance(TextEmbeddingFloatResults instance) throws IOException { + protected DenseEmbeddingFloatResults mutateInstance(DenseEmbeddingFloatResults instance) throws IOException { // if true we reduce the embeddings list by a random amount, if false we add an embedding to the list if (randomBoolean()) { // -1 to remove at least one item from the list int end = randomInt(instance.embeddings().size() - 1); - return new TextEmbeddingFloatResults(instance.embeddings().subList(0, end)); + return new DenseEmbeddingFloatResults(instance.embeddings().subList(0, end)); } else { - List embeddings = new ArrayList<>(instance.embeddings()); + List embeddings = new ArrayList<>(instance.embeddings()); embeddings.add(createRandomEmbedding()); - return new TextEmbeddingFloatResults(embeddings); + return new DenseEmbeddingFloatResults(embeddings); } } public static Map buildExpectationFloat(List embeddings) { - return Map.of(TextEmbeddingFloatResults.TEXT_EMBEDDING, embeddings.stream().map(TextEmbeddingFloatResults.Embedding::new).toList()); + return Map.of( + DenseEmbeddingFloatResults.TEXT_EMBEDDING, + embeddings.stream().map(DenseEmbeddingFloatResults.Embedding::new).toList() + ); } public static Map buildExpectationByte(List embeddings) { return Map.of( - TextEmbeddingByteResults.TEXT_EMBEDDING_BYTES, - embeddings.stream().map(TextEmbeddingByteResults.Embedding::new).toList() + DenseEmbeddingByteResults.TEXT_EMBEDDING_BYTES, + embeddings.stream().map(DenseEmbeddingByteResults.Embedding::new).toList() ); } public static Map buildExpectationBinary(List embeddings) { - return Map.of("text_embedding_bits", embeddings.stream().map(TextEmbeddingByteResults.Embedding::new).toList()); + return Map.of("text_embedding_bits", embeddings.stream().map(DenseEmbeddingByteResults.Embedding::new).toList()); } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionResponseTests.java index 87049d6bde90c..b897dc8aef6dd 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionResponseTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionResponseTests.java @@ -17,8 +17,8 @@ import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResultsTests; import org.elasticsearch.xpack.core.ml.inference.results.FillMaskResults; import org.elasticsearch.xpack.core.ml.inference.results.FillMaskResultsTests; -import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults; -import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResultsTests; +import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults; +import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResultsTests; import org.elasticsearch.xpack.core.ml.inference.results.NerResults; import org.elasticsearch.xpack.core.ml.inference.results.NerResultsTests; import org.elasticsearch.xpack.core.ml.inference.results.PyTorchPassThroughResults; @@ -50,7 +50,7 @@ public class InferModelActionResponseTests extends AbstractWireSerializingTestCa PyTorchPassThroughResults.NAME, QuestionAnsweringInferenceResults.NAME, RegressionInferenceResults.NAME, - MlTextEmbeddingResults.NAME, + MlDenseEmbeddingResults.NAME, TextExpansionResults.NAME, TextSimilarityInferenceResults.NAME, WarningInferenceResults.NAME @@ -87,7 +87,7 @@ private static InferenceResults randomInferenceResult(String resultType) { case PyTorchPassThroughResults.NAME -> PyTorchPassThroughResultsTests.createRandomResults(); case QuestionAnsweringInferenceResults.NAME -> QuestionAnsweringInferenceResultsTests.createRandomResults(); case RegressionInferenceResults.NAME -> RegressionInferenceResultsTests.createRandomResults(); - case MlTextEmbeddingResults.NAME -> MlTextEmbeddingResultsTests.createRandomResults(); + case MlDenseEmbeddingResults.NAME -> MlDenseEmbeddingResultsTests.createRandomResults(); case TextExpansionResults.NAME -> TextExpansionResultsTests.createRandomResults(); case TextSimilarityInferenceResults.NAME -> TextSimilarityInferenceResultsTests.createRandomResults(); case WarningInferenceResults.NAME -> WarningInferenceResultsTests.createRandomResults(); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferTrainedModelDeploymentResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferTrainedModelDeploymentResponseTests.java index 62d26ead2b641..80d905fe1d83f 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferTrainedModelDeploymentResponseTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferTrainedModelDeploymentResponseTests.java @@ -13,7 +13,7 @@ import org.elasticsearch.xcontent.NamedXContentRegistry; import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; -import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResultsTests; +import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResultsTests; import org.junit.Before; import java.util.List; @@ -49,10 +49,10 @@ protected Writeable.Reader instanceR protected InferTrainedModelDeploymentAction.Response createTestInstance() { return new InferTrainedModelDeploymentAction.Response( List.of( - MlTextEmbeddingResultsTests.createRandomResults(), - MlTextEmbeddingResultsTests.createRandomResults(), - MlTextEmbeddingResultsTests.createRandomResults(), - MlTextEmbeddingResultsTests.createRandomResults() + MlDenseEmbeddingResultsTests.createRandomResults(), + MlDenseEmbeddingResultsTests.createRandomResults(), + MlDenseEmbeddingResultsTests.createRandomResults(), + MlDenseEmbeddingResultsTests.createRandomResults() ) ); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/MlTextEmbeddingResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/MlDenseEmbeddingResultsTests.java similarity index 68% rename from x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/MlTextEmbeddingResultsTests.java rename to x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/MlDenseEmbeddingResultsTests.java index 3338609eebdc3..b5aee244771ea 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/MlTextEmbeddingResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/MlDenseEmbeddingResultsTests.java @@ -16,35 +16,35 @@ import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.is; -public class MlTextEmbeddingResultsTests extends InferenceResultsTestCase { +public class MlDenseEmbeddingResultsTests extends InferenceResultsTestCase { - public static MlTextEmbeddingResults createRandomResults() { + public static MlDenseEmbeddingResults createRandomResults() { int columns = randomIntBetween(1, 10); double[] arr = new double[columns]; for (int i = 0; i < columns; i++) { arr[i] = randomDouble(); } - return new MlTextEmbeddingResults(DEFAULT_RESULTS_FIELD, arr, randomBoolean()); + return new MlDenseEmbeddingResults(DEFAULT_RESULTS_FIELD, arr, randomBoolean()); } @Override - protected Writeable.Reader instanceReader() { - return MlTextEmbeddingResults::new; + protected Writeable.Reader instanceReader() { + return MlDenseEmbeddingResults::new; } @Override - protected MlTextEmbeddingResults createTestInstance() { + protected MlDenseEmbeddingResults createTestInstance() { return createRandomResults(); } @Override - protected MlTextEmbeddingResults mutateInstance(MlTextEmbeddingResults instance) { + protected MlDenseEmbeddingResults mutateInstance(MlDenseEmbeddingResults instance) { return null;// TODO implement https://github.com/elastic/elasticsearch/issues/25929 } public void testAsMap() { - MlTextEmbeddingResults testInstance = createTestInstance(); + MlDenseEmbeddingResults testInstance = createTestInstance(); Map asMap = testInstance.asMap(); int size = testInstance.isTruncated ? 2 : 1; assertThat(asMap.keySet(), hasSize(size)); @@ -55,7 +55,7 @@ public void testAsMap() { } @Override - void assertFieldValues(MlTextEmbeddingResults createdInstance, IngestDocument document, String parentField, String resultsField) { + void assertFieldValues(MlDenseEmbeddingResults createdInstance, IngestDocument document, String parentField, String resultsField) { assertArrayEquals(document.getFieldValue(parentField + resultsField, double[].class), createdInstance.getInference(), 1e-10); } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorOutputBuilder.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorOutputBuilder.java index a2b0a32e77b05..6cc27f931cfc8 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorOutputBuilder.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorOutputBuilder.java @@ -12,14 +12,14 @@ import org.elasticsearch.compute.data.Page; import org.elasticsearch.core.Releasables; import org.elasticsearch.xpack.core.inference.action.InferenceAction; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingResults; import org.elasticsearch.xpack.esql.inference.InferenceOperator; /** * {@link TextEmbeddingOperatorOutputBuilder} builds the output page for text embedding by converting - * {@link TextEmbeddingResults} into a {@link FloatBlock} containing dense vector embeddings. + * {@link DenseEmbeddingResults} into a {@link FloatBlock} containing dense vector embeddings. */ class TextEmbeddingOperatorOutputBuilder implements InferenceOperator.OutputBuilder { private final Page inputPage; @@ -39,7 +39,7 @@ public void close() { * Adds an inference response to the output builder. * *

- * If the response is null or not of type {@link TextEmbeddingResults} an {@link IllegalStateException} is thrown. + * If the response is null or not of type {@link DenseEmbeddingResults} an {@link IllegalStateException} is thrown. * Else, the embedding vector is added to the output block as a multi-value position. *

* @@ -55,7 +55,7 @@ public void addInferenceResponse(InferenceAction.Response inferenceResponse) { return; } - TextEmbeddingResults embeddingResults = inferenceResults(inferenceResponse); + DenseEmbeddingResults embeddingResults = inferenceResults(inferenceResponse); var embeddings = embeddingResults.embeddings(); if (embeddings.isEmpty()) { @@ -82,21 +82,25 @@ public Page buildOutput() { return inputPage.appendBlock(outputBlock); } - private TextEmbeddingResults inferenceResults(InferenceAction.Response inferenceResponse) { - return InferenceOperator.OutputBuilder.inferenceResults(inferenceResponse, TextEmbeddingResults.class); + private DenseEmbeddingResults inferenceResults(InferenceAction.Response inferenceResponse) { + return InferenceOperator.OutputBuilder.inferenceResults(inferenceResponse, DenseEmbeddingResults.class); } /** * Extracts the embedding as a float array from the embedding result. */ - private static float[] getEmbeddingAsFloatArray(TextEmbeddingResults embedding) { + private static float[] getEmbeddingAsFloatArray(DenseEmbeddingResults embedding) { return switch (embedding.embeddings().get(0)) { - case TextEmbeddingFloatResults.Embedding floatEmbedding -> floatEmbedding.values(); - case TextEmbeddingByteResults.Embedding byteEmbedding -> toFloatArray(byteEmbedding.values()); + case DenseEmbeddingFloatResults.Embedding floatEmbedding -> floatEmbedding.values(); + case DenseEmbeddingByteResults.Embedding byteEmbedding -> toFloatArray(byteEmbedding.values()); default -> throw new IllegalArgumentException( "Unsupported embedding type: " + embedding.embeddings().get(0).getClass().getName() - + ". Expected TextEmbeddingFloatResults.Embedding or TextEmbeddingByteResults.Embedding." + + ". Expected " + + DenseEmbeddingFloatResults.Embedding.class.getName() + + " or " + + DenseEmbeddingByteResults.Embedding.class.getName() + + "." ); }; } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluatorTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluatorTests.java index 875c136ab861c..ac847f50ce353 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluatorTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluatorTests.java @@ -16,7 +16,7 @@ import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.inference.action.InferenceAction; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.FoldContext; @@ -250,7 +250,7 @@ private float[] randomEmbedding(int length) { } private InferenceAction.Response inferenceResponse(float[] embedding) { - TextEmbeddingFloatResults.Embedding embeddingResult = new TextEmbeddingFloatResults.Embedding(embedding); - return new InferenceAction.Response(new TextEmbeddingFloatResults(List.of(embeddingResult))); + DenseEmbeddingFloatResults.Embedding embeddingResult = new DenseEmbeddingFloatResults.Embedding(embedding); + return new InferenceAction.Response(new DenseEmbeddingFloatResults(List.of(embeddingResult))); } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorOutputBuilderTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorOutputBuilderTests.java index ea77c6bed3c38..ac1fb5daa36a5 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorOutputBuilderTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorOutputBuilderTests.java @@ -15,8 +15,8 @@ import org.elasticsearch.compute.test.RandomBlock; import org.elasticsearch.core.Releasables; import org.elasticsearch.xpack.core.inference.action.InferenceAction; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import java.util.List; @@ -182,7 +182,7 @@ private void assertByteEmbeddingContent(FloatBlock block, byte[][] expectedByteE int firstValueIndex = block.getFirstValueIndex(currentPos); for (int i = 0; i < expectedByteEmbeddings[currentPos].length; i++) { float actualValue = block.getFloat(firstValueIndex + i); - // Convert byte to float the same way as TextEmbeddingByteResults.Embedding.toFloatArray() + // Convert byte to float the same way as DenseEmbeddingByteResults.Embedding.toFloatArray() float expectedValue = expectedByteEmbeddings[currentPos][i]; assertThat(actualValue, equalTo(expectedValue)); } @@ -206,20 +206,20 @@ private byte[] randomByteEmbedding(int dimension) { } private static InferenceAction.Response createFloatEmbeddingResponse(float[] embedding) { - var embeddingResult = new TextEmbeddingFloatResults.Embedding(embedding); - var textEmbeddingResults = new TextEmbeddingFloatResults(List.of(embeddingResult)); - return new InferenceAction.Response(textEmbeddingResults); + var embeddingResult = new DenseEmbeddingFloatResults.Embedding(embedding); + var denseEmbeddingResults = new DenseEmbeddingFloatResults(List.of(embeddingResult)); + return new InferenceAction.Response(denseEmbeddingResults); } private static InferenceAction.Response createByteEmbeddingResponse(byte[] embedding) { - var embeddingResult = new TextEmbeddingByteResults.Embedding(embedding); - var textEmbeddingResults = new TextEmbeddingByteResults(List.of(embeddingResult)); - return new InferenceAction.Response(textEmbeddingResults); + var embeddingResult = new DenseEmbeddingByteResults.Embedding(embedding); + var denseEmbeddingResults = new DenseEmbeddingByteResults(List.of(embeddingResult)); + return new InferenceAction.Response(denseEmbeddingResults); } private static InferenceAction.Response createEmptyFloatEmbeddingResponse() { - var textEmbeddingResults = new TextEmbeddingFloatResults(List.of()); - return new InferenceAction.Response(textEmbeddingResults); + var denseEmbeddingResults = new DenseEmbeddingFloatResults(List.of()); + return new InferenceAction.Response(denseEmbeddingResults); } private Page randomInputPage(int positionCount, int columnCount) { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorTests.java index 6ff9a90b70b16..06441be9e7148 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorTests.java @@ -13,7 +13,7 @@ import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.Operator; import org.elasticsearch.xpack.core.inference.action.InferenceAction; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.esql.inference.InferenceOperatorTestCase; import org.hamcrest.Matcher; import org.junit.Before; @@ -23,7 +23,7 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; -public class TextEmbeddingOperatorTests extends InferenceOperatorTestCase { +public class TextEmbeddingOperatorTests extends InferenceOperatorTestCase { private static final String SIMPLE_INFERENCE_ID = "test_text_embedding"; private static final int EMBEDDING_DIMENSION = 384; // Common embedding dimension @@ -89,15 +89,15 @@ private void assertTextEmbeddingResults(Page inputPage, Page resultPage) { } @Override - protected TextEmbeddingFloatResults mockInferenceResult(InferenceAction.Request request) { + protected DenseEmbeddingFloatResults mockInferenceResult(InferenceAction.Request request) { // For text embedding, we expect one input text per request String inputText = request.getInput().get(0); // Generate a deterministic mock embedding based on the input text float[] mockEmbedding = generateMockEmbedding(inputText, EMBEDDING_DIMENSION); - var embeddingResult = new TextEmbeddingFloatResults.Embedding(mockEmbedding); - return new TextEmbeddingFloatResults(List.of(embeddingResult)); + var embeddingResult = new DenseEmbeddingFloatResults.Embedding(mockEmbedding); + return new DenseEmbeddingFloatResults(List.of(embeddingResult)); } @Override diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java index 051b6dbf3e8fa..45ecb3dedf3f1 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java @@ -36,7 +36,7 @@ import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import java.io.IOException; import java.nio.charset.StandardCharsets; @@ -167,22 +167,22 @@ public void chunkedInfer( } } - private TextEmbeddingFloatResults makeResults(List input, ServiceSettings serviceSettings) { - List embeddings = new ArrayList<>(); + private DenseEmbeddingFloatResults makeResults(List input, ServiceSettings serviceSettings) { + List embeddings = new ArrayList<>(); for (String inputString : input) { List floatEmbeddings = generateEmbedding(inputString, serviceSettings.dimensions(), serviceSettings.elementType()); - embeddings.add(TextEmbeddingFloatResults.Embedding.of(floatEmbeddings)); + embeddings.add(DenseEmbeddingFloatResults.Embedding.of(floatEmbeddings)); } - return new TextEmbeddingFloatResults(embeddings); + return new DenseEmbeddingFloatResults(embeddings); } private List makeChunkedResults(List inputs, ServiceSettings serviceSettings) { var results = new ArrayList(); for (ChunkInferenceInput input : inputs) { List chunkedInput = chunkInputs(input); - List chunks = chunkedInput.stream() + List chunks = chunkedInput.stream() .map( - c -> new TextEmbeddingFloatResults.Chunk( + c -> new DenseEmbeddingFloatResults.Chunk( makeResults(List.of(c.input()), serviceSettings).embeddings().get(0), new ChunkedInference.TextOffset(c.startOffset(), c.endOffset()) ) diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java index 28a191a1bbfac..9ea2301abfa0c 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java @@ -35,9 +35,9 @@ import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.core.inference.DequeUtils; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults; import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import java.io.IOException; import java.util.ArrayList; @@ -138,9 +138,9 @@ public void infer( ) ); } else { - // Return text embedding results when creating a sparse_embedding inference endpoint to allow creation validation to - // pass. This is required to test that streaming fails for a sparse_embedding endpoint. - listener.onResponse(makeTextEmbeddingResults(input)); + // Return dense embedding results when creating a sparse_embedding inference endpoint to allow creation validation + // to pass. This is required to test that streaming fails for a sparse_embedding endpoint. + listener.onResponse(makeDenseEmbeddingResults(input)); } } default -> listener.onFailure( @@ -189,16 +189,16 @@ public void cancel() {} }); } - private TextEmbeddingFloatResults makeTextEmbeddingResults(List input) { - var embeddings = new ArrayList(); + private DenseEmbeddingFloatResults makeDenseEmbeddingResults(List input) { + var embeddings = new ArrayList(); for (int i = 0; i < input.size(); i++) { var values = new float[5]; for (int j = 0; j < 5; j++) { values[j] = random.nextFloat(); } - embeddings.add(new TextEmbeddingFloatResults.Embedding(values)); + embeddings.add(new DenseEmbeddingFloatResults.Embedding(values)); } - return new TextEmbeddingFloatResults(embeddings); + return new DenseEmbeddingFloatResults(embeddings); } private InferenceServiceResults.Result completionChunk(String delta) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index b1887cff763fa..f227140e431bd 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -20,12 +20,12 @@ import org.elasticsearch.inference.TaskSettings; import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults; import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; import org.elasticsearch.xpack.inference.chunking.NoneChunkingSettings; import org.elasticsearch.xpack.inference.chunking.RecursiveChunkingSettings; @@ -71,10 +71,10 @@ import org.elasticsearch.xpack.inference.services.custom.CustomTaskSettings; import org.elasticsearch.xpack.inference.services.custom.response.CompletionResponseParser; import org.elasticsearch.xpack.inference.services.custom.response.CustomResponseParser; +import org.elasticsearch.xpack.inference.services.custom.response.DenseEmbeddingResponseParser; import org.elasticsearch.xpack.inference.services.custom.response.NoopResponseParser; import org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser; import org.elasticsearch.xpack.inference.services.custom.response.SparseEmbeddingResponseParser; -import org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser; import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekChatCompletionModel; import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings; import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankServiceSettings; @@ -198,7 +198,11 @@ private static void addCustomNamedWriteables(List namedWriteables.add(new NamedWriteableRegistry.Entry(SecretSettings.class, CustomSecretSettings.NAME, CustomSecretSettings::new)); namedWriteables.add( - new NamedWriteableRegistry.Entry(CustomResponseParser.class, TextEmbeddingResponseParser.NAME, TextEmbeddingResponseParser::new) + new NamedWriteableRegistry.Entry( + CustomResponseParser.class, + DenseEmbeddingResponseParser.NAME, + DenseEmbeddingResponseParser::new + ) ); namedWriteables.add( @@ -649,10 +653,14 @@ private static void addInferenceResultsNamedWriteables(List * */ - public static TextEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException { + public static DenseEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException { try (var p = XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, response.body())) { - return EmbeddingFloatResult.PARSER.apply(p, null).toTextEmbeddingFloatResults(); + return EmbeddingFloatResult.PARSER.apply(p, null).toDenseEmbeddingFloatResults(); } } @@ -81,9 +81,9 @@ public record EmbeddingFloatResult(List embeddingResu }, new ParseField("data"), org.elasticsearch.xcontent.ObjectParser.ValueType.OBJECT_ARRAY); } - public TextEmbeddingFloatResults toTextEmbeddingFloatResults() { - return new TextEmbeddingFloatResults( - embeddingResults.stream().map(entry -> TextEmbeddingFloatResults.Embedding.of(entry.embedding)).toList() + public DenseEmbeddingFloatResults toDenseEmbeddingFloatResults() { + return new DenseEmbeddingFloatResults( + embeddingResults.stream().map(entry -> DenseEmbeddingFloatResults.Embedding.of(entry.embedding)).toList() ); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index 1a8b162eb1b46..1e9a66c776b53 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -80,7 +80,7 @@ import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults; +import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults; import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder; import org.elasticsearch.xpack.inference.highlight.SemanticTextHighlighter; @@ -967,13 +967,13 @@ yield new SparseVectorQueryBuilder( ); } case TEXT_EMBEDDING -> { - if (inferenceResults instanceof MlTextEmbeddingResults == false) { + if (inferenceResults instanceof MlDenseEmbeddingResults == false) { throw new IllegalArgumentException( - generateQueryInferenceResultsTypeMismatchMessage(inferenceResults, MlTextEmbeddingResults.NAME) + generateQueryInferenceResultsTypeMismatchMessage(inferenceResults, MlDenseEmbeddingResults.NAME) ); } - MlTextEmbeddingResults textEmbeddingResults = (MlTextEmbeddingResults) inferenceResults; + MlDenseEmbeddingResults textEmbeddingResults = (MlDenseEmbeddingResults) inferenceResults; float[] inference = textEmbeddingResults.getInferenceAsFloat(); int dimensions = modelSettings.elementType() == DenseVectorFieldMapper.ElementType.BIT ? inference.length * Byte.SIZE // Bit vectors encode 8 dimensions into each byte value diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceKnnVectorQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceKnnVectorQueryBuilder.java index 210ab2b67f9c9..cd590b8410234 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceKnnVectorQueryBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceKnnVectorQueryBuilder.java @@ -25,7 +25,7 @@ import org.elasticsearch.search.vectors.KnnVectorQueryBuilder; import org.elasticsearch.search.vectors.QueryVectorBuilder; import org.elasticsearch.search.vectors.VectorData; -import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults; +import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults; import org.elasticsearch.xpack.core.ml.vectors.TextEmbeddingQueryVectorBuilder; import org.elasticsearch.xpack.inference.mapper.SemanticTextField; import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; @@ -195,7 +195,7 @@ private QueryBuilder querySemanticTextField(String clusterAlias, SemanticTextFie fullyQualifiedInferenceId = new FullyQualifiedInferenceId(clusterAlias, semanticTextFieldType.getSearchInferenceId()); } - MlTextEmbeddingResults textEmbeddingResults = getTextEmbeddingResults(fullyQualifiedInferenceId); + MlDenseEmbeddingResults textEmbeddingResults = getTextEmbeddingResults(fullyQualifiedInferenceId); queryVector = new VectorData(textEmbeddingResults.getInferenceAsFloat()); } @@ -226,7 +226,7 @@ private QueryBuilder queryNonSemanticTextField() { throw new IllegalStateException("No query vector or query vector builder model ID specified"); } - MlTextEmbeddingResults textEmbeddingResults = getTextEmbeddingResults(fullyQualifiedInferenceId); + MlDenseEmbeddingResults textEmbeddingResults = getTextEmbeddingResults(fullyQualifiedInferenceId); queryVector = new VectorData(textEmbeddingResults.getInferenceAsFloat()); } @@ -244,20 +244,20 @@ private QueryBuilder queryNonSemanticTextField() { return knnQuery; } - private MlTextEmbeddingResults getTextEmbeddingResults(FullyQualifiedInferenceId fullyQualifiedInferenceId) { + private MlDenseEmbeddingResults getTextEmbeddingResults(FullyQualifiedInferenceId fullyQualifiedInferenceId) { InferenceResults inferenceResults = inferenceResultsMap.get(fullyQualifiedInferenceId); if (inferenceResults == null) { throw new IllegalStateException("Could not find inference results from inference endpoint [" + fullyQualifiedInferenceId + "]"); - } else if (inferenceResults instanceof MlTextEmbeddingResults == false) { + } else if (inferenceResults instanceof MlDenseEmbeddingResults == false) { throw new IllegalArgumentException( "Expected query inference results to be of type [" - + MlTextEmbeddingResults.NAME + + MlDenseEmbeddingResults.NAME + "], got [" + inferenceResults.getWriteableName() + "]. Are you specifying a compatible inference endpoint? Has the inference endpoint configuration changed?" ); } - return (MlTextEmbeddingResults) inferenceResults; + return (MlDenseEmbeddingResults) inferenceResults; } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java index fddcf75ce3ba3..52abe808ee4f2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java @@ -33,7 +33,7 @@ import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; -import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults; +import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults; import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; import org.elasticsearch.xpack.inference.InferenceException; @@ -499,7 +499,7 @@ private static InferenceResults validateAndConvertInferenceResults( InferenceResults inferenceResults = inferenceResultsList.getFirst(); if (inferenceResults instanceof TextExpansionResults == false - && inferenceResults instanceof MlTextEmbeddingResults == false + && inferenceResults instanceof MlDenseEmbeddingResults == false && inferenceResults instanceof ErrorInferenceResults == false && inferenceResults instanceof WarningInferenceResults == false) { return new ErrorInferenceResults( @@ -507,7 +507,7 @@ private static InferenceResults validateAndConvertInferenceResults( "Expected query inference results to be of type [" + TextExpansionResults.NAME + "] or [" - + MlTextEmbeddingResults.NAME + + MlDenseEmbeddingResults.NAME + "], got [" + inferenceResults.getWriteableName() + "]. Has the inference endpoint [" diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/response/AlibabaCloudSearchEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/response/AlibabaCloudSearchEmbeddingsResponseEntity.java index 4e73f03e2898b..0a314202c922c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/response/AlibabaCloudSearchEmbeddingsResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/response/AlibabaCloudSearchEmbeddingsResponseEntity.java @@ -9,7 +9,7 @@ import org.elasticsearch.common.xcontent.XContentParserUtils; import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; @@ -70,20 +70,20 @@ public class AlibabaCloudSearchEmbeddingsResponseEntity extends AlibabaCloudSear * * */ - public static TextEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException { + public static DenseEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException { return fromResponse(request, response, parser -> { positionParserAtTokenAfterField(parser, "embeddings", FAILED_TO_FIND_FIELD_TEMPLATE); - List embeddingList = XContentParserUtils.parseList( + List embeddingList = XContentParserUtils.parseList( parser, AlibabaCloudSearchEmbeddingsResponseEntity::parseEmbeddingObject ); - return new TextEmbeddingFloatResults(embeddingList); + return new DenseEmbeddingFloatResults(embeddingList); }); } - private static TextEmbeddingFloatResults.Embedding parseEmbeddingObject(XContentParser parser) throws IOException { + private static DenseEmbeddingFloatResults.Embedding parseEmbeddingObject(XContentParser parser) throws IOException { XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); positionParserAtTokenAfterField(parser, "embedding", FAILED_TO_FIND_FIELD_TEMPLATE); @@ -95,7 +95,7 @@ private static TextEmbeddingFloatResults.Embedding parseEmbeddingObject(XContent // if there are additional fields within this object, lets skip them, so we can begin parsing the next embedding array parser.skipChildren(); - return TextEmbeddingFloatResults.Embedding.of(embeddingValues); + return DenseEmbeddingFloatResults.Embedding.of(embeddingValues); } private static float parseEmbeddingList(XContentParser parser) throws IOException { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/response/embeddings/AmazonBedrockEmbeddingsResponse.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/response/embeddings/AmazonBedrockEmbeddingsResponse.java index 831bf9938c211..61ec5f0c39790 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/response/embeddings/AmazonBedrockEmbeddingsResponse.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/response/embeddings/AmazonBedrockEmbeddingsResponse.java @@ -16,7 +16,7 @@ import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.response.XContentUtils; import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProvider; import org.elasticsearch.xpack.inference.services.amazonbedrock.request.AmazonBedrockRequest; @@ -48,7 +48,7 @@ public InferenceServiceResults accept(AmazonBedrockRequest request) { throw new ElasticsearchException("unexpected request type [" + request.getClass() + "]"); } - public static TextEmbeddingFloatResults fromResponse(InvokeModelResponse response, AmazonBedrockProvider provider) { + public static DenseEmbeddingFloatResults fromResponse(InvokeModelResponse response, AmazonBedrockProvider provider) { var charset = StandardCharsets.UTF_8; var bodyText = String.valueOf(charset.decode(response.body().asByteBuffer())); @@ -63,13 +63,13 @@ public static TextEmbeddingFloatResults fromResponse(InvokeModelResponse respons var embeddingList = parseEmbeddings(jsonParser, provider); - return new TextEmbeddingFloatResults(embeddingList); + return new DenseEmbeddingFloatResults(embeddingList); } catch (IOException e) { throw new ElasticsearchException(e); } } - private static List parseEmbeddings(XContentParser jsonParser, AmazonBedrockProvider provider) + private static List parseEmbeddings(XContentParser jsonParser, AmazonBedrockProvider provider) throws IOException { switch (provider) { case AMAZONTITAN -> { @@ -82,7 +82,7 @@ private static List parseEmbeddings(XConten } } - private static List parseTitanEmbeddings(XContentParser parser) throws IOException { + private static List parseTitanEmbeddings(XContentParser parser) throws IOException { /* Titan response: { @@ -92,11 +92,11 @@ private static List parseTitanEmbeddings(XC */ positionParserAtTokenAfterField(parser, "embedding", FAILED_TO_FIND_FIELD_TEMPLATE); List embeddingValuesList = parseList(parser, XContentUtils::parseFloat); - var embeddingValues = TextEmbeddingFloatResults.Embedding.of(embeddingValuesList); + var embeddingValues = DenseEmbeddingFloatResults.Embedding.of(embeddingValuesList); return List.of(embeddingValues); } - private static List parseCohereEmbeddings(XContentParser parser) throws IOException { + private static List parseCohereEmbeddings(XContentParser parser) throws IOException { /* Cohere response: { @@ -111,7 +111,7 @@ private static List parseCohereEmbeddings(X */ positionParserAtTokenAfterField(parser, "embeddings", FAILED_TO_FIND_FIELD_TEMPLATE); - List embeddingList = parseList( + List embeddingList = parseList( parser, AmazonBedrockEmbeddingsResponse::parseCohereEmbeddingsListItem ); @@ -119,9 +119,9 @@ private static List parseCohereEmbeddings(X return embeddingList; } - private static TextEmbeddingFloatResults.Embedding parseCohereEmbeddingsListItem(XContentParser parser) throws IOException { + private static DenseEmbeddingFloatResults.Embedding parseCohereEmbeddingsListItem(XContentParser parser) throws IOException { List embeddingValuesList = parseList(parser, XContentUtils::parseFloat); - return TextEmbeddingFloatResults.Embedding.of(embeddingValuesList); + return DenseEmbeddingFloatResults.Embedding.of(embeddingValuesList); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/response/CohereEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/response/CohereEmbeddingsResponseEntity.java index b4a2e142b3792..a5a174e80b289 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/response/CohereEmbeddingsResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/response/CohereEmbeddingsResponseEntity.java @@ -15,9 +15,9 @@ import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBitResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingBitResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.external.response.XContentUtils; @@ -189,20 +189,20 @@ private static InferenceServiceResults parseBitEmbeddingsArray(XContentParser pa // Cohere returns array of binary embeddings encoded as bytes with int8 precision so we can reuse the byte parser var embeddingList = parseList(parser, CohereEmbeddingsResponseEntity::parseByteArrayEntry); - return new TextEmbeddingBitResults(embeddingList); + return new DenseEmbeddingBitResults(embeddingList); } private static InferenceServiceResults parseByteEmbeddingsArray(XContentParser parser) throws IOException { var embeddingList = parseList(parser, CohereEmbeddingsResponseEntity::parseByteArrayEntry); - return new TextEmbeddingByteResults(embeddingList); + return new DenseEmbeddingByteResults(embeddingList); } - private static TextEmbeddingByteResults.Embedding parseByteArrayEntry(XContentParser parser) throws IOException { + private static DenseEmbeddingByteResults.Embedding parseByteArrayEntry(XContentParser parser) throws IOException { ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); List embeddingValuesList = parseList(parser, CohereEmbeddingsResponseEntity::parseEmbeddingInt8Entry); - return TextEmbeddingByteResults.Embedding.of(embeddingValuesList); + return DenseEmbeddingByteResults.Embedding.of(embeddingValuesList); } private static Byte parseEmbeddingInt8Entry(XContentParser parser) throws IOException { @@ -223,13 +223,13 @@ private static void checkByteBounds(short value) { private static InferenceServiceResults parseFloatEmbeddingsArray(XContentParser parser) throws IOException { var embeddingList = parseList(parser, CohereEmbeddingsResponseEntity::parseFloatArrayEntry); - return new TextEmbeddingFloatResults(embeddingList); + return new DenseEmbeddingFloatResults(embeddingList); } - private static TextEmbeddingFloatResults.Embedding parseFloatArrayEntry(XContentParser parser) throws IOException { + private static DenseEmbeddingFloatResults.Embedding parseFloatArrayEntry(XContentParser parser) throws IOException { ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); List embeddingValuesList = parseList(parser, XContentUtils::parseFloat); - return TextEmbeddingFloatResults.Embedding.of(embeddingValuesList); + return DenseEmbeddingFloatResults.Embedding.of(embeddingValuesList); } private CohereEmbeddingsResponseEntity() {} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java index 5986b2104ffb1..44901b49a7e47 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java @@ -25,10 +25,10 @@ import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.custom.response.CompletionResponseParser; import org.elasticsearch.xpack.inference.services.custom.response.CustomResponseParser; +import org.elasticsearch.xpack.inference.services.custom.response.DenseEmbeddingResponseParser; import org.elasticsearch.xpack.inference.services.custom.response.NoopResponseParser; import org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser; import org.elasticsearch.xpack.inference.services.custom.response.SparseEmbeddingResponseParser; -import org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser; import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; @@ -501,7 +501,7 @@ private static CustomResponseParser extractResponseParser( } return switch (taskType) { - case TEXT_EMBEDDING -> TextEmbeddingResponseParser.fromMap(responseParserMap, RESPONSE_SCOPE, validationException); + case TEXT_EMBEDDING -> DenseEmbeddingResponseParser.fromMap(responseParserMap, RESPONSE_SCOPE, validationException); case SPARSE_EMBEDDING -> SparseEmbeddingResponseParser.fromMap(responseParserMap, RESPONSE_SCOPE, validationException); case RERANK -> RerankResponseParser.fromMap(responseParserMap, RESPONSE_SCOPE, validationException); case COMPLETION -> CompletionResponseParser.fromMap(responseParserMap, RESPONSE_SCOPE, validationException); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParser.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/DenseEmbeddingResponseParser.java similarity index 80% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParser.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/DenseEmbeddingResponseParser.java index f665c0be81511..18aaf186b4a3f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParser.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/DenseEmbeddingResponseParser.java @@ -14,9 +14,9 @@ import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBitResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingBitResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.common.MapPathExtractor; import org.elasticsearch.xpack.inference.services.custom.CustomServiceEmbeddingType; @@ -31,8 +31,8 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; import static org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings.JSON_PARSER; -public class TextEmbeddingResponseParser extends BaseCustomResponseParser { - +public class DenseEmbeddingResponseParser extends BaseCustomResponseParser { + // This name is a holdover from before this class was renamed public static final String NAME = "text_embedding_response_parser"; public static final String TEXT_EMBEDDING_PARSER_EMBEDDINGS = "text_embeddings"; public static final String EMBEDDING_TYPE = "embedding_type"; @@ -41,7 +41,7 @@ public class TextEmbeddingResponseParser extends BaseCustomResponseParser { "ml_inference_custom_service_embedding_type" ); - public static TextEmbeddingResponseParser fromMap( + public static DenseEmbeddingResponseParser fromMap( Map responseParserMap, String scope, ValidationException validationException @@ -70,18 +70,18 @@ public static TextEmbeddingResponseParser fromMap( throw validationException; } - return new TextEmbeddingResponseParser(path, embeddingType); + return new DenseEmbeddingResponseParser(path, embeddingType); } private final String textEmbeddingsPath; private final CustomServiceEmbeddingType embeddingType; - public TextEmbeddingResponseParser(String textEmbeddingsPath, CustomServiceEmbeddingType embeddingType) { + public DenseEmbeddingResponseParser(String textEmbeddingsPath, CustomServiceEmbeddingType embeddingType) { this.textEmbeddingsPath = Objects.requireNonNull(textEmbeddingsPath); this.embeddingType = Objects.requireNonNull(embeddingType); } - public TextEmbeddingResponseParser(StreamInput in) throws IOException { + public DenseEmbeddingResponseParser(StreamInput in) throws IOException { this.textEmbeddingsPath = in.readString(); if (in.getTransportVersion().supports(ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_TYPE)) { this.embeddingType = in.readEnum(CustomServiceEmbeddingType.class); @@ -122,7 +122,7 @@ public CustomServiceEmbeddingType getEmbeddingType() { public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; - TextEmbeddingResponseParser that = (TextEmbeddingResponseParser) o; + DenseEmbeddingResponseParser that = (DenseEmbeddingResponseParser) o; return Objects.equals(textEmbeddingsPath, that.textEmbeddingsPath) && Objects.equals(embeddingType, that.embeddingType); } @@ -174,7 +174,7 @@ private interface EmbeddingConverter { private static class FloatEmbeddings implements EmbeddingConverter { - private final List embeddings; + private final List embeddings; FloatEmbeddings() { this.embeddings = new ArrayList<>(); @@ -182,17 +182,17 @@ private static class FloatEmbeddings implements EmbeddingConverter { public void toEmbedding(Object entry, String fieldName) { var embeddingsAsListFloats = convertToListOfFloats(entry, fieldName); - embeddings.add(TextEmbeddingFloatResults.Embedding.of(embeddingsAsListFloats)); + embeddings.add(DenseEmbeddingFloatResults.Embedding.of(embeddingsAsListFloats)); } - public TextEmbeddingFloatResults getResults() { - return new TextEmbeddingFloatResults(embeddings); + public DenseEmbeddingFloatResults getResults() { + return new DenseEmbeddingFloatResults(embeddings); } } private static class ByteEmbeddings implements EmbeddingConverter { - private final List embeddings; + private final List embeddings; ByteEmbeddings() { this.embeddings = new ArrayList<>(); @@ -200,17 +200,17 @@ private static class ByteEmbeddings implements EmbeddingConverter { public void toEmbedding(Object entry, String fieldName) { var convertedEmbeddings = convertToListOfBytes(entry, fieldName); - this.embeddings.add(TextEmbeddingByteResults.Embedding.of(convertedEmbeddings)); + this.embeddings.add(DenseEmbeddingByteResults.Embedding.of(convertedEmbeddings)); } - public TextEmbeddingByteResults getResults() { - return new TextEmbeddingByteResults(embeddings); + public DenseEmbeddingByteResults getResults() { + return new DenseEmbeddingByteResults(embeddings); } } private static class BitEmbeddings implements EmbeddingConverter { - private final List embeddings; + private final List embeddings; BitEmbeddings() { this.embeddings = new ArrayList<>(); @@ -218,11 +218,11 @@ private static class BitEmbeddings implements EmbeddingConverter { public void toEmbedding(Object entry, String fieldName) { var convertedEmbeddings = convertToListOfBits(entry, fieldName); - this.embeddings.add(TextEmbeddingByteResults.Embedding.of(convertedEmbeddings)); + this.embeddings.add(DenseEmbeddingByteResults.Embedding.of(convertedEmbeddings)); } - public TextEmbeddingBitResults getResults() { - return new TextEmbeddingBitResults(embeddings); + public DenseEmbeddingBitResults getResults() { + return new DenseEmbeddingBitResults(embeddings); } } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index 2c1ee96b519a3..646ee520de83f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -38,15 +38,15 @@ import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.core.XPackSettings; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction; import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; import org.elasticsearch.xpack.core.ml.action.InferModelAction; import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats; import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; -import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults; +import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults; import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.EmptyConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; @@ -642,7 +642,7 @@ public void inferTextEmbedding( ); ActionListener mlResultsListener = listener.delegateFailureAndWrap( - (l, inferenceResult) -> l.onResponse(TextEmbeddingFloatResults.of(inferenceResult.getInferenceResults())) + (l, inferenceResult) -> l.onResponse(DenseEmbeddingFloatResults.of(inferenceResult.getInferenceResults())) ); var maybeDeployListener = mlResultsListener.delegateResponse( @@ -772,22 +772,22 @@ private static void translateToChunkedResult( ActionListener chunkPartListener ) { if (taskType == TaskType.TEXT_EMBEDDING) { - var translated = new ArrayList(); + var translated = new ArrayList(); for (var inferenceResult : inferenceResults) { - if (inferenceResult instanceof MlTextEmbeddingResults mlTextEmbeddingResult) { - translated.add(new TextEmbeddingFloatResults.Embedding(mlTextEmbeddingResult.getInferenceAsFloat())); + if (inferenceResult instanceof MlDenseEmbeddingResults mlTextEmbeddingResult) { + translated.add(new DenseEmbeddingFloatResults.Embedding(mlTextEmbeddingResult.getInferenceAsFloat())); } else if (inferenceResult instanceof ErrorInferenceResults error) { chunkPartListener.onFailure(error.getException()); return; } else { chunkPartListener.onFailure( - createInvalidChunkedResultException(MlTextEmbeddingResults.NAME, inferenceResult.getWriteableName()) + createInvalidChunkedResultException(MlDenseEmbeddingResults.NAME, inferenceResult.getWriteableName()) ); return; } } - chunkPartListener.onResponse(new TextEmbeddingFloatResults(translated)); + chunkPartListener.onResponse(new DenseEmbeddingFloatResults(translated)); } else { // sparse var translated = new ArrayList(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/response/GoogleAiStudioEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/response/GoogleAiStudioEmbeddingsResponseEntity.java index 499fe9ae0c6c7..67527698bf02c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/response/GoogleAiStudioEmbeddingsResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/response/GoogleAiStudioEmbeddingsResponseEntity.java @@ -12,7 +12,7 @@ import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.external.response.XContentUtils; @@ -70,7 +70,7 @@ public class GoogleAiStudioEmbeddingsResponseEntity { * */ - public static TextEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException { + public static DenseEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException { var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) { @@ -81,16 +81,16 @@ public static TextEmbeddingFloatResults fromResponse(Request request, HttpResult positionParserAtTokenAfterField(jsonParser, "embeddings", FAILED_TO_FIND_FIELD_TEMPLATE); - List embeddingList = parseList( + List embeddingList = parseList( jsonParser, GoogleAiStudioEmbeddingsResponseEntity::parseEmbeddingObject ); - return new TextEmbeddingFloatResults(embeddingList); + return new DenseEmbeddingFloatResults(embeddingList); } } - private static TextEmbeddingFloatResults.Embedding parseEmbeddingObject(XContentParser parser) throws IOException { + private static DenseEmbeddingFloatResults.Embedding parseEmbeddingObject(XContentParser parser) throws IOException { ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); positionParserAtTokenAfterField(parser, "values", FAILED_TO_FIND_FIELD_TEMPLATE); @@ -99,7 +99,7 @@ private static TextEmbeddingFloatResults.Embedding parseEmbeddingObject(XContent // parse and discard the rest of the object consumeUntilObjectEnd(parser); - return TextEmbeddingFloatResults.Embedding.of(embeddingValuesList); + return DenseEmbeddingFloatResults.Embedding.of(embeddingValuesList); } private GoogleAiStudioEmbeddingsResponseEntity() {} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/response/GoogleVertexAiEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/response/GoogleVertexAiEmbeddingsResponseEntity.java index b4038e42c62cb..94272815e8db2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/response/GoogleVertexAiEmbeddingsResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/response/GoogleVertexAiEmbeddingsResponseEntity.java @@ -13,7 +13,7 @@ import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; @@ -64,7 +64,7 @@ public class GoogleVertexAiEmbeddingsResponseEntity { * */ - public static TextEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException { + public static DenseEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException { var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) { @@ -75,16 +75,16 @@ public static TextEmbeddingFloatResults fromResponse(Request request, HttpResult positionParserAtTokenAfterField(jsonParser, "predictions", FAILED_TO_FIND_FIELD_TEMPLATE); - List embeddingList = parseList( + List embeddingList = parseList( jsonParser, GoogleVertexAiEmbeddingsResponseEntity::parseEmbeddingObject ); - return new TextEmbeddingFloatResults(embeddingList); + return new DenseEmbeddingFloatResults(embeddingList); } } - private static TextEmbeddingFloatResults.Embedding parseEmbeddingObject(XContentParser parser) throws IOException { + private static DenseEmbeddingFloatResults.Embedding parseEmbeddingObject(XContentParser parser) throws IOException { ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); positionParserAtTokenAfterField(parser, "embeddings", FAILED_TO_FIND_FIELD_TEMPLATE); @@ -99,7 +99,7 @@ private static TextEmbeddingFloatResults.Embedding parseEmbeddingObject(XContent consumeUntilObjectEnd(parser); consumeUntilObjectEnd(parser); - return TextEmbeddingFloatResults.Embedding.of(embeddingValueList); + return DenseEmbeddingFloatResults.Embedding.of(embeddingValueList); } private static float parseEmbeddingList(XContentParser parser) throws IOException { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java index 081d5c63b84ff..681058d8c21a5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java @@ -26,9 +26,9 @@ import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.EmbeddingResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; @@ -127,8 +127,8 @@ private static List translateToChunkedResults( List inputs, InferenceServiceResults inferenceResults ) { - if (inferenceResults instanceof TextEmbeddingFloatResults textEmbeddingResults) { - validateInputSizeAgainstEmbeddings(ChunkInferenceInput.inputs(inputs), textEmbeddingResults.embeddings().size()); + if (inferenceResults instanceof DenseEmbeddingFloatResults denseEmbeddingResults) { + validateInputSizeAgainstEmbeddings(ChunkInferenceInput.inputs(inputs), denseEmbeddingResults.embeddings().size()); var results = new ArrayList(inputs.size()); @@ -137,7 +137,7 @@ private static List translateToChunkedResults( new ChunkedInferenceEmbedding( List.of( new EmbeddingResults.Chunk( - textEmbeddingResults.embeddings().get(i), + denseEmbeddingResults.embeddings().get(i), new ChunkedInference.TextOffset(0, inputs.get(i).input().length()) ) ) @@ -153,7 +153,7 @@ private static List translateToChunkedResults( } else { String expectedClasses = Strings.format( "One of [%s,%s]", - TextEmbeddingFloatResults.class.getSimpleName(), + DenseEmbeddingFloatResults.class.getSimpleName(), SparseEmbeddingResults.class.getSimpleName() ); throw createInvalidChunkedResultException(expectedClasses, inferenceResults.getWriteableName()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceEmbeddingsResponseEntity.java index baf1e884108fb..126d5b03fcde4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceEmbeddingsResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceEmbeddingsResponseEntity.java @@ -12,7 +12,7 @@ import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.external.response.XContentUtils; @@ -33,7 +33,7 @@ public class HuggingFaceEmbeddingsResponseEntity { * Parse the response from hugging face. The known formats are an array of arrays and object with an {@code embeddings} field containing * an array of arrays. */ - public static TextEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException { + public static DenseEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException { var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) { @@ -91,13 +91,13 @@ public static TextEmbeddingFloatResults fromResponse(Request request, HttpResult * sentence-transformers/all-MiniLM-L6-v2 * sentence-transformers/all-MiniLM-L12-v2 */ - private static TextEmbeddingFloatResults parseArrayFormat(XContentParser parser) throws IOException { - List embeddingList = parseList( + private static DenseEmbeddingFloatResults parseArrayFormat(XContentParser parser) throws IOException { + List embeddingList = parseList( parser, HuggingFaceEmbeddingsResponseEntity::parseEmbeddingEntry ); - return new TextEmbeddingFloatResults(embeddingList); + return new DenseEmbeddingFloatResults(embeddingList); } /** @@ -136,22 +136,22 @@ private static TextEmbeddingFloatResults parseArrayFormat(XContentParser parser) * intfloat/multilingual-e5-small * sentence-transformers/all-mpnet-base-v2 */ - private static TextEmbeddingFloatResults parseObjectFormat(XContentParser parser) throws IOException { + private static DenseEmbeddingFloatResults parseObjectFormat(XContentParser parser) throws IOException { positionParserAtTokenAfterField(parser, "embeddings", FAILED_TO_FIND_FIELD_TEMPLATE); - List embeddingList = parseList( + List embeddingList = parseList( parser, HuggingFaceEmbeddingsResponseEntity::parseEmbeddingEntry ); - return new TextEmbeddingFloatResults(embeddingList); + return new DenseEmbeddingFloatResults(embeddingList); } - private static TextEmbeddingFloatResults.Embedding parseEmbeddingEntry(XContentParser parser) throws IOException { + private static DenseEmbeddingFloatResults.Embedding parseEmbeddingEntry(XContentParser parser) throws IOException { ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); List embeddingValuesList = parseList(parser, XContentUtils::parseFloat); - return TextEmbeddingFloatResults.Embedding.of(embeddingValuesList); + return DenseEmbeddingFloatResults.Embedding.of(embeddingValuesList); } private HuggingFaceEmbeddingsResponseEntity() {} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/response/IbmWatsonxEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/response/IbmWatsonxEmbeddingsResponseEntity.java index 4fda9d5661a2c..d12f44932ed6e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/response/IbmWatsonxEmbeddingsResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/response/IbmWatsonxEmbeddingsResponseEntity.java @@ -12,7 +12,7 @@ import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.external.response.XContentUtils; @@ -30,7 +30,7 @@ public class IbmWatsonxEmbeddingsResponseEntity { private static final String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in IBM watsonx embeddings response"; - public static TextEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException { + public static DenseEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException { var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) { @@ -41,16 +41,16 @@ public static TextEmbeddingFloatResults fromResponse(Request request, HttpResult positionParserAtTokenAfterField(jsonParser, "results", FAILED_TO_FIND_FIELD_TEMPLATE); - List embeddingList = parseList( + List embeddingList = parseList( jsonParser, IbmWatsonxEmbeddingsResponseEntity::parseEmbeddingObject ); - return new TextEmbeddingFloatResults(embeddingList); + return new DenseEmbeddingFloatResults(embeddingList); } } - private static TextEmbeddingFloatResults.Embedding parseEmbeddingObject(XContentParser parser) throws IOException { + private static DenseEmbeddingFloatResults.Embedding parseEmbeddingObject(XContentParser parser) throws IOException { ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); positionParserAtTokenAfterField(parser, "embedding", FAILED_TO_FIND_FIELD_TEMPLATE); @@ -59,7 +59,7 @@ private static TextEmbeddingFloatResults.Embedding parseEmbeddingObject(XContent // parse and discard the rest of the object consumeUntilObjectEnd(parser); - return TextEmbeddingFloatResults.Embedding.of(embeddingValuesList); + return DenseEmbeddingFloatResults.Embedding.of(embeddingValuesList); } private IbmWatsonxEmbeddingsResponseEntity() {} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/response/JinaAIEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/response/JinaAIEmbeddingsResponseEntity.java index 8eee003accba0..f9c80d2593642 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/response/JinaAIEmbeddingsResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/response/JinaAIEmbeddingsResponseEntity.java @@ -15,9 +15,9 @@ import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBitResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingBitResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.external.response.XContentUtils; @@ -126,15 +126,15 @@ public static InferenceServiceResults fromResponse(Request request, HttpResult r } private static InferenceServiceResults parseFloatDataObject(XContentParser jsonParser) throws IOException { - List embeddingList = parseList( + List embeddingList = parseList( jsonParser, JinaAIEmbeddingsResponseEntity::parseFloatEmbeddingObject ); - return new TextEmbeddingFloatResults(embeddingList); + return new DenseEmbeddingFloatResults(embeddingList); } - private static TextEmbeddingFloatResults.Embedding parseFloatEmbeddingObject(XContentParser parser) throws IOException { + private static DenseEmbeddingFloatResults.Embedding parseFloatEmbeddingObject(XContentParser parser) throws IOException { ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); positionParserAtTokenAfterField(parser, "embedding", FAILED_TO_FIND_FIELD_TEMPLATE); @@ -143,19 +143,19 @@ private static TextEmbeddingFloatResults.Embedding parseFloatEmbeddingObject(XCo // parse and discard the rest of the object consumeUntilObjectEnd(parser); - return TextEmbeddingFloatResults.Embedding.of(embeddingValuesList); + return DenseEmbeddingFloatResults.Embedding.of(embeddingValuesList); } private static InferenceServiceResults parseBitDataObject(XContentParser jsonParser) throws IOException { - List embeddingList = parseList( + List embeddingList = parseList( jsonParser, JinaAIEmbeddingsResponseEntity::parseBitEmbeddingObject ); - return new TextEmbeddingBitResults(embeddingList); + return new DenseEmbeddingBitResults(embeddingList); } - private static TextEmbeddingByteResults.Embedding parseBitEmbeddingObject(XContentParser parser) throws IOException { + private static DenseEmbeddingByteResults.Embedding parseBitEmbeddingObject(XContentParser parser) throws IOException { ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); positionParserAtTokenAfterField(parser, "embedding", FAILED_TO_FIND_FIELD_TEMPLATE); @@ -164,7 +164,7 @@ private static TextEmbeddingByteResults.Embedding parseBitEmbeddingObject(XConte // parse and discard the rest of the object consumeUntilObjectEnd(parser); - return TextEmbeddingByteResults.Embedding.of(embeddingList); + return DenseEmbeddingByteResults.Embedding.of(embeddingList); } private static Byte parseEmbeddingInt8Entry(XContentParser parser) throws IOException { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/response/OpenAiEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/response/OpenAiEmbeddingsResponseEntity.java index b8130545a711d..3298ca310c01d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/response/OpenAiEmbeddingsResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/response/OpenAiEmbeddingsResponseEntity.java @@ -12,7 +12,7 @@ import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; @@ -65,9 +65,9 @@ public class OpenAiEmbeddingsResponseEntity { * * */ - public static TextEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException { + public static DenseEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException { try (var p = XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, response.body())) { - return EmbeddingFloatResult.PARSER.apply(p, null).toTextEmbeddingFloatResults(); + return EmbeddingFloatResult.PARSER.apply(p, null).toDenseEmbeddingFloatResults(); } } @@ -83,9 +83,9 @@ public record EmbeddingFloatResult(List embeddingResu PARSER.declareObjectArray(constructorArg(), EmbeddingFloatResultEntry.PARSER::apply, new ParseField("data")); } - public TextEmbeddingFloatResults toTextEmbeddingFloatResults() { - return new TextEmbeddingFloatResults( - embeddingResults.stream().map(entry -> TextEmbeddingFloatResults.Embedding.of(entry.embedding)).toList() + public DenseEmbeddingFloatResults toDenseEmbeddingFloatResults() { + return new DenseEmbeddingFloatResults( + embeddingResults.stream().map(entry -> DenseEmbeddingFloatResults.Embedding.of(entry.embedding)).toList() ); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerStoredServiceSchema.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerStoredServiceSchema.java index 70ed23e215dd6..94d96b3db6d26 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerStoredServiceSchema.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerStoredServiceSchema.java @@ -12,6 +12,7 @@ import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.validation.DenseEmbeddingModelValidator; /** * Contains any model-specific settings that are stored in SageMakerServiceSettings. @@ -72,7 +73,7 @@ default boolean isFragment() { /** * If this Schema supports Text Embeddings, then we need to implement this. - * {@link org.elasticsearch.xpack.inference.services.validation.TextEmbeddingModelValidator} will set the dimensions if the user + * {@link DenseEmbeddingModelValidator} will set the dimensions if the user * does not do it, so we need to store the dimensions and flip the {@link #dimensionsSetByUser()} boolean. */ default SageMakerStoredServiceSchema updateModelWithEmbeddingDetails(Integer dimensions) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticTextEmbeddingPayload.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticTextEmbeddingPayload.java index a5fd194f12109..dbbf82f5703b9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticTextEmbeddingPayload.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticTextEmbeddingPayload.java @@ -24,10 +24,10 @@ import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParserConfiguration; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBitResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingBitResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingResults; import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerInferenceRequest; import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel; import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredServiceSchema; @@ -92,7 +92,7 @@ public Stream namedWriteables() { } @Override - public TextEmbeddingResults responseBody(SageMakerModel model, InvokeEndpointResponse response) throws Exception { + public DenseEmbeddingResults responseBody(SageMakerModel model, InvokeEndpointResponse response) throws Exception { try (var p = jsonXContent.createParser(XContentParserConfiguration.EMPTY, response.body().asInputStream())) { return switch (model.apiServiceSettings().elementType()) { case BIT -> TextEmbeddingBinary.PARSER.apply(p, null); @@ -103,7 +103,7 @@ public TextEmbeddingResults responseBody(SageMakerModel model, InvokeEndpoint } /** - * Reads binary format (it says bytes, but the lengths are different) + * Reads binary format * { * "text_embedding_bits": [ * { @@ -120,12 +120,12 @@ public TextEmbeddingResults responseBody(SageMakerModel model, InvokeEndpoint * } */ private static class TextEmbeddingBinary { - private static final ParseField TEXT_EMBEDDING_BITS = new ParseField(TextEmbeddingBitResults.TEXT_EMBEDDING_BITS); + private static final ParseField TEXT_EMBEDDING_BITS = new ParseField(DenseEmbeddingBitResults.TEXT_EMBEDDING_BITS); @SuppressWarnings("unchecked") - private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( - TextEmbeddingBitResults.class.getSimpleName(), + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + DenseEmbeddingBitResults.class.getSimpleName(), IGNORE_UNKNOWN_FIELDS, - args -> new TextEmbeddingBitResults((List) args[0]) + args -> new DenseEmbeddingBitResults((List) args[0]) ); static { @@ -151,20 +151,20 @@ private static class TextEmbeddingBinary { * } */ private static class TextEmbeddingBytes { - private static final ParseField TEXT_EMBEDDING_BYTES = new ParseField("text_embedding_bytes"); + private static final ParseField TEXT_EMBEDDING_BYTES = new ParseField(DenseEmbeddingByteResults.TEXT_EMBEDDING_BYTES); @SuppressWarnings("unchecked") - private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( - TextEmbeddingByteResults.class.getSimpleName(), + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + DenseEmbeddingByteResults.class.getSimpleName(), IGNORE_UNKNOWN_FIELDS, - args -> new TextEmbeddingByteResults((List) args[0]) + args -> new DenseEmbeddingByteResults((List) args[0]) ); @SuppressWarnings("unchecked") - private static final ConstructingObjectParser BYTE_PARSER = + private static final ConstructingObjectParser BYTE_PARSER = new ConstructingObjectParser<>( - TextEmbeddingByteResults.Embedding.class.getSimpleName(), + DenseEmbeddingByteResults.Embedding.class.getSimpleName(), IGNORE_UNKNOWN_FIELDS, - args -> TextEmbeddingByteResults.Embedding.of((List) args[0]) + args -> DenseEmbeddingByteResults.Embedding.of((List) args[0]) ); static { @@ -197,20 +197,20 @@ private static class TextEmbeddingBytes { * } */ private static class TextEmbeddingFloat { - private static final ParseField TEXT_EMBEDDING_FLOAT = new ParseField("text_embedding"); + private static final ParseField TEXT_EMBEDDING_FLOAT = new ParseField(DenseEmbeddingFloatResults.TEXT_EMBEDDING); @SuppressWarnings("unchecked") - private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( - TextEmbeddingByteResults.class.getSimpleName(), + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + DenseEmbeddingFloatResults.class.getSimpleName(), IGNORE_UNKNOWN_FIELDS, - args -> new TextEmbeddingFloatResults((List) args[0]) + args -> new DenseEmbeddingFloatResults((List) args[0]) ); @SuppressWarnings("unchecked") - private static final ConstructingObjectParser FLOAT_PARSER = + private static final ConstructingObjectParser FLOAT_PARSER = new ConstructingObjectParser<>( - TextEmbeddingFloatResults.Embedding.class.getSimpleName(), + DenseEmbeddingFloatResults.Embedding.class.getSimpleName(), IGNORE_UNKNOWN_FIELDS, - args -> TextEmbeddingFloatResults.Embedding.of((List) args[0]) + args -> DenseEmbeddingFloatResults.Embedding.of((List) args[0]) ); static { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiTextEmbeddingPayload.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiTextEmbeddingPayload.java index f060c7a75797b..b4c0b62f7ca12 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiTextEmbeddingPayload.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiTextEmbeddingPayload.java @@ -25,7 +25,7 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.json.JsonXContent; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.services.openai.response.OpenAiEmbeddingsResponseEntity; import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerInferenceRequest; import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel; @@ -116,9 +116,9 @@ public SdkBytes requestBytes(SageMakerModel model, SageMakerInferenceRequest req } @Override - public TextEmbeddingFloatResults responseBody(SageMakerModel model, InvokeEndpointResponse response) throws Exception { + public DenseEmbeddingFloatResults responseBody(SageMakerModel model, InvokeEndpointResponse response) throws Exception { try (var p = jsonXContent.createParser(XContentParserConfiguration.EMPTY, response.body().asInputStream())) { - return OpenAiEmbeddingsResponseEntity.EmbeddingFloatResult.PARSER.apply(p, null).toTextEmbeddingFloatResults(); + return OpenAiEmbeddingsResponseEntity.EmbeddingFloatResult.PARSER.apply(p, null).toDenseEmbeddingFloatResults(); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/DenseEmbeddingModelValidator.java similarity index 85% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidator.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/DenseEmbeddingModelValidator.java index ce9df7376ebcb..fb7b415758ca5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/DenseEmbeddingModelValidator.java @@ -16,14 +16,14 @@ import org.elasticsearch.inference.Model; import org.elasticsearch.inference.validation.ServiceIntegrationValidator; import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingResults; -public class TextEmbeddingModelValidator implements ModelValidator { +public class DenseEmbeddingModelValidator implements ModelValidator { private final ServiceIntegrationValidator serviceIntegrationValidator; - public TextEmbeddingModelValidator(ServiceIntegrationValidator serviceIntegrationValidator) { + public DenseEmbeddingModelValidator(ServiceIntegrationValidator serviceIntegrationValidator) { this.serviceIntegrationValidator = serviceIntegrationValidator; } @@ -35,7 +35,7 @@ public void validate(InferenceService service, Model model, TimeValue timeout, A } private Model postValidate(InferenceService service, Model model, InferenceServiceResults results) { - if (results instanceof TextEmbeddingResults embeddingResults) { + if (results instanceof DenseEmbeddingResults embeddingResults) { var serviceSettings = model.getServiceSettings(); var dimensions = serviceSettings.dimensions(); int embeddingSize = getEmbeddingSize(embeddingResults); @@ -60,7 +60,7 @@ private Model postValidate(InferenceService service, Model model, InferenceServi throw new ElasticsearchStatusException( "Validation call did not return expected results type." + "Expected a result of type [" - + TextEmbeddingFloatResults.NAME + + DenseEmbeddingFloatResults.NAME + "] got [" + (results == null ? "null" : results.getWriteableName()) + "]", @@ -69,7 +69,7 @@ private Model postValidate(InferenceService service, Model model, InferenceServi } } - private int getEmbeddingSize(TextEmbeddingResults embeddingResults) { + private int getEmbeddingSize(DenseEmbeddingResults embeddingResults) { int embeddingSize; try { embeddingSize = embeddingResults.getFirstEmbeddingSize(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidator.java index fa0e1b3e590a4..6078531eec3f9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidator.java @@ -17,8 +17,8 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.validation.ServiceIntegrationValidator; import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingResults; import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandEmbeddingModel; public class ElasticsearchInternalServiceModelValidator implements ModelValidator { @@ -54,7 +54,7 @@ public void validate(InferenceService service, Model model, TimeValue timeout, A } private Model postValidate(InferenceService service, Model model, InferenceServiceResults results) { - if (results instanceof TextEmbeddingResults embeddingResults) { + if (results instanceof DenseEmbeddingResults embeddingResults) { var serviceSettings = model.getServiceSettings(); var dimensions = serviceSettings.dimensions(); int embeddingSize = getEmbeddingSize(embeddingResults); @@ -79,7 +79,7 @@ private Model postValidate(InferenceService service, Model model, InferenceServi throw new ElasticsearchStatusException( "Validation call did not return expected results type." + "Expected a result of type [" - + TextEmbeddingFloatResults.NAME + + DenseEmbeddingFloatResults.NAME + "] got [" + (results == null ? "null" : results.getWriteableName()) + "]", @@ -88,7 +88,7 @@ private Model postValidate(InferenceService service, Model model, InferenceServi } } - private int getEmbeddingSize(TextEmbeddingResults embeddingResults) { + private int getEmbeddingSize(DenseEmbeddingResults embeddingResults) { int embeddingSize; try { embeddingSize = embeddingResults.getFirstEmbeddingSize(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilder.java index fac9ee5e9c1c1..59fcdac82c2a4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilder.java @@ -36,7 +36,7 @@ private static ModelValidator buildModelValidatorForTaskType(TaskType taskType, switch (taskType) { case TEXT_EMBEDDING -> { - return new TextEmbeddingModelValidator( + return new DenseEmbeddingModelValidator( Objects.requireNonNullElse(validatorFromService, new SimpleServiceIntegrationValidator()) ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/response/VoyageAIEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/response/VoyageAIEmbeddingsResponseEntity.java index f9ba5fd58d21a..61436d509e45a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/response/VoyageAIEmbeddingsResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/response/VoyageAIEmbeddingsResponseEntity.java @@ -15,9 +15,9 @@ import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBitResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingBitResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingType; @@ -75,9 +75,9 @@ private static void checkByteBounds(Integer value) { } } - public TextEmbeddingByteResults.Embedding toInferenceByteEmbedding() { + public DenseEmbeddingByteResults.Embedding toInferenceByteEmbedding() { embedding.forEach(EmbeddingInt8ResultEntry::checkByteBounds); - return TextEmbeddingByteResults.Embedding.of(embedding.stream().map(Integer::byteValue).toList()); + return DenseEmbeddingByteResults.Embedding.of(embedding.stream().map(Integer::byteValue).toList()); } } @@ -108,8 +108,8 @@ record EmbeddingFloatResultEntry(Integer index, List embedding) { PARSER.declareFloatArray(constructorArg(), new ParseField("embedding")); } - public TextEmbeddingFloatResults.Embedding toInferenceFloatEmbedding() { - return TextEmbeddingFloatResults.Embedding.of(embedding); + public DenseEmbeddingFloatResults.Embedding toInferenceFloatEmbedding() { + return DenseEmbeddingFloatResults.Embedding.of(embedding); } } @@ -166,22 +166,22 @@ public static InferenceServiceResults fromResponse(Request request, HttpResult r if (embeddingType == null || embeddingType == VoyageAIEmbeddingType.FLOAT) { var embeddingResult = EmbeddingFloatResult.PARSER.apply(jsonParser, null); - List embeddingList = embeddingResult.entries.stream() + List embeddingList = embeddingResult.entries.stream() .map(EmbeddingFloatResultEntry::toInferenceFloatEmbedding) .toList(); - return new TextEmbeddingFloatResults(embeddingList); + return new DenseEmbeddingFloatResults(embeddingList); } else if (embeddingType == VoyageAIEmbeddingType.INT8) { var embeddingResult = EmbeddingInt8Result.PARSER.apply(jsonParser, null); - List embeddingList = embeddingResult.entries.stream() + List embeddingList = embeddingResult.entries.stream() .map(EmbeddingInt8ResultEntry::toInferenceByteEmbedding) .toList(); - return new TextEmbeddingByteResults(embeddingList); + return new DenseEmbeddingByteResults(embeddingList); } else if (embeddingType == VoyageAIEmbeddingType.BIT || embeddingType == VoyageAIEmbeddingType.BINARY) { var embeddingResult = EmbeddingInt8Result.PARSER.apply(jsonParser, null); - List embeddingList = embeddingResult.entries.stream() + List embeddingList = embeddingResult.entries.stream() .map(EmbeddingInt8ResultEntry::toInferenceByteEmbedding) .toList(); - return new TextEmbeddingBitResults(embeddingList); + return new DenseEmbeddingBitResults(embeddingList); } else { throw new IllegalArgumentException( "Illegal embedding_type value: " + embeddingType + ". Supported types are: " + VALID_EMBEDDING_TYPES_STRING diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionResponseTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionResponseTests.java index c7ff03b05e975..429cf65991b13 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionResponseTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionResponseTests.java @@ -11,8 +11,8 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests; import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import org.elasticsearch.xpack.inference.InferenceNamedWriteablesProvider; @@ -39,7 +39,7 @@ protected Writeable.Reader instanceReader() { @Override protected InferenceAction.Response createTestInstance() { var result = randomBoolean() - ? TextEmbeddingFloatResultsTests.createRandomResults() + ? DenseEmbeddingFloatResultsTests.createRandomResults() : SparseEmbeddingResultsTests.createRandomResults(); return new InferenceAction.Response(result); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java index 411d992adfa3d..9d450912d9c4e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java @@ -14,10 +14,10 @@ import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingBitResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBitResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import org.hamcrest.Matchers; import java.util.ArrayList; @@ -393,12 +393,12 @@ public void testVeryLongInput_Float() { // Produce inference results for each request, with increasing weights. float weight = 0f; for (var batch : batches) { - var embeddings = new ArrayList(); + var embeddings = new ArrayList(); for (int i = 0; i < batch.batch().requests().size(); i++) { weight += 1 / 16384f; - embeddings.add(new TextEmbeddingFloatResults.Embedding(new float[] { weight })); + embeddings.add(new DenseEmbeddingFloatResults.Embedding(new float[] { weight })); } - batch.listener().onResponse(new TextEmbeddingFloatResults(embeddings)); + batch.listener().onResponse(new DenseEmbeddingFloatResults(embeddings)); } assertNotNull(finalListener.results); @@ -410,8 +410,10 @@ public void testVeryLongInput_Float() { ChunkedInferenceEmbedding chunkedEmbedding = (ChunkedInferenceEmbedding) inference; assertThat(chunkedEmbedding.chunks(), hasSize(1)); assertThat(getMatchedText(inputs.get(0).input(), chunkedEmbedding.chunks().get(0).offset()), equalTo("1st small")); - assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class)); - TextEmbeddingFloatResults.Embedding embedding = (TextEmbeddingFloatResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); + assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(DenseEmbeddingFloatResults.Embedding.class)); + DenseEmbeddingFloatResults.Embedding embedding = (DenseEmbeddingFloatResults.Embedding) chunkedEmbedding.chunks() + .get(0) + .embedding(); assertThat(embedding.values(), equalTo(new float[] { 1 / 16384f })); // The very long passage "word0 word1 ... word199999" is split into 10000 chunks for @@ -427,8 +429,8 @@ public void testVeryLongInput_Float() { // is the average of the weights 2/16384 ... 21/16384. assertThat(getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(0).offset()), startsWith("word0 word1 ")); assertThat(getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(0).offset()), endsWith(" word398 word399")); - assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class)); - embedding = (TextEmbeddingFloatResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); + assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(DenseEmbeddingFloatResults.Embedding.class)); + embedding = (DenseEmbeddingFloatResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); assertThat(embedding.values(), equalTo(new float[] { (2 + 21) / (2 * 16384f) })); // The last merged chunk consists of 19 small chunks (so 380 words) and the weight @@ -438,8 +440,8 @@ public void testVeryLongInput_Float() { startsWith(" word199620 word199621 ") ); assertThat(getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(511).offset()), endsWith(" word199998 word199999")); - assertThat(chunkedEmbedding.chunks().get(511).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class)); - embedding = (TextEmbeddingFloatResults.Embedding) chunkedEmbedding.chunks().get(511).embedding(); + assertThat(chunkedEmbedding.chunks().get(511).embedding(), instanceOf(DenseEmbeddingFloatResults.Embedding.class)); + embedding = (DenseEmbeddingFloatResults.Embedding) chunkedEmbedding.chunks().get(511).embedding(); assertThat(embedding.values(), equalTo(new float[] { (9983 + 10001) / (2 * 16384f) })); // The last input has the token with weight 10002/16384. @@ -448,8 +450,8 @@ public void testVeryLongInput_Float() { chunkedEmbedding = (ChunkedInferenceEmbedding) inference; assertThat(chunkedEmbedding.chunks(), hasSize(1)); assertThat(getMatchedText(inputs.get(2).input(), chunkedEmbedding.chunks().get(0).offset()), equalTo("2nd small")); - assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class)); - embedding = (TextEmbeddingFloatResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); + assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(DenseEmbeddingFloatResults.Embedding.class)); + embedding = (DenseEmbeddingFloatResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); assertThat(embedding.values(), equalTo(new float[] { 10002 / 16384f })); } @@ -484,12 +486,12 @@ public void testVeryLongInput_Byte() { // Produce inference results for each request, with increasing weights. byte weight = 0; for (var batch : batches) { - var embeddings = new ArrayList(); + var embeddings = new ArrayList(); for (int i = 0; i < batch.batch().requests().size(); i++) { weight += 1; - embeddings.add(new TextEmbeddingByteResults.Embedding(new byte[] { weight })); + embeddings.add(new DenseEmbeddingByteResults.Embedding(new byte[] { weight })); } - batch.listener().onResponse(new TextEmbeddingByteResults(embeddings)); + batch.listener().onResponse(new DenseEmbeddingByteResults(embeddings)); } assertNotNull(finalListener.results); @@ -501,8 +503,8 @@ public void testVeryLongInput_Byte() { ChunkedInferenceEmbedding chunkedEmbedding = (ChunkedInferenceEmbedding) inference; assertThat(chunkedEmbedding.chunks(), hasSize(1)); assertThat(getMatchedText(inputs.get(0).input(), chunkedEmbedding.chunks().get(0).offset()), equalTo("1st small")); - assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(TextEmbeddingByteResults.Embedding.class)); - TextEmbeddingByteResults.Embedding embedding = (TextEmbeddingByteResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); + assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(DenseEmbeddingByteResults.Embedding.class)); + DenseEmbeddingByteResults.Embedding embedding = (DenseEmbeddingByteResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); assertThat(embedding.values(), equalTo(new byte[] { 1 })); // The very long passage "word0 word1 ... word199999" is split into 10000 chunks for @@ -518,8 +520,8 @@ public void testVeryLongInput_Byte() { // is the average of the weights 2 ... 21, so 11.5, which is rounded to 12. assertThat(getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(0).offset()), startsWith("word0 word1 ")); assertThat(getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(0).offset()), endsWith(" word398 word399")); - assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(TextEmbeddingByteResults.Embedding.class)); - embedding = (TextEmbeddingByteResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); + assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(DenseEmbeddingByteResults.Embedding.class)); + embedding = (DenseEmbeddingByteResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); assertThat(embedding.values(), equalTo(new byte[] { 12 })); // The last merged chunk consists of 19 small chunks (so 380 words) and the weight @@ -530,8 +532,8 @@ public void testVeryLongInput_Byte() { startsWith(" word199620 word199621 ") ); assertThat(getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(511).offset()), endsWith(" word199998 word199999")); - assertThat(chunkedEmbedding.chunks().get(511).embedding(), instanceOf(TextEmbeddingByteResults.Embedding.class)); - embedding = (TextEmbeddingByteResults.Embedding) chunkedEmbedding.chunks().get(511).embedding(); + assertThat(chunkedEmbedding.chunks().get(511).embedding(), instanceOf(DenseEmbeddingByteResults.Embedding.class)); + embedding = (DenseEmbeddingByteResults.Embedding) chunkedEmbedding.chunks().get(511).embedding(); assertThat(embedding.values(), equalTo(new byte[] { 8 })); // The last input has the token with weight 10002 % 256 = 18 @@ -540,8 +542,8 @@ public void testVeryLongInput_Byte() { chunkedEmbedding = (ChunkedInferenceEmbedding) inference; assertThat(chunkedEmbedding.chunks(), hasSize(1)); assertThat(getMatchedText(inputs.get(2).input(), chunkedEmbedding.chunks().get(0).offset()), equalTo("2nd small")); - assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(TextEmbeddingByteResults.Embedding.class)); - embedding = (TextEmbeddingByteResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); + assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(DenseEmbeddingByteResults.Embedding.class)); + embedding = (DenseEmbeddingByteResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); assertThat(embedding.values(), equalTo(new byte[] { 18 })); } @@ -570,18 +572,18 @@ public void testMergingListener_Float() { // 4 inputs in 2 batches { - var embeddings = new ArrayList(); + var embeddings = new ArrayList(); for (int i = 0; i < batchSize; i++) { - embeddings.add(new TextEmbeddingFloatResults.Embedding(new float[] { randomFloat() })); + embeddings.add(new DenseEmbeddingFloatResults.Embedding(new float[] { randomFloat() })); } - batches.get(0).listener().onResponse(new TextEmbeddingFloatResults(embeddings)); + batches.get(0).listener().onResponse(new DenseEmbeddingFloatResults(embeddings)); } { - var embeddings = new ArrayList(); + var embeddings = new ArrayList(); for (int i = 0; i < 4; i++) { // 4 requests in the 2nd batch - embeddings.add(new TextEmbeddingFloatResults.Embedding(new float[] { randomFloat() })); + embeddings.add(new DenseEmbeddingFloatResults.Embedding(new float[] { randomFloat() })); } - batches.get(1).listener().onResponse(new TextEmbeddingFloatResults(embeddings)); + batches.get(1).listener().onResponse(new DenseEmbeddingFloatResults(embeddings)); } assertNotNull(finalListener.results); @@ -650,18 +652,18 @@ public void testMergingListener_Byte() { // 4 inputs in 2 batches { - var embeddings = new ArrayList(); + var embeddings = new ArrayList(); for (int i = 0; i < batchSize; i++) { - embeddings.add(new TextEmbeddingByteResults.Embedding(new byte[] { randomByte() })); + embeddings.add(new DenseEmbeddingByteResults.Embedding(new byte[] { randomByte() })); } - batches.get(0).listener().onResponse(new TextEmbeddingByteResults(embeddings)); + batches.get(0).listener().onResponse(new DenseEmbeddingByteResults(embeddings)); } { - var embeddings = new ArrayList(); + var embeddings = new ArrayList(); for (int i = 0; i < 4; i++) { // 4 requests in the 2nd batch - embeddings.add(new TextEmbeddingByteResults.Embedding(new byte[] { randomByte() })); + embeddings.add(new DenseEmbeddingByteResults.Embedding(new byte[] { randomByte() })); } - batches.get(1).listener().onResponse(new TextEmbeddingByteResults(embeddings)); + batches.get(1).listener().onResponse(new DenseEmbeddingByteResults(embeddings)); } assertNotNull(finalListener.results); @@ -727,18 +729,18 @@ public void testMergingListener_Bit() { // 4 inputs in 2 batches { - var embeddings = new ArrayList(); + var embeddings = new ArrayList(); for (int i = 0; i < batchSize; i++) { - embeddings.add(new TextEmbeddingByteResults.Embedding(new byte[] { randomByte() })); + embeddings.add(new DenseEmbeddingByteResults.Embedding(new byte[] { randomByte() })); } - batches.get(0).listener().onResponse(new TextEmbeddingBitResults(embeddings)); + batches.get(0).listener().onResponse(new DenseEmbeddingBitResults(embeddings)); } { - var embeddings = new ArrayList(); + var embeddings = new ArrayList(); for (int i = 0; i < 4; i++) { // 4 requests in the 2nd batch - embeddings.add(new TextEmbeddingByteResults.Embedding(new byte[] { randomByte() })); + embeddings.add(new DenseEmbeddingByteResults.Embedding(new byte[] { randomByte() })); } - batches.get(1).listener().onResponse(new TextEmbeddingBitResults(embeddings)); + batches.get(1).listener().onResponse(new DenseEmbeddingBitResults(embeddings)); } assertNotNull(finalListener.results); @@ -892,10 +894,10 @@ public void onFailure(Exception e) { var batches = new EmbeddingRequestChunker<>(inputs, 10, 100, 0).batchRequestsWithListeners(listener); assertThat(batches, hasSize(1)); - var embeddings = new ArrayList(); - embeddings.add(new TextEmbeddingFloatResults.Embedding(new float[] { randomFloat() })); - embeddings.add(new TextEmbeddingFloatResults.Embedding(new float[] { randomFloat() })); - batches.get(0).listener().onResponse(new TextEmbeddingFloatResults(embeddings)); + var embeddings = new ArrayList(); + embeddings.add(new DenseEmbeddingFloatResults.Embedding(new float[] { randomFloat() })); + embeddings.add(new DenseEmbeddingFloatResults.Embedding(new float[] { randomFloat() })); + batches.get(0).listener().onResponse(new DenseEmbeddingFloatResults(embeddings)); assertEquals("Error the number of embedding responses [2] does not equal the number of requests [3]", failureMessage.get()); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java index 027a19aca6d1f..f104e7b87136e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java @@ -48,7 +48,7 @@ import java.util.concurrent.atomic.AtomicReference; import static org.elasticsearch.core.Strings.format; -import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat; +import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java index d1499f4009d0a..d596677f4125f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java @@ -25,10 +25,10 @@ import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.EmbeddingResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import org.elasticsearch.xpack.core.utils.FloatConversionUtils; import org.elasticsearch.xpack.inference.chunking.NoneChunkingSettings; import org.elasticsearch.xpack.inference.chunking.SentenceBoundaryChunkingSettings; @@ -211,7 +211,7 @@ public static ChunkedInferenceEmbedding randomChunkedInferenceEmbeddingByte(Mode } chunks.add( new EmbeddingResults.Chunk( - new TextEmbeddingByteResults.Embedding(values), + new DenseEmbeddingByteResults.Embedding(values), new ChunkedInference.TextOffset(0, input.length()) ) ); @@ -233,7 +233,7 @@ public static ChunkedInferenceEmbedding randomChunkedInferenceEmbeddingFloat(Mod } chunks.add( new EmbeddingResults.Chunk( - new TextEmbeddingFloatResults.Embedding(values), + new DenseEmbeddingFloatResults.Embedding(values), new ChunkedInference.TextOffset(0, input.length()) ) ); @@ -415,8 +415,8 @@ public static ChunkedInference toChunkedResult( ChunkedInference.TextOffset offset = createOffset(useLegacyFormat, entryChunk, matchedText); double[] values = parseDenseVector(entryChunk.rawEmbeddings(), embeddingLength, field.contentType()); EmbeddingResults.Embedding embedding = switch (elementType) { - case FLOAT -> new TextEmbeddingFloatResults.Embedding(FloatConversionUtils.floatArrayOf(values)); - case BYTE, BIT -> new TextEmbeddingByteResults.Embedding(byteArrayOf(values)); + case FLOAT -> new DenseEmbeddingFloatResults.Embedding(FloatConversionUtils.floatArrayOf(values)); + case BYTE, BIT -> new DenseEmbeddingByteResults.Embedding(byteArrayOf(values)); }; chunks.add(new EmbeddingResults.Chunk(embedding, offset)); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceKnnVectorQueryBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceKnnVectorQueryBuilderTests.java index 20354aa9b2dc7..de8583132319b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceKnnVectorQueryBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceKnnVectorQueryBuilderTests.java @@ -22,7 +22,7 @@ import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor; import org.elasticsearch.search.vectors.KnnVectorQueryBuilder; import org.elasticsearch.search.vectors.VectorData; -import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults; +import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults; import org.elasticsearch.xpack.core.ml.vectors.TextEmbeddingQueryVectorBuilder; import org.elasticsearch.xpack.inference.mapper.SemanticTextField; @@ -175,8 +175,8 @@ public void testInterceptAndRewrite() throws Exception { new FullyQualifiedInferenceId(LOCAL_CLUSTER_GROUP_KEY, DENSE_INFERENCE_ID) ); assertThat(inferenceResults, notNullValue()); - assertThat(inferenceResults, instanceOf(MlTextEmbeddingResults.class)); - VectorData queryVector = new VectorData(((MlTextEmbeddingResults) inferenceResults).getInferenceAsFloat()); + assertThat(inferenceResults, instanceOf(MlDenseEmbeddingResults.class)); + VectorData queryVector = new VectorData(((MlDenseEmbeddingResults) inferenceResults).getInferenceAsFloat()); // Perform data node rewrite on test index 1 final QueryRewriteContext indexMetadataContextTestIndex1 = createIndexMetadataContext( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/MockInferenceClient.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/MockInferenceClient.java index 17438d9786ba3..b22191735f1d9 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/MockInferenceClient.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/MockInferenceClient.java @@ -20,11 +20,11 @@ import org.elasticsearch.test.client.NoOpClient; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction; import org.elasticsearch.xpack.core.ml.action.InferModelAction; -import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults; +import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults; import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; import java.util.Arrays; @@ -58,8 +58,8 @@ protected void InferenceResults inferenceResults = generateInferenceResults(inferenceId, input); if (inferenceResults instanceof TextExpansionResults textExpansionResults) { inferenceServiceResults = SparseEmbeddingResults.of(List.of(textExpansionResults)); - } else if (inferenceResults instanceof MlTextEmbeddingResults mlTextEmbeddingResults) { - inferenceServiceResults = TextEmbeddingFloatResults.of(List.of(mlTextEmbeddingResults)); + } else if (inferenceResults instanceof MlDenseEmbeddingResults mlDenseEmbeddingResults) { + inferenceServiceResults = DenseEmbeddingFloatResults.of(List.of(mlDenseEmbeddingResults)); } else { throw new IllegalStateException("Unexpected inference results type [" + inferenceResults.getWriteableName() + "]"); } @@ -126,6 +126,6 @@ private static InferenceResults generateTextEmbeddingResults(MinimalServiceSetti double[] embedding = new double[embeddingSize]; Arrays.fill(embedding, Byte.MIN_VALUE); // Always use a byte value so that the embedding is valid regardless of the element type - return new MlTextEmbeddingResults(DEFAULT_RESULTS_FIELD, embedding, false); + return new MlDenseEmbeddingResults(DEFAULT_RESULTS_FIELD, embedding, false); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java index b2d7218720a57..160261eb1b1cc 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java @@ -67,10 +67,10 @@ import org.elasticsearch.xcontent.json.JsonXContent; import org.elasticsearch.xpack.core.XPackClientPlugin; import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; -import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults; +import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults; import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; import org.elasticsearch.xpack.inference.InferencePlugin; import org.elasticsearch.xpack.inference.mapper.SemanticTextField; @@ -348,9 +348,9 @@ private InferenceAction.Response generateTextEmbeddingInferenceResponse() { int inferenceLength = DenseVectorFieldMapperTestUtils.getEmbeddingLength(denseVectorElementType, TEXT_EMBEDDING_DIMENSION_COUNT); double[] inference = new double[inferenceLength]; Arrays.fill(inference, 1.0); - MlTextEmbeddingResults textEmbeddingResults = new MlTextEmbeddingResults(DEFAULT_RESULTS_FIELD, inference, false); + MlDenseEmbeddingResults textEmbeddingResults = new MlDenseEmbeddingResults(DEFAULT_RESULTS_FIELD, inference, false); - return new InferenceAction.Response(TextEmbeddingFloatResults.of(List.of(textEmbeddingResults))); + return new InferenceAction.Response(DenseEmbeddingFloatResults.of(List.of(textEmbeddingResults))); } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/BaseInferenceActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/BaseInferenceActionTests.java index 4e60b09530684..ece25247bb5b5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/BaseInferenceActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/BaseInferenceActionTests.java @@ -23,7 +23,7 @@ import org.elasticsearch.xpack.core.inference.InferenceContext; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.action.InferenceActionProxy; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults; import org.elasticsearch.xpack.inference.InferencePlugin; import org.junit.Before; @@ -234,7 +234,7 @@ public void testExtractProductUseCase_EmptyWhenHeaderValueEmpty() { static InferenceAction.Response createResponse() { return new InferenceAction.Response( - new TextEmbeddingByteResults(List.of(new TextEmbeddingByteResults.Embedding(new byte[] { (byte) -1 }))) + new DenseEmbeddingByteResults(List.of(new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) -1 }))) ); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java index c9f5ea738b33b..94f65acb67cff 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java @@ -33,9 +33,9 @@ import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import org.elasticsearch.xpack.inference.InputTypeTests; import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; @@ -516,7 +516,7 @@ private void testChunkedInfer(TaskType taskType, ChunkingSettings chunkingSettin var firstResult = results.getFirst(); assertThat(firstResult, instanceOf(ChunkedInferenceEmbedding.class)); Class expectedClass = switch (taskType) { - case TEXT_EMBEDDING -> TextEmbeddingFloatResults.Chunk.class; + case TEXT_EMBEDDING -> DenseEmbeddingFloatResults.Chunk.class; case SPARSE_EMBEDDING -> SparseEmbeddingResults.Chunk.class; default -> null; }; @@ -650,10 +650,10 @@ private AlibabaCloudSearchModel createEmbeddingsModel( ) { public ExecutableAction accept(AlibabaCloudSearchActionVisitor visitor, Map taskSettings) { return (inferenceInputs, timeout, listener) -> { - TextEmbeddingFloatResults results = new TextEmbeddingFloatResults( + DenseEmbeddingFloatResults results = new DenseEmbeddingFloatResults( List.of( - new TextEmbeddingFloatResults.Embedding(new float[] { 0.0123f, -0.0123f }), - new TextEmbeddingFloatResults.Embedding(new float[] { 0.0456f, -0.0456f }) + new DenseEmbeddingFloatResults.Embedding(new float[] { 0.0123f, -0.0123f }), + new DenseEmbeddingFloatResults.Embedding(new float[] { 0.0456f, -0.0456f }) ) ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/action/AlibabaCloudSearchActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/action/AlibabaCloudSearchActionCreatorTests.java index b09fbf43a8ca4..dec6539732156 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/action/AlibabaCloudSearchActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/action/AlibabaCloudSearchActionCreatorTests.java @@ -21,11 +21,11 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; import org.elasticsearch.xpack.core.inference.results.RankedDocsResultsTests; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; @@ -50,9 +50,9 @@ import java.util.concurrent.TimeUnit; import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; +import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.core.inference.results.RankedDocsResultsTests.buildExpectationRerank; import static org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests.buildExpectationSparseEmbeddings; -import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.hamcrest.Matchers.is; @@ -88,7 +88,7 @@ public void testExecute_withTextEmbeddingsAction_Success() { float[] values = { 0.1111111f, 0.2222222f, 0.3333333f }; doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); - listener.onResponse(new TextEmbeddingFloatResults(List.of(new TextEmbeddingFloatResults.Embedding(values)))); + listener.onResponse(new DenseEmbeddingFloatResults(List.of(new DenseEmbeddingFloatResults.Embedding(values)))); return Void.TYPE; }).when(sender).send(any(), any(), any(), any()); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/response/AlibabaCloudSearchEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/response/AlibabaCloudSearchEmbeddingsResponseEntityTests.java index ed8a1185bd846..28e5c9422a970 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/response/AlibabaCloudSearchEmbeddingsResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/response/AlibabaCloudSearchEmbeddingsResponseEntityTests.java @@ -9,7 +9,7 @@ import org.apache.http.HttpResponse; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.services.alibabacloudsearch.request.AlibabaCloudSearchRequest; @@ -50,14 +50,14 @@ public void testFromResponse_CreatesResultsForASingleItem() throws IOException, URI uri = new URI("mock_uri"); when(request.getURI()).thenReturn(uri); - TextEmbeddingFloatResults parsedResults = AlibabaCloudSearchEmbeddingsResponseEntity.fromResponse( + DenseEmbeddingFloatResults parsedResults = AlibabaCloudSearchEmbeddingsResponseEntity.fromResponse( request, new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); assertThat( parsedResults.embeddings(), - is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { -0.02868066355586052f, 0.022033605724573135f }))) + is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { -0.02868066355586052f, 0.022033605724573135f }))) ); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java index 04c8bca2287e5..33d2d120e1b8d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java @@ -38,7 +38,7 @@ import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.Utils; import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; @@ -70,7 +70,7 @@ import static org.elasticsearch.common.xcontent.XContentHelper.toXContent; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; -import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat; +import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.Utils.getInvalidModel; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; @@ -1026,7 +1026,7 @@ public void testInfer_SendsRequest_ForTitanEmbeddingsModel() throws IOException ); var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender() ) { - var results = new TextEmbeddingFloatResults(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.123F, 0.678F }))); + var results = new DenseEmbeddingFloatResults(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.123F, 0.678F }))); requestSender.enqueue(results); PlainActionFuture listener = new PlainActionFuture<>(); service.infer( @@ -1067,8 +1067,8 @@ public void testInfer_SendsRequest_ForCohereEmbeddingsModel() throws IOException ) ) { try (var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender()) { - var results = new TextEmbeddingFloatResults( - List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.123F, 0.678F })) + var results = new DenseEmbeddingFloatResults( + List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.123F, 0.678F })) ); requestSender.enqueue(results); @@ -1343,14 +1343,14 @@ private void testChunkedInfer(AmazonBedrockEmbeddingsModel model) throws IOExcep ) { try (var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender()) { { - var mockResults1 = new TextEmbeddingFloatResults( - List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.123F, 0.678F })) + var mockResults1 = new DenseEmbeddingFloatResults( + List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.123F, 0.678F })) ); requestSender.enqueue(mockResults1); } { - var mockResults2 = new TextEmbeddingFloatResults( - List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.223F, 0.278F })) + var mockResults2 = new DenseEmbeddingFloatResults( + List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.223F, 0.278F })) ); requestSender.enqueue(mockResults2); } @@ -1373,10 +1373,10 @@ private void testChunkedInfer(AmazonBedrockEmbeddingsModel model) throws IOExcep var floatResult = (ChunkedInferenceEmbedding) results.get(0); assertThat(floatResult.chunks(), hasSize(1)); assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().get(0).offset()); - assertThat(floatResult.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertThat(floatResult.chunks().get(0).embedding(), instanceOf(DenseEmbeddingFloatResults.Embedding.class)); assertArrayEquals( new float[] { 0.123F, 0.678F }, - ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), + ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), 0.0f ); } @@ -1385,10 +1385,10 @@ private void testChunkedInfer(AmazonBedrockEmbeddingsModel model) throws IOExcep var floatResult = (ChunkedInferenceEmbedding) results.get(1); assertThat(floatResult.chunks(), hasSize(1)); assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().get(0).offset()); - assertThat(floatResult.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertThat(floatResult.chunks().get(0).embedding(), instanceOf(DenseEmbeddingFloatResults.Embedding.class)); assertArrayEquals( new float[] { 0.223F, 0.278F }, - ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), + ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), 0.0f ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/action/AmazonBedrockActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/action/AmazonBedrockActionCreatorTests.java index 5dd42dc66485f..8f1ba6b277c37 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/action/AmazonBedrockActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/action/AmazonBedrockActionCreatorTests.java @@ -15,7 +15,7 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.InputTypeTests; import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; @@ -33,7 +33,7 @@ import java.util.concurrent.TimeUnit; import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; -import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat; +import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; import static org.hamcrest.Matchers.is; @@ -53,8 +53,8 @@ public void shutdown() throws IOException { public void testEmbeddingsRequestAction_Titan() throws IOException { var serviceComponents = ServiceComponentsTests.createWithEmptySettings(threadPool); - var mockedFloatResults = List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F })); - var mockedResult = new TextEmbeddingFloatResults(mockedFloatResults); + var mockedFloatResults = List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F })); + var mockedResult = new DenseEmbeddingFloatResults(mockedFloatResults); try (var sender = new AmazonBedrockMockRequestSender()) { sender.enqueue(mockedResult); var creator = new AmazonBedrockActionCreator(sender, serviceComponents, TIMEOUT); @@ -91,8 +91,8 @@ public void testEmbeddingsRequestAction_Titan() throws IOException { public void testEmbeddingsRequestAction_Cohere() throws IOException { var serviceComponents = ServiceComponentsTests.createWithEmptySettings(threadPool); - var mockedFloatResults = List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F })); - var mockedResult = new TextEmbeddingFloatResults(mockedFloatResults); + var mockedFloatResults = List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F })); + var mockedResult = new DenseEmbeddingFloatResults(mockedFloatResults); try (var sender = new AmazonBedrockMockRequestSender()) { sender.enqueue(mockedResult); var creator = new AmazonBedrockActionCreator(sender, serviceComponents, TIMEOUT); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockExecutorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockExecutorTests.java index 19fe2ca784b1c..d267f80080885 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockExecutorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockExecutorTests.java @@ -34,7 +34,7 @@ import java.util.List; import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; -import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat; +import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.common.TruncatorTests.createTruncator; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.is; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockRequestSenderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockRequestSenderTests.java index d608b6ec9fb8e..38a9ca17fec7b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockRequestSenderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockRequestSenderTests.java @@ -32,7 +32,7 @@ import java.util.concurrent.atomic.AtomicReference; import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; -import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat; +import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.services.amazonbedrock.client.AmazonBedrockExecutorTests.TEST_AMAZON_TITAN_EMBEDDINGS_RESULT; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java index 06668b91e1965..734d66a20bf56 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java @@ -39,8 +39,8 @@ import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; @@ -1347,10 +1347,10 @@ private void testChunkedInfer(AzureAiStudioEmbeddingsModel model) throws IOExcep var floatResult = (ChunkedInferenceEmbedding) results.get(0); assertThat(floatResult.chunks(), hasSize(1)); assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().get(0).offset()); - assertThat(floatResult.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertThat(floatResult.chunks().get(0).embedding(), instanceOf(DenseEmbeddingFloatResults.Embedding.class)); assertArrayEquals( new float[] { 0.0123f, -0.0123f }, - ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), + ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), 0.0f ); } @@ -1359,10 +1359,10 @@ private void testChunkedInfer(AzureAiStudioEmbeddingsModel model) throws IOExcep var floatResult = (ChunkedInferenceEmbedding) results.get(1); assertThat(floatResult.chunks(), hasSize(1)); assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().get(0).offset()); - assertThat(floatResult.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertThat(floatResult.chunks().get(0).embedding(), instanceOf(DenseEmbeddingFloatResults.Embedding.class)); assertArrayEquals( new float[] { 1.0123f, -1.0123f }, - ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), + ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), 0.0f ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/action/AzureAiStudioActionAndCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/action/AzureAiStudioActionAndCreatorTests.java index 975661d180795..ef82cc18946b1 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/action/AzureAiStudioActionAndCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/action/AzureAiStudioActionAndCreatorTests.java @@ -46,7 +46,7 @@ import java.util.concurrent.TimeUnit; import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; -import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat; +import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/response/AzureAiStudioEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/response/AzureAiStudioEmbeddingsResponseEntityTests.java index 3da3598a4637a..e2ebd72b17da5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/response/AzureAiStudioEmbeddingsResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/response/AzureAiStudioEmbeddingsResponseEntityTests.java @@ -9,7 +9,7 @@ import org.apache.http.HttpResponse; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; @@ -50,11 +50,11 @@ public void testFromResponse_CreatesResultsForASingleItem() throws IOException { var entity = new AzureAiStudioEmbeddingsResponseEntity(); - var parsedResults = (TextEmbeddingFloatResults) entity.apply( + var parsedResults = (DenseEmbeddingFloatResults) entity.apply( mock(Request.class), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat(parsedResults.embeddings(), is(List.of(TextEmbeddingFloatResults.Embedding.of(List.of(0.014539449F, -0.015288644F))))); + assertThat(parsedResults.embeddings(), is(List.of(DenseEmbeddingFloatResults.Embedding.of(List.of(0.014539449F, -0.015288644F))))); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java index b5d8fb887c62c..6c0b4f272650a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java @@ -37,7 +37,7 @@ import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; @@ -63,7 +63,7 @@ import static org.elasticsearch.common.xcontent.XContentHelper.toXContent; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; -import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat; +import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.Utils.getInvalidModel; import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; @@ -1007,10 +1007,10 @@ private void testChunkedInfer(AzureOpenAiEmbeddingsModel model) throws IOExcepti var floatResult = (ChunkedInferenceEmbedding) results.get(0); assertThat(floatResult.chunks(), hasSize(1)); assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().get(0).offset()); - assertThat(floatResult.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertThat(floatResult.chunks().get(0).embedding(), instanceOf(DenseEmbeddingFloatResults.Embedding.class)); assertArrayEquals( new float[] { 0.123f, -0.123f }, - ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), + ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), 0.0f ); } @@ -1019,10 +1019,10 @@ private void testChunkedInfer(AzureOpenAiEmbeddingsModel model) throws IOExcepti var floatResult = (ChunkedInferenceEmbedding) results.get(1); assertThat(floatResult.chunks(), hasSize(1)); assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().get(0).offset()); - assertThat(floatResult.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertThat(floatResult.chunks().get(0).embedding(), instanceOf(DenseEmbeddingFloatResults.Embedding.class)); assertArrayEquals( new float[] { 1.123f, -1.123f }, - ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), + ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), 0.0f ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiActionCreatorTests.java index 1724747f9861a..23c773ab0d61b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiActionCreatorTests.java @@ -42,7 +42,7 @@ import java.util.concurrent.TimeUnit; import static org.elasticsearch.core.Strings.format; -import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat; +import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiEmbeddingsActionTests.java index 20ac31398de2f..c6db0f441d348 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiEmbeddingsActionTests.java @@ -44,7 +44,7 @@ import java.util.List; import java.util.concurrent.TimeUnit; -import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat; +import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java index 0f897500698ee..a340236e31eaa 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java @@ -39,8 +39,8 @@ import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; @@ -69,7 +69,7 @@ import static org.elasticsearch.common.xcontent.XContentHelper.toXContent; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; -import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat; +import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.Utils.getInvalidModel; import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; @@ -1351,7 +1351,7 @@ private void testChunkedInfer(CohereEmbeddingsModel model) throws IOException { assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().get(0).offset()); assertArrayEquals( new float[] { 0.123f, -0.123f }, - ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), + ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), 0.0f ); } @@ -1362,7 +1362,7 @@ private void testChunkedInfer(CohereEmbeddingsModel model) throws IOException { assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().get(0).offset()); assertArrayEquals( new float[] { 0.223f, -0.223f }, - ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), + ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), 0.0f ); } @@ -1448,10 +1448,10 @@ public void testChunkedInfer_BatchesCalls_Bytes() throws IOException { var byteResult = (ChunkedInferenceEmbedding) results.get(0); assertThat(byteResult.chunks(), hasSize(1)); assertEquals(new ChunkedInference.TextOffset(0, 1), byteResult.chunks().get(0).offset()); - assertThat(byteResult.chunks().get(0).embedding(), instanceOf(TextEmbeddingByteResults.Embedding.class)); + assertThat(byteResult.chunks().get(0).embedding(), instanceOf(DenseEmbeddingByteResults.Embedding.class)); assertArrayEquals( new byte[] { 23, -23 }, - ((TextEmbeddingByteResults.Embedding) byteResult.chunks().get(0).embedding()).values() + ((DenseEmbeddingByteResults.Embedding) byteResult.chunks().get(0).embedding()).values() ); } { @@ -1459,10 +1459,10 @@ public void testChunkedInfer_BatchesCalls_Bytes() throws IOException { var byteResult = (ChunkedInferenceEmbedding) results.get(1); assertThat(byteResult.chunks(), hasSize(1)); assertEquals(new ChunkedInference.TextOffset(0, 2), byteResult.chunks().get(0).offset()); - assertThat(byteResult.chunks().get(0).embedding(), instanceOf(TextEmbeddingByteResults.Embedding.class)); + assertThat(byteResult.chunks().get(0).embedding(), instanceOf(DenseEmbeddingByteResults.Embedding.class)); assertArrayEquals( new byte[] { 24, -24 }, - ((TextEmbeddingByteResults.Embedding) byteResult.chunks().get(0).embedding()).values() + ((DenseEmbeddingByteResults.Embedding) byteResult.chunks().get(0).embedding()).values() ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreatorTests.java index f3d816d9a118d..a6f3e6a07d66c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreatorTests.java @@ -40,7 +40,7 @@ import java.util.concurrent.TimeUnit; import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; -import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat; +import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereEmbeddingsActionTests.java index f5191ecac5ce6..008bce852b8d6 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereEmbeddingsActionTests.java @@ -44,8 +44,8 @@ import java.util.Map; import java.util.concurrent.TimeUnit; -import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationByte; -import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat; +import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationByte; +import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/response/CohereEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/response/CohereEmbeddingsResponseEntityTests.java index 6df356bfe0a80..78bf38f204bf7 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/response/CohereEmbeddingsResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/response/CohereEmbeddingsResponseEntityTests.java @@ -10,9 +10,9 @@ import org.apache.http.HttpResponse; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBitResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingBitResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; import org.hamcrest.MatcherAssert; @@ -56,10 +56,10 @@ public void testFromResponse_CreatesResultsForASingleItem() throws IOException { new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - MatcherAssert.assertThat(parsedResults, instanceOf(TextEmbeddingFloatResults.class)); + MatcherAssert.assertThat(parsedResults, instanceOf(DenseEmbeddingFloatResults.class)); MatcherAssert.assertThat( - ((TextEmbeddingFloatResults) parsedResults).embeddings(), - is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { -0.0018434525F, 0.01777649F }))) + ((DenseEmbeddingFloatResults) parsedResults).embeddings(), + is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { -0.0018434525F, 0.01777649F }))) ); } @@ -90,14 +90,14 @@ public void testFromResponse_CreatesResultsForASingleItem_ObjectFormat() throws } """; - TextEmbeddingFloatResults parsedResults = (TextEmbeddingFloatResults) CohereEmbeddingsResponseEntity.fromResponse( + DenseEmbeddingFloatResults parsedResults = (DenseEmbeddingFloatResults) CohereEmbeddingsResponseEntity.fromResponse( mock(Request.class), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); MatcherAssert.assertThat( parsedResults.embeddings(), - is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { -0.0018434525F, 0.01777649F }))) + is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { -0.0018434525F, 0.01777649F }))) ); } @@ -134,14 +134,14 @@ public void testFromResponse_UsesTheFirstValidEmbeddingsEntry() throws IOExcepti } """; - TextEmbeddingFloatResults parsedResults = (TextEmbeddingFloatResults) CohereEmbeddingsResponseEntity.fromResponse( + DenseEmbeddingFloatResults parsedResults = (DenseEmbeddingFloatResults) CohereEmbeddingsResponseEntity.fromResponse( mock(Request.class), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); MatcherAssert.assertThat( parsedResults.embeddings(), - is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { -0.0018434525F, 0.01777649F }))) + is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { -0.0018434525F, 0.01777649F }))) ); } @@ -178,14 +178,14 @@ public void testFromResponse_UsesTheFirstValidEmbeddingsEntryInt8_WithInvalidFir } """; - TextEmbeddingByteResults parsedResults = (TextEmbeddingByteResults) CohereEmbeddingsResponseEntity.fromResponse( + DenseEmbeddingByteResults parsedResults = (DenseEmbeddingByteResults) CohereEmbeddingsResponseEntity.fromResponse( mock(Request.class), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); MatcherAssert.assertThat( parsedResults.embeddings(), - is(List.of(new TextEmbeddingByteResults.Embedding(new byte[] { (byte) -1, (byte) 0 }))) + is(List.of(new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) -1, (byte) 0 }))) ); } @@ -216,14 +216,14 @@ public void testFromResponse_ParsesBytes() throws IOException { } """; - TextEmbeddingByteResults parsedResults = (TextEmbeddingByteResults) CohereEmbeddingsResponseEntity.fromResponse( + DenseEmbeddingByteResults parsedResults = (DenseEmbeddingByteResults) CohereEmbeddingsResponseEntity.fromResponse( mock(Request.class), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); MatcherAssert.assertThat( parsedResults.embeddings(), - is(List.of(new TextEmbeddingByteResults.Embedding(new byte[] { (byte) -1, (byte) 0 }))) + is(List.of(new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) -1, (byte) 0 }))) ); } @@ -257,14 +257,14 @@ public void testFromResponse_ParsesBytes_FromBinaryEmbeddingsEntry() throws IOEx } """; - TextEmbeddingBitResults parsedResults = (TextEmbeddingBitResults) CohereEmbeddingsResponseEntity.fromResponse( + DenseEmbeddingBitResults parsedResults = (DenseEmbeddingBitResults) CohereEmbeddingsResponseEntity.fromResponse( mock(Request.class), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); MatcherAssert.assertThat( parsedResults.embeddings(), - is(List.of(new TextEmbeddingByteResults.Embedding(new byte[] { (byte) -55, (byte) 74, (byte) 101, (byte) 67, (byte) 83 }))) + is(List.of(new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) -55, (byte) 74, (byte) 101, (byte) 67, (byte) 83 }))) ); } @@ -297,7 +297,7 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException } """; - TextEmbeddingFloatResults parsedResults = (TextEmbeddingFloatResults) CohereEmbeddingsResponseEntity.fromResponse( + DenseEmbeddingFloatResults parsedResults = (DenseEmbeddingFloatResults) CohereEmbeddingsResponseEntity.fromResponse( mock(Request.class), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); @@ -306,8 +306,8 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException parsedResults.embeddings(), is( List.of( - new TextEmbeddingFloatResults.Embedding(new float[] { -0.0018434525F, 0.01777649F }), - new TextEmbeddingFloatResults.Embedding(new float[] { -0.123F, 0.123F }) + new DenseEmbeddingFloatResults.Embedding(new float[] { -0.0018434525F, 0.01777649F }), + new DenseEmbeddingFloatResults.Embedding(new float[] { -0.123F, 0.123F }) ) ) ); @@ -344,7 +344,7 @@ public void testFromResponse_CreatesResultsForMultipleItems_ObjectFormat() throw } """; - TextEmbeddingFloatResults parsedResults = (TextEmbeddingFloatResults) CohereEmbeddingsResponseEntity.fromResponse( + DenseEmbeddingFloatResults parsedResults = (DenseEmbeddingFloatResults) CohereEmbeddingsResponseEntity.fromResponse( mock(Request.class), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); @@ -353,8 +353,8 @@ public void testFromResponse_CreatesResultsForMultipleItems_ObjectFormat() throw parsedResults.embeddings(), is( List.of( - new TextEmbeddingFloatResults.Embedding(new float[] { -0.0018434525F, 0.01777649F }), - new TextEmbeddingFloatResults.Embedding(new float[] { -0.123F, 0.123F }) + new DenseEmbeddingFloatResults.Embedding(new float[] { -0.0018434525F, 0.01777649F }), + new DenseEmbeddingFloatResults.Embedding(new float[] { -0.123F, 0.123F }) ) ) ); @@ -397,7 +397,7 @@ public void testFromResponse_CreatesResultsForMultipleItems_ObjectFormat_Binary( } """; - TextEmbeddingBitResults parsedResults = (TextEmbeddingBitResults) CohereEmbeddingsResponseEntity.fromResponse( + DenseEmbeddingBitResults parsedResults = (DenseEmbeddingBitResults) CohereEmbeddingsResponseEntity.fromResponse( mock(Request.class), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); @@ -406,8 +406,8 @@ public void testFromResponse_CreatesResultsForMultipleItems_ObjectFormat_Binary( parsedResults.embeddings(), is( List.of( - new TextEmbeddingByteResults.Embedding(new byte[] { (byte) -55, (byte) 74, (byte) 101, (byte) 67 }), - new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 34, (byte) -64, (byte) 97, (byte) 65, (byte) -42 }) + new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) -55, (byte) 74, (byte) 101, (byte) 67 }), + new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) 34, (byte) -64, (byte) 97, (byte) 65, (byte) -42 }) ) ) ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomModelTests.java index 1c5c13e2086c2..fb9802c6938d5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomModelTests.java @@ -14,7 +14,7 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.inference.services.custom.response.CustomResponseParser; -import org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser; +import org.elasticsearch.xpack.inference.services.custom.response.DenseEmbeddingResponseParser; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; import org.hamcrest.MatcherAssert; @@ -94,7 +94,7 @@ public static CustomModel createModel( public static CustomModel getTestModel() { return getTestModel( TaskType.TEXT_EMBEDDING, - new TextEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT) + new DenseEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT) ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java index 65d9de30576ff..74c3e0f57e73c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java @@ -26,10 +26,10 @@ import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.ServiceFields; import org.elasticsearch.xpack.inference.services.custom.response.CompletionResponseParser; +import org.elasticsearch.xpack.inference.services.custom.response.DenseEmbeddingResponseParser; import org.elasticsearch.xpack.inference.services.custom.response.NoopResponseParser; import org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser; import org.elasticsearch.xpack.inference.services.custom.response.SparseEmbeddingResponseParser; -import org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; import org.hamcrest.MatcherAssert; @@ -60,7 +60,7 @@ public static CustomServiceSettings createRandom(String inputUrl) { var requestContentString = randomAlphaOfLength(10); var responseJsonParser = switch (taskType) { - case TEXT_EMBEDDING -> new TextEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT); + case TEXT_EMBEDDING -> new DenseEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT); case SPARSE_EMBEDDING -> new SparseEmbeddingResponseParser( "$.result.sparse_embeddings[*].embedding[*].token_id", "$.result.sparse_embeddings[*].embedding[*].weights" @@ -100,7 +100,7 @@ public void testFromMap() { var queryParameters = List.of(List.of("key", "value")); String requestContentString = "request body"; - var responseParser = new TextEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT); + var responseParser = new DenseEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT); var settings = CustomServiceSettings.fromMap( new HashMap<>( @@ -124,7 +124,7 @@ public void testFromMap() { Map.of( CustomServiceSettings.JSON_PARSER, new HashMap<>( - Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") + Map.of(DenseEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") ) ) ), @@ -158,7 +158,7 @@ public void testFromMap_EmbeddingType_Bit() { String url = "http://www.abc.com"; String requestContentString = "request body"; - var responseParser = new TextEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.BIT); + var responseParser = new DenseEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.BIT); var settings = CustomServiceSettings.fromMap( new HashMap<>( @@ -173,9 +173,9 @@ public void testFromMap_EmbeddingType_Bit() { CustomServiceSettings.JSON_PARSER, new HashMap<>( Map.of( - TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, + DenseEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding", - TextEmbeddingResponseParser.EMBEDDING_TYPE, + DenseEmbeddingResponseParser.EMBEDDING_TYPE, CustomServiceEmbeddingType.BIT.toString() ) ) @@ -207,7 +207,7 @@ public void testFromMap_EmbeddingType_Binary() { String url = "http://www.abc.com"; String requestContentString = "request body"; - var responseParser = new TextEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.BINARY); + var responseParser = new DenseEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.BINARY); var settings = CustomServiceSettings.fromMap( new HashMap<>( @@ -222,9 +222,9 @@ public void testFromMap_EmbeddingType_Binary() { CustomServiceSettings.JSON_PARSER, new HashMap<>( Map.of( - TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, + DenseEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding", - TextEmbeddingResponseParser.EMBEDDING_TYPE, + DenseEmbeddingResponseParser.EMBEDDING_TYPE, CustomServiceEmbeddingType.BINARY.toString() ) ) @@ -256,7 +256,7 @@ public void testFromMap_EmbeddingType_Byte() { String url = "http://www.abc.com"; String requestContentString = "request body"; - var responseParser = new TextEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.BYTE); + var responseParser = new DenseEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.BYTE); var settings = CustomServiceSettings.fromMap( new HashMap<>( @@ -271,9 +271,9 @@ public void testFromMap_EmbeddingType_Byte() { CustomServiceSettings.JSON_PARSER, new HashMap<>( Map.of( - TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, + DenseEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding", - TextEmbeddingResponseParser.EMBEDDING_TYPE, + DenseEmbeddingResponseParser.EMBEDDING_TYPE, CustomServiceEmbeddingType.BYTE.toString() ) ) @@ -367,7 +367,7 @@ public void testFromMap_Completion_ThrowsWhenEmbeddingIsIncludedInMap() { Map.of( CompletionResponseParser.COMPLETION_PARSER_RESULT, "$.result.text", - TextEmbeddingResponseParser.EMBEDDING_TYPE, + DenseEmbeddingResponseParser.EMBEDDING_TYPE, "byte" ) ) @@ -393,7 +393,7 @@ public void testFromMap_WithOptionalsNotSpecified() { String url = "http://www.abc.com"; String requestContentString = "request body"; - var responseParser = new TextEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT); + var responseParser = new DenseEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT); var settings = CustomServiceSettings.fromMap( new HashMap<>( @@ -407,7 +407,7 @@ public void testFromMap_WithOptionalsNotSpecified() { Map.of( CustomServiceSettings.JSON_PARSER, new HashMap<>( - Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") + Map.of(DenseEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") ) ) ) @@ -445,7 +445,7 @@ public void testFromMap_RemovesNullValues_FromMaps() { String requestContentString = "request body"; - var responseParser = new TextEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT); + var responseParser = new DenseEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT); var settings = CustomServiceSettings.fromMap( new HashMap<>( @@ -467,7 +467,7 @@ public void testFromMap_RemovesNullValues_FromMaps() { Map.of( CustomServiceSettings.JSON_PARSER, new HashMap<>( - Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") + Map.of(DenseEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") ) ) ) @@ -519,7 +519,7 @@ public void testFromMap_ReturnsError_IfHeadersContainsNonStringValues() { Map.of( CustomServiceSettings.JSON_PARSER, new HashMap<>( - Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") + Map.of(DenseEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") ) ) ) @@ -566,7 +566,7 @@ public void testFromMap_ReturnsError_IfQueryParamsContainsNonStringValues() { Map.of( CustomServiceSettings.JSON_PARSER, new HashMap<>( - Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") + Map.of(DenseEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") ) ) ) @@ -604,7 +604,7 @@ public void testFromMap_ReturnsError_IfRequestMapIsMissing() { Map.of( CustomServiceSettings.JSON_PARSER, new HashMap<>( - Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") + Map.of(DenseEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") ) ) ) @@ -636,7 +636,7 @@ public void testFromMap_ReturnsError_IfResponseMapIsMissing() { Map.of( CustomServiceSettings.JSON_PARSER, new HashMap<>( - Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") + Map.of(DenseEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") ) ) ) @@ -675,7 +675,7 @@ public void testFromMap_ReturnsError_IfJsonParserMapIsNotEmptyAfterParsing() { CustomServiceSettings.JSON_PARSER, new HashMap<>( Map.of( - TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, + DenseEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding", "key", "value" @@ -717,7 +717,7 @@ public void testFromMap_ReturnsError_IfResponseMapIsNotEmptyAfterParsing() { Map.of( CustomServiceSettings.JSON_PARSER, new HashMap<>( - Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") + Map.of(DenseEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") ), "key", "value" @@ -757,7 +757,7 @@ public void testFromMap_ReturnsError_IfTaskTypeIsInvalid() { Map.of( CustomServiceSettings.JSON_PARSER, new HashMap<>( - Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") + Map.of(DenseEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") ) ) ) @@ -779,7 +779,7 @@ public void testXContent() throws IOException { Map.of("key", "value"), null, "string", - new TextEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT), + new DenseEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT), null ); @@ -868,7 +868,7 @@ public void testXContent_WithInputTypeTranslationValues() throws IOException { Map.of("key", "value"), null, "string", - new TextEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT), + new DenseEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT), null, null, new InputTypeTranslator(Map.of(InputType.SEARCH, "do_search", InputType.INGEST, "do_ingest"), "a_default") @@ -915,7 +915,7 @@ public void testXContent_BatchSize11() throws IOException { Map.of("key", "value"), null, "string", - new TextEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT), + new DenseEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT), null, 11, InputTypeTranslator.EMPTY_TRANSLATOR diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java index 44bf17c3ac96d..1d43330381fa4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java @@ -28,9 +28,9 @@ import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; @@ -39,9 +39,9 @@ import org.elasticsearch.xpack.inference.services.ServiceFields; import org.elasticsearch.xpack.inference.services.custom.response.CompletionResponseParser; import org.elasticsearch.xpack.inference.services.custom.response.CustomResponseParser; +import org.elasticsearch.xpack.inference.services.custom.response.DenseEmbeddingResponseParser; import org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser; import org.elasticsearch.xpack.inference.services.custom.response.SparseEmbeddingResponseParser; -import org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; import org.hamcrest.CoreMatchers; import org.hamcrest.Matchers; @@ -138,7 +138,7 @@ private static void assertTextEmbeddingModel(Model model, boolean modelIncludesS var customModel = assertCommonModelFields(model, modelIncludesSecrets); assertThat(customModel.getTaskType(), is(TaskType.TEXT_EMBEDDING)); - assertThat(customModel.getServiceSettings().getResponseJsonParser(), instanceOf(TextEmbeddingResponseParser.class)); + assertThat(customModel.getServiceSettings().getResponseJsonParser(), instanceOf(DenseEmbeddingResponseParser.class)); } private static CustomModel assertCommonModelFields(Model model, boolean modelIncludesSecrets) { @@ -204,7 +204,7 @@ private static Map createServiceSettingsMap(TaskType taskType) { private static Map createResponseParserMap(TaskType taskType) { return switch (taskType) { case TEXT_EMBEDDING -> new HashMap<>( - Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") + Map.of(DenseEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") ); case COMPLETION -> new HashMap<>(Map.of(CompletionResponseParser.COMPLETION_PARSER_RESULT, "$.result.text")); case SPARSE_EMBEDDING -> new HashMap<>( @@ -240,18 +240,18 @@ private static Map createSecretSettingsMap() { private static CustomModel createInternalEmbeddingModel(SimilarityMeasure similarityMeasure) { return createInternalEmbeddingModel( similarityMeasure, - new TextEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT), + new DenseEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT), "http://www.abc.com" ); } - private static CustomModel createInternalEmbeddingModel(TextEmbeddingResponseParser parser, String url) { + private static CustomModel createInternalEmbeddingModel(DenseEmbeddingResponseParser parser, String url) { return createInternalEmbeddingModel(SimilarityMeasure.DOT_PRODUCT, parser, url); } private static CustomModel createInternalEmbeddingModel( @Nullable SimilarityMeasure similarityMeasure, - TextEmbeddingResponseParser parser, + DenseEmbeddingResponseParser parser, String url ) { var inferenceId = "inference_id"; @@ -276,7 +276,7 @@ private static CustomModel createInternalEmbeddingModel( private static CustomModel createInternalEmbeddingModel( @Nullable SimilarityMeasure similarityMeasure, - TextEmbeddingResponseParser parser, + DenseEmbeddingResponseParser parser, String url, @Nullable ChunkingSettings chunkingSettings, @Nullable Integer batchSize @@ -336,7 +336,7 @@ public void testInfer_ReturnsAnError_WithoutParsingTheResponseBody() throws IOEx webServer.enqueue(new MockResponse().setResponseCode(400).setBody(responseJson)); var model = createInternalEmbeddingModel( - new TextEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT), + new DenseEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT), getUrl(webServer) ); PlainActionFuture listener = new PlainActionFuture<>(); @@ -394,7 +394,7 @@ public void testInfer_HandlesTextEmbeddingRequest_OpenAI_Format() throws IOExcep webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); var model = createInternalEmbeddingModel( - new TextEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT), + new DenseEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT), getUrl(webServer) ); PlainActionFuture listener = new PlainActionFuture<>(); @@ -412,12 +412,12 @@ public void testInfer_HandlesTextEmbeddingRequest_OpenAI_Format() throws IOExcep ); InferenceServiceResults results = listener.actionGet(TIMEOUT); - assertThat(results, instanceOf(TextEmbeddingFloatResults.class)); + assertThat(results, instanceOf(DenseEmbeddingFloatResults.class)); - var embeddingResults = (TextEmbeddingFloatResults) results; + var embeddingResults = (DenseEmbeddingFloatResults) results; assertThat( embeddingResults.embeddings(), - is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F }))) + is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F }))) ); } } @@ -676,7 +676,7 @@ public void testParseRequestConfig_ThrowsAValidationError_WhenReplacementDoesNot public void testChunkedInfer_ChunkingSettingsSet() throws IOException { var model = createInternalEmbeddingModel( SimilarityMeasure.DOT_PRODUCT, - new TextEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT), + new DenseEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT), getUrl(webServer), ChunkingSettingsTests.createRandomChunkingSettings(), 2 @@ -732,10 +732,10 @@ public void testChunkedInfer_ChunkingSettingsSet() throws IOException { var floatResult = (ChunkedInferenceEmbedding) results.get(0); assertThat(floatResult.chunks(), hasSize(1)); assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().get(0).offset()); - assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class)); assertArrayEquals( new float[] { 0.123f, -0.123f }, - ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), + ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), 0.0f ); } @@ -744,10 +744,10 @@ public void testChunkedInfer_ChunkingSettingsSet() throws IOException { var floatResult = (ChunkedInferenceEmbedding) results.get(1); assertThat(floatResult.chunks(), hasSize(1)); assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().get(0).offset()); - assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class)); assertArrayEquals( new float[] { 0.223f, -0.223f }, - ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), + ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), 0.0f ); } @@ -762,7 +762,7 @@ public void testChunkedInfer_ChunkingSettingsSet() throws IOException { public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException { var model = createInternalEmbeddingModel( - new TextEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT), + new DenseEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT), getUrl(webServer) ); String responseJson = """ @@ -807,10 +807,10 @@ public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException { var floatResult = (ChunkedInferenceEmbedding) results.get(0); assertThat(floatResult.chunks(), hasSize(1)); assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().get(0).offset()); - assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class)); assertArrayEquals( new float[] { 0.123f, -0.123f }, - ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), + ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), 0.0f ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java index 8b35979c3daf5..8ad0b028d3b34 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java @@ -27,8 +27,8 @@ import org.elasticsearch.xpack.inference.services.custom.CustomTaskSettings; import org.elasticsearch.xpack.inference.services.custom.InputTypeTranslator; import org.elasticsearch.xpack.inference.services.custom.QueryParameters; +import org.elasticsearch.xpack.inference.services.custom.response.DenseEmbeddingResponseParser; import org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser; -import org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; import java.io.IOException; @@ -62,7 +62,7 @@ public void testCreateRequest() throws IOException { headers, new QueryParameters(List.of(new QueryParameters.Parameter("key", "value"), new QueryParameters.Parameter("key", "value2"))), requestContentString, - new TextEmbeddingResponseParser("$.result.embeddings", CustomServiceEmbeddingType.FLOAT), + new DenseEmbeddingResponseParser("$.result.embeddings", CustomServiceEmbeddingType.FLOAT), new RateLimitSettings(10_000), null, new InputTypeTranslator(Map.of(InputType.INGEST, "value"), "default") @@ -122,7 +122,7 @@ public void testCreateRequest_QueryParametersAreEscaped_AndEncoded() throws IOEx ) ), requestContentString, - new TextEmbeddingResponseParser("$.result.embeddings", CustomServiceEmbeddingType.FLOAT), + new DenseEmbeddingResponseParser("$.result.embeddings", CustomServiceEmbeddingType.FLOAT), new RateLimitSettings(10_000), null, new InputTypeTranslator(Map.of(InputType.INGEST, "value"), "default") @@ -180,7 +180,7 @@ public void testCreateRequest_SecretsInTheJsonBody_AreEncodedCorrectly() throws headers, new QueryParameters(List.of(new QueryParameters.Parameter("key", "value"), new QueryParameters.Parameter("key", "value2"))), requestContentString, - new TextEmbeddingResponseParser("$.result.embeddings", CustomServiceEmbeddingType.FLOAT), + new DenseEmbeddingResponseParser("$.result.embeddings", CustomServiceEmbeddingType.FLOAT), new RateLimitSettings(10_000) ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseEntityTests.java index e53add6733aca..b475a770c7cfe 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseEntityTests.java @@ -13,9 +13,9 @@ import org.elasticsearch.inference.WeightedToken; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; @@ -62,7 +62,7 @@ public void testFromTextEmbeddingResponse() throws IOException { var model = CustomModelTests.getTestModel( TaskType.TEXT_EMBEDDING, - new TextEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT) + new DenseEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT) ); var request = new CustomRequest( EmbeddingParameters.of(new EmbeddingsInput(List.of("abc"), null), model.getServiceSettings().getInputTypeTranslator()), @@ -73,10 +73,10 @@ public void testFromTextEmbeddingResponse() throws IOException { new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat(results, instanceOf(TextEmbeddingFloatResults.class)); + assertThat(results, instanceOf(DenseEmbeddingFloatResults.class)); assertThat( - ((TextEmbeddingFloatResults) results).embeddings(), - is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { -0.02868066355586052f, 0.022033605724573135f }))) + ((DenseEmbeddingFloatResults) results).embeddings(), + is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { -0.02868066355586052f, 0.022033605724573135f }))) ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParserTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/DenseEmbeddingResponseParserTests.java similarity index 71% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParserTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/DenseEmbeddingResponseParserTests.java index 7796b5e1e7f6b..cffb2a5db6af5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParserTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/DenseEmbeddingResponseParserTests.java @@ -16,9 +16,9 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBitResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingBitResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.services.custom.CustomServiceEmbeddingType; @@ -29,24 +29,24 @@ import java.util.List; import java.util.Map; -import static org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser.EMBEDDING_TYPE; -import static org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS; +import static org.elasticsearch.xpack.inference.services.custom.response.DenseEmbeddingResponseParser.EMBEDDING_TYPE; +import static org.elasticsearch.xpack.inference.services.custom.response.DenseEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS; import static org.hamcrest.Matchers.is; import static org.mockito.Mockito.mock; -public class TextEmbeddingResponseParserTests extends AbstractBWCWireSerializationTestCase { +public class DenseEmbeddingResponseParserTests extends AbstractBWCWireSerializationTestCase { private static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_TYPE = TransportVersion.fromName( "ml_inference_custom_service_embedding_type" ); - public static TextEmbeddingResponseParser createRandom() { - return new TextEmbeddingResponseParser("$." + randomAlphaOfLength(5), randomFrom(CustomServiceEmbeddingType.values())); + public static DenseEmbeddingResponseParser createRandom() { + return new DenseEmbeddingResponseParser("$." + randomAlphaOfLength(5), randomFrom(CustomServiceEmbeddingType.values())); } public void testFromMap() { var validation = new ValidationException(); - var parser = TextEmbeddingResponseParser.fromMap( + var parser = DenseEmbeddingResponseParser.fromMap( new HashMap<>( Map.of( TEXT_EMBEDDING_PARSER_EMBEDDINGS, @@ -59,14 +59,14 @@ public void testFromMap() { validation ); - assertThat(parser, is(new TextEmbeddingResponseParser("$.result[*].embeddings", CustomServiceEmbeddingType.BIT))); + assertThat(parser, is(new DenseEmbeddingResponseParser("$.result[*].embeddings", CustomServiceEmbeddingType.BIT))); } public void testFromMap_ThrowsException_WhenRequiredFieldIsNotPresent() { var validation = new ValidationException(); var exception = expectThrows( ValidationException.class, - () -> TextEmbeddingResponseParser.fromMap(new HashMap<>(Map.of("some_field", "$.result[*].embeddings")), "scope", validation) + () -> DenseEmbeddingResponseParser.fromMap(new HashMap<>(Map.of("some_field", "$.result[*].embeddings")), "scope", validation) ); assertThat( @@ -76,7 +76,7 @@ public void testFromMap_ThrowsException_WhenRequiredFieldIsNotPresent() { } public void testToXContent() throws IOException { - var entity = new TextEmbeddingResponseParser("$.result.path", CustomServiceEmbeddingType.BINARY); + var entity = new DenseEmbeddingResponseParser("$.result.path", CustomServiceEmbeddingType.BINARY); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); { @@ -120,14 +120,18 @@ public void testParse() throws IOException { } """; - var parser = new TextEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT); - TextEmbeddingFloatResults parsedResults = (TextEmbeddingFloatResults) parser.parse( + var parser = new DenseEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT); + DenseEmbeddingFloatResults parsedResults = (DenseEmbeddingFloatResults) parser.parse( new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); assertThat( parsedResults, - is(new TextEmbeddingFloatResults(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F })))) + is( + new DenseEmbeddingFloatResults( + List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F })) + ) + ) ); } @@ -153,12 +157,15 @@ public void testParseByte() throws IOException { } """; - var parser = new TextEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.BYTE); - TextEmbeddingByteResults parsedResults = (TextEmbeddingByteResults) parser.parse( + var parser = new DenseEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.BYTE); + DenseEmbeddingByteResults parsedResults = (DenseEmbeddingByteResults) parser.parse( new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat(parsedResults, is(new TextEmbeddingByteResults(List.of(new TextEmbeddingByteResults.Embedding(new byte[] { 1, -2 }))))); + assertThat( + parsedResults, + is(new DenseEmbeddingByteResults(List.of(new DenseEmbeddingByteResults.Embedding(new byte[] { 1, -2 })))) + ); } public void testParseBit() throws IOException { @@ -183,12 +190,12 @@ public void testParseBit() throws IOException { } """; - var parser = new TextEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.BIT); - TextEmbeddingBitResults parsedResults = (TextEmbeddingBitResults) parser.parse( + var parser = new DenseEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.BIT); + DenseEmbeddingBitResults parsedResults = (DenseEmbeddingBitResults) parser.parse( new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat(parsedResults, is(new TextEmbeddingBitResults(List.of(new TextEmbeddingByteResults.Embedding(new byte[] { 1, -2 }))))); + assertThat(parsedResults, is(new DenseEmbeddingBitResults(List.of(new DenseEmbeddingByteResults.Embedding(new byte[] { 1, -2 }))))); } public void testParse_MultipleEmbeddings() throws IOException { @@ -221,18 +228,18 @@ public void testParse_MultipleEmbeddings() throws IOException { } """; - var parser = new TextEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT); - TextEmbeddingFloatResults parsedResults = (TextEmbeddingFloatResults) parser.parse( + var parser = new DenseEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT); + DenseEmbeddingFloatResults parsedResults = (DenseEmbeddingFloatResults) parser.parse( new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); assertThat( parsedResults, is( - new TextEmbeddingFloatResults( + new DenseEmbeddingFloatResults( List.of( - new TextEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }), - new TextEmbeddingFloatResults.Embedding(new float[] { 1F, -2F }) + new DenseEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }), + new DenseEmbeddingFloatResults.Embedding(new float[] { 1F, -2F }) ) ) ) @@ -269,7 +276,7 @@ public void testParse_ThrowsException_WhenExtractedField_IsNotAListOfFloats() { } """; - var parser = new TextEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT); + var parser = new DenseEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT); var exception = expectThrows( IllegalArgumentException.class, () -> parser.parse(new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))) @@ -303,7 +310,7 @@ public void testParse_ThrowsException_WhenExtractedField_IsNotAList() { } """; - var parser = new TextEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT); + var parser = new DenseEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT); var exception = expectThrows( IllegalArgumentException.class, () -> parser.parse(new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))) @@ -319,25 +326,25 @@ public void testParse_ThrowsException_WhenExtractedField_IsNotAList() { } @Override - protected TextEmbeddingResponseParser mutateInstanceForVersion(TextEmbeddingResponseParser instance, TransportVersion version) { + protected DenseEmbeddingResponseParser mutateInstanceForVersion(DenseEmbeddingResponseParser instance, TransportVersion version) { if (version.supports(ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_TYPE) == false) { - return new TextEmbeddingResponseParser(instance.getTextEmbeddingsPath(), CustomServiceEmbeddingType.FLOAT); + return new DenseEmbeddingResponseParser(instance.getTextEmbeddingsPath(), CustomServiceEmbeddingType.FLOAT); } return instance; } @Override - protected Writeable.Reader instanceReader() { - return TextEmbeddingResponseParser::new; + protected Writeable.Reader instanceReader() { + return DenseEmbeddingResponseParser::new; } @Override - protected TextEmbeddingResponseParser createTestInstance() { + protected DenseEmbeddingResponseParser createTestInstance() { return createRandom(); } @Override - protected TextEmbeddingResponseParser mutateInstance(TextEmbeddingResponseParser instance) throws IOException { - return randomValueOtherThan(instance, TextEmbeddingResponseParserTests::createRandom); + protected DenseEmbeddingResponseParser mutateInstance(DenseEmbeddingResponseParser instance) throws IOException { + return randomValueOtherThan(instance, DenseEmbeddingResponseParserTests::createRandom); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index b8a82d6a7a29c..551152aafa533 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -40,8 +40,8 @@ import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; import org.elasticsearch.xpack.inference.InferencePlugin; import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; @@ -889,9 +889,9 @@ public void testChunkedInfer_BatchesCallsChunkingSettingsSet() throws IOExceptio var denseResult = (ChunkedInferenceEmbedding) results.getFirst(); assertThat(denseResult.chunks(), hasSize(1)); assertEquals(new ChunkedInference.TextOffset(0, "hello world".length()), denseResult.chunks().getFirst().offset()); - assertThat(denseResult.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertThat(denseResult.chunks().get(0).embedding(), instanceOf(DenseEmbeddingFloatResults.Embedding.class)); - var embedding = (TextEmbeddingFloatResults.Embedding) denseResult.chunks().get(0).embedding(); + var embedding = (DenseEmbeddingFloatResults.Embedding) denseResult.chunks().get(0).embedding(); assertArrayEquals(new float[] { 0.123f, -0.456f, 0.789f }, embedding.values(), 0.0f); } @@ -901,9 +901,9 @@ public void testChunkedInfer_BatchesCallsChunkingSettingsSet() throws IOExceptio var denseResult = (ChunkedInferenceEmbedding) results.get(1); assertThat(denseResult.chunks(), hasSize(1)); assertEquals(new ChunkedInference.TextOffset(0, "dense embedding".length()), denseResult.chunks().getFirst().offset()); - assertThat(denseResult.chunks().getFirst().embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertThat(denseResult.chunks().getFirst().embedding(), instanceOf(DenseEmbeddingFloatResults.Embedding.class)); - var embedding = (TextEmbeddingFloatResults.Embedding) denseResult.chunks().get(0).embedding(); + var embedding = (DenseEmbeddingFloatResults.Embedding) denseResult.chunks().get(0).embedding(); assertArrayEquals(new float[] { 0.987f, -0.654f, 0.321f }, embedding.values(), 0.0f); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java index 1d91713eef931..ab715b7f73e96 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java @@ -20,9 +20,9 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.RankedDocsResultsTests; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; @@ -298,8 +298,8 @@ public void testExecute_ReturnsSuccessfulResponse_ForDenseTextEmbeddingsAction() var result = listener.actionGet(TIMEOUT); - assertThat(result, instanceOf(TextEmbeddingFloatResults.class)); - var textEmbeddingResults = (TextEmbeddingFloatResults) result; + assertThat(result, instanceOf(DenseEmbeddingFloatResults.class)); + var textEmbeddingResults = (DenseEmbeddingFloatResults) result; assertThat(textEmbeddingResults.embeddings(), hasSize(2)); var firstEmbedding = textEmbeddingResults.embeddings().get(0); @@ -354,8 +354,8 @@ public void testExecute_ReturnsSuccessfulResponse_ForDenseTextEmbeddingsAction_W var result = listener.actionGet(TIMEOUT); - assertThat(result, instanceOf(TextEmbeddingFloatResults.class)); - var textEmbeddingResults = (TextEmbeddingFloatResults) result; + assertThat(result, instanceOf(DenseEmbeddingFloatResults.class)); + var textEmbeddingResults = (DenseEmbeddingFloatResults) result; assertThat(textEmbeddingResults.embeddings(), hasSize(1)); var embedding = textEmbeddingResults.embeddings().get(0); @@ -447,8 +447,8 @@ public void testExecute_ReturnsSuccessfulResponse_ForDenseTextEmbeddingsAction_E var result = listener.actionGet(TIMEOUT); - assertThat(result, instanceOf(TextEmbeddingFloatResults.class)); - var textEmbeddingResults = (TextEmbeddingFloatResults) result; + assertThat(result, instanceOf(DenseEmbeddingFloatResults.class)); + var textEmbeddingResults = (DenseEmbeddingFloatResults) result; assertThat(textEmbeddingResults.embeddings(), hasSize(0)); assertThat(webServer.requests(), hasSize(1)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceDenseTextEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceDenseTextEmbeddingsResponseEntityTests.java index 2883a1ab73c21..79721b95af067 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceDenseTextEmbeddingsResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceDenseTextEmbeddingsResponseEntityTests.java @@ -9,7 +9,7 @@ import org.apache.http.HttpResponse; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.external.response.elastic.ElasticInferenceServiceDenseTextEmbeddingsResponseEntity; @@ -35,7 +35,7 @@ public void testDenseTextEmbeddingsResponse_SingleEmbeddingInData_NoMeta() throw } """; - TextEmbeddingFloatResults parsedResults = ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.fromResponse( + DenseEmbeddingFloatResults parsedResults = ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.fromResponse( mock(Request.class), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); @@ -64,7 +64,7 @@ public void testDenseTextEmbeddingsResponse_MultipleEmbeddingsInData_NoMeta() th } """; - TextEmbeddingFloatResults parsedResults = ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.fromResponse( + DenseEmbeddingFloatResults parsedResults = ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.fromResponse( mock(Request.class), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); @@ -85,7 +85,7 @@ public void testDenseTextEmbeddingsResponse_EmptyData() throws Exception { } """; - TextEmbeddingFloatResults parsedResults = ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.fromResponse( + DenseEmbeddingFloatResults parsedResults = ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.fromResponse( mock(Request.class), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); @@ -111,7 +111,7 @@ public void testDenseTextEmbeddingsResponse_SingleEmbeddingInData_IgnoresMeta() } """; - TextEmbeddingFloatResults parsedResults = ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.fromResponse( + DenseEmbeddingFloatResults parsedResults = ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.fromResponse( mock(Request.class), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java index 2cc0323e6d913..0db705c1770ba 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java @@ -53,9 +53,9 @@ import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import org.elasticsearch.xpack.core.ml.MachineLearningField; import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction; import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction; @@ -71,8 +71,8 @@ import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats; import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentTests; import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; -import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults; -import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResultsTests; +import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults; +import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResultsTests; import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResultsTests; import org.elasticsearch.xpack.core.ml.inference.results.TextSimilarityInferenceResultsTests; @@ -1109,8 +1109,8 @@ public void testChunkInfer_E5ChunkingSettingsSet() throws InterruptedException { @SuppressWarnings("unchecked") private void testChunkInfer_e5(ChunkingSettings chunkingSettings) throws InterruptedException { var mlTrainedModelResults = new ArrayList(); - mlTrainedModelResults.add(MlTextEmbeddingResultsTests.createRandomResults()); - mlTrainedModelResults.add(MlTextEmbeddingResultsTests.createRandomResults()); + mlTrainedModelResults.add(MlDenseEmbeddingResultsTests.createRandomResults()); + mlTrainedModelResults.add(MlDenseEmbeddingResultsTests.createRandomResults()); var response = new InferModelAction.Response(mlTrainedModelResults, "foo", true); Client client = mock(Client.class); @@ -1136,20 +1136,20 @@ private void testChunkInfer_e5(ChunkingSettings chunkingSettings) throws Interru assertThat(chunkedResponse.get(0), instanceOf(ChunkedInferenceEmbedding.class)); var result1 = (ChunkedInferenceEmbedding) chunkedResponse.get(0); assertThat(result1.chunks(), hasSize(1)); - assertThat(result1.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertThat(result1.chunks().get(0).embedding(), instanceOf(DenseEmbeddingFloatResults.Embedding.class)); assertArrayEquals( - ((MlTextEmbeddingResults) mlTrainedModelResults.get(0)).getInferenceAsFloat(), - ((TextEmbeddingFloatResults.Embedding) result1.chunks().get(0).embedding()).values(), + ((MlDenseEmbeddingResults) mlTrainedModelResults.get(0)).getInferenceAsFloat(), + ((DenseEmbeddingFloatResults.Embedding) result1.chunks().get(0).embedding()).values(), 0.0001f ); assertEquals(new ChunkedInference.TextOffset(0, 1), result1.chunks().get(0).offset()); assertThat(chunkedResponse.get(1), instanceOf(ChunkedInferenceEmbedding.class)); var result2 = (ChunkedInferenceEmbedding) chunkedResponse.get(1); assertThat(result2.chunks(), hasSize(1)); - assertThat(result2.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertThat(result2.chunks().get(0).embedding(), instanceOf(DenseEmbeddingFloatResults.Embedding.class)); assertArrayEquals( - ((MlTextEmbeddingResults) mlTrainedModelResults.get(1)).getInferenceAsFloat(), - ((TextEmbeddingFloatResults.Embedding) result2.chunks().get(0).embedding()).values(), + ((MlDenseEmbeddingResults) mlTrainedModelResults.get(1)).getInferenceAsFloat(), + ((DenseEmbeddingFloatResults.Embedding) result2.chunks().get(0).embedding()).values(), 0.0001f ); assertEquals(new ChunkedInference.TextOffset(0, 2), result2.chunks().get(0).offset()); @@ -1377,8 +1377,8 @@ public void testChunkInferSetsTokenization() { @SuppressWarnings("unchecked") public void testChunkInfer_FailsBatch() throws InterruptedException { var mlTrainedModelResults = new ArrayList(); - mlTrainedModelResults.add(MlTextEmbeddingResultsTests.createRandomResults()); - mlTrainedModelResults.add(MlTextEmbeddingResultsTests.createRandomResults()); + mlTrainedModelResults.add(MlDenseEmbeddingResultsTests.createRandomResults()); + mlTrainedModelResults.add(MlDenseEmbeddingResultsTests.createRandomResults()); mlTrainedModelResults.add(new ErrorInferenceResults(new RuntimeException("boom"))); var response = new InferModelAction.Response(mlTrainedModelResults, "foo", true); @@ -1454,7 +1454,7 @@ public void testChunkingLargeDocument() throws InterruptedException { var listener = (ActionListener) invocationOnMock.getArguments()[2]; var mlTrainedModelResults = new ArrayList(); for (int i = 0; i < request.numberOfDocuments(); i++) { - mlTrainedModelResults.add(MlTextEmbeddingResultsTests.createRandomResults()); + mlTrainedModelResults.add(MlDenseEmbeddingResultsTests.createRandomResults()); } var response = new InferModelAction.Response(mlTrainedModelResults, "foo", true); listener.onResponse(response); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java index be087f73f8d5b..7a5fef5ff0995 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java @@ -39,7 +39,7 @@ import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; @@ -66,7 +66,7 @@ import static org.elasticsearch.common.xcontent.XContentHelper.toXContent; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; -import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat; +import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.Utils.getInvalidModel; import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; @@ -941,11 +941,11 @@ private void testChunkedInfer(String modelId, String apiKey, GoogleAiStudioEmbed var floatResult = (ChunkedInferenceEmbedding) results.get(0); assertThat(floatResult.chunks(), hasSize(1)); assertEquals(new ChunkedInference.TextOffset(0, input.get(0).input().length()), floatResult.chunks().get(0).offset()); - assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class)); assertTrue( Arrays.equals( new float[] { 0.0123f, -0.0123f }, - ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values() + ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values() ) ); } @@ -956,11 +956,11 @@ private void testChunkedInfer(String modelId, String apiKey, GoogleAiStudioEmbed var floatResult = (ChunkedInferenceEmbedding) results.get(1); assertThat(floatResult.chunks(), hasSize(1)); assertEquals(new ChunkedInference.TextOffset(0, input.get(1).input().length()), floatResult.chunks().get(0).offset()); - assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class)); assertTrue( Arrays.equals( new float[] { 0.0456f, -0.0456f }, - ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values() + ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values() ) ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/action/GoogleAiStudioEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/action/GoogleAiStudioEmbeddingsActionTests.java index 38e33c546b291..9e6b5547a6b1d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/action/GoogleAiStudioEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/action/GoogleAiStudioEmbeddingsActionTests.java @@ -40,7 +40,7 @@ import java.util.Map; import java.util.concurrent.TimeUnit; -import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat; +import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/response/GoogleAiStudioEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/response/GoogleAiStudioEmbeddingsResponseEntityTests.java index eca4a369c29c8..6bfb33769602f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/response/GoogleAiStudioEmbeddingsResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/response/GoogleAiStudioEmbeddingsResponseEntityTests.java @@ -9,7 +9,7 @@ import org.apache.http.HttpResponse; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; @@ -36,12 +36,12 @@ public void testFromResponse_CreatesResultsForASingleItem() throws IOException { } """; - TextEmbeddingFloatResults parsedResults = GoogleAiStudioEmbeddingsResponseEntity.fromResponse( + DenseEmbeddingFloatResults parsedResults = GoogleAiStudioEmbeddingsResponseEntity.fromResponse( mock(Request.class), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat(parsedResults.embeddings(), is(List.of(TextEmbeddingFloatResults.Embedding.of(List.of(-0.00606332F, 0.058092743F))))); + assertThat(parsedResults.embeddings(), is(List.of(DenseEmbeddingFloatResults.Embedding.of(List.of(-0.00606332F, 0.058092743F))))); } public void testFromResponse_CreatesResultsForMultipleItems() throws IOException { @@ -64,7 +64,7 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException } """; - TextEmbeddingFloatResults parsedResults = GoogleAiStudioEmbeddingsResponseEntity.fromResponse( + DenseEmbeddingFloatResults parsedResults = GoogleAiStudioEmbeddingsResponseEntity.fromResponse( mock(Request.class), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); @@ -73,8 +73,8 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException parsedResults.embeddings(), is( List.of( - TextEmbeddingFloatResults.Embedding.of(List.of(-0.00606332F, 0.058092743F)), - TextEmbeddingFloatResults.Embedding.of(List.of(0.030681048F, 0.01714732F)) + DenseEmbeddingFloatResults.Embedding.of(List.of(-0.00606332F, 0.058092743F)), + DenseEmbeddingFloatResults.Embedding.of(List.of(0.030681048F, 0.01714732F)) ) ) ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/response/GoogleVertexAiEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/response/GoogleVertexAiEmbeddingsResponseEntityTests.java index 8f19edb3031d7..c3a8a8b74078b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/response/GoogleVertexAiEmbeddingsResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/response/GoogleVertexAiEmbeddingsResponseEntityTests.java @@ -9,7 +9,7 @@ import org.apache.http.HttpResponse; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; @@ -42,12 +42,12 @@ public void testFromResponse_CreatesResultsForASingleItem() throws IOException { } """; - TextEmbeddingFloatResults parsedResults = GoogleVertexAiEmbeddingsResponseEntity.fromResponse( + DenseEmbeddingFloatResults parsedResults = GoogleVertexAiEmbeddingsResponseEntity.fromResponse( mock(Request.class), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat(parsedResults.embeddings(), is(List.of(TextEmbeddingFloatResults.Embedding.of(List.of(-0.123F, 0.123F))))); + assertThat(parsedResults.embeddings(), is(List.of(DenseEmbeddingFloatResults.Embedding.of(List.of(-0.123F, 0.123F))))); } public void testFromResponse_CreatesResultsForMultipleItems() throws IOException { @@ -82,7 +82,7 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException } """; - TextEmbeddingFloatResults parsedResults = GoogleVertexAiEmbeddingsResponseEntity.fromResponse( + DenseEmbeddingFloatResults parsedResults = GoogleVertexAiEmbeddingsResponseEntity.fromResponse( mock(Request.class), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); @@ -91,8 +91,8 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException parsedResults.embeddings(), is( List.of( - TextEmbeddingFloatResults.Embedding.of(List.of(-0.123F, 0.123F)), - TextEmbeddingFloatResults.Embedding.of(List.of(-0.456F, 0.456F)) + DenseEmbeddingFloatResults.Embedding.of(List.of(-0.123F, 0.123F)), + DenseEmbeddingFloatResults.Embedding.of(List.of(-0.456F, 0.456F)) ) ) ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java index 67534d83a2e15..5d60a87200f3e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java @@ -42,8 +42,8 @@ import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; @@ -76,7 +76,7 @@ import static org.elasticsearch.common.xcontent.XContentHelper.toXContent; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS; -import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat; +import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.Utils.getInvalidModel; import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; @@ -1212,10 +1212,10 @@ public void testChunkedInfer_CallsInfer_TextEmbedding_ConvertsFloatResponse() th var embeddingResult = (ChunkedInferenceEmbedding) result; assertThat(embeddingResult.chunks(), hasSize(1)); assertThat(embeddingResult.chunks().get(0).offset(), is(new ChunkedInference.TextOffset(0, "abc".length()))); - assertThat(embeddingResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertThat(embeddingResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class)); assertArrayEquals( new float[] { -0.0123f, 0.0123f }, - ((TextEmbeddingFloatResults.Embedding) embeddingResult.chunks().get(0).embedding()).values(), + ((DenseEmbeddingFloatResults.Embedding) embeddingResult.chunks().get(0).embedding()).values(), 0.001f ); assertThat(webServer.requests(), hasSize(1)); @@ -1266,10 +1266,10 @@ public void testChunkedInfer() throws IOException { var floatResult = (ChunkedInferenceEmbedding) results.get(0); assertThat(floatResult.chunks(), hasSize(1)); assertEquals(new ChunkedInference.TextOffset(0, 3), floatResult.chunks().get(0).offset()); - assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class)); assertArrayEquals( new float[] { 0.123f, -0.123f }, - ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), + ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), 0.0f ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java index 1194f3a5fa95c..1225e40a0fc0c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java @@ -19,8 +19,8 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests; import org.elasticsearch.xpack.inference.InputTypeTests; import org.elasticsearch.xpack.inference.common.TruncatorTests; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; @@ -231,7 +231,10 @@ public void testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction() throws I var result = listener.actionGet(TIMEOUT); - assertThat(result.asMap(), is(TextEmbeddingFloatResultsTests.buildExpectationFloat(List.of(new float[] { -0.0123F, 0.123F })))); + assertThat( + result.asMap(), + is(DenseEmbeddingFloatResultsTests.buildExpectationFloat(List.of(new float[] { -0.0123F, 0.123F }))) + ); assertThat(webServer.requests(), hasSize(1)); assertNull(webServer.requests().get(0).getUri().getQuery()); @@ -416,7 +419,10 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating() throws IOExc var result = listener.actionGet(TIMEOUT); - assertThat(result.asMap(), is(TextEmbeddingFloatResultsTests.buildExpectationFloat(List.of(new float[] { -0.0123F, 0.123F })))); + assertThat( + result.asMap(), + is(DenseEmbeddingFloatResultsTests.buildExpectationFloat(List.of(new float[] { -0.0123F, 0.123F }))) + ); assertThat(webServer.requests(), hasSize(2)); { @@ -478,7 +484,10 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException { var result = listener.actionGet(TIMEOUT); - assertThat(result.asMap(), is(TextEmbeddingFloatResultsTests.buildExpectationFloat(List.of(new float[] { -0.0123F, 0.123F })))); + assertThat( + result.asMap(), + is(DenseEmbeddingFloatResultsTests.buildExpectationFloat(List.of(new float[] { -0.0123F, 0.123F }))) + ); assertThat(webServer.requests(), hasSize(1)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceEmbeddingsResponseEntityTests.java index 61e035326d163..e157038f2f244 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceEmbeddingsResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceEmbeddingsResponseEntityTests.java @@ -10,7 +10,7 @@ import org.apache.http.HttpResponse; import org.elasticsearch.common.ParsingException; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; @@ -32,14 +32,14 @@ public void testFromResponse_CreatesResultsForASingleItem_ArrayFormat() throws I ] """; - TextEmbeddingFloatResults parsedResults = HuggingFaceEmbeddingsResponseEntity.fromResponse( + DenseEmbeddingFloatResults parsedResults = HuggingFaceEmbeddingsResponseEntity.fromResponse( mock(Request.class), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); assertThat( parsedResults.embeddings(), - is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }))) + is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }))) ); } @@ -55,14 +55,14 @@ public void testFromResponse_CreatesResultsForASingleItem_ObjectFormat() throws } """; - TextEmbeddingFloatResults parsedResults = HuggingFaceEmbeddingsResponseEntity.fromResponse( + DenseEmbeddingFloatResults parsedResults = HuggingFaceEmbeddingsResponseEntity.fromResponse( mock(Request.class), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); assertThat( parsedResults.embeddings(), - is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }))) + is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }))) ); } @@ -80,7 +80,7 @@ public void testFromResponse_CreatesResultsForMultipleItems_ArrayFormat() throws ] """; - TextEmbeddingFloatResults parsedResults = HuggingFaceEmbeddingsResponseEntity.fromResponse( + DenseEmbeddingFloatResults parsedResults = HuggingFaceEmbeddingsResponseEntity.fromResponse( mock(Request.class), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); @@ -89,8 +89,8 @@ public void testFromResponse_CreatesResultsForMultipleItems_ArrayFormat() throws parsedResults.embeddings(), is( List.of( - new TextEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }), - new TextEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F }) + new DenseEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }), + new DenseEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F }) ) ) ); @@ -112,7 +112,7 @@ public void testFromResponse_CreatesResultsForMultipleItems_ObjectFormat() throw } """; - TextEmbeddingFloatResults parsedResults = HuggingFaceEmbeddingsResponseEntity.fromResponse( + DenseEmbeddingFloatResults parsedResults = HuggingFaceEmbeddingsResponseEntity.fromResponse( mock(Request.class), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); @@ -121,8 +121,8 @@ public void testFromResponse_CreatesResultsForMultipleItems_ObjectFormat() throw parsedResults.embeddings(), is( List.of( - new TextEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }), - new TextEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F }) + new DenseEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }), + new DenseEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F }) ) ) ); @@ -255,12 +255,12 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsInt_ArrayFormat() throw ] """; - TextEmbeddingFloatResults parsedResults = HuggingFaceEmbeddingsResponseEntity.fromResponse( + DenseEmbeddingFloatResults parsedResults = HuggingFaceEmbeddingsResponseEntity.fromResponse( mock(Request.class), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 1.0F })))); + assertThat(parsedResults.embeddings(), is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 1.0F })))); } public void testFromResponse_SucceedsWhenEmbeddingValueIsInt_ObjectFormat() throws IOException { @@ -274,12 +274,12 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsInt_ObjectFormat() thro } """; - TextEmbeddingFloatResults parsedResults = HuggingFaceEmbeddingsResponseEntity.fromResponse( + DenseEmbeddingFloatResults parsedResults = HuggingFaceEmbeddingsResponseEntity.fromResponse( mock(Request.class), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 1.0F })))); + assertThat(parsedResults.embeddings(), is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 1.0F })))); } public void testFromResponse_SucceedsWhenEmbeddingValueIsLong_ArrayFormat() throws IOException { @@ -291,12 +291,12 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsLong_ArrayFormat() thro ] """; - TextEmbeddingFloatResults parsedResults = HuggingFaceEmbeddingsResponseEntity.fromResponse( + DenseEmbeddingFloatResults parsedResults = HuggingFaceEmbeddingsResponseEntity.fromResponse( mock(Request.class), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 4.0294965E10F })))); + assertThat(parsedResults.embeddings(), is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 4.0294965E10F })))); } public void testFromResponse_SucceedsWhenEmbeddingValueIsLong_ObjectFormat() throws IOException { @@ -310,12 +310,12 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsLong_ObjectFormat() thr } """; - TextEmbeddingFloatResults parsedResults = HuggingFaceEmbeddingsResponseEntity.fromResponse( + DenseEmbeddingFloatResults parsedResults = HuggingFaceEmbeddingsResponseEntity.fromResponse( mock(Request.class), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 4.0294965E10F })))); + assertThat(parsedResults.embeddings(), is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 4.0294965E10F })))); } public void testFromResponse_FailsWhenEmbeddingValueIsAnObject_ObjectFormat() { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java index 6e24981e3f3b3..229bbb73eb14d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java @@ -38,7 +38,7 @@ import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.common.Truncator; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; @@ -70,7 +70,7 @@ import static org.elasticsearch.common.xcontent.XContentHelper.toXContent; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; -import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat; +import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.Utils.getInvalidModel; import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; @@ -789,11 +789,11 @@ private void testChunkedInfer_Batches(ChunkingSettings chunkingSettings) throws var floatResult = (ChunkedInferenceEmbedding) results.get(0); assertThat(floatResult.chunks(), hasSize(1)); assertEquals(new ChunkedInference.TextOffset(0, input.get(0).input().length()), floatResult.chunks().get(0).offset()); - assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class)); assertTrue( Arrays.equals( new float[] { 0.0123f, -0.0123f }, - ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values() + ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values() ) ); } @@ -804,11 +804,11 @@ private void testChunkedInfer_Batches(ChunkingSettings chunkingSettings) throws var floatResult = (ChunkedInferenceEmbedding) results.get(1); assertThat(floatResult.chunks(), hasSize(1)); assertEquals(new ChunkedInference.TextOffset(0, input.get(1).input().length()), floatResult.chunks().get(0).offset()); - assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class)); assertTrue( Arrays.equals( new float[] { 0.0456f, -0.0456f }, - ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values() + ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values() ) ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/action/IbmWatsonxEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/action/IbmWatsonxEmbeddingsActionTests.java index a0980a8151036..793bb535c0cc5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/action/IbmWatsonxEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/action/IbmWatsonxEmbeddingsActionTests.java @@ -44,7 +44,7 @@ import java.util.Map; import java.util.concurrent.TimeUnit; -import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat; +import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/response/IbmWatsonxEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/response/IbmWatsonxEmbeddingsResponseEntityTests.java index db42e9c49e1e5..c3bdb8c686985 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/response/IbmWatsonxEmbeddingsResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/response/IbmWatsonxEmbeddingsResponseEntityTests.java @@ -9,7 +9,7 @@ import org.apache.http.HttpResponse; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; @@ -36,12 +36,12 @@ public void testFromResponse_CreatesResultsForASingleItem() throws IOException { } """; - TextEmbeddingFloatResults parsedResults = IbmWatsonxEmbeddingsResponseEntity.fromResponse( + DenseEmbeddingFloatResults parsedResults = IbmWatsonxEmbeddingsResponseEntity.fromResponse( mock(Request.class), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat(parsedResults.embeddings(), is(List.of(TextEmbeddingFloatResults.Embedding.of(List.of(-0.00606332F, 0.058092743F))))); + assertThat(parsedResults.embeddings(), is(List.of(DenseEmbeddingFloatResults.Embedding.of(List.of(-0.00606332F, 0.058092743F))))); } public void testFromResponse_CreatesResultsForMultipleItems() throws IOException { @@ -66,7 +66,7 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException } """; - TextEmbeddingFloatResults parsedResults = IbmWatsonxEmbeddingsResponseEntity.fromResponse( + DenseEmbeddingFloatResults parsedResults = IbmWatsonxEmbeddingsResponseEntity.fromResponse( mock(Request.class), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); @@ -75,8 +75,8 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException parsedResults.embeddings(), is( List.of( - TextEmbeddingFloatResults.Embedding.of(List.of(-0.00606332F, 0.058092743F)), - TextEmbeddingFloatResults.Embedding.of(List.of(0.030681048F, 0.01714732F)) + DenseEmbeddingFloatResults.Embedding.of(List.of(-0.00606332F, 0.058092743F)), + DenseEmbeddingFloatResults.Embedding.of(List.of(0.030681048F, 0.01714732F)) ) ) ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java index d408e269219cb..a6a04692d33bf 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java @@ -38,7 +38,7 @@ import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; @@ -65,7 +65,7 @@ import static org.elasticsearch.common.xcontent.XContentHelper.toXContent; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; -import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat; +import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.Utils.getInvalidModel; import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; @@ -1688,10 +1688,10 @@ private void test_Embedding_ChunkedInfer_BatchesCalls(JinaAIEmbeddingsModel mode var floatResult = (ChunkedInferenceEmbedding) results.get(0); assertThat(floatResult.chunks(), hasSize(1)); assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().get(0).offset()); - assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class)); assertArrayEquals( new float[] { 0.123f, -0.123f }, - ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), + ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), 0.0f ); } @@ -1700,10 +1700,10 @@ private void test_Embedding_ChunkedInfer_BatchesCalls(JinaAIEmbeddingsModel mode var floatResult = (ChunkedInferenceEmbedding) results.get(1); assertThat(floatResult.chunks(), hasSize(1)); assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().get(0).offset()); - assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class)); assertArrayEquals( new float[] { 0.223f, -0.223f }, - ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), + ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), 0.0f ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/response/JinaAIEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/response/JinaAIEmbeddingsResponseEntityTests.java index c1b19cb450789..4df61cdc439af 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/response/JinaAIEmbeddingsResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/response/JinaAIEmbeddingsResponseEntityTests.java @@ -11,9 +11,9 @@ import org.elasticsearch.common.ParsingException; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBitResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingBitResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.InputTypeTests; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingType; @@ -69,10 +69,10 @@ public void testFromResponse_CreatesResultsForASingleItem() throws IOException { new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat(parsedResults, instanceOf(TextEmbeddingFloatResults.class)); + assertThat(parsedResults, instanceOf(DenseEmbeddingFloatResults.class)); assertThat( - ((TextEmbeddingFloatResults) parsedResults).embeddings(), - is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }))) + ((DenseEmbeddingFloatResults) parsedResults).embeddings(), + is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }))) ); } @@ -123,13 +123,13 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat(parsedResults, instanceOf(TextEmbeddingFloatResults.class)); + assertThat(parsedResults, instanceOf(DenseEmbeddingFloatResults.class)); assertThat( - ((TextEmbeddingFloatResults) parsedResults).embeddings(), + ((DenseEmbeddingFloatResults) parsedResults).embeddings(), is( List.of( - new TextEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }), - new TextEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F }) + new DenseEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }), + new DenseEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F }) ) ) ); @@ -363,8 +363,8 @@ public void testFromResponse_SucceedsWhenEmbeddingType_IsBinary() throws IOExcep ); assertThat( - ((TextEmbeddingBitResults) parsedResults).embeddings(), - is(List.of(new TextEmbeddingByteResults.Embedding(new byte[] { (byte) -55, (byte) 74, (byte) 101, (byte) 67, (byte) 83 }))) + ((DenseEmbeddingBitResults) parsedResults).embeddings(), + is(List.of(new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) -55, (byte) 74, (byte) 101, (byte) 67, (byte) 83 }))) ); } @@ -411,8 +411,8 @@ public void testFromResponse_SucceedsWhenEmbeddingType_IsBit() throws IOExceptio ); assertThat( - ((TextEmbeddingBitResults) parsedResults).embeddings(), - is(List.of(new TextEmbeddingByteResults.Embedding(new byte[] { (byte) -55, (byte) 74, (byte) 101, (byte) 67, (byte) 83 }))) + ((DenseEmbeddingBitResults) parsedResults).embeddings(), + is(List.of(new DenseEmbeddingByteResults.Embedding(new byte[] { (byte) -55, (byte) 74, (byte) 101, (byte) 67, (byte) 83 }))) ); } @@ -504,7 +504,7 @@ public void testFieldsInDifferentOrderServer() throws IOException { } }"""; - TextEmbeddingFloatResults parsedResults = (TextEmbeddingFloatResults) JinaAIEmbeddingsResponseEntity.fromResponse( + DenseEmbeddingFloatResults parsedResults = (DenseEmbeddingFloatResults) JinaAIEmbeddingsResponseEntity.fromResponse( JinaAIEmbeddingsRequestTests.createRequest( List.of("abc"), InputTypeTests.randomWithNull(), @@ -525,9 +525,9 @@ public void testFieldsInDifferentOrderServer() throws IOException { parsedResults.embeddings(), is( List.of( - new TextEmbeddingFloatResults.Embedding(new float[] { -0.9F, 0.5F, 0.3F }), - new TextEmbeddingFloatResults.Embedding(new float[] { 0.1F, 0.5F }), - new TextEmbeddingFloatResults.Embedding(new float[] { 0.5F, 0.5F }) + new DenseEmbeddingFloatResults.Embedding(new float[] { -0.9F, 0.5F, 0.3F }), + new DenseEmbeddingFloatResults.Embedding(new float[] { 0.1F, 0.5F }), + new DenseEmbeddingFloatResults.Embedding(new float[] { 0.5F, 0.5F }) ) ) ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java index f6a0232db529c..818ae94192ce1 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java @@ -42,7 +42,7 @@ import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; @@ -695,11 +695,11 @@ public void testChunkedInfer(LlamaEmbeddingsModel model) throws IOException { assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class)); var floatResult = (ChunkedInferenceEmbedding) results.get(0); assertThat(floatResult.chunks(), hasSize(1)); - assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class)); assertTrue( Arrays.equals( new float[] { 0.010060793f, -0.0017529363f }, - ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values() + ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values() ) ); } @@ -707,11 +707,11 @@ public void testChunkedInfer(LlamaEmbeddingsModel model) throws IOException { assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class)); var floatResult = (ChunkedInferenceEmbedding) results.get(1); assertThat(floatResult.chunks(), hasSize(1)); - assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class)); assertTrue( Arrays.equals( new float[] { 0.110060793f, -0.1017529363f }, - ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values() + ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values() ) ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionCreatorTests.java index 5bf65870e1dfc..df2ce97614edc 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionCreatorTests.java @@ -19,7 +19,7 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests; import org.elasticsearch.xpack.inference.InputTypeTests; import org.elasticsearch.xpack.inference.common.TruncatorTests; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; @@ -98,7 +98,10 @@ public void testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction() throws I var result = listener.actionGet(TIMEOUT); - assertThat(result.asMap(), is(TextEmbeddingFloatResultsTests.buildExpectationFloat(List.of(new float[] { -0.0123F, 0.123F })))); + assertThat( + result.asMap(), + is(DenseEmbeddingFloatResultsTests.buildExpectationFloat(List.of(new float[] { -0.0123F, 0.123F }))) + ); assertEmbeddingsRequest(); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java index db94ec5c9c2f5..3ff7bf1a78374 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java @@ -41,7 +41,7 @@ import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; import org.elasticsearch.xpack.inference.ModelConfigurationsTests; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; @@ -1135,11 +1135,11 @@ public void testChunkedInfer(MistralEmbeddingsModel model) throws IOException { assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class)); var floatResult = (ChunkedInferenceEmbedding) results.get(0); assertThat(floatResult.chunks(), hasSize(1)); - assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class)); assertTrue( Arrays.equals( new float[] { 0.123f, -0.123f }, - ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values() + ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values() ) ); } @@ -1147,11 +1147,11 @@ public void testChunkedInfer(MistralEmbeddingsModel model) throws IOException { assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class)); var floatResult = (ChunkedInferenceEmbedding) results.get(1); assertThat(floatResult.chunks(), hasSize(1)); - assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class)); assertTrue( Arrays.equals( new float[] { 0.223f, -0.223f }, - ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values() + ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values() ) ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java index 0a03c7a231e8c..85c9c73d002b2 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java @@ -43,7 +43,7 @@ import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; @@ -84,7 +84,7 @@ import static org.elasticsearch.common.xcontent.XContentHelper.toXContent; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS; -import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat; +import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.Utils.getInvalidModel; import static org.elasticsearch.xpack.inference.Utils.getRequestConfigMap; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; @@ -1049,11 +1049,11 @@ private void testChunkedInfer(OpenAiEmbeddingsModel model) throws IOException { var floatResult = (ChunkedInferenceEmbedding) results.get(0); assertThat(floatResult.chunks(), hasSize(1)); assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().get(0).offset()); - assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class)); assertTrue( Arrays.equals( new float[] { 0.123f, -0.123f }, - ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values() + ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values() ) ); } @@ -1062,11 +1062,11 @@ private void testChunkedInfer(OpenAiEmbeddingsModel model) throws IOException { var floatResult = (ChunkedInferenceEmbedding) results.get(1); assertThat(floatResult.chunks(), hasSize(1)); assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().get(0).offset()); - assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class)); assertTrue( Arrays.equals( new float[] { 0.223f, -0.223f }, - ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values() + ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values() ) ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiActionCreatorTests.java index 5cdcf402835b4..1291a1ad4a36c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiActionCreatorTests.java @@ -36,7 +36,7 @@ import static org.elasticsearch.core.Strings.format; import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; -import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat; +import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiEmbeddingsActionTests.java index c34ebbde4ac8f..1851dc4cad2b3 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiEmbeddingsActionTests.java @@ -40,7 +40,7 @@ import java.util.List; import java.util.concurrent.TimeUnit; -import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat; +import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/response/OpenAiEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/response/OpenAiEmbeddingsResponseEntityTests.java index f2a430eefc801..77e1f384509b4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/response/OpenAiEmbeddingsResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/response/OpenAiEmbeddingsResponseEntityTests.java @@ -10,7 +10,7 @@ import org.apache.http.HttpResponse; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xcontent.XContentParseException; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; @@ -45,14 +45,14 @@ public void testFromResponse_CreatesResultsForASingleItem() throws IOException { } """; - TextEmbeddingFloatResults parsedResults = OpenAiEmbeddingsResponseEntity.fromResponse( + DenseEmbeddingFloatResults parsedResults = OpenAiEmbeddingsResponseEntity.fromResponse( mock(Request.class), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); assertThat( parsedResults.embeddings(), - is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }))) + is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }))) ); } @@ -86,7 +86,7 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException } """; - TextEmbeddingFloatResults parsedResults = OpenAiEmbeddingsResponseEntity.fromResponse( + DenseEmbeddingFloatResults parsedResults = OpenAiEmbeddingsResponseEntity.fromResponse( mock(Request.class), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); @@ -95,8 +95,8 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException parsedResults.embeddings(), is( List.of( - new TextEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }), - new TextEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F }) + new DenseEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }), + new DenseEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F }) ) ) ); @@ -254,12 +254,12 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsInt() throws IOExceptio } """; - TextEmbeddingFloatResults parsedResults = OpenAiEmbeddingsResponseEntity.fromResponse( + DenseEmbeddingFloatResults parsedResults = OpenAiEmbeddingsResponseEntity.fromResponse( mock(Request.class), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 1.0F })))); + assertThat(parsedResults.embeddings(), is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 1.0F })))); } public void testFromResponse_SucceedsWhenEmbeddingValueIsLong() throws IOException { @@ -283,12 +283,12 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsLong() throws IOExcepti } """; - TextEmbeddingFloatResults parsedResults = OpenAiEmbeddingsResponseEntity.fromResponse( + DenseEmbeddingFloatResults parsedResults = OpenAiEmbeddingsResponseEntity.fromResponse( mock(Request.class), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 4.0294965E10F })))); + assertThat(parsedResults.embeddings(), is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 4.0294965E10F })))); } public void testFromResponse_FailsWhenEmbeddingValueIsAnObject() { @@ -365,7 +365,7 @@ public void testFieldsInDifferentOrderServer() throws IOException { } }"""; - TextEmbeddingFloatResults parsedResults = OpenAiEmbeddingsResponseEntity.fromResponse( + DenseEmbeddingFloatResults parsedResults = OpenAiEmbeddingsResponseEntity.fromResponse( mock(Request.class), new HttpResult(mock(HttpResponse.class), response.getBytes(StandardCharsets.UTF_8)) ); @@ -374,9 +374,9 @@ public void testFieldsInDifferentOrderServer() throws IOException { parsedResults.embeddings(), is( List.of( - new TextEmbeddingFloatResults.Embedding(new float[] { -0.9F, 0.5F, 0.3F }), - new TextEmbeddingFloatResults.Embedding(new float[] { 0.1F, 0.5F }), - new TextEmbeddingFloatResults.Embedding(new float[] { 0.5F, 0.5F }) + new DenseEmbeddingFloatResults.Embedding(new float[] { -0.9F, 0.5F, 0.3F }), + new DenseEmbeddingFloatResults.Embedding(new float[] { 0.1F, 0.5F }), + new DenseEmbeddingFloatResults.Embedding(new float[] { 0.5F, 0.5F }) ) ) ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerServiceTests.java index 5d6bec1bcfbff..b391ce0cc7949 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerServiceTests.java @@ -29,7 +29,7 @@ import org.elasticsearch.rest.RestStatus; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests; import org.elasticsearch.xpack.inference.InferencePlugin; import org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkingSettings; import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings; @@ -449,7 +449,7 @@ public void testChunkedInfer() throws Exception { var model = mockModelForChunking(); SageMakerSchema schema = mock(); - when(schema.response(any(), any(), any())).thenReturn(TextEmbeddingFloatResultsTests.createRandomResults()); + when(schema.response(any(), any(), any())).thenReturn(DenseEmbeddingFloatResultsTests.createRandomResults()); when(schemas.schemaFor(model)).thenReturn(schema); mockInvoke(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticTextEmbeddingPayloadTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticTextEmbeddingPayloadTests.java index ed0ee43266ab5..3ac0f20d4242c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticTextEmbeddingPayloadTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticTextEmbeddingPayloadTests.java @@ -9,8 +9,8 @@ import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.inference.TaskType; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel; import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredServiceSchema; @@ -66,8 +66,8 @@ public void testBitResponse() throws Exception { assertThat(bitResults.embeddings().size(), is(1)); var embedding = bitResults.embeddings().get(0); - assertThat(embedding, isA(TextEmbeddingByteResults.Embedding.class)); - assertThat(((TextEmbeddingByteResults.Embedding) embedding).values(), is(new byte[] { 23 })); + assertThat(embedding, isA(DenseEmbeddingByteResults.Embedding.class)); + assertThat(((DenseEmbeddingByteResults.Embedding) embedding).values(), is(new byte[] { 23 })); } public void testByteResponse() throws Exception { @@ -87,8 +87,8 @@ public void testByteResponse() throws Exception { assertThat(byteResults.embeddings().size(), is(1)); var embedding = byteResults.embeddings().get(0); - assertThat(embedding, isA(TextEmbeddingByteResults.Embedding.class)); - assertThat(((TextEmbeddingByteResults.Embedding) embedding).values(), is(new byte[] { 23 })); + assertThat(embedding, isA(DenseEmbeddingByteResults.Embedding.class)); + assertThat(((DenseEmbeddingByteResults.Embedding) embedding).values(), is(new byte[] { 23 })); } public void testFloatResponse() throws Exception { @@ -108,7 +108,7 @@ public void testFloatResponse() throws Exception { assertThat(byteResults.embeddings().size(), is(1)); var embedding = byteResults.embeddings().get(0); - assertThat(embedding, isA(TextEmbeddingFloatResults.Embedding.class)); - assertThat(((TextEmbeddingFloatResults.Embedding) embedding).values(), is(new float[] { 0.1F })); + assertThat(embedding, isA(DenseEmbeddingFloatResults.Embedding.class)); + assertThat(((DenseEmbeddingFloatResults.Embedding) embedding).values(), is(new float[] { 0.1F })); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiTextEmbeddingPayloadTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiTextEmbeddingPayloadTests.java index 35b78b004618c..2435514e9f7e1 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiTextEmbeddingPayloadTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiTextEmbeddingPayloadTests.java @@ -12,7 +12,7 @@ import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.TaskType; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerInferenceRequest; import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel; import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemaPayloadTestCase; @@ -145,11 +145,11 @@ public void testResponse() throws Exception { .body(SdkBytes.fromString(responseJson, StandardCharsets.UTF_8)) .build(); - var textEmbeddingFloatResults = payload.responseBody(mock(), invokeEndpointResponse); + var denseEmbeddingFloatResults = payload.responseBody(mock(), invokeEndpointResponse); assertThat( - textEmbeddingFloatResults.embeddings(), - is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }))) + denseEmbeddingFloatResults.embeddings(), + is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }))) ); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/DenseEmbeddingModelValidatorTests.java similarity index 90% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidatorTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/DenseEmbeddingModelValidatorTests.java index 45726f0789667..9ad4e81b20cb7 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/DenseEmbeddingModelValidatorTests.java @@ -16,10 +16,10 @@ import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.inference.validation.ServiceIntegrationValidator; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResultsTests; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResultsTests; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import org.elasticsearch.xpack.inference.EmptyTaskSettingsTests; import org.elasticsearch.xpack.inference.ModelConfigurationsTests; import org.junit.Before; @@ -38,7 +38,7 @@ import static org.mockito.Mockito.when; import static org.mockito.MockitoAnnotations.openMocks; -public class TextEmbeddingModelValidatorTests extends ESTestCase { +public class DenseEmbeddingModelValidatorTests extends ESTestCase { private static final TimeValue TIMEOUT = TimeValue.ONE_MINUTE; @@ -53,13 +53,13 @@ public class TextEmbeddingModelValidatorTests extends ESTestCase { @Mock private ServiceSettings mockServiceSettings; - private TextEmbeddingModelValidator underTest; + private DenseEmbeddingModelValidator underTest; @Before public void setup() { openMocks(this); - underTest = new TextEmbeddingModelValidator(mockServiceIntegrationValidator); + underTest = new DenseEmbeddingModelValidator(mockServiceIntegrationValidator); when(mockInferenceService.updateModelWithEmbeddingDetails(eq(mockModel), anyInt())).thenReturn(mockModel); when(mockActionListener.delegateFailureAndWrap(any())).thenCallRealMethod(); @@ -94,7 +94,7 @@ public void testValidate_ServiceReturnsNonTextEmbeddingResults() { } public void testValidate_RetrievingEmbeddingSizeThrowsIllegalStateException() { - TextEmbeddingFloatResults results = new TextEmbeddingFloatResults(List.of()); + DenseEmbeddingFloatResults results = new DenseEmbeddingFloatResults(List.of()); when(mockServiceSettings.dimensionsSetByUser()).thenReturn(true); when(mockServiceSettings.dimensions()).thenReturn(randomNonNegativeInt()); @@ -107,7 +107,7 @@ public void testValidate_RetrievingEmbeddingSizeThrowsIllegalStateException() { } public void testValidate_DimensionsSetByUserDoNotEqualEmbeddingSize() { - TextEmbeddingByteResults results = TextEmbeddingByteResultsTests.createRandomResults(); + DenseEmbeddingByteResults results = DenseEmbeddingByteResultsTests.createRandomResults(); var dimensions = randomValueOtherThan(results.getFirstEmbeddingSize(), ESTestCase::randomNonNegativeInt); when(mockServiceSettings.dimensionsSetByUser()).thenReturn(true); @@ -131,7 +131,7 @@ public void testValidate_DimensionsNotSetByUser() { } private void mockSuccessfulValidation(Boolean dimensionsSetByUser) { - TextEmbeddingByteResults results = TextEmbeddingByteResultsTests.createRandomResults(); + DenseEmbeddingByteResults results = DenseEmbeddingByteResultsTests.createRandomResults(); when(mockModel.getConfigurations()).thenReturn(ModelConfigurationsTests.createRandomInstance()); when(mockModel.getTaskSettings()).thenReturn(EmptyTaskSettingsTests.createRandom()); when(mockServiceSettings.dimensionsSetByUser()).thenReturn(dimensionsSetByUser); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidatorTests.java index 30e7a33757c16..c9db9907792fd 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidatorTests.java @@ -16,7 +16,7 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingResults; import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests; import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandEmbeddingModel; import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalTextEmbeddingServiceSettings; @@ -118,7 +118,7 @@ public void testValidate_ElandTextEmbeddingModelValidationFails() { public void testValidate_ElandTextEmbeddingModelValidationSucceedsAndDimensionsSetByUserValid() { var dimensions = randomIntBetween(1, 10); - var mockInferenceServiceResults = mock(TextEmbeddingResults.class); + var mockInferenceServiceResults = mock(DenseEmbeddingResults.class); var mockUpdatedModel = mock(CustomElandEmbeddingModel.class); when(mockInferenceServiceResults.getFirstEmbeddingSize()).thenReturn(dimensions); CustomElandEmbeddingModel customElandEmbeddingModel = createCustomElandEmbeddingModel(true, dimensions); @@ -151,7 +151,7 @@ public void testValidate_ElandTextEmbeddingModelValidationSucceedsAndDimensionsS public void testValidate_ElandTextEmbeddingModelValidationSucceedsAndDimensionsSetByUserInvalid() { var dimensions = randomIntBetween(1, 10); - var mockInferenceServiceResults = mock(TextEmbeddingResults.class); + var mockInferenceServiceResults = mock(DenseEmbeddingResults.class); when(mockInferenceServiceResults.getFirstEmbeddingSize()).thenReturn( randomValueOtherThan(dimensions, () -> randomIntBetween(1, 10)) ); @@ -207,7 +207,7 @@ public void testValidate_ElandTextEmbeddingAndValidationReturnsInvalidResultsTyp public void testValidate_ElandTextEmbeddingModelDimensionsNotSetByUser() { var dimensions = randomIntBetween(1, 10); - var mockInferenceServiceResults = mock(TextEmbeddingResults.class); + var mockInferenceServiceResults = mock(DenseEmbeddingResults.class); when(mockInferenceServiceResults.getFirstEmbeddingSize()).thenReturn(dimensions); CustomElandEmbeddingModel customElandEmbeddingModel = createCustomElandEmbeddingModel(false, null); @@ -239,7 +239,7 @@ public void testValidate_ElandTextEmbeddingModelDimensionsNotSetByUser() { } public void testValidate_ElandTextEmbeddingModelAndEmbeddingSizeRetrievalThrowsException() { - var mockInferenceServiceResults = mock(TextEmbeddingResults.class); + var mockInferenceServiceResults = mock(DenseEmbeddingResults.class); when(mockInferenceServiceResults.getFirstEmbeddingSize()).thenThrow(ElasticsearchStatusException.class); CustomElandEmbeddingModel customElandEmbeddingModel = createCustomElandEmbeddingModel(false, null); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilderTests.java index b4d563b565eee..804984c866542 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilderTests.java @@ -95,7 +95,7 @@ public void testBuildModelValidator_ValidTaskType() { private Map> taskTypeToModelValidatorClassMap() { return Map.of( TaskType.TEXT_EMBEDDING, - TextEmbeddingModelValidator.class, + DenseEmbeddingModelValidator.class, TaskType.SPARSE_EMBEDDING, SimpleModelValidator.class, TaskType.RERANK, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java index 9bbfcb4d58667..0e9059132f23d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java @@ -37,7 +37,7 @@ import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; @@ -63,7 +63,7 @@ import static org.elasticsearch.common.xcontent.XContentHelper.toXContent; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; -import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat; +import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.Utils.getInvalidModel; import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; @@ -1654,10 +1654,10 @@ private void test_Embedding_ChunkedInfer_BatchesCalls(VoyageAIEmbeddingsModel mo var floatResult = (ChunkedInferenceEmbedding) results.getFirst(); assertThat(floatResult.chunks(), hasSize(1)); assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().getFirst().offset()); - assertThat(floatResult.chunks().get(0).embedding(), CoreMatchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertThat(floatResult.chunks().get(0).embedding(), CoreMatchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class)); assertArrayEquals( new float[] { 0.123f, -0.123f }, - ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), + ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), 0.0f ); } @@ -1666,10 +1666,13 @@ private void test_Embedding_ChunkedInfer_BatchesCalls(VoyageAIEmbeddingsModel mo var floatResult = (ChunkedInferenceEmbedding) results.get(1); assertThat(floatResult.chunks(), hasSize(1)); assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().getFirst().offset()); - assertThat(floatResult.chunks().getFirst().embedding(), CoreMatchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertThat( + floatResult.chunks().getFirst().embedding(), + CoreMatchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class) + ); assertArrayEquals( new float[] { 0.223f, -0.223f }, - ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), + ((DenseEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), 0.0f ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIActionCreatorTests.java index afbb5532d5fdd..86d6fab29842b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIActionCreatorTests.java @@ -37,7 +37,7 @@ import java.util.Map; import java.util.concurrent.TimeUnit; -import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat; +import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIEmbeddingsActionTests.java index 7050c2d131e17..b326664c527c1 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIEmbeddingsActionTests.java @@ -46,9 +46,9 @@ import java.util.Map; import java.util.concurrent.TimeUnit; -import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationBinary; -import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationByte; -import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat; +import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationBinary; +import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationByte; +import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/response/VoyageAIEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/response/VoyageAIEmbeddingsResponseEntityTests.java index 80a00737ddf52..7fa6b0c7bece3 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/response/VoyageAIEmbeddingsResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/response/VoyageAIEmbeddingsResponseEntityTests.java @@ -11,7 +11,7 @@ import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xcontent.XContentParseException; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.inference.InputTypeTests; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.services.voyageai.request.VoyageAIEmbeddingsRequest; @@ -60,8 +60,8 @@ public void testFromResponse_CreatesResultsForASingleItem() throws IOException { ); assertThat( - ((TextEmbeddingFloatResults) parsedResults).embeddings(), - is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }))) + ((DenseEmbeddingFloatResults) parsedResults).embeddings(), + is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }))) ); } @@ -106,11 +106,11 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException ); assertThat( - ((TextEmbeddingFloatResults) parsedResults).embeddings(), + ((DenseEmbeddingFloatResults) parsedResults).embeddings(), is( List.of( - new TextEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }), - new TextEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F }) + new DenseEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }), + new DenseEmbeddingFloatResults.Embedding(new float[] { 0.0123F, -0.0123F }) ) ) ); @@ -299,8 +299,8 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsInt() throws IOExceptio ); assertThat( - ((TextEmbeddingFloatResults) parsedResults).embeddings(), - is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 1.0F }))) + ((DenseEmbeddingFloatResults) parsedResults).embeddings(), + is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 1.0F }))) ); } @@ -336,8 +336,8 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsLong() throws IOExcepti ); assertThat( - ((TextEmbeddingFloatResults) parsedResults).embeddings(), - is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 4.0294965E10F }))) + ((DenseEmbeddingFloatResults) parsedResults).embeddings(), + is(List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { 4.0294965E10F }))) ); } @@ -427,15 +427,15 @@ public void testFieldsInDifferentOrderServer() throws IOException { new HttpResult(mock(HttpResponse.class), response.getBytes(StandardCharsets.UTF_8)) ); - assertThat(parsedResults, instanceOf(TextEmbeddingFloatResults.class)); + assertThat(parsedResults, instanceOf(DenseEmbeddingFloatResults.class)); assertThat( - ((TextEmbeddingFloatResults) parsedResults).embeddings(), + ((DenseEmbeddingFloatResults) parsedResults).embeddings(), is( List.of( - new TextEmbeddingFloatResults.Embedding(new float[] { -0.9F, 0.5F, 0.3F }), - new TextEmbeddingFloatResults.Embedding(new float[] { 0.1F, 0.5F }), - new TextEmbeddingFloatResults.Embedding(new float[] { 0.5F, 0.5F }) + new DenseEmbeddingFloatResults.Embedding(new float[] { -0.9F, 0.5F, 0.3F }), + new DenseEmbeddingFloatResults.Embedding(new float[] { 0.1F, 0.5F }), + new DenseEmbeddingFloatResults.Embedding(new float[] { 0.5F, 0.5F }) ) ) ); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextEmbeddingProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextEmbeddingProcessor.java index c28fc8f44c3fa..49598aa620464 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextEmbeddingProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextEmbeddingProcessor.java @@ -9,7 +9,7 @@ import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextEmbeddingFloatResults; -import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults; +import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig; import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer; import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult; @@ -85,7 +85,7 @@ static InferenceResults processResult( tokenization.anyTruncated() ); } else { - return new MlTextEmbeddingResults( + return new MlDenseEmbeddingResults( Optional.ofNullable(resultsField).orElse(DEFAULT_RESULTS_FIELD), pyTorchResult.getInferenceResult()[0][0], tokenization.anyTruncated() diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextEmbeddingProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextEmbeddingProcessorTests.java index 8369412580b88..6f66a05187fee 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextEmbeddingProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextEmbeddingProcessorTests.java @@ -10,7 +10,7 @@ import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextEmbeddingFloatResults; import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextExpansionResults; -import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults; +import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenization; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization; import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizationResult; @@ -40,9 +40,9 @@ public void testSingleResult() { var tokenization = tokenizer.tokenize(input, Tokenization.Truncate.NONE, 0, 0, null); var tokenizationResult = new BertTokenizationResult(TextExpansionProcessorTests.TEST_CASED_VOCAB, tokenization, 0); var inferenceResult = TextEmbeddingProcessor.processResult(tokenizationResult, pytorchResult, "foo", false); - assertThat(inferenceResult, instanceOf(MlTextEmbeddingResults.class)); + assertThat(inferenceResult, instanceOf(MlDenseEmbeddingResults.class)); - var result = (MlTextEmbeddingResults) inferenceResult; + var result = (MlDenseEmbeddingResults) inferenceResult; assertThat(result.getInference().length, greaterThan(0)); } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/vectors/TextEmbeddingQueryVectorBuilderTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/vectors/TextEmbeddingQueryVectorBuilderTests.java index 86b4e8c588e51..fe269b6bcb0f5 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/vectors/TextEmbeddingQueryVectorBuilderTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/vectors/TextEmbeddingQueryVectorBuilderTests.java @@ -17,7 +17,7 @@ import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction; import org.elasticsearch.xpack.core.ml.action.InferModelAction; import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings; -import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults; +import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults; import org.elasticsearch.xpack.core.ml.vectors.TextEmbeddingQueryVectorBuilder; import org.elasticsearch.xpack.ml.MachineLearningTests; @@ -53,7 +53,7 @@ public ActionResponse createResponse(float[] array, TextEmbeddingQueryVectorBuil embedding[i] = array[i]; } return new InferModelAction.Response( - List.of(new MlTextEmbeddingResults("foo", embedding, randomBoolean())), + List.of(new MlDenseEmbeddingResults("foo", embedding, randomBoolean())), builder.getModelId(), true );