Skip to content
2 changes: 2 additions & 0 deletions server/src/main/java/org/elasticsearch/TransportVersions.java
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ static TransportVersion def(int id) {
public static final TransportVersion ML_INFERENCE_ELASTIC_DENSE_TEXT_EMBEDDINGS_ADDED_8_19 = def(8_841_0_59);
public static final TransportVersion ML_INFERENCE_COHERE_API_VERSION_8_19 = def(8_841_0_60);
public static final TransportVersion ESQL_DOCUMENTS_FOUND_AND_VALUES_LOADED_8_19 = def(8_841_0_61);
public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_TYPE_8_19 = def(8_841_0_62);
public static final TransportVersion V_9_0_0 = def(9_000_0_09);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11);
Expand Down Expand Up @@ -325,6 +326,7 @@ static TransportVersion def(int id) {
public static final TransportVersion ML_INFERENCE_COHERE_API_VERSION = def(9_110_0_00);
public static final TransportVersion ESQL_PROFILE_INCLUDE_PLAN = def(9_111_0_00);
public static final TransportVersion MAPPINGS_IN_DATA_STREAMS = def(9_112_0_00);
public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_TYPE = def(9_113_0_00);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1085,5 +1085,11 @@ public static void validateInputTypeAgainstAllowlist(
}
}

/**
 * Verifies that a {@code short} can be represented as a signed byte.
 *
 * @param value the value to validate
 * @throws IllegalArgumentException if {@code value} lies outside
 *         [{@link Byte#MIN_VALUE}, {@link Byte#MAX_VALUE}]
 */
public static void checkByteBounds(short value) {
    final boolean withinByteRange = value >= Byte.MIN_VALUE && value <= Byte.MAX_VALUE;
    if (withinByteRange) {
        return;
    }

    throw new IllegalArgumentException("Value [" + value + "] is out of range for a byte");
}

