Skip to content

Commit 00ba02e

Browse files
committed
Adding embeddings type to Jina AI service settings
1 parent cc6e84e commit 00ba02e

File tree

13 files changed

+883
-137
lines changed

13 files changed

+883
-137
lines changed

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,8 +172,11 @@ static TransportVersion def(int id) {
172172
public static final TransportVersion TIMEOUT_GET_PARAM_FOR_RESOLVE_CLUSTER = def(8_838_0_00);
173173
public static final TransportVersion INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING = def(8_839_0_00);
174174
public static final TransportVersion ML_INFERENCE_IBM_WATSONX_RERANK_ADDED = def(8_840_0_00);
175+
public static final TransportVersion JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED_BACKPORT_8_X = def(8_840_0_02);
175176
public static final TransportVersion ELASTICSEARCH_9_0 = def(9_000_0_00);
176177
public static final TransportVersion COHERE_BIT_EMBEDDING_TYPE_SUPPORT_ADDED = def(9_001_0_00);
178+
public static final TransportVersion JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED = def(9_002_0_00);
179+
177180
/*
178181
* STOP! READ THIS FIRST! No, really,
179182
* ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIEmbeddingsRequest.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import org.elasticsearch.xpack.inference.external.jinaai.JinaAIAccount;
1515
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
1616
import org.elasticsearch.xpack.inference.external.request.Request;
17+
import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingType;
1718
import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsModel;
1819
import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettings;
1920

@@ -30,6 +31,7 @@ public class JinaAIEmbeddingsRequest extends JinaAIRequest {
3031
private final JinaAIEmbeddingsTaskSettings taskSettings;
3132
private final String model;
3233
private final String inferenceEntityId;
34+
private final JinaAIEmbeddingType embeddingType;
3335

3436
public JinaAIEmbeddingsRequest(List<String> input, JinaAIEmbeddingsModel embeddingsModel) {
3537
Objects.requireNonNull(embeddingsModel);
@@ -38,6 +40,7 @@ public JinaAIEmbeddingsRequest(List<String> input, JinaAIEmbeddingsModel embeddi
3840
this.input = Objects.requireNonNull(input);
3941
taskSettings = embeddingsModel.getTaskSettings();
4042
model = embeddingsModel.getServiceSettings().getCommonSettings().modelId();
43+
embeddingType = embeddingsModel.getServiceSettings().getEmbeddingType();
4144
inferenceEntityId = embeddingsModel.getInferenceEntityId();
4245
}
4346

@@ -46,7 +49,7 @@ public HttpRequest createHttpRequest() {
4649
HttpPost httpPost = new HttpPost(account.uri());
4750

4851
ByteArrayEntity byteEntity = new ByteArrayEntity(
49-
Strings.toString(new JinaAIEmbeddingsRequestEntity(input, taskSettings, model)).getBytes(StandardCharsets.UTF_8)
52+
Strings.toString(new JinaAIEmbeddingsRequestEntity(input, taskSettings, model, embeddingType)).getBytes(StandardCharsets.UTF_8)
5053
);
5154
httpPost.setEntity(byteEntity);
5255

@@ -75,6 +78,10 @@ public boolean[] getTruncationInfo() {
7578
return null;
7679
}
7780

81+
public JinaAIEmbeddingType getEmbeddingType() {
82+
return embeddingType;
83+
}
84+
7885
public static URI buildDefaultUri() throws URISyntaxException {
7986
return new URIBuilder().setScheme("https")
8087
.setHost(JinaAIUtils.HOST)

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIEmbeddingsRequestEntity.java

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import org.elasticsearch.inference.InputType;
1212
import org.elasticsearch.xcontent.ToXContentObject;
1313
import org.elasticsearch.xcontent.XContentBuilder;
14+
import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingType;
1415
import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettings;
1516

1617
import java.io.IOException;
@@ -19,9 +20,12 @@
1920

2021
import static org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettings.invalidInputTypeMessage;
2122

22-
public record JinaAIEmbeddingsRequestEntity(List<String> input, JinaAIEmbeddingsTaskSettings taskSettings, @Nullable String model)
23-
implements
24-
ToXContentObject {
23+
public record JinaAIEmbeddingsRequestEntity(
24+
List<String> input,
25+
JinaAIEmbeddingsTaskSettings taskSettings,
26+
@Nullable String model,
27+
@Nullable JinaAIEmbeddingType embeddingType
28+
) implements ToXContentObject {
2529

2630
private static final String SEARCH_DOCUMENT = "retrieval.passage";
2731
private static final String SEARCH_QUERY = "retrieval.query";
@@ -30,6 +34,7 @@ public record JinaAIEmbeddingsRequestEntity(List<String> input, JinaAIEmbeddings
3034
private static final String INPUT_FIELD = "input";
3135
private static final String MODEL_FIELD = "model";
3236
public static final String TASK_TYPE_FIELD = "task";
37+
static final String EMBEDDING_TYPE_FIELD = "embedding_type";
3338

3439
public JinaAIEmbeddingsRequestEntity {
3540
Objects.requireNonNull(input);
@@ -43,6 +48,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
4348
builder.field(INPUT_FIELD, input);
4449
builder.field(MODEL_FIELD, model);
4550

51+
if (embeddingType != null) {
52+
builder.field(EMBEDDING_TYPE_FIELD, embeddingType.toRequestString());
53+
}
54+
4655
if (taskSettings.getInputType() != null) {
4756
builder.field(TASK_TYPE_FIELD, convertToString(taskSettings.getInputType()));
4857
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/jinaai/JinaAIEmbeddingsResponseEntity.java

Lines changed: 85 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,28 +9,54 @@
99

1010
package org.elasticsearch.xpack.inference.external.response.jinaai;
1111

12+
import org.elasticsearch.common.Strings;
1213
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
14+
import org.elasticsearch.core.CheckedFunction;
15+
import org.elasticsearch.inference.InferenceServiceResults;
1316
import org.elasticsearch.xcontent.XContentFactory;
1417
import org.elasticsearch.xcontent.XContentParser;
1518
import org.elasticsearch.xcontent.XContentParserConfiguration;
1619
import org.elasticsearch.xcontent.XContentType;
20+
import org.elasticsearch.xpack.core.inference.results.InferenceByteEmbedding;
21+
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingBitResults;
1722
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults;
1823
import org.elasticsearch.xpack.inference.external.http.HttpResult;
1924
import org.elasticsearch.xpack.inference.external.request.Request;
25+
import org.elasticsearch.xpack.inference.external.request.jinaai.JinaAIEmbeddingsRequest;
2026
import org.elasticsearch.xpack.inference.external.response.XContentUtils;
27+
import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingType;
2128

2229
import java.io.IOException;
30+
import java.util.Arrays;
2331
import java.util.List;
32+
import java.util.Map;
2433

2534
import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
2635
import static org.elasticsearch.common.xcontent.XContentParserUtils.parseList;
2736
import static org.elasticsearch.xpack.inference.external.response.XContentUtils.consumeUntilObjectEnd;
2837
import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken;
2938
import static org.elasticsearch.xpack.inference.external.response.XContentUtils.positionParserAtTokenAfterField;
39+
import static org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingType.toLowerCase;
3040

3141
public class JinaAIEmbeddingsResponseEntity {
3242
private static final String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in JinaAI embeddings response";
3343

44+
private static final Map<String, CheckedFunction<XContentParser, InferenceServiceResults, IOException>> EMBEDDING_PARSERS = Map.of(
45+
toLowerCase(JinaAIEmbeddingType.FLOAT),
46+
JinaAIEmbeddingsResponseEntity::parseFloatDataObject,
47+
toLowerCase(JinaAIEmbeddingType.BIT),
48+
JinaAIEmbeddingsResponseEntity::parseBitDataObject,
49+
toLowerCase(JinaAIEmbeddingType.BINARY),
50+
JinaAIEmbeddingsResponseEntity::parseBitDataObject
51+
);
52+
private static final String VALID_EMBEDDING_TYPES_STRING = supportedEmbeddingTypes();
53+
54+
private static String supportedEmbeddingTypes() {
55+
var validTypes = EMBEDDING_PARSERS.keySet().toArray(String[]::new);
56+
Arrays.sort(validTypes);
57+
return String.join(", ", validTypes);
58+
}
59+
3460
/**
3561
* Parses the JinaAI json response.
3662
* For a request like:
@@ -73,8 +99,21 @@ public class JinaAIEmbeddingsResponseEntity {
7399
* </code>
74100
* </pre>
75101
*/
76-
public static InferenceTextEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException {
102+
public static InferenceServiceResults fromResponse(Request request, HttpResult response) throws IOException {
103+
// embeddings type is not specified anywhere in the response so grab it from the request
104+
JinaAIEmbeddingsRequest embeddingsRequest = (JinaAIEmbeddingsRequest) request;
105+
var embeddingType = embeddingsRequest.getEmbeddingType().toString();
77106
var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
107+
var embeddingValueParser = EMBEDDING_PARSERS.get(embeddingType);
108+
109+
if (embeddingValueParser == null) {
110+
throw new IllegalStateException(
111+
Strings.format(
112+
"Failed to find a supported embedding type for in the Jina AI embeddings response. Supported types are [%s]",
113+
VALID_EMBEDDING_TYPES_STRING
114+
)
115+
);
116+
}
78117

79118
try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) {
80119
moveToFirstToken(jsonParser);
@@ -84,27 +123,64 @@ public static InferenceTextEmbeddingFloatResults fromResponse(Request request, H
84123

85124
positionParserAtTokenAfterField(jsonParser, "data", FAILED_TO_FIND_FIELD_TEMPLATE);
86125

87-
List<InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding> embeddingList = parseList(
88-
jsonParser,
89-
JinaAIEmbeddingsResponseEntity::parseEmbeddingObject
90-
);
91-
92-
return new InferenceTextEmbeddingFloatResults(embeddingList);
126+
return embeddingValueParser.apply(jsonParser);
93127
}
94128
}
95129

96-
private static InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding parseEmbeddingObject(XContentParser parser)
130+
private static InferenceTextEmbeddingFloatResults parseFloatDataObject(XContentParser jsonParser) throws IOException {
131+
List<InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding> embeddingList = parseList(
132+
jsonParser,
133+
JinaAIEmbeddingsResponseEntity::parseFloatEmbeddingObject
134+
);
135+
136+
return new InferenceTextEmbeddingFloatResults(embeddingList);
137+
}
138+
139+
private static InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding parseFloatEmbeddingObject(XContentParser parser)
97140
throws IOException {
98141
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
99142

100143
positionParserAtTokenAfterField(parser, "embedding", FAILED_TO_FIND_FIELD_TEMPLATE);
101144

102-
List<Float> embeddingValuesList = parseList(parser, XContentUtils::parseFloat);
145+
var embeddingValuesList = parseList(parser, XContentUtils::parseFloat);
103146
// parse and discard the rest of the object
104147
consumeUntilObjectEnd(parser);
105148

106149
return InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding.of(embeddingValuesList);
107150
}
108151

152+
private static InferenceTextEmbeddingBitResults parseBitDataObject(XContentParser jsonParser) throws IOException {
153+
List<InferenceByteEmbedding> embeddingList = parseList(jsonParser, JinaAIEmbeddingsResponseEntity::parseBitEmbeddingObject);
154+
155+
return new InferenceTextEmbeddingBitResults(embeddingList);
156+
}
157+
158+
private static InferenceByteEmbedding parseBitEmbeddingObject(XContentParser parser) throws IOException {
159+
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
160+
161+
positionParserAtTokenAfterField(parser, "embedding", FAILED_TO_FIND_FIELD_TEMPLATE);
162+
163+
var embeddingList = parseList(parser, JinaAIEmbeddingsResponseEntity::parseEmbeddingInt8Entry);
164+
// parse and discard the rest of the object
165+
consumeUntilObjectEnd(parser);
166+
167+
return InferenceByteEmbedding.of(embeddingList);
168+
}
169+
170+
private static Byte parseEmbeddingInt8Entry(XContentParser parser) throws IOException {
171+
XContentParser.Token token = parser.currentToken();
172+
ensureExpectedToken(XContentParser.Token.VALUE_NUMBER, token, parser);
173+
var parsedByte = parser.shortValue();
174+
checkByteBounds(parsedByte);
175+
176+
return (byte) parsedByte;
177+
}
178+
179+
private static void checkByteBounds(short value) {
180+
if (value < Byte.MIN_VALUE || value > Byte.MAX_VALUE) {
181+
throw new IllegalArgumentException("Value [" + value + "] is out of range for a byte");
182+
}
183+
}
184+
109185
private JinaAIEmbeddingsResponseEntity() {}
110186
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,8 @@ public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
306306
),
307307
similarityToUse,
308308
embeddingSize,
309-
maxInputTokens
309+
maxInputTokens,
310+
serviceSettings.getEmbeddingType()
310311
);
311312

