Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public interface InferenceServiceResults extends NamedWriteable, ChunkedToXConte

/**
* <p>Transform the result to match the format required for the TransportCoordinatedInferenceAction.
* TransportCoordinatedInferenceAction expects an ml plugin TextEmbeddingResults or SparseEmbeddingResults.</p>
* TransportCoordinatedInferenceAction expects an ml plugin DenseEmbeddingResults or SparseEmbeddingResults.</p>
*/
default List<? extends InferenceResults> transformToCoordinationFormat() {
throw new UnsupportedOperationException("transformToCoordinationFormat() is not implemented");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults;

import java.io.IOException;
import java.util.Iterator;
Expand All @@ -24,9 +24,10 @@
import java.util.Objects;

/**
* Writes a text embedding result in the follow json format
* Writes a dense embedding result in the follow json format.
* <pre>
* {
* "text_embedding_bytes": [
* "text_embedding_bits": [
* {
* "embedding": [
* 23
Expand All @@ -39,17 +40,19 @@
* }
* ]
* }
* </pre>
*/
// Note: inheriting from TextEmbeddingByteResults gives a bad implementation of the
// Note: inheriting from DenseEmbeddingByteResults gives a bad implementation of the
// Embedding.merge method for bits. TODO: implement a proper merge method
public record TextEmbeddingBitResults(List<TextEmbeddingByteResults.Embedding> embeddings)
public record DenseEmbeddingBitResults(List<DenseEmbeddingByteResults.Embedding> embeddings)
implements
TextEmbeddingResults<TextEmbeddingByteResults.Embedding> {
DenseEmbeddingResults<DenseEmbeddingByteResults.Embedding> {
// This name is a holdover from before this class was renamed
public static final String NAME = "text_embedding_service_bit_results";
public static final String TEXT_EMBEDDING_BITS = "text_embedding_bits";

public TextEmbeddingBitResults(StreamInput in) throws IOException {
this(in.readCollectionAsList(TextEmbeddingByteResults.Embedding::new));
public DenseEmbeddingBitResults(StreamInput in) throws IOException {
this(in.readCollectionAsList(DenseEmbeddingByteResults.Embedding::new));
}

@Override
Expand Down Expand Up @@ -79,7 +82,7 @@ public String getWriteableName() {
@Override
public List<? extends InferenceResults> transformToCoordinationFormat() {
return embeddings.stream()
.map(embedding -> new MlTextEmbeddingResults(TEXT_EMBEDDING_BITS, embedding.toDoubleArray(), false))
.map(embedding -> new MlDenseEmbeddingResults(TEXT_EMBEDDING_BITS, embedding.toDoubleArray(), false))
.toList();
}

Expand All @@ -94,7 +97,7 @@ public Map<String, Object> asMap() {
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
TextEmbeddingBitResults that = (TextEmbeddingBitResults) o;
DenseEmbeddingBitResults that = (DenseEmbeddingBitResults) o;
return Objects.equals(embeddings, that.embeddings);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults;

import java.io.IOException;
import java.util.Arrays;
Expand All @@ -31,7 +31,8 @@
import java.util.Objects;

/**
* Writes a text embedding result in the follow json format
* Writes a dense embedding result in the follow json format
* <pre>
* {
* "text_embedding_bytes": [
* {
Expand All @@ -46,13 +47,15 @@
* }
* ]
* }
* </pre>
*/
public record TextEmbeddingByteResults(List<Embedding> embeddings) implements TextEmbeddingResults<TextEmbeddingByteResults.Embedding> {
public record DenseEmbeddingByteResults(List<Embedding> embeddings) implements DenseEmbeddingResults<DenseEmbeddingByteResults.Embedding> {
// This name is a holdover from before this class was renamed
public static final String NAME = "text_embedding_service_byte_results";
public static final String TEXT_EMBEDDING_BYTES = "text_embedding_bytes";

public TextEmbeddingByteResults(StreamInput in) throws IOException {
this(in.readCollectionAsList(TextEmbeddingByteResults.Embedding::new));
public DenseEmbeddingByteResults(StreamInput in) throws IOException {
this(in.readCollectionAsList(DenseEmbeddingByteResults.Embedding::new));
}

@Override
Expand Down Expand Up @@ -81,7 +84,7 @@ public String getWriteableName() {
@Override
public List<? extends InferenceResults> transformToCoordinationFormat() {
return embeddings.stream()
.map(embedding -> new MlTextEmbeddingResults(TEXT_EMBEDDING_BYTES, embedding.toDoubleArray(), false))
.map(embedding -> new MlDenseEmbeddingResults(TEXT_EMBEDDING_BYTES, embedding.toDoubleArray(), false))
.toList();
}

Expand All @@ -96,7 +99,7 @@ public Map<String, Object> asMap() {
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
TextEmbeddingByteResults that = (TextEmbeddingByteResults) o;
DenseEmbeddingByteResults that = (DenseEmbeddingByteResults) o;
return Objects.equals(embeddings, that.embeddings);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults;

import java.io.IOException;
import java.util.ArrayList;
Expand All @@ -35,7 +35,8 @@
import java.util.Objects;

/**
* Writes a text embedding result in the follow json format
* Writes a dense embedding result in the follow json format
* <pre>
* {
* "text_embedding": [
* {
Expand All @@ -50,20 +51,24 @@
* }
* ]
* }
* </pre>
*/
public record TextEmbeddingFloatResults(List<Embedding> embeddings) implements TextEmbeddingResults<TextEmbeddingFloatResults.Embedding> {
public record DenseEmbeddingFloatResults(List<Embedding> embeddings)
implements
DenseEmbeddingResults<DenseEmbeddingFloatResults.Embedding> {
// This name is a holdover from before this class was renamed
public static final String NAME = "text_embedding_service_results";
public static final String TEXT_EMBEDDING = TaskType.TEXT_EMBEDDING.toString();

public TextEmbeddingFloatResults(StreamInput in) throws IOException {
this(in.readCollectionAsList(TextEmbeddingFloatResults.Embedding::new));
public DenseEmbeddingFloatResults(StreamInput in) throws IOException {
this(in.readCollectionAsList(DenseEmbeddingFloatResults.Embedding::new));
}

public static TextEmbeddingFloatResults of(List<? extends InferenceResults> results) {
public static DenseEmbeddingFloatResults of(List<? extends InferenceResults> results) {
List<Embedding> embeddings = new ArrayList<>(results.size());
for (InferenceResults result : results) {
if (result instanceof MlTextEmbeddingResults embeddingResult) {
embeddings.add(TextEmbeddingFloatResults.Embedding.of(embeddingResult));
if (result instanceof MlDenseEmbeddingResults embeddingResult) {
embeddings.add(DenseEmbeddingFloatResults.Embedding.of(embeddingResult));
} else if (result instanceof org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults errorResult) {
if (errorResult.getException() instanceof ElasticsearchStatusException statusException) {
throw statusException;
Expand All @@ -76,11 +81,15 @@ public static TextEmbeddingFloatResults of(List<? extends InferenceResults> resu
}
} else {
throw new IllegalArgumentException(
"Received invalid inference result, of type " + result.getClass().getName() + " but expected TextEmbeddingResults."
"Received invalid inference result, of type "
+ result.getClass().getName()
+ " but expected "
+ MlDenseEmbeddingResults.class.getName()
+ "."
);
}
}
return new TextEmbeddingFloatResults(embeddings);
return new DenseEmbeddingFloatResults(embeddings);
}

@Override
Expand Down Expand Up @@ -108,7 +117,7 @@ public String getWriteableName() {

@Override
public List<? extends InferenceResults> transformToCoordinationFormat() {
return embeddings.stream().map(embedding -> new MlTextEmbeddingResults(TEXT_EMBEDDING, embedding.asDoubleArray(), false)).toList();
return embeddings.stream().map(embedding -> new MlDenseEmbeddingResults(TEXT_EMBEDDING, embedding.asDoubleArray(), false)).toList();
}

public Map<String, Object> asMap() {
Expand All @@ -122,7 +131,7 @@ public Map<String, Object> asMap() {
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
TextEmbeddingFloatResults that = (TextEmbeddingFloatResults) o;
DenseEmbeddingFloatResults that = (DenseEmbeddingFloatResults) o;
return Objects.equals(embeddings, that.embeddings);
}

Expand All @@ -148,7 +157,7 @@ public Embedding(StreamInput in) throws IOException {
this(in.readFloatArray());
}

public static Embedding of(MlTextEmbeddingResults embeddingResult) {
public static Embedding of(MlDenseEmbeddingResults embeddingResult) {
float[] embeddingAsArray = embeddingResult.getInferenceAsFloat();
return new Embedding(embeddingAsArray);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@

package org.elasticsearch.xpack.core.inference.results;

public interface TextEmbeddingResults<E extends EmbeddingResults.Embedding<E>> extends EmbeddingResults<E> {
public interface DenseEmbeddingResults<E extends EmbeddingResults.Embedding<E>> extends EmbeddingResults<E> {

/**
* Returns the first text embedding entry in the result list's array size.
* @return the size of the text embedding
* Returns the first embedding entry in the result list's array size.
* @return the size of the embedding
* @throws IllegalStateException if the list of embeddings is empty
*/
int getFirstEmbeddingSize() throws IllegalStateException;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import org.elasticsearch.xpack.core.ml.inference.results.FillMaskResults;
import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextEmbeddingFloatResults;
import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextExpansionResults;
import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults;
import org.elasticsearch.xpack.core.ml.inference.results.NerResults;
import org.elasticsearch.xpack.core.ml.inference.results.NlpClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.PyTorchPassThroughResults;
Expand Down Expand Up @@ -669,7 +669,7 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
);
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class, TextExpansionResults.NAME, TextExpansionResults::new));
namedWriteables.add(
new NamedWriteableRegistry.Entry(InferenceResults.class, MlTextEmbeddingResults.NAME, MlTextEmbeddingResults::new)
new NamedWriteableRegistry.Entry(InferenceResults.class, MlDenseEmbeddingResults.NAME, MlDenseEmbeddingResults::new)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,21 @@
import java.util.Map;
import java.util.Objects;

public class MlTextEmbeddingResults extends NlpInferenceResults {
public class MlDenseEmbeddingResults extends NlpInferenceResults {

// This name is a holdover from before this class was renamed
public static final String NAME = "text_embedding_result";

private final String resultsField;
private final double[] inference;

public MlTextEmbeddingResults(String resultsField, double[] inference, boolean isTruncated) {
public MlDenseEmbeddingResults(String resultsField, double[] inference, boolean isTruncated) {
super(isTruncated);
this.inference = inference;
this.resultsField = resultsField;
}

public MlTextEmbeddingResults(StreamInput in) throws IOException {
public MlDenseEmbeddingResults(StreamInput in) throws IOException {
super(in);
inference = in.readDoubleArray();
resultsField = in.readString();
Expand Down Expand Up @@ -89,7 +90,7 @@ public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
if (super.equals(o) == false) return false;
MlTextEmbeddingResults that = (MlTextEmbeddingResults) o;
MlDenseEmbeddingResults that = (MlDenseEmbeddingResults) o;
return Objects.equals(resultsField, that.resultsField) && Arrays.equals(inference, that.inference);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings;
import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfigUpdate;

Expand Down Expand Up @@ -117,14 +117,14 @@ public void buildVector(Client client, ActionListener<float[]> listener) {
return;
}

if (response.getInferenceResults().get(0) instanceof MlTextEmbeddingResults textEmbeddingResults) {
if (response.getInferenceResults().get(0) instanceof MlDenseEmbeddingResults textEmbeddingResults) {
listener.onResponse(textEmbeddingResults.getInferenceAsFloat());
} else if (response.getInferenceResults().get(0) instanceof WarningInferenceResults warning) {
listener.onFailure(new IllegalStateException(warning.getWarning()));
} else {
throw new IllegalArgumentException(
"expected a result of type ["
+ MlTextEmbeddingResults.NAME
+ MlDenseEmbeddingResults.NAME
+ "] received ["
+ response.getInferenceResults().get(0).getWriteableName()
+ "]. Is ["
Expand Down
Loading