// Private, empty constructor: prevents this class from being instantiated from outside.
private ServiceUtils() {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public CustomModel(
inferenceId,
taskType,
service,
CustomServiceSettings.fromMap(serviceSettings, context, taskType, inferenceId),
CustomServiceSettings.fromMap(serviceSettings, context, taskType),
CustomTaskSettings.fromMap(taskSettings),
CustomSecretSettings.fromMap(secrets)
);
Expand All @@ -66,7 +66,7 @@ public CustomModel(
inferenceId,
taskType,
service,
CustomServiceSettings.fromMap(serviceSettings, context, taskType, inferenceId),
CustomServiceSettings.fromMap(serviceSettings, context, taskType),
CustomTaskSettings.fromMap(taskSettings),
CustomSecretSettings.fromMap(secrets),
chunkingSettings
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -333,12 +333,7 @@ private static CustomServiceSettings getCustomServiceSettings(CustomModel custom
var similarityToUse = similarityFromModel == null ? SimilarityMeasure.DOT_PRODUCT : similarityFromModel;

return new CustomServiceSettings(
new CustomServiceSettings.TextEmbeddingSettings(
similarityToUse,
embeddingSize,
serviceSettings.getMaxInputTokens(),
serviceSettings.elementType()
),
new CustomServiceSettings.TextEmbeddingSettings(similarityToUse, embeddingSize, serviceSettings.getMaxInputTokens()),
serviceSettings.getUrl(),
serviceSettings.getHeaders(),
serviceSettings.getQueryParameters(),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.custom;

import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;

import java.util.Locale;

/**
 * The embedding encodings that can be configured for the custom service, each paired with the
 * {@link DenseVectorFieldMapper.ElementType} it maps to.
 */
public enum CustomServiceEmbeddingType {
    /**
     * The default: embeddings are returned as floats.
     */
    FLOAT(DenseVectorFieldMapper.ElementType.FLOAT),
    /**
     * Embeddings are returned as signed int8 values.
     */
    BYTE(DenseVectorFieldMapper.ElementType.BYTE),
    /**
     * Embeddings are returned in binary form.
     */
    BIT(DenseVectorFieldMapper.ElementType.BIT),
    /**
     * Alias for {@link #BIT}; maps to the same element type.
     */
    BINARY(DenseVectorFieldMapper.ElementType.BIT);

    private final DenseVectorFieldMapper.ElementType mappedElementType;

    CustomServiceEmbeddingType(DenseVectorFieldMapper.ElementType mappedElementType) {
        this.mappedElementType = mappedElementType;
    }

    /**
     * Resolves a constant from its textual form. Surrounding whitespace is trimmed and the lookup is
     * case-insensitive (upper-cased with {@link Locale#ROOT}).
     *
     * @param name the textual name of the embedding type
     * @return the matching constant
     * @throws IllegalArgumentException if no constant has the normalized name
     */
    public static CustomServiceEmbeddingType fromString(String name) {
        final String normalized = name.trim().toUpperCase(Locale.ROOT);
        return CustomServiceEmbeddingType.valueOf(normalized);
    }

    /**
     * @return the dense vector element type associated with this embedding type
     */
    public DenseVectorFieldMapper.ElementType toElementType() {
        return mappedElementType;
    }

    /**
     * @return the constant's name in lower case, using {@link Locale#ROOT}
     */
    @Override
    public String toString() {
        return this.name().toLowerCase(Locale.ROOT);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,7 @@ public class CustomServiceSettings extends FilteredXContentObject implements Ser
private static final String RESPONSE_SCOPE = String.join(".", ModelConfigurations.SERVICE_SETTINGS, RESPONSE);
private static final int DEFAULT_EMBEDDING_BATCH_SIZE = 10;

public static CustomServiceSettings fromMap(
Map<String, Object> map,
ConfigurationParseContext context,
TaskType taskType,
String inferenceId
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

inferenceId wasn't being used.

) {
public static CustomServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context, TaskType taskType) {
ValidationException validationException = new ValidationException();

var textEmbeddingSettings = TextEmbeddingSettings.fromMap(map, taskType, validationException);
Expand Down Expand Up @@ -137,22 +132,12 @@ public static CustomServiceSettings fromMap(
);
}

public record TextEmbeddingSettings(
@Nullable SimilarityMeasure similarityMeasure,
@Nullable Integer dimensions,
@Nullable Integer maxInputTokens,
@Nullable DenseVectorFieldMapper.ElementType elementType
) implements ToXContentFragment, Writeable {
public static class TextEmbeddingSettings implements ToXContentFragment, Writeable {

// This specifies float for the element type but null for all other settings
public static final TextEmbeddingSettings DEFAULT_FLOAT = new TextEmbeddingSettings(
null,
null,
null,
DenseVectorFieldMapper.ElementType.FLOAT
);
public static final TextEmbeddingSettings DEFAULT_FLOAT = new TextEmbeddingSettings(null, null, null);
// This refers to settings that are not related to the text embedding task type (all the settings should be null)
public static final TextEmbeddingSettings NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS = new TextEmbeddingSettings(null, null, null, null);
public static final TextEmbeddingSettings NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS = new TextEmbeddingSettings(null, null, null);

public static TextEmbeddingSettings fromMap(Map<String, Object> map, TaskType taskType, ValidationException validationException) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We never included elementType in the toXContent method, so we don't have to worry about backwards compatibility with older versions of this model, correct?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep that's correct, it was just hard coded previously.

if (taskType != TaskType.TEXT_EMBEDDING) {
Expand All @@ -162,24 +147,44 @@ public static TextEmbeddingSettings fromMap(Map<String, Object> map, TaskType ta
SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException);
Integer dims = removeAsType(map, DIMENSIONS, Integer.class);
Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class);
return new TextEmbeddingSettings(similarity, dims, maxInputTokens, DenseVectorFieldMapper.ElementType.FLOAT);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The element type logic has been delegated to the TextEmbeddingResponseParser, so it is no longer hard-coded to float here.

return new TextEmbeddingSettings(similarity, dims, maxInputTokens);
}

private final SimilarityMeasure similarityMeasure;
private final Integer dimensions;
private final Integer maxInputTokens;

/**
 * Creates the settings from already-parsed values. Every argument is optional and is stored as given.
 *
 * @param similarityMeasure the similarity measure, or {@code null} if not specified
 * @param dimensions the number of embedding dimensions, or {@code null} if not specified
 * @param maxInputTokens the maximum number of input tokens, or {@code null} if not specified
 */
public TextEmbeddingSettings(
@Nullable SimilarityMeasure similarityMeasure,
@Nullable Integer dimensions,
@Nullable Integer maxInputTokens
) {
this.similarityMeasure = similarityMeasure;
this.dimensions = dimensions;
this.maxInputTokens = maxInputTokens;
}

public TextEmbeddingSettings(StreamInput in) throws IOException {
this(
in.readOptionalEnum(SimilarityMeasure.class),
in.readOptionalVInt(),
in.readOptionalVInt(),
in.readOptionalEnum(DenseVectorFieldMapper.ElementType.class)
);
this.similarityMeasure = in.readOptionalEnum(SimilarityMeasure.class);
this.dimensions = in.readOptionalVInt();
this.maxInputTokens = in.readOptionalVInt();

if (in.getTransportVersion().before(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_TYPE)
&& in.getTransportVersion().isPatchFrom(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_TYPE_8_19) == false) {
in.readOptionalEnum(DenseVectorFieldMapper.ElementType.class);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For older versions, we'll read it but ignore it. It should only be float which we'll default to in the TextEmbeddingResponseParser.

}
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalEnum(similarityMeasure);
out.writeOptionalVInt(dimensions);
out.writeOptionalVInt(maxInputTokens);
out.writeOptionalEnum(elementType);

if (out.getTransportVersion().before(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_TYPE)
&& out.getTransportVersion().isPatchFrom(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_TYPE_8_19) == false) {
out.writeOptionalEnum(null);
}
}

@Override
Expand All @@ -193,8 +198,23 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (maxInputTokens != null) {
builder.field(MAX_INPUT_TOKENS, maxInputTokens);
}

return builder;
}

/**
 * Two instances are equal when they are of exactly the same class and their similarity measure,
 * dimensions and max input tokens all match.
 */
@Override
public boolean equals(Object o) {
    if (o == null) {
        return false;
    }
    if (o.getClass() != getClass()) {
        return false;
    }

    final TextEmbeddingSettings other = (TextEmbeddingSettings) o;
    if (similarityMeasure != other.similarityMeasure) {
        return false;
    }
    if (Objects.equals(dimensions, other.dimensions) == false) {
        return false;
    }
    return Objects.equals(maxInputTokens, other.maxInputTokens);
}

/**
 * Combines the same three fields that {@code equals} compares, using the standard 31-based fold
 * (a {@code null} field contributes 0).
 */
@Override
public int hashCode() {
    int result = 1;
    result = 31 * result + Objects.hashCode(similarityMeasure);
    result = 31 * result + Objects.hashCode(dimensions);
    result = 31 * result + Objects.hashCode(maxInputTokens);
    return result;
}
}

private final TextEmbeddingSettings textEmbeddingSettings;
Expand Down Expand Up @@ -300,7 +320,12 @@ public Integer dimensions() {

@Override
public DenseVectorFieldMapper.ElementType elementType() {
return textEmbeddingSettings.elementType;
var embeddingType = responseJsonParser.getEmbeddingType();
if (embeddingType != null) {
return embeddingType.toElementType();
}

return null;
}

public Integer getMaxInputTokens() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@
import java.util.Objects;
import java.util.function.BiFunction;

public abstract class BaseCustomResponseParser<T extends InferenceServiceResults> implements CustomResponseParser {
import static org.elasticsearch.xpack.inference.services.ServiceUtils.checkByteBounds;

public abstract class BaseCustomResponseParser implements CustomResponseParser {

@Override
public InferenceServiceResults parse(HttpResult response) throws IOException {
Expand All @@ -36,7 +38,7 @@ public InferenceServiceResults parse(HttpResult response) throws IOException {
}
}

protected abstract T transform(Map<String, Object> extractedField);
protected abstract InferenceServiceResults transform(Map<String, Object> extractedField);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The TextEmbeddingResponseParser will now return different types of results depending on the embedding type, so we need this to be the base results type to cover all of them.


static List<?> validateList(Object obj, String fieldName) {
validateNonNull(obj, fieldName);
Expand Down Expand Up @@ -97,6 +99,21 @@ static Float toFloat(Object obj, String fieldName) {
return toNumber(obj, fieldName).floatValue();
}

/**
 * Converts a parsed value into a list of bytes for bit embeddings. The value is validated and
 * converted exactly as it is for byte embeddings: it must be a list whose entries each pass
 * {@link #toByte(Object, String)}.
 *
 * @param obj the extracted value, expected to be a list of numbers
 * @param fieldName the field name passed through to the validation and conversion helpers
 * @return the converted entries
 */
static List<Byte> convertToListOfBits(Object obj, String fieldName) {
    var rawEntries = validateList(obj, fieldName);
    return castList(rawEntries, BaseCustomResponseParser::toByte, fieldName);
}

/**
 * Converts a parsed value into a list of bytes. The value is first validated as a list, then each
 * entry is converted with {@link #toByte(Object, String)}.
 *
 * @param obj the extracted value, expected to be a list of numbers
 * @param fieldName the field name passed through to the validation and conversion helpers
 * @return the converted entries
 */
static List<Byte> convertToListOfBytes(Object obj, String fieldName) {
    var rawEntries = validateList(obj, fieldName);
    return castList(rawEntries, BaseCustomResponseParser::toByte, fieldName);
}

/**
 * Converts a parsed value into a {@link Byte}, rejecting numbers outside the signed byte range.
 * Fractional values are truncated toward zero by the numeric conversion.
 *
 * @param obj the value to convert; must be a {@link Number}
 * @param fieldName the field name passed to {@code toNumber}
 * @return the value as a byte
 * @throws IllegalArgumentException if {@code obj} is not a number or is out of range for a byte
 */
static Byte toByte(Object obj, String fieldName) {
    var number = toNumber(obj, fieldName);

    // Widen before narrowing: Number#shortValue() silently wraps integral values that do not fit in
    // a short (e.g. 65537 becomes 1), which would let out-of-range input slip past the byte check.
    long longValue = number.longValue();
    if (longValue < Short.MIN_VALUE || longValue > Short.MAX_VALUE) {
        throw new IllegalArgumentException("Value [" + number + "] is out of range for a byte");
    }

    // The cast is lossless here, so the shared helper sees the real value.
    short shortValue = (short) longValue;
    checkByteBounds(shortValue);

    return (byte) shortValue;
}

private static Number toNumber(Object obj, String fieldName) {
if (obj instanceof Number == false) {
throw new IllegalArgumentException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString;
import static org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings.JSON_PARSER;

public class CompletionResponseParser extends BaseCustomResponseParser<ChatCompletionResults> {
public class CompletionResponseParser extends BaseCustomResponseParser {

public static final String NAME = "completion_response_parser";
public static final String COMPLETION_PARSER_RESULT = "completion_result";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,17 @@
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xcontent.ToXContentFragment;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.services.custom.CustomServiceEmbeddingType;

import java.io.IOException;

/**
 * Contract for parsers that turn an HTTP response into inference results. Implementations are also
 * {@link ToXContentFragment} and {@link NamedWriteable}.
 */
public interface CustomResponseParser extends ToXContentFragment, NamedWriteable {
/**
 * Parses the given HTTP response into inference results.
 *
 * @param response the HTTP result to parse
 * @return the parsed inference results
 * @throws IOException if parsing the response fails
 */
InferenceServiceResults parse(HttpResult response) throws IOException;

/**
 * Returns the configured embedding type for this response parser. This should be overridden for text embedding parsers.
 *
 * @return the embedding type; the default implementation returns {@code null}
 */
default CustomServiceEmbeddingType getEmbeddingType() {
return null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString;
import static org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings.JSON_PARSER;

public class RerankResponseParser extends BaseCustomResponseParser<RankedDocsResults> {
public class RerankResponseParser extends BaseCustomResponseParser {

public static final String NAME = "rerank_response_parser";
public static final String RERANK_PARSER_SCORE = "relevance_score";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString;
import static org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings.JSON_PARSER;

public class SparseEmbeddingResponseParser extends BaseCustomResponseParser<SparseEmbeddingResults> {
public class SparseEmbeddingResponseParser extends BaseCustomResponseParser {

public static final String NAME = "sparse_embedding_response_parser";
public static final String SPARSE_EMBEDDING_TOKEN_PATH = "token_path";
Expand Down
Loading