Skip to content

Commit a0f415d

Browse files
authored
[ML] Replace "text embedding" with "dense embedding" (#136321)
The name "text embedding" is used in many places where dense vector embeddings are handled, even though this type of embedding vector is not exclusive to text embeddings. For example, image or multimodal embeddings may also produce a dense vector. To allow future reuse of classes related to dense vectors with multimodal embeddings, the naming is being changed to the more general "dense embedding". Classes which explicitly relate to text embeddings are not being renamed. This rename is internal to the code only and does not change the name of any JSON objects which currently use "text_embedding", as doing so would be a breaking change.
- For everything not exclusively related to text embedding, rename classes, methods and variables to use "dense embedding" instead of "text embedding"
- Use correct class name in ElasticTextEmbeddingPayload.TextEmbeddingFloat.PARSER
- Correct the javadoc in DenseEmbeddingBitResults
1 parent 1758533 commit a0f415d

File tree

112 files changed

+956
-864
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

112 files changed

+956
-864
lines changed

server/src/main/java/org/elasticsearch/inference/InferenceServiceResults.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ public interface InferenceServiceResults extends NamedWriteable, ChunkedToXConte
2424

2525
/**
2626
* <p>Transform the result to match the format required for the TransportCoordinatedInferenceAction.
27-
* TransportCoordinatedInferenceAction expects an ml plugin TextEmbeddingResults or SparseEmbeddingResults.</p>
27+
* TransportCoordinatedInferenceAction expects an ml plugin DenseEmbeddingResults or SparseEmbeddingResults.</p>
2828
*/
2929
default List<? extends InferenceResults> transformToCoordinationFormat() {
3030
throw new UnsupportedOperationException("transformToCoordinationFormat() is not implemented");
Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
1515
import org.elasticsearch.inference.InferenceResults;
1616
import org.elasticsearch.xcontent.ToXContent;
17-
import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
17+
import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults;
1818

1919
import java.io.IOException;
2020
import java.util.Iterator;
@@ -24,9 +24,10 @@
2424
import java.util.Objects;
2525

2626
/**
27-
* Writes a text embedding result in the follow json format
27+
* Writes a dense embedding result in the following JSON format.
28+
* <pre>
2829
* {
29-
* "text_embedding_bytes": [
30+
* "text_embedding_bits": [
3031
* {
3132
* "embedding": [
3233
* 23
@@ -39,17 +40,19 @@
3940
* }
4041
* ]
4142
* }
43+
* </pre>
4244
*/
43-
// Note: inheriting from TextEmbeddingByteResults gives a bad implementation of the
45+
// Note: inheriting from DenseEmbeddingByteResults gives a bad implementation of the
4446
// Embedding.merge method for bits. TODO: implement a proper merge method
45-
public record TextEmbeddingBitResults(List<TextEmbeddingByteResults.Embedding> embeddings)
47+
public record DenseEmbeddingBitResults(List<DenseEmbeddingByteResults.Embedding> embeddings)
4648
implements
47-
TextEmbeddingResults<TextEmbeddingByteResults.Embedding> {
49+
DenseEmbeddingResults<DenseEmbeddingByteResults.Embedding> {
50+
// This name is a holdover from before this class was renamed
4851
public static final String NAME = "text_embedding_service_bit_results";
4952
public static final String TEXT_EMBEDDING_BITS = "text_embedding_bits";
5053

51-
public TextEmbeddingBitResults(StreamInput in) throws IOException {
52-
this(in.readCollectionAsList(TextEmbeddingByteResults.Embedding::new));
54+
public DenseEmbeddingBitResults(StreamInput in) throws IOException {
55+
this(in.readCollectionAsList(DenseEmbeddingByteResults.Embedding::new));
5356
}
5457

5558
@Override
@@ -79,7 +82,7 @@ public String getWriteableName() {
7982
@Override
8083
public List<? extends InferenceResults> transformToCoordinationFormat() {
8184
return embeddings.stream()
82-
.map(embedding -> new MlTextEmbeddingResults(TEXT_EMBEDDING_BITS, embedding.toDoubleArray(), false))
85+
.map(embedding -> new MlDenseEmbeddingResults(TEXT_EMBEDDING_BITS, embedding.toDoubleArray(), false))
8386
.toList();
8487
}
8588

@@ -94,7 +97,7 @@ public Map<String, Object> asMap() {
9497
public boolean equals(Object o) {
9598
if (this == o) return true;
9699
if (o == null || getClass() != o.getClass()) return false;
97-
TextEmbeddingBitResults that = (TextEmbeddingBitResults) o;
100+
DenseEmbeddingBitResults that = (DenseEmbeddingBitResults) o;
98101
return Objects.equals(embeddings, that.embeddings);
99102
}
100103

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import org.elasticsearch.xcontent.ToXContentObject;
2121
import org.elasticsearch.xcontent.XContent;
2222
import org.elasticsearch.xcontent.XContentBuilder;
23-
import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
23+
import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults;
2424

2525
import java.io.IOException;
2626
import java.util.Arrays;
@@ -31,7 +31,8 @@
3131
import java.util.Objects;
3232

3333
/**
34-
* Writes a text embedding result in the follow json format
34+
* Writes a dense embedding result in the following JSON format.
35+
* <pre>
3536
* {
3637
* "text_embedding_bytes": [
3738
* {
@@ -46,13 +47,15 @@
4647
* }
4748
* ]
4849
* }
50+
* </pre>
4951
*/
50-
public record TextEmbeddingByteResults(List<Embedding> embeddings) implements TextEmbeddingResults<TextEmbeddingByteResults.Embedding> {
52+
public record DenseEmbeddingByteResults(List<Embedding> embeddings) implements DenseEmbeddingResults<DenseEmbeddingByteResults.Embedding> {
53+
// This name is a holdover from before this class was renamed
5154
public static final String NAME = "text_embedding_service_byte_results";
5255
public static final String TEXT_EMBEDDING_BYTES = "text_embedding_bytes";
5356

54-
public TextEmbeddingByteResults(StreamInput in) throws IOException {
55-
this(in.readCollectionAsList(TextEmbeddingByteResults.Embedding::new));
57+
public DenseEmbeddingByteResults(StreamInput in) throws IOException {
58+
this(in.readCollectionAsList(DenseEmbeddingByteResults.Embedding::new));
5659
}
5760

5861
@Override
@@ -81,7 +84,7 @@ public String getWriteableName() {
8184
@Override
8285
public List<? extends InferenceResults> transformToCoordinationFormat() {
8386
return embeddings.stream()
84-
.map(embedding -> new MlTextEmbeddingResults(TEXT_EMBEDDING_BYTES, embedding.toDoubleArray(), false))
87+
.map(embedding -> new MlDenseEmbeddingResults(TEXT_EMBEDDING_BYTES, embedding.toDoubleArray(), false))
8588
.toList();
8689
}
8790

@@ -96,7 +99,7 @@ public Map<String, Object> asMap() {
9699
public boolean equals(Object o) {
97100
if (this == o) return true;
98101
if (o == null || getClass() != o.getClass()) return false;
99-
TextEmbeddingByteResults that = (TextEmbeddingByteResults) o;
102+
DenseEmbeddingByteResults that = (DenseEmbeddingByteResults) o;
100103
return Objects.equals(embeddings, that.embeddings);
101104
}
102105

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import org.elasticsearch.xcontent.ToXContentObject;
2424
import org.elasticsearch.xcontent.XContent;
2525
import org.elasticsearch.xcontent.XContentBuilder;
26-
import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
26+
import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults;
2727

2828
import java.io.IOException;
2929
import java.util.ArrayList;
@@ -35,7 +35,8 @@
3535
import java.util.Objects;
3636

3737
/**
38-
* Writes a text embedding result in the follow json format
38+
* Writes a dense embedding result in the following JSON format.
39+
* <pre>
3940
* {
4041
* "text_embedding": [
4142
* {
@@ -50,20 +51,24 @@
5051
* }
5152
* ]
5253
* }
54+
* </pre>
5355
*/
54-
public record TextEmbeddingFloatResults(List<Embedding> embeddings) implements TextEmbeddingResults<TextEmbeddingFloatResults.Embedding> {
56+
public record DenseEmbeddingFloatResults(List<Embedding> embeddings)
57+
implements
58+
DenseEmbeddingResults<DenseEmbeddingFloatResults.Embedding> {
59+
// This name is a holdover from before this class was renamed
5560
public static final String NAME = "text_embedding_service_results";
5661
public static final String TEXT_EMBEDDING = TaskType.TEXT_EMBEDDING.toString();
5762

58-
public TextEmbeddingFloatResults(StreamInput in) throws IOException {
59-
this(in.readCollectionAsList(TextEmbeddingFloatResults.Embedding::new));
63+
public DenseEmbeddingFloatResults(StreamInput in) throws IOException {
64+
this(in.readCollectionAsList(DenseEmbeddingFloatResults.Embedding::new));
6065
}
6166

62-
public static TextEmbeddingFloatResults of(List<? extends InferenceResults> results) {
67+
public static DenseEmbeddingFloatResults of(List<? extends InferenceResults> results) {
6368
List<Embedding> embeddings = new ArrayList<>(results.size());
6469
for (InferenceResults result : results) {
65-
if (result instanceof MlTextEmbeddingResults embeddingResult) {
66-
embeddings.add(TextEmbeddingFloatResults.Embedding.of(embeddingResult));
70+
if (result instanceof MlDenseEmbeddingResults embeddingResult) {
71+
embeddings.add(DenseEmbeddingFloatResults.Embedding.of(embeddingResult));
6772
} else if (result instanceof org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults errorResult) {
6873
if (errorResult.getException() instanceof ElasticsearchStatusException statusException) {
6974
throw statusException;
@@ -76,11 +81,15 @@ public static TextEmbeddingFloatResults of(List<? extends InferenceResults> resu
7681
}
7782
} else {
7883
throw new IllegalArgumentException(
79-
"Received invalid inference result, of type " + result.getClass().getName() + " but expected TextEmbeddingResults."
84+
"Received invalid inference result, of type "
85+
+ result.getClass().getName()
86+
+ " but expected "
87+
+ MlDenseEmbeddingResults.class.getName()
88+
+ "."
8089
);
8190
}
8291
}
83-
return new TextEmbeddingFloatResults(embeddings);
92+
return new DenseEmbeddingFloatResults(embeddings);
8493
}
8594

8695
@Override
@@ -108,7 +117,7 @@ public String getWriteableName() {
108117

109118
@Override
110119
public List<? extends InferenceResults> transformToCoordinationFormat() {
111-
return embeddings.stream().map(embedding -> new MlTextEmbeddingResults(TEXT_EMBEDDING, embedding.asDoubleArray(), false)).toList();
120+
return embeddings.stream().map(embedding -> new MlDenseEmbeddingResults(TEXT_EMBEDDING, embedding.asDoubleArray(), false)).toList();
112121
}
113122

114123
public Map<String, Object> asMap() {
@@ -122,7 +131,7 @@ public Map<String, Object> asMap() {
122131
public boolean equals(Object o) {
123132
if (this == o) return true;
124133
if (o == null || getClass() != o.getClass()) return false;
125-
TextEmbeddingFloatResults that = (TextEmbeddingFloatResults) o;
134+
DenseEmbeddingFloatResults that = (DenseEmbeddingFloatResults) o;
126135
return Objects.equals(embeddings, that.embeddings);
127136
}
128137

@@ -148,7 +157,7 @@ public Embedding(StreamInput in) throws IOException {
148157
this(in.readFloatArray());
149158
}
150159

151-
public static Embedding of(MlTextEmbeddingResults embeddingResult) {
160+
public static Embedding of(MlDenseEmbeddingResults embeddingResult) {
152161
float[] embeddingAsArray = embeddingResult.getInferenceAsFloat();
153162
return new Embedding(embeddingAsArray);
154163
}
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@
77

88
package org.elasticsearch.xpack.core.inference.results;
99

10-
public interface TextEmbeddingResults<E extends EmbeddingResults.Embedding<E>> extends EmbeddingResults<E> {
10+
public interface DenseEmbeddingResults<E extends EmbeddingResults.Embedding<E>> extends EmbeddingResults<E> {
1111

1212
/**
13-
* Returns the first text embedding entry in the result list's array size.
14-
* @return the size of the text embedding
13+
* Returns the array size of the first embedding entry in the result list.
14+
* @return the size of the embedding
1515
* @throws IllegalStateException if the list of embeddings is empty
1616
*/
1717
int getFirstEmbeddingSize() throws IllegalStateException;

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
import org.elasticsearch.xpack.core.ml.inference.results.FillMaskResults;
2626
import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextEmbeddingFloatResults;
2727
import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextExpansionResults;
28-
import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
28+
import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults;
2929
import org.elasticsearch.xpack.core.ml.inference.results.NerResults;
3030
import org.elasticsearch.xpack.core.ml.inference.results.NlpClassificationInferenceResults;
3131
import org.elasticsearch.xpack.core.ml.inference.results.PyTorchPassThroughResults;
@@ -669,7 +669,7 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
669669
);
670670
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class, TextExpansionResults.NAME, TextExpansionResults::new));
671671
namedWriteables.add(
672-
new NamedWriteableRegistry.Entry(InferenceResults.class, MlTextEmbeddingResults.NAME, MlTextEmbeddingResults::new)
672+
new NamedWriteableRegistry.Entry(InferenceResults.class, MlDenseEmbeddingResults.NAME, MlDenseEmbeddingResults::new)
673673
);
674674
namedWriteables.add(
675675
new NamedWriteableRegistry.Entry(
Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,21 @@
1616
import java.util.Map;
1717
import java.util.Objects;
1818

19-
public class MlTextEmbeddingResults extends NlpInferenceResults {
19+
public class MlDenseEmbeddingResults extends NlpInferenceResults {
2020

21+
// This name is a holdover from before this class was renamed
2122
public static final String NAME = "text_embedding_result";
2223

2324
private final String resultsField;
2425
private final double[] inference;
2526

26-
public MlTextEmbeddingResults(String resultsField, double[] inference, boolean isTruncated) {
27+
public MlDenseEmbeddingResults(String resultsField, double[] inference, boolean isTruncated) {
2728
super(isTruncated);
2829
this.inference = inference;
2930
this.resultsField = resultsField;
3031
}
3132

32-
public MlTextEmbeddingResults(StreamInput in) throws IOException {
33+
public MlDenseEmbeddingResults(StreamInput in) throws IOException {
3334
super(in);
3435
inference = in.readDoubleArray();
3536
resultsField = in.readString();
@@ -89,7 +90,7 @@ public boolean equals(Object o) {
8990
if (this == o) return true;
9091
if (o == null || getClass() != o.getClass()) return false;
9192
if (super.equals(o) == false) return false;
92-
MlTextEmbeddingResults that = (MlTextEmbeddingResults) o;
93+
MlDenseEmbeddingResults that = (MlDenseEmbeddingResults) o;
9394
return Objects.equals(resultsField, that.resultsField) && Arrays.equals(inference, that.inference);
9495
}
9596

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/vectors/TextEmbeddingQueryVectorBuilder.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction;
2222
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
2323
import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings;
24-
import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
24+
import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults;
2525
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
2626
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfigUpdate;
2727

@@ -126,14 +126,14 @@ public void buildVector(Client client, ActionListener<float[]> listener) {
126126
return;
127127
}
128128

129-
if (response.getInferenceResults().get(0) instanceof MlTextEmbeddingResults textEmbeddingResults) {
129+
if (response.getInferenceResults().get(0) instanceof MlDenseEmbeddingResults textEmbeddingResults) {
130130
listener.onResponse(textEmbeddingResults.getInferenceAsFloat());
131131
} else if (response.getInferenceResults().get(0) instanceof WarningInferenceResults warning) {
132132
listener.onFailure(new IllegalStateException(warning.getWarning()));
133133
} else {
134134
throw new IllegalArgumentException(
135135
"expected a result of type ["
136-
+ MlTextEmbeddingResults.NAME
136+
+ MlDenseEmbeddingResults.NAME
137137
+ "] received ["
138138
+ response.getInferenceResults().get(0).getWriteableName()
139139
+ "]. Is ["

0 commit comments

Comments
 (0)