312313
return new JinaAIEmbeddingsModel(embeddingsModel, updatedServiceSettings);
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.jinaai.embeddings;
9+
10+
import org.elasticsearch.TransportVersion;
11+
import org.elasticsearch.TransportVersions;
12+
import org.elasticsearch.common.Strings;
13+
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
14+
15+
import java.util.Arrays;
16+
import java.util.EnumSet;
17+
import java.util.Locale;
18+
import java.util.Map;
19+
20+
/**
21+
* Defines the type of embedding that the Jina AI API should return for a request.
22+
*
23+
*/
24+
public enum JinaAIEmbeddingType {
25+
/**
26+
* Use this when you want to get back the default float embeddings.
27+
*/
28+
FLOAT(DenseVectorFieldMapper.ElementType.FLOAT, RequestConstants.FLOAT),
29+
/**
30+
* Use this when you want to get back binary embeddings.
31+
*/
32+
BIT(DenseVectorFieldMapper.ElementType.BIT, RequestConstants.BIT),
33+
/**
34+
* This is a synonym for BIT
35+
*/
36+
BINARY(DenseVectorFieldMapper.ElementType.BIT, RequestConstants.BIT);
37+
38+
private static final class RequestConstants {
39+
private static final String FLOAT = "float";
40+
private static final String BIT = "binary";
41+
}
42+
43+
private static final Map<DenseVectorFieldMapper.ElementType, JinaAIEmbeddingType> ELEMENT_TYPE_TO_JINA_AI_EMBEDDING = Map.of(
44+
DenseVectorFieldMapper.ElementType.FLOAT,
45+
FLOAT,
46+
DenseVectorFieldMapper.ElementType.BIT,
47+
BIT
48+
);
49+
static final EnumSet<DenseVectorFieldMapper.ElementType> SUPPORTED_ELEMENT_TYPES = EnumSet.copyOf(
50+
ELEMENT_TYPE_TO_JINA_AI_EMBEDDING.keySet()
51+
);
52+
53+
private final DenseVectorFieldMapper.ElementType elementType;
54+
private final String requestString;
55+
56+
JinaAIEmbeddingType(DenseVectorFieldMapper.ElementType elementType, String requestString) {
57+
this.elementType = elementType;
58+
this.requestString = requestString;
59+
}
60+
61+
@Override
62+
public String toString() {
63+
return name().toLowerCase(Locale.ROOT);
64+
}
65+
66+
public String toRequestString() {
67+
return requestString;
68+
}
69+
70+
public static String toLowerCase(JinaAIEmbeddingType type) {
71+
return type.toString().toLowerCase(Locale.ROOT);
72+
}
73+
74+
public static JinaAIEmbeddingType fromString(String name) {
75+
return valueOf(name.trim().toUpperCase(Locale.ROOT));
76+
}
77+
78+
public static JinaAIEmbeddingType fromElementType(DenseVectorFieldMapper.ElementType elementType) {
79+
var embedding = ELEMENT_TYPE_TO_JINA_AI_EMBEDDING.get(elementType);
80+
81+
if (embedding == null) {
82+
var validElementTypes = SUPPORTED_ELEMENT_TYPES.stream()
83+
.map(value -> value.toString().toLowerCase(Locale.ROOT))
84+
.toArray(String[]::new);
85+
Arrays.sort(validElementTypes);
86+
87+
throw new IllegalArgumentException(
88+
Strings.format(
89+
"Element type [%s] does not map to a Jina AI embedding value, must be one of [%s]",
90+
elementType,
91+
String.join(", ", validElementTypes)
92+
)
93+
);
94+
}
95+
96+
return embedding;
97+
}
98+
99+
public DenseVectorFieldMapper.ElementType toElementType() {
100+
return elementType;
101+
}
102+
103+
/**
104+
* Returns an embedding type that is known based on the transport version provided. If the embedding type enum was not yet
105+
* introduced it will be defaulted FLOAT.
106+
*
107+
* @param embeddingType the value to translate if necessary
108+
* @param version the version that dictates the translation
109+
* @return the embedding type that is known to the version passed in
110+
*/
111+
public static JinaAIEmbeddingType translateToVersion(JinaAIEmbeddingType embeddingType, TransportVersion version) {
112+
if (version.onOrAfter(TransportVersions.JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED)
113+
|| version.isPatchFrom(TransportVersions.JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED_BACKPORT_8_X)) {
114+
return embeddingType;
115+
}
116+
117+
return FLOAT;
118+
}
119+
}

0 commit comments

Comments
 (0)