Skip to content

Commit 71dfdc8

Browse files
committed
Adding BIT support
1 parent be1e9cf commit 71dfdc8

File tree

2 files changed

+29
-10
lines changed

2 files changed

+29
-10
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIEmbeddingsResponseEntity.java

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
import org.elasticsearch.xcontent.XContentParser;
1717
import org.elasticsearch.xcontent.XContentParserConfiguration;
1818
import org.elasticsearch.xcontent.XContentType;
19+
import org.elasticsearch.xpack.core.inference.results.InferenceByteEmbedding;
20+
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingBitResults;
1921
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults;
2022
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults;
2123
import org.elasticsearch.xpack.inference.external.http.HttpResult;
@@ -43,7 +45,9 @@ public class VoyageAIEmbeddingsResponseEntity {
4345
toLowerCase(VoyageAIEmbeddingType.FLOAT),
4446
VoyageAIEmbeddingsResponseEntity::parseFloatEmbeddingsArray,
4547
toLowerCase(VoyageAIEmbeddingType.INT8),
46-
VoyageAIEmbeddingsResponseEntity::parseByteEmbeddingsArray
48+
VoyageAIEmbeddingsResponseEntity::parseByteEmbeddingsArray,
49+
toLowerCase(VoyageAIEmbeddingType.BINARY),
50+
VoyageAIEmbeddingsResponseEntity::parseBitEmbeddingsArray
4751
);
4852

4953
private static final String VALID_EMBEDDING_TYPES_STRING = supportedEmbeddingTypes();
@@ -119,7 +123,7 @@ public static InferenceServiceResults fromResponse(Request request, HttpResult r
119123

120124
return new InferenceTextEmbeddingFloatResults(embeddingList);
121125
} else if(embeddingType == VoyageAIEmbeddingType.INT8) {
122-
List<InferenceTextEmbeddingByteResults.InferenceByteEmbedding> embeddingList = parseList(
126+
List<InferenceByteEmbedding> embeddingList = parseList(
123127
jsonParser,
124128
VoyageAIEmbeddingsResponseEntity::parseEmbeddingObjectByte
125129
);
@@ -144,7 +148,7 @@ private static InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding parseE
144148
return InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding.of(embeddingValuesList);
145149
}
146150

147-
private static InferenceTextEmbeddingByteResults.InferenceByteEmbedding parseEmbeddingObjectByte(XContentParser parser)
151+
private static InferenceByteEmbedding parseEmbeddingObjectByte(XContentParser parser)
148152
throws IOException {
149153
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
150154

@@ -154,7 +158,13 @@ private static InferenceTextEmbeddingByteResults.InferenceByteEmbedding parseEmb
154158
// parse and discard the rest of the object
155159
consumeUntilObjectEnd(parser);
156160

157-
return InferenceTextEmbeddingByteResults.InferenceByteEmbedding.of(embeddingValuesList);
161+
return InferenceByteEmbedding.of(embeddingValuesList);
162+
}
163+
164+
private static InferenceServiceResults parseBitEmbeddingsArray(XContentParser parser) throws IOException {
165+
var embeddingList = parseList(parser, VoyageAIEmbeddingsResponseEntity::parseByteArrayEntry);
166+
167+
return new InferenceTextEmbeddingBitResults(embeddingList);
158168
}
159169

160170
private static InferenceServiceResults parseByteEmbeddingsArray(XContentParser parser) throws IOException {
@@ -163,11 +173,11 @@ private static InferenceServiceResults parseByteEmbeddingsArray(XContentParser p
163173
return new InferenceTextEmbeddingByteResults(embeddingList);
164174
}
165175

166-
private static InferenceTextEmbeddingByteResults.InferenceByteEmbedding parseByteArrayEntry(XContentParser parser) throws IOException {
176+
private static InferenceByteEmbedding parseByteArrayEntry(XContentParser parser) throws IOException {
167177
ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
168178
List<Byte> embeddingValuesList = parseList(parser, VoyageAIEmbeddingsResponseEntity::parseEmbeddingInt8Entry);
169179

170-
return InferenceTextEmbeddingByteResults.InferenceByteEmbedding.of(embeddingValuesList);
180+
return InferenceByteEmbedding.of(embeddingValuesList);
171181
}
172182

173183
private static Byte parseEmbeddingInt8Entry(XContentParser parser) throws IOException {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingType.java

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@
77

88
package org.elasticsearch.xpack.inference.services.voyageai.embeddings;
99

10-
import org.elasticsearch.TransportVersion;
11-
import org.elasticsearch.TransportVersions;
1210
import org.elasticsearch.common.Strings;
1311
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
1412

@@ -36,18 +34,29 @@ public enum VoyageAIEmbeddingType {
3634
/**
3735
* This is a synonym for INT8
3836
*/
39-
BYTE(DenseVectorFieldMapper.ElementType.BYTE, RequestConstants.INT8);
37+
BYTE(DenseVectorFieldMapper.ElementType.BYTE, RequestConstants.INT8),
38+
/**
39+
* Use this when you want to get back binary embeddings. Valid only for v3 models.
40+
*/
41+
BIT(DenseVectorFieldMapper.ElementType.BIT, RequestConstants.BIT),
42+
/**
43+
* This is a synonym for BIT
44+
*/
45+
BINARY(DenseVectorFieldMapper.ElementType.BIT, RequestConstants.BIT);
4046

4147
private static final class RequestConstants {
4248
private static final String FLOAT = "float";
4349
private static final String INT8 = "int8";
50+
private static final String BIT = "binary";
4451
}
4552

4653
private static final Map<DenseVectorFieldMapper.ElementType, VoyageAIEmbeddingType> ELEMENT_TYPE_TO_VOYAGE_EMBEDDING = Map.of(
4754
DenseVectorFieldMapper.ElementType.FLOAT,
4855
FLOAT,
4956
DenseVectorFieldMapper.ElementType.BYTE,
50-
BYTE
57+
BYTE,
58+
DenseVectorFieldMapper.ElementType.BIT,
59+
BIT
5160
);
5261
static final EnumSet<DenseVectorFieldMapper.ElementType> SUPPORTED_ELEMENT_TYPES = EnumSet.copyOf(
5362
ELEMENT_TYPE_TO_VOYAGE_EMBEDDING.keySet()

0 commit comments

Comments
 (0)