Skip to content

Commit 172637b

Browse files
[ML] Custom Service add embedding type support (#130141)
* Adding embedding type * Adding more tests and cleaning up
1 parent b34a8c8 commit 172637b

File tree

18 files changed

+686
-136
lines changed

18 files changed

+686
-136
lines changed

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,7 @@ static TransportVersion def(int id) {
334334
public static final TransportVersion ML_INFERENCE_IBM_WATSONX_COMPLETION_ADDED = def(9_115_0_00);
335335
public static final TransportVersion ESQL_SPLIT_ON_BIG_VALUES = def(9_116_0_00);
336336
public static final TransportVersion ESQL_LOCAL_RELATION_WITH_NEW_BLOCKS = def(9_117_0_00);
337+
public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_TYPE = def(9_118_0_00);
337338

338339
/*
339340
* STOP! READ THIS FIRST! No, really,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1085,5 +1085,11 @@ public static void validateInputTypeAgainstAllowlist(
10851085
}
10861086
}
10871087

1088+
public static void checkByteBounds(short value) {
1089+
if (value < Byte.MIN_VALUE || value > Byte.MAX_VALUE) {
1090+
throw new IllegalArgumentException("Value [" + value + "] is out of range for a byte");
1091+
}
1092+
}
1093+
10881094
private ServiceUtils() {}
10891095
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomModel.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ public CustomModel(
4646
inferenceId,
4747
taskType,
4848
service,
49-
CustomServiceSettings.fromMap(serviceSettings, context, taskType, inferenceId),
49+
CustomServiceSettings.fromMap(serviceSettings, context, taskType),
5050
CustomTaskSettings.fromMap(taskSettings),
5151
CustomSecretSettings.fromMap(secrets)
5252
);
@@ -66,7 +66,7 @@ public CustomModel(
6666
inferenceId,
6767
taskType,
6868
service,
69-
CustomServiceSettings.fromMap(serviceSettings, context, taskType, inferenceId),
69+
CustomServiceSettings.fromMap(serviceSettings, context, taskType),
7070
CustomTaskSettings.fromMap(taskSettings),
7171
CustomSecretSettings.fromMap(secrets),
7272
chunkingSettings

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -333,12 +333,7 @@ private static CustomServiceSettings getCustomServiceSettings(CustomModel custom
333333
var similarityToUse = similarityFromModel == null ? SimilarityMeasure.DOT_PRODUCT : similarityFromModel;
334334

335335
return new CustomServiceSettings(
336-
new CustomServiceSettings.TextEmbeddingSettings(
337-
similarityToUse,
338-
embeddingSize,
339-
serviceSettings.getMaxInputTokens(),
340-
serviceSettings.elementType()
341-
),
336+
new CustomServiceSettings.TextEmbeddingSettings(similarityToUse, embeddingSize, serviceSettings.getMaxInputTokens()),
342337
serviceSettings.getUrl(),
343338
serviceSettings.getHeaders(),
344339
serviceSettings.getQueryParameters(),
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.custom;
9+
10+
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
11+
12+
import java.util.Locale;
13+
14+
public enum CustomServiceEmbeddingType {
15+
/**
16+
* Use this when you want to get back the default float embeddings.
17+
*/
18+
FLOAT(DenseVectorFieldMapper.ElementType.FLOAT),
19+
/**
20+
* Use this when you want to get back signed int8 embeddings.
21+
*/
22+
BYTE(DenseVectorFieldMapper.ElementType.BYTE),
23+
/**
24+
* Use this when you want to get back binary embeddings.
25+
*/
26+
BIT(DenseVectorFieldMapper.ElementType.BIT),
27+
/**
28+
* This is a synonym for BIT
29+
*/
30+
BINARY(DenseVectorFieldMapper.ElementType.BIT);
31+
32+
private final DenseVectorFieldMapper.ElementType elementType;
33+
34+
CustomServiceEmbeddingType(DenseVectorFieldMapper.ElementType elementType) {
35+
this.elementType = elementType;
36+
}
37+
38+
@Override
39+
public String toString() {
40+
return name().toLowerCase(Locale.ROOT);
41+
}
42+
43+
public DenseVectorFieldMapper.ElementType toElementType() {
44+
return elementType;
45+
}
46+
47+
public static CustomServiceEmbeddingType fromString(String name) {
48+
return valueOf(name.trim().toUpperCase(Locale.ROOT));
49+
}
50+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java

Lines changed: 51 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -66,12 +66,7 @@ public class CustomServiceSettings extends FilteredXContentObject implements Ser
6666
private static final String RESPONSE_SCOPE = String.join(".", ModelConfigurations.SERVICE_SETTINGS, RESPONSE);
6767
private static final int DEFAULT_EMBEDDING_BATCH_SIZE = 10;
6868

69-
public static CustomServiceSettings fromMap(
70-
Map<String, Object> map,
71-
ConfigurationParseContext context,
72-
TaskType taskType,
73-
String inferenceId
74-
) {
69+
public static CustomServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context, TaskType taskType) {
7570
ValidationException validationException = new ValidationException();
7671

7772
var textEmbeddingSettings = TextEmbeddingSettings.fromMap(map, taskType, validationException);
@@ -137,22 +132,12 @@ public static CustomServiceSettings fromMap(
137132
);
138133
}
139134

140-
public record TextEmbeddingSettings(
141-
@Nullable SimilarityMeasure similarityMeasure,
142-
@Nullable Integer dimensions,
143-
@Nullable Integer maxInputTokens,
144-
@Nullable DenseVectorFieldMapper.ElementType elementType
145-
) implements ToXContentFragment, Writeable {
135+
public static class TextEmbeddingSettings implements ToXContentFragment, Writeable {
146136

147137
// This specifies float for the element type but null for all other settings
148-
public static final TextEmbeddingSettings DEFAULT_FLOAT = new TextEmbeddingSettings(
149-
null,
150-
null,
151-
null,
152-
DenseVectorFieldMapper.ElementType.FLOAT
153-
);
138+
public static final TextEmbeddingSettings DEFAULT_FLOAT = new TextEmbeddingSettings(null, null, null);
154139
// This refers to settings that are not related to the text embedding task type (all the settings should be null)
155-
public static final TextEmbeddingSettings NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS = new TextEmbeddingSettings(null, null, null, null);
140+
public static final TextEmbeddingSettings NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS = new TextEmbeddingSettings(null, null, null);
156141

157142
public static TextEmbeddingSettings fromMap(Map<String, Object> map, TaskType taskType, ValidationException validationException) {
158143
if (taskType != TaskType.TEXT_EMBEDDING) {
@@ -162,24 +147,42 @@ public static TextEmbeddingSettings fromMap(Map<String, Object> map, TaskType ta
162147
SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException);
163148
Integer dims = removeAsType(map, DIMENSIONS, Integer.class);
164149
Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class);
165-
return new TextEmbeddingSettings(similarity, dims, maxInputTokens, DenseVectorFieldMapper.ElementType.FLOAT);
150+
return new TextEmbeddingSettings(similarity, dims, maxInputTokens);
151+
}
152+
153+
private final SimilarityMeasure similarityMeasure;
154+
private final Integer dimensions;
155+
private final Integer maxInputTokens;
156+
157+
public TextEmbeddingSettings(
158+
@Nullable SimilarityMeasure similarityMeasure,
159+
@Nullable Integer dimensions,
160+
@Nullable Integer maxInputTokens
161+
) {
162+
this.similarityMeasure = similarityMeasure;
163+
this.dimensions = dimensions;
164+
this.maxInputTokens = maxInputTokens;
166165
}
167166

168167
public TextEmbeddingSettings(StreamInput in) throws IOException {
169-
this(
170-
in.readOptionalEnum(SimilarityMeasure.class),
171-
in.readOptionalVInt(),
172-
in.readOptionalVInt(),
173-
in.readOptionalEnum(DenseVectorFieldMapper.ElementType.class)
174-
);
168+
this.similarityMeasure = in.readOptionalEnum(SimilarityMeasure.class);
169+
this.dimensions = in.readOptionalVInt();
170+
this.maxInputTokens = in.readOptionalVInt();
171+
172+
if (in.getTransportVersion().before(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_TYPE)) {
173+
in.readOptionalEnum(DenseVectorFieldMapper.ElementType.class);
174+
}
175175
}
176176

177177
@Override
178178
public void writeTo(StreamOutput out) throws IOException {
179179
out.writeOptionalEnum(similarityMeasure);
180180
out.writeOptionalVInt(dimensions);
181181
out.writeOptionalVInt(maxInputTokens);
182-
out.writeOptionalEnum(elementType);
182+
183+
if (out.getTransportVersion().before(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_TYPE)) {
184+
out.writeOptionalEnum(null);
185+
}
183186
}
184187

185188
@Override
@@ -193,8 +196,23 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
193196
if (maxInputTokens != null) {
194197
builder.field(MAX_INPUT_TOKENS, maxInputTokens);
195198
}
199+
196200
return builder;
197201
}
202+
203+
@Override
204+
public boolean equals(Object o) {
205+
if (o == null || getClass() != o.getClass()) return false;
206+
TextEmbeddingSettings that = (TextEmbeddingSettings) o;
207+
return similarityMeasure == that.similarityMeasure
208+
&& Objects.equals(dimensions, that.dimensions)
209+
&& Objects.equals(maxInputTokens, that.maxInputTokens);
210+
}
211+
212+
@Override
213+
public int hashCode() {
214+
return Objects.hash(similarityMeasure, dimensions, maxInputTokens);
215+
}
198216
}
199217

200218
private final TextEmbeddingSettings textEmbeddingSettings;
@@ -300,7 +318,12 @@ public Integer dimensions() {
300318

301319
@Override
302320
public DenseVectorFieldMapper.ElementType elementType() {
303-
return textEmbeddingSettings.elementType;
321+
var embeddingType = responseJsonParser.getEmbeddingType();
322+
if (embeddingType != null) {
323+
return embeddingType.toElementType();
324+
}
325+
326+
return null;
304327
}
305328

306329
public Integer getMaxInputTokens() {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/BaseCustomResponseParser.java

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@
2222
import java.util.Objects;
2323
import java.util.function.BiFunction;
2424

25-
public abstract class BaseCustomResponseParser<T extends InferenceServiceResults> implements CustomResponseParser {
25+
import static org.elasticsearch.xpack.inference.services.ServiceUtils.checkByteBounds;
26+
27+
public abstract class BaseCustomResponseParser implements CustomResponseParser {
2628

2729
@Override
2830
public InferenceServiceResults parse(HttpResult response) throws IOException {
@@ -36,7 +38,7 @@ public InferenceServiceResults parse(HttpResult response) throws IOException {
3638
}
3739
}
3840

39-
protected abstract T transform(Map<String, Object> extractedField);
41+
protected abstract InferenceServiceResults transform(Map<String, Object> extractedField);
4042

4143
static List<?> validateList(Object obj, String fieldName) {
4244
validateNonNull(obj, fieldName);
@@ -97,6 +99,21 @@ static Float toFloat(Object obj, String fieldName) {
9799
return toNumber(obj, fieldName).floatValue();
98100
}
99101

102+
static List<Byte> convertToListOfBits(Object obj, String fieldName) {
103+
return convertToListOfBytes(obj, fieldName);
104+
}
105+
106+
static List<Byte> convertToListOfBytes(Object obj, String fieldName) {
107+
return castList(validateList(obj, fieldName), BaseCustomResponseParser::toByte, fieldName);
108+
}
109+
110+
static Byte toByte(Object obj, String fieldName) {
111+
var shortValue = toNumber(obj, fieldName).shortValue();
112+
checkByteBounds(shortValue);
113+
114+
return (byte) shortValue;
115+
}
116+
100117
private static Number toNumber(Object obj, String fieldName) {
101118
if (obj instanceof Number == false) {
102119
throw new IllegalArgumentException(

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/CompletionResponseParser.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString;
2424
import static org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings.JSON_PARSER;
2525

26-
public class CompletionResponseParser extends BaseCustomResponseParser<ChatCompletionResults> {
26+
public class CompletionResponseParser extends BaseCustomResponseParser {
2727

2828
public static final String NAME = "completion_response_parser";
2929
public static final String COMPLETION_PARSER_RESULT = "completion_result";

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseParser.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,17 @@
1111
import org.elasticsearch.inference.InferenceServiceResults;
1212
import org.elasticsearch.xcontent.ToXContentFragment;
1313
import org.elasticsearch.xpack.inference.external.http.HttpResult;
14+
import org.elasticsearch.xpack.inference.services.custom.CustomServiceEmbeddingType;
1415

1516
import java.io.IOException;
1617

1718
public interface CustomResponseParser extends ToXContentFragment, NamedWriteable {
1819
InferenceServiceResults parse(HttpResult response) throws IOException;
20+
21+
/**
22+
* Returns the configured embedding type for this response parser. This should be overridden for text embedding parsers.
23+
*/
24+
default CustomServiceEmbeddingType getEmbeddingType() {
25+
return null;
26+
}
1927
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/RerankResponseParser.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString;
2727
import static org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings.JSON_PARSER;
2828

29-
public class RerankResponseParser extends BaseCustomResponseParser<RankedDocsResults> {
29+
public class RerankResponseParser extends BaseCustomResponseParser {
3030

3131
public static final String NAME = "rerank_response_parser";
3232
public static final String RERANK_PARSER_SCORE = "relevance_score";

0 commit comments

Comments
 (0)