Skip to content

Commit 8f6e03b

Browse files
committed
Initial tests
1 parent 71dfdc8 commit 8f6e03b

16 files changed

+276
-251
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequest.java

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -49,12 +49,8 @@ public HttpRequest createHttpRequest() {
4949
HttpPost httpPost = new HttpPost(account.uri());
5050

5151
ByteArrayEntity byteEntity = new ByteArrayEntity(
52-
Strings.toString(new VoyageAIEmbeddingsRequestEntity(
53-
input,
54-
serviceSettings,
55-
taskSettings,
56-
model
57-
)).getBytes(StandardCharsets.UTF_8)
52+
Strings.toString(new VoyageAIEmbeddingsRequestEntity(input, serviceSettings, taskSettings, model))
53+
.getBytes(StandardCharsets.UTF_8)
5854
);
5955
httpPost.setEntity(byteEntity);
6056

@@ -83,9 +79,13 @@ public boolean[] getTruncationInfo() {
8379
return null;
8480
}
8581

86-
public VoyageAIEmbeddingsTaskSettings getTaskSettings() { return taskSettings; }
82+
public VoyageAIEmbeddingsTaskSettings getTaskSettings() {
83+
return taskSettings;
84+
}
8785

88-
public VoyageAIEmbeddingsServiceSettings getServiceSettings() { return serviceSettings; }
86+
public VoyageAIEmbeddingsServiceSettings getServiceSettings() {
87+
return serviceSettings;
88+
}
8989

9090
public static URI buildDefaultUri() throws URISyntaxException {
9191
return new URIBuilder().setScheme("https")

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequestEntity.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,16 +53,16 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
5353
builder.field(INPUT_TYPE_FIELD, inputType);
5454
}
5555

56-
if(taskSettings.getTruncation() != null) {
56+
if (taskSettings.getTruncation() != null) {
5757
builder.field(TRUNCATION_FIELD, taskSettings.getTruncation());
5858
}
5959

60-
if(serviceSettings.dimensions() != null) {
60+
if (serviceSettings.dimensions() != null) {
6161
builder.field(OUTPUT_DIMENSION, serviceSettings.dimensions());
6262
}
6363

64-
if(serviceSettings.getEmbeddingType() != null) {
65-
builder.field(OUTPUT_DTYPE_FIELD, serviceSettings.getEmbeddingType());
64+
if (serviceSettings.getEmbeddingType() != null) {
65+
builder.field(OUTPUT_DTYPE_FIELD, serviceSettings.getEmbeddingType().toRequestString());
6666
}
6767

6868
builder.endObject();
@@ -71,6 +71,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
7171

7272
static String convertToString(InputType inputType) {
7373
return switch (inputType) {
74+
case null -> null;
7475
case INGEST -> DOCUMENT;
7576
case SEARCH -> QUERY;
7677
default -> {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestEntity.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
5252
builder.field(VoyageAIRerankTaskSettings.TOP_K_DOCS_ONLY, taskSettings.getTopKDocumentsOnly());
5353
}
5454

55-
if(taskSettings.getTruncation() != null) {
55+
if (taskSettings.getTruncation() != null) {
5656
builder.field(TRUNCATION_FIELD, taskSettings.getTruncation());
5757
}
5858

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIEmbeddingsResponseEntity.java

Lines changed: 36 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
package org.elasticsearch.xpack.inference.external.response.voyageai;
1111

1212
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
13-
import org.elasticsearch.core.CheckedFunction;
1413
import org.elasticsearch.inference.InferenceServiceResults;
1514
import org.elasticsearch.xcontent.XContentFactory;
1615
import org.elasticsearch.xcontent.XContentParser;
@@ -26,34 +25,27 @@
2625
import org.elasticsearch.xpack.inference.external.response.XContentUtils;
2726
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingType;
2827

29-
import static org.elasticsearch.xpack.inference.external.response.XContentUtils.consumeUntilObjectEnd;
30-
import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken;
31-
import static org.elasticsearch.xpack.inference.external.response.XContentUtils.positionParserAtTokenAfterField;
32-
3328
import java.io.IOException;
3429
import java.util.Arrays;
3530
import java.util.List;
36-
import java.util.Map;
3731

3832
import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
3933
import static org.elasticsearch.common.xcontent.XContentParserUtils.parseList;
34+
import static org.elasticsearch.xpack.inference.external.response.XContentUtils.consumeUntilObjectEnd;
35+
import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken;
36+
import static org.elasticsearch.xpack.inference.external.response.XContentUtils.positionParserAtTokenAfterField;
4037
import static org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingType.toLowerCase;
4138

4239
public class VoyageAIEmbeddingsResponseEntity {
4340
private static final String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in VoyageAI embeddings response";
44-
private static final Map<String, CheckedFunction<XContentParser, InferenceServiceResults, IOException>> EMBEDDING_PARSERS = Map.of(
45-
toLowerCase(VoyageAIEmbeddingType.FLOAT),
46-
VoyageAIEmbeddingsResponseEntity::parseFloatEmbeddingsArray,
47-
toLowerCase(VoyageAIEmbeddingType.INT8),
48-
VoyageAIEmbeddingsResponseEntity::parseByteEmbeddingsArray,
49-
toLowerCase(VoyageAIEmbeddingType.BINARY),
50-
VoyageAIEmbeddingsResponseEntity::parseBitEmbeddingsArray
51-
);
5241

5342
private static final String VALID_EMBEDDING_TYPES_STRING = supportedEmbeddingTypes();
5443

5544
private static String supportedEmbeddingTypes() {
56-
var validTypes = EMBEDDING_PARSERS.keySet().toArray(String[]::new);
45+
String[] validTypes = new String[] {
46+
toLowerCase(VoyageAIEmbeddingType.FLOAT),
47+
toLowerCase(VoyageAIEmbeddingType.INT8),
48+
toLowerCase(VoyageAIEmbeddingType.BIT) };
5749
Arrays.sort(validTypes);
5850
return String.join(", ", validTypes);
5951
}
@@ -105,7 +97,7 @@ private static String supportedEmbeddingTypes() {
10597
*/
10698
public static InferenceServiceResults fromResponse(Request request, HttpResult response) throws IOException {
10799
var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
108-
VoyageAIEmbeddingType embeddingType = ((VoyageAIEmbeddingsRequest)request).getServiceSettings().getEmbeddingType();
100+
VoyageAIEmbeddingType embeddingType = ((VoyageAIEmbeddingsRequest) request).getServiceSettings().getEmbeddingType();
109101

110102
try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) {
111103
moveToFirstToken(jsonParser);
@@ -115,22 +107,31 @@ public static InferenceServiceResults fromResponse(Request request, HttpResult r
115107

116108
positionParserAtTokenAfterField(jsonParser, "data", FAILED_TO_FIND_FIELD_TEMPLATE);
117109

118-
if(embeddingType == null || embeddingType == VoyageAIEmbeddingType.FLOAT) {
110+
if (embeddingType == null || embeddingType == VoyageAIEmbeddingType.FLOAT) {
119111
List<InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding> embeddingList = parseList(
120112
jsonParser,
121113
VoyageAIEmbeddingsResponseEntity::parseEmbeddingObjectFloat
122114
);
123115

124116
return new InferenceTextEmbeddingFloatResults(embeddingList);
125-
} else if(embeddingType == VoyageAIEmbeddingType.INT8) {
117+
} else if (embeddingType == VoyageAIEmbeddingType.INT8) {
126118
List<InferenceByteEmbedding> embeddingList = parseList(
127119
jsonParser,
128120
VoyageAIEmbeddingsResponseEntity::parseEmbeddingObjectByte
129121
);
130122

131123
return new InferenceTextEmbeddingByteResults(embeddingList);
124+
} else if (embeddingType == VoyageAIEmbeddingType.BIT || embeddingType == VoyageAIEmbeddingType.BINARY) {
125+
List<InferenceByteEmbedding> embeddingList = parseList(
126+
jsonParser,
127+
VoyageAIEmbeddingsResponseEntity::parseEmbeddingObjectBit
128+
);
129+
130+
return new InferenceTextEmbeddingBitResults(embeddingList);
132131
} else {
133-
throw new IllegalArgumentException("Illegal output_dtype value: " + embeddingType);
132+
throw new IllegalArgumentException(
133+
"Illegal embedding_type value: " + embeddingType + ". Supported types are: " + VALID_EMBEDDING_TYPES_STRING
134+
);
134135
}
135136
}
136137
}
@@ -148,8 +149,7 @@ private static InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding parseE
148149
return InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding.of(embeddingValuesList);
149150
}
150151

151-
private static InferenceByteEmbedding parseEmbeddingObjectByte(XContentParser parser)
152-
throws IOException {
152+
private static InferenceByteEmbedding parseEmbeddingObjectByte(XContentParser parser) throws IOException {
153153
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
154154

155155
positionParserAtTokenAfterField(parser, "embedding", FAILED_TO_FIND_FIELD_TEMPLATE);
@@ -161,21 +161,14 @@ private static InferenceByteEmbedding parseEmbeddingObjectByte(XContentParser pa
161161
return InferenceByteEmbedding.of(embeddingValuesList);
162162
}
163163

164-
private static InferenceServiceResults parseBitEmbeddingsArray(XContentParser parser) throws IOException {
165-
var embeddingList = parseList(parser, VoyageAIEmbeddingsResponseEntity::parseByteArrayEntry);
166-
167-
return new InferenceTextEmbeddingBitResults(embeddingList);
168-
}
169-
170-
private static InferenceServiceResults parseByteEmbeddingsArray(XContentParser parser) throws IOException {
171-
var embeddingList = parseList(parser, VoyageAIEmbeddingsResponseEntity::parseByteArrayEntry);
164+
private static InferenceByteEmbedding parseEmbeddingObjectBit(XContentParser parser) throws IOException {
165+
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
172166

173-
return new InferenceTextEmbeddingByteResults(embeddingList);
174-
}
167+
positionParserAtTokenAfterField(parser, "embedding", FAILED_TO_FIND_FIELD_TEMPLATE);
175168

176-
private static InferenceByteEmbedding parseByteArrayEntry(XContentParser parser) throws IOException {
177-
ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
178-
List<Byte> embeddingValuesList = parseList(parser, VoyageAIEmbeddingsResponseEntity::parseEmbeddingInt8Entry);
169+
List<Byte> embeddingValuesList = parseList(parser, VoyageAIEmbeddingsResponseEntity::parseEmbeddingBitEntry);
170+
// parse and discard the rest of the object
171+
consumeUntilObjectEnd(parser);
179172

180173
return InferenceByteEmbedding.of(embeddingValuesList);
181174
}
@@ -189,24 +182,20 @@ private static Byte parseEmbeddingInt8Entry(XContentParser parser) throws IOExce
189182
return (byte) parsedByte;
190183
}
191184

185+
private static Byte parseEmbeddingBitEntry(XContentParser parser) throws IOException {
186+
XContentParser.Token token = parser.currentToken();
187+
ensureExpectedToken(XContentParser.Token.VALUE_NUMBER, token, parser);
188+
var parsedBit = parser.shortValue();
189+
checkByteBounds(parsedBit);
190+
191+
return (byte) parsedBit;
192+
}
193+
192194
private static void checkByteBounds(short value) {
193195
if (value < Byte.MIN_VALUE || value > Byte.MAX_VALUE) {
194196
throw new IllegalArgumentException("Value [" + value + "] is out of range for a byte");
195197
}
196198
}
197199

198-
private static InferenceServiceResults parseFloatEmbeddingsArray(XContentParser parser) throws IOException {
199-
var embeddingList = parseList(parser, VoyageAIEmbeddingsResponseEntity::parseFloatArrayEntry);
200-
201-
return new InferenceTextEmbeddingFloatResults(embeddingList);
202-
}
203-
204-
private static InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding parseFloatArrayEntry(XContentParser parser)
205-
throws IOException {
206-
ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
207-
List<Float> embeddingValuesList = parseList(parser, XContentUtils::parseFloat);
208-
return InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding.of(embeddingValuesList);
209-
}
210-
211200
private VoyageAIEmbeddingsResponseEntity() {}
212201
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingType.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,16 +38,16 @@ public enum VoyageAIEmbeddingType {
3838
/**
3939
* Use this when you want to get back binary embeddings. Valid only for v3 models.
4040
*/
41-
BIT(DenseVectorFieldMapper.ElementType.BIT, RequestConstants.BIT),
41+
BIT(DenseVectorFieldMapper.ElementType.BIT, RequestConstants.BINARY),
4242
/**
4343
* This is a synonym for BIT
4444
*/
45-
BINARY(DenseVectorFieldMapper.ElementType.BIT, RequestConstants.BIT);
45+
BINARY(DenseVectorFieldMapper.ElementType.BIT, RequestConstants.BINARY);
4646

4747
private static final class RequestConstants {
4848
private static final String FLOAT = "float";
4949
private static final String INT8 = "int8";
50-
private static final String BIT = "binary";
50+
private static final String BINARY = "binary";
5151
}
5252

5353
private static final Map<DenseVectorFieldMapper.ElementType, VoyageAIEmbeddingType> ELEMENT_TYPE_TO_VOYAGE_EMBEDDING = Map.of(

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsServiceSettings.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,10 +132,10 @@ public VoyageAIEmbeddingsServiceSettings(
132132

133133
public VoyageAIEmbeddingsServiceSettings(StreamInput in) throws IOException {
134134
this.commonSettings = new VoyageAIServiceSettings(in);
135-
this.embeddingType = Objects.requireNonNullElse(in.readOptionalEnum(VoyageAIEmbeddingType.class), VoyageAIEmbeddingType.FLOAT);
136135
this.similarity = in.readOptionalEnum(SimilarityMeasure.class);
137136
this.dimensions = in.readOptionalVInt();
138137
this.maxInputTokens = in.readOptionalVInt();
138+
this.embeddingType = Objects.requireNonNullElse(in.readOptionalEnum(VoyageAIEmbeddingType.class), VoyageAIEmbeddingType.FLOAT);
139139
}
140140

141141
public VoyageAIServiceSettings getCommonSettings() {
@@ -165,7 +165,6 @@ public VoyageAIEmbeddingType getEmbeddingType() {
165165
return embeddingType;
166166
}
167167

168-
169168
@Override
170169
public DenseVectorFieldMapper.ElementType elementType() {
171170
return embeddingType == null ? DenseVectorFieldMapper.ElementType.FLOAT : embeddingType.toElementType();

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsTaskSettings.java

Lines changed: 5 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,6 @@
2727

2828
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean;
2929
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum;
30-
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger;
31-
import static org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceFields.OUTPUT_DIMENSION;
3230
import static org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceFields.TRUNCATION;
3331

3432
/**
@@ -43,10 +41,7 @@ public class VoyageAIEmbeddingsTaskSettings implements TaskSettings {
4341
public static final String NAME = "voyageai_embeddings_task_settings";
4442
public static final VoyageAIEmbeddingsTaskSettings EMPTY_SETTINGS = new VoyageAIEmbeddingsTaskSettings(null, null);
4543
static final String INPUT_TYPE = "input_type";
46-
static final EnumSet<InputType> VALID_REQUEST_VALUES = EnumSet.of(
47-
InputType.INGEST,
48-
InputType.SEARCH
49-
);
44+
static final EnumSet<InputType> VALID_REQUEST_VALUES = EnumSet.of(InputType.INGEST, InputType.SEARCH);
5045

5146
public static VoyageAIEmbeddingsTaskSettings fromMap(Map<String, Object> map) {
5247
if (map == null || map.isEmpty()) {
@@ -63,17 +58,7 @@ public static VoyageAIEmbeddingsTaskSettings fromMap(Map<String, Object> map) {
6358
VALID_REQUEST_VALUES,
6459
validationException
6560
);
66-
Boolean truncation = extractOptionalBoolean(
67-
map,
68-
TRUNCATION,
69-
validationException
70-
);
71-
Integer outputDimension = extractOptionalPositiveInteger(
72-
map,
73-
OUTPUT_DIMENSION,
74-
ModelConfigurations.TASK_SETTINGS,
75-
validationException
76-
);
61+
Boolean truncation = extractOptionalBoolean(map, TRUNCATION, validationException);
7762

7863
if (validationException.validationErrors().isEmpty() == false) {
7964
throw validationException;
@@ -132,16 +117,10 @@ private static Boolean getValidTruncation(
132117
private final Boolean truncation;
133118

134119
public VoyageAIEmbeddingsTaskSettings(StreamInput in) throws IOException {
135-
this(
136-
in.readOptionalEnum(InputType.class),
137-
in.readOptionalBoolean()
138-
);
120+
this(in.readOptionalEnum(InputType.class), in.readOptionalBoolean());
139121
}
140122

141-
public VoyageAIEmbeddingsTaskSettings(
142-
@Nullable InputType inputType,
143-
@Nullable Boolean truncation
144-
) {
123+
public VoyageAIEmbeddingsTaskSettings(@Nullable InputType inputType, @Nullable Boolean truncation) {
145124
validateInputType(inputType);
146125
this.inputType = inputType;
147126
this.truncation = truncation;
@@ -204,8 +183,7 @@ public boolean equals(Object o) {
204183
if (this == o) return true;
205184
if (o == null || getClass() != o.getClass()) return false;
206185
VoyageAIEmbeddingsTaskSettings that = (VoyageAIEmbeddingsTaskSettings) o;
207-
return Objects.equals(inputType, that.inputType) &&
208-
Objects.equals(truncation, that.truncation);
186+
return Objects.equals(inputType, that.inputType) && Objects.equals(truncation, that.truncation);
209187
}
210188

211189
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankTaskSettings.java

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,7 @@ public static VoyageAIRerankTaskSettings fromMap(Map<String, Object> map) {
5555
validationException
5656
);
5757

58-
Boolean truncation = extractOptionalBoolean(
59-
map,
60-
TRUNCATION,
61-
validationException
62-
);
58+
Boolean truncation = extractOptionalBoolean(map, TRUNCATION, validationException);
6359

6460
if (validationException.validationErrors().isEmpty() == false) {
6561
throw validationException;
@@ -156,9 +152,9 @@ public boolean equals(Object o) {
156152
if (this == o) return true;
157153
if (o == null || getClass() != o.getClass()) return false;
158154
VoyageAIRerankTaskSettings that = (VoyageAIRerankTaskSettings) o;
159-
return Objects.equals(topKDocumentsOnly, that.topKDocumentsOnly) &&
160-
Objects.equals(returnDocuments, that.returnDocuments) &&
161-
Objects.equals(truncation, that.truncation);
155+
return Objects.equals(topKDocumentsOnly, that.topKDocumentsOnly)
156+
&& Objects.equals(returnDocuments, that.returnDocuments)
157+
&& Objects.equals(truncation, that.truncation);
162158
}
163159

164160
@Override

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceSettingsTests.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
1919
import org.elasticsearch.xpack.inference.services.ServiceFields;
2020
import org.elasticsearch.xpack.inference.services.ServiceUtils;
21-
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettings;
2221
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
2322
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests;
2423
import org.hamcrest.MatcherAssert;

0 commit comments

Comments
 (0)