- 
                Notifications
    
You must be signed in to change notification settings  - Fork 25.6k
 
[ML] Custom Service add embedding type support #130141
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
ad091b6
              c6f9a7f
              a4b50bd
              9c42ecd
              12d271e
              2a9330b
              43f550c
              a00919b
              File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,50 @@ | ||
| /* | ||
| * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
| * or more contributor license agreements. Licensed under the Elastic License | ||
| * 2.0; you may not use this file except in compliance with the Elastic License | ||
| * 2.0. | ||
| */ | ||
| 
     | 
||
| package org.elasticsearch.xpack.inference.services.custom; | ||
| 
     | 
||
| import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; | ||
| 
     | 
||
| import java.util.Locale; | ||
| 
     | 
||
| public enum CustomServiceEmbeddingType { | ||
| /** | ||
| * Use this when you want to get back the default float embeddings. | ||
| */ | ||
| FLOAT(DenseVectorFieldMapper.ElementType.FLOAT), | ||
| /** | ||
| * Use this when you want to get back signed int8 embeddings. | ||
| */ | ||
| BYTE(DenseVectorFieldMapper.ElementType.BYTE), | ||
| /** | ||
| * Use this when you want to get back binary embeddings. | ||
| */ | ||
| BIT(DenseVectorFieldMapper.ElementType.BIT), | ||
| /** | ||
| * This is a synonym for BIT | ||
| */ | ||
| BINARY(DenseVectorFieldMapper.ElementType.BIT); | ||
| 
     | 
||
| private final DenseVectorFieldMapper.ElementType elementType; | ||
| 
     | 
||
| CustomServiceEmbeddingType(DenseVectorFieldMapper.ElementType elementType) { | ||
| this.elementType = elementType; | ||
| } | ||
| 
     | 
||
| @Override | ||
| public String toString() { | ||
| return name().toLowerCase(Locale.ROOT); | ||
| } | ||
| 
     | 
||
| public DenseVectorFieldMapper.ElementType toElementType() { | ||
| return elementType; | ||
| } | ||
| 
     | 
||
| public static CustomServiceEmbeddingType fromString(String name) { | ||
| return valueOf(name.trim().toUpperCase(Locale.ROOT)); | ||
| } | ||
| } | 
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
| 
          
            
          
           | 
    @@ -66,12 +66,7 @@ public class CustomServiceSettings extends FilteredXContentObject implements Ser | |
| private static final String RESPONSE_SCOPE = String.join(".", ModelConfigurations.SERVICE_SETTINGS, RESPONSE); | ||
| private static final int DEFAULT_EMBEDDING_BATCH_SIZE = 10; | ||
| 
     | 
||
| public static CustomServiceSettings fromMap( | ||
| Map<String, Object> map, | ||
| ConfigurationParseContext context, | ||
| TaskType taskType, | ||
| String inferenceId | ||
| ) { | ||
| public static CustomServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context, TaskType taskType) { | ||
| ValidationException validationException = new ValidationException(); | ||
| 
     | 
||
| var textEmbeddingSettings = TextEmbeddingSettings.fromMap(map, taskType, validationException); | ||
| 
          
            
          
           | 
    @@ -137,22 +132,12 @@ public static CustomServiceSettings fromMap( | |
| ); | ||
| } | ||
| 
     | 
||
| public record TextEmbeddingSettings( | ||
| @Nullable SimilarityMeasure similarityMeasure, | ||
| @Nullable Integer dimensions, | ||
| @Nullable Integer maxInputTokens, | ||
| @Nullable DenseVectorFieldMapper.ElementType elementType | ||
| ) implements ToXContentFragment, Writeable { | ||
| public static class TextEmbeddingSettings implements ToXContentFragment, Writeable { | ||
| 
     | 
||
| // This specifies float for the element type but null for all other settings | ||
| public static final TextEmbeddingSettings DEFAULT_FLOAT = new TextEmbeddingSettings( | ||
| null, | ||
| null, | ||
| null, | ||
| DenseVectorFieldMapper.ElementType.FLOAT | ||
| ); | ||
| public static final TextEmbeddingSettings DEFAULT_FLOAT = new TextEmbeddingSettings(null, null, null); | ||
| // This refers to settings that are not related to the text embedding task type (all the settings should be null) | ||
| public static final TextEmbeddingSettings NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS = new TextEmbeddingSettings(null, null, null, null); | ||
| public static final TextEmbeddingSettings NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS = new TextEmbeddingSettings(null, null, null); | ||
| 
     | 
||
| public static TextEmbeddingSettings fromMap(Map<String, Object> map, TaskType taskType, ValidationException validationException) { | ||
| 
         There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We never included  There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yep that's correct, it was just hard coded previously.  | 
||
| if (taskType != TaskType.TEXT_EMBEDDING) { | ||
| 
        
          
        
         | 
    @@ -162,24 +147,42 @@ public static TextEmbeddingSettings fromMap(Map<String, Object> map, TaskType ta | |
| SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException); | ||
| Integer dims = removeAsType(map, DIMENSIONS, Integer.class); | ||
| Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class); | ||
| return new TextEmbeddingSettings(similarity, dims, maxInputTokens, DenseVectorFieldMapper.ElementType.FLOAT); | ||
| 
         There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The element type logic has been delegated to the   | 
||
| return new TextEmbeddingSettings(similarity, dims, maxInputTokens); | ||
| } | ||
| 
     | 
||
| private final SimilarityMeasure similarityMeasure; | ||
| private final Integer dimensions; | ||
| private final Integer maxInputTokens; | ||
| 
     | 
||
| public TextEmbeddingSettings( | ||
| @Nullable SimilarityMeasure similarityMeasure, | ||
| @Nullable Integer dimensions, | ||
| @Nullable Integer maxInputTokens | ||
| ) { | ||
| this.similarityMeasure = similarityMeasure; | ||
| this.dimensions = dimensions; | ||
| this.maxInputTokens = maxInputTokens; | ||
| } | ||
| 
     | 
||
| public TextEmbeddingSettings(StreamInput in) throws IOException { | ||
| this( | ||
| in.readOptionalEnum(SimilarityMeasure.class), | ||
| in.readOptionalVInt(), | ||
| in.readOptionalVInt(), | ||
| in.readOptionalEnum(DenseVectorFieldMapper.ElementType.class) | ||
| ); | ||
| this.similarityMeasure = in.readOptionalEnum(SimilarityMeasure.class); | ||
| this.dimensions = in.readOptionalVInt(); | ||
| this.maxInputTokens = in.readOptionalVInt(); | ||
| 
     | 
||
| if (in.getTransportVersion().before(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_TYPE)) { | ||
| in.readOptionalEnum(DenseVectorFieldMapper.ElementType.class); | ||
| 
         There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For older versions, we'll read it but ignore it. It should only be float which we'll default to in the   | 
||
| } | ||
| } | ||
| 
     | 
||
| @Override | ||
| public void writeTo(StreamOutput out) throws IOException { | ||
| out.writeOptionalEnum(similarityMeasure); | ||
| out.writeOptionalVInt(dimensions); | ||
| out.writeOptionalVInt(maxInputTokens); | ||
| out.writeOptionalEnum(elementType); | ||
| 
     | 
||
| if (out.getTransportVersion().before(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_TYPE)) { | ||
| out.writeOptionalEnum(null); | ||
| } | ||
| } | ||
| 
     | 
||
| @Override | ||
| 
        
          
        
         | 
    @@ -193,8 +196,23 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws | |
| if (maxInputTokens != null) { | ||
| builder.field(MAX_INPUT_TOKENS, maxInputTokens); | ||
| } | ||
| 
     | 
||
| return builder; | ||
| } | ||
| 
     | 
||
| @Override | ||
| public boolean equals(Object o) { | ||
| if (o == null || getClass() != o.getClass()) return false; | ||
| TextEmbeddingSettings that = (TextEmbeddingSettings) o; | ||
| return similarityMeasure == that.similarityMeasure | ||
| && Objects.equals(dimensions, that.dimensions) | ||
| && Objects.equals(maxInputTokens, that.maxInputTokens); | ||
| } | ||
| 
     | 
||
| @Override | ||
| public int hashCode() { | ||
| return Objects.hash(similarityMeasure, dimensions, maxInputTokens); | ||
| } | ||
| } | ||
| 
     | 
||
| private final TextEmbeddingSettings textEmbeddingSettings; | ||
| 
          
            
          
           | 
    @@ -300,7 +318,12 @@ public Integer dimensions() { | |
| 
     | 
||
| @Override | ||
| public DenseVectorFieldMapper.ElementType elementType() { | ||
| return textEmbeddingSettings.elementType; | ||
| var embeddingType = responseJsonParser.getEmbeddingType(); | ||
| if (embeddingType != null) { | ||
| return embeddingType.toElementType(); | ||
| } | ||
| 
     | 
||
| return null; | ||
| } | ||
| 
     | 
||
| public Integer getMaxInputTokens() { | ||
| 
          
            
          
           | 
    ||
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
| 
          
            
          
           | 
    @@ -22,7 +22,9 @@ | |
| import java.util.Objects; | ||
| import java.util.function.BiFunction; | ||
| 
     | 
||
| public abstract class BaseCustomResponseParser<T extends InferenceServiceResults> implements CustomResponseParser { | ||
| import static org.elasticsearch.xpack.inference.services.ServiceUtils.checkByteBounds; | ||
| 
     | 
||
| public abstract class BaseCustomResponseParser implements CustomResponseParser { | ||
| 
     | 
||
| @Override | ||
| public InferenceServiceResults parse(HttpResult response) throws IOException { | ||
| 
        
          
        
         | 
    @@ -36,7 +38,7 @@ public InferenceServiceResults parse(HttpResult response) throws IOException { | |
| } | ||
| } | ||
| 
     | 
||
| protected abstract T transform(Map<String, Object> extractedField); | ||
| protected abstract InferenceServiceResults transform(Map<String, Object> extractedField); | ||
| 
         There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The   | 
||
| 
     | 
||
| static List<?> validateList(Object obj, String fieldName) { | ||
| validateNonNull(obj, fieldName); | ||
| 
          
            
          
           | 
    @@ -97,6 +99,21 @@ static Float toFloat(Object obj, String fieldName) { | |
| return toNumber(obj, fieldName).floatValue(); | ||
| } | ||
| 
     | 
||
| static List<Byte> convertToListOfBits(Object obj, String fieldName) { | ||
| return convertToListOfBytes(obj, fieldName); | ||
| } | ||
| 
     | 
||
| static List<Byte> convertToListOfBytes(Object obj, String fieldName) { | ||
| return castList(validateList(obj, fieldName), BaseCustomResponseParser::toByte, fieldName); | ||
| } | ||
| 
     | 
||
| static Byte toByte(Object obj, String fieldName) { | ||
| var shortValue = toNumber(obj, fieldName).shortValue(); | ||
| checkByteBounds(shortValue); | ||
| 
     | 
||
| return (byte) shortValue; | ||
| } | ||
| 
     | 
||
| private static Number toNumber(Object obj, String fieldName) { | ||
| if (obj instanceof Number == false) { | ||
| throw new IllegalArgumentException( | ||
| 
          
            
          
           | 
    ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
inferenceIdwasn't being used.