Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/changelog/126565.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 126565
summary: Bedrock Cohere support for embedding types
area: Machine Learning
type: enhancement
issues:
- 126526
2 changes: 2 additions & 0 deletions server/src/main/java/org/elasticsearch/TransportVersions.java
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ static TransportVersion def(int id) {
public static final TransportVersion RERANK_COMMON_OPTIONS_ADDED_8_19 = def(8_841_0_15);
public static final TransportVersion REMOTE_EXCEPTION_8_19 = def(8_841_0_16);
public static final TransportVersion AMAZON_BEDROCK_TASK_SETTINGS_8_19 = def(8_841_0_17);
public static final TransportVersion AMAZON_BEDROCK_EMBEDDING_TYPES_8_19 = def(8_841_0_18);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0 = def(9_000_0_00);
public static final TransportVersion REMOVE_SNAPSHOT_FAILURES_90 = def(9_000_0_01);
public static final TransportVersion TRANSPORT_STATS_HANDLING_TIME_REQUIRED_90 = def(9_000_0_02);
Expand Down Expand Up @@ -217,6 +218,7 @@ static TransportVersion def(int id) {
public static final TransportVersion SEMANTIC_TEXT_CHUNKING_CONFIG = def(9_047_00_0);
public static final TransportVersion REPO_ANALYSIS_COPY_BLOB = def(9_048_00_0);
public static final TransportVersion AMAZON_BEDROCK_TASK_SETTINGS = def(9_049_00_0);
public static final TransportVersion AMAZON_BEDROCK_EMBEDDING_TYPES = def(9_050_00_0);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,8 @@ public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
serviceSettings.dimensionsSetByUser(),
serviceSettings.maxInputTokens(),
similarityToUse,
serviceSettings.rateLimitSettings()
serviceSettings.rateLimitSettings(),
serviceSettings.embeddingType()
);

return new AmazonBedrockEmbeddingsModel(embeddingsModel, updatedServiceSettings);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

package org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings;

import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
Expand All @@ -19,6 +20,7 @@
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProvider;
import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockServiceSettings;
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;

import java.io.IOException;
Expand All @@ -29,17 +31,20 @@
import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS;
import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity;

public class AmazonBedrockEmbeddingsServiceSettings extends AmazonBedrockServiceSettings {
public static final String NAME = "amazon_bedrock_embeddings_service_settings";
static final String EMBEDDING_TYPE = "embedding_type";
static final String DIMENSIONS_SET_BY_USER = "dimensions_set_by_user";

private final Integer dimensions;
private final Boolean dimensionsSetByUser;
private final Integer maxInputTokens;
private final SimilarityMeasure similarity;
private final CohereEmbeddingType embeddingType;

public static AmazonBedrockEmbeddingsServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
ValidationException validationException = new ValidationException();
Expand Down Expand Up @@ -71,6 +76,15 @@ private static AmazonBedrockEmbeddingsServiceSettings embeddingSettingsFromMap(

Boolean dimensionsSetByUser = extractOptionalBoolean(map, DIMENSIONS_SET_BY_USER, validationException);

var embeddingType = extractOptionalEnum(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a tricky one: the embedding type option depends on the provider, but it must go in the service settings because the type cannot change. We can rely on the validation step to catch cases where the setting is required but missing, or where the value is not supported.

#126540 is a community contribution that addresses the same missing embedding type problem for Titan. They have added an AmazonBedrockEmbeddingType class; we could map that type to the specific values supported by each provider. If the provider does not support that type, the validation will fail.

https://github.com/elastic/elasticsearch/pull/126540/files#diff-73bad4ea6d9d6626fb73c2489b51bf631a9ef592f39599e9be8b372c16e11c38

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't mind waiting until after their change is in, and then we can work from there?

map,
EMBEDDING_TYPE,
ModelConfigurations.SERVICE_SETTINGS,
CohereEmbeddingType::fromString,
CohereEmbeddingType.ALL,
validationException
);

switch (context) {
case REQUEST -> {
if (dimensionsSetByUser != null) {
Expand Down Expand Up @@ -102,7 +116,8 @@ private static AmazonBedrockEmbeddingsServiceSettings embeddingSettingsFromMap(
dimensionsSetByUser,
maxTokens,
similarity,
baseSettings.rateLimitSettings()
baseSettings.rateLimitSettings(),
embeddingType
);
}

Expand All @@ -112,6 +127,9 @@ public AmazonBedrockEmbeddingsServiceSettings(StreamInput in) throws IOException
dimensionsSetByUser = in.readBoolean();
maxInputTokens = in.readOptionalVInt();
similarity = in.readOptionalEnum(SimilarityMeasure.class);
embeddingType = in.getTransportVersion().onOrAfter(TransportVersions.AMAZON_BEDROCK_EMBEDDING_TYPES)
? in.readOptionalEnum(CohereEmbeddingType.class)
: null;
}

public AmazonBedrockEmbeddingsServiceSettings(
Expand All @@ -122,13 +140,15 @@ public AmazonBedrockEmbeddingsServiceSettings(
Boolean dimensionsSetByUser,
@Nullable Integer maxInputTokens,
@Nullable SimilarityMeasure similarity,
RateLimitSettings rateLimitSettings
RateLimitSettings rateLimitSettings,
@Nullable CohereEmbeddingType embeddingType
) {
super(region, model, provider, rateLimitSettings);
this.dimensions = dimensions;
this.dimensionsSetByUser = dimensionsSetByUser;
this.maxInputTokens = maxInputTokens;
this.similarity = similarity;
this.embeddingType = embeddingType;
}

@Override
Expand All @@ -138,6 +158,9 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeBoolean(dimensionsSetByUser);
out.writeOptionalVInt(maxInputTokens);
out.writeOptionalEnum(similarity);
if (out.getTransportVersion().onOrAfter(TransportVersions.AMAZON_BEDROCK_EMBEDDING_TYPES)) {
out.writeOptionalEnum(embeddingType);
}
}

@Override
Expand Down Expand Up @@ -169,6 +192,9 @@ protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder buil
if (similarity != null) {
builder.field(SIMILARITY, similarity);
}
if (embeddingType != null) {
builder.field(EMBEDDING_TYPE, embeddingType);
}

return builder;
}
Expand All @@ -192,6 +218,10 @@ public Integer maxInputTokens() {
return maxInputTokens;
}

/**
 * Returns the embedding type held by these service settings.
 *
 * @return the stored {@link CohereEmbeddingType}; may be {@code null} because the field is optional
 */
public CohereEmbeddingType embeddingType() {
    return this.embeddingType;
}

@Override
public DenseVectorFieldMapper.ElementType elementType() {
return DenseVectorFieldMapper.ElementType.FLOAT;
Expand All @@ -210,12 +240,23 @@ public boolean equals(Object o) {
&& Objects.equals(dimensionsSetByUser, that.dimensionsSetByUser)
&& Objects.equals(maxInputTokens, that.maxInputTokens)
&& Objects.equals(similarity, that.similarity)
&& Objects.equals(rateLimitSettings, that.rateLimitSettings);
&& Objects.equals(rateLimitSettings, that.rateLimitSettings)
&& Objects.equals(embeddingType, that.embeddingType);
}

@Override
public int hashCode() {
return Objects.hash(region, model, provider, dimensions, dimensionsSetByUser, maxInputTokens, similarity, rateLimitSettings);
return Objects.hash(
region,
model,
provider,
dimensions,
dimensionsSetByUser,
maxInputTokens,
similarity,
rateLimitSettings,
embeddingType
);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsTaskSettings;
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType;

import java.io.IOException;
import java.util.List;
Expand All @@ -22,7 +23,8 @@
public record AmazonBedrockCohereEmbeddingsRequestEntity(
List<String> input,
@Nullable InputType inputType,
AmazonBedrockEmbeddingsTaskSettings taskSettings
AmazonBedrockEmbeddingsTaskSettings taskSettings,
@Nullable CohereEmbeddingType embeddingType
) implements ToXContentObject {

private static final String TEXTS_FIELD = "texts";
Expand All @@ -32,6 +34,7 @@ public record AmazonBedrockCohereEmbeddingsRequestEntity(
private static final String CLUSTERING = "clustering";
private static final String CLASSIFICATION = "classification";
private static final String TRUNCATE = "truncate";
private static final String EMBEDDING_TYPES = "embedding_types";

public AmazonBedrockCohereEmbeddingsRequestEntity {
Objects.requireNonNull(input);
Expand All @@ -54,6 +57,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field(TRUNCATE, taskSettings.cohereTruncation().name());
}

if (embeddingType != null) {
builder.field(EMBEDDING_TYPES, List.of(embeddingType.toRequestString()));
}

builder.endObject();
return builder;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,12 @@ public static ToXContent createEntity(
return new AmazonBedrockTitanEmbeddingsRequestEntity(truncatedInput.get(0));
}
case COHERE -> {
return new AmazonBedrockCohereEmbeddingsRequestEntity(truncatedInput, inputType, model.getTaskSettings());
return new AmazonBedrockCohereEmbeddingsRequestEntity(
truncatedInput,
inputType,
model.getTaskSettings(),
model.getServiceSettings().embeddingType()
);
}
default -> {
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.elasticsearch.xpack.inference.services.amazonbedrock.request.AmazonBedrockRequest;
import org.elasticsearch.xpack.inference.services.amazonbedrock.request.embeddings.AmazonBedrockEmbeddingsRequest;
import org.elasticsearch.xpack.inference.services.amazonbedrock.response.AmazonBedrockResponse;
import org.elasticsearch.xpack.inference.services.cohere.response.CohereEmbeddingsResponseEntity;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
Expand All @@ -48,13 +49,25 @@ public InferenceServiceResults accept(AmazonBedrockRequest request) {
throw new ElasticsearchException("unexpected request type [" + request.getClass() + "]");
}

public static TextEmbeddingFloatResults fromResponse(InvokeModelResponse response, AmazonBedrockProvider provider) {
private static InferenceServiceResults fromResponse(InvokeModelResponse response, AmazonBedrockProvider provider) {
var charset = StandardCharsets.UTF_8;
var bodyText = String.valueOf(charset.decode(response.body().asByteBuffer()));

try {
if (provider == AmazonBedrockProvider.COHERE) {
return CohereEmbeddingsResponseEntity.fromResponse(bodyText.getBytes(StandardCharsets.UTF_8));
} else {
return fromResponse(bodyText, provider);
}
} catch (IOException e) {
throw new ElasticsearchException(e);
}
}

private static TextEmbeddingFloatResults fromResponse(String response, AmazonBedrockProvider provider) {
var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);

try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, bodyText)) {
try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response)) {
// move to the first token
jsonParser.nextToken();

Expand All @@ -71,15 +84,10 @@ public static TextEmbeddingFloatResults fromResponse(InvokeModelResponse respons

private static List<TextEmbeddingFloatResults.Embedding> parseEmbeddings(XContentParser jsonParser, AmazonBedrockProvider provider)
throws IOException {
switch (provider) {
case AMAZONTITAN -> {
return parseTitanEmbeddings(jsonParser);
}
case COHERE -> {
return parseCohereEmbeddings(jsonParser);
}
default -> throw new IOException("Unsupported provider [" + provider + "]");
if (provider == AmazonBedrockProvider.AMAZONTITAN) {
return parseTitanEmbeddings(jsonParser);
}
throw new IOException("Unsupported provider [" + provider + "]");
}

private static List<TextEmbeddingFloatResults.Embedding> parseTitanEmbeddings(XContentParser parser) throws IOException {
Expand All @@ -96,32 +104,4 @@ private static List<TextEmbeddingFloatResults.Embedding> parseTitanEmbeddings(XC
return List.of(embeddingValues);
}

/**
 * Parses the embeddings out of a Cohere response body.
 *
 * @param parser a parser over the Cohere JSON response
 * @return one {@link TextEmbeddingFloatResults.Embedding} per entry of the {@code embeddings} array
 * @throws IOException if the underlying parsing helpers fail
 */
private static List<TextEmbeddingFloatResults.Embedding> parseCohereEmbeddings(XContentParser parser) throws IOException {
    /*
    Expected shape of the Cohere response:
    {
    "embeddings": [
    [< array of 1024 floats >],
    ...
    ],
    "id": string,
    "response_type" : "embeddings_floats",
    "texts": [string]
    }
    */
    // Advance to the value of the "embeddings" field, then read each inner array as one embedding.
    positionParserAtTokenAfterField(parser, "embeddings", FAILED_TO_FIND_FIELD_TEMPLATE);
    return parseList(parser, AmazonBedrockEmbeddingsResponse::parseCohereEmbeddingsListItem);
}

/**
 * Parses a single entry of the Cohere {@code embeddings} array into an embedding.
 *
 * @param parser a parser positioned at one inner list of float values
 * @return the embedding built from the parsed floats
 * @throws IOException if the underlying parsing helpers fail
 */
private static TextEmbeddingFloatResults.Embedding parseCohereEmbeddingsListItem(XContentParser parser) throws IOException {
    // Read every element of the current list as a float.
    List<Float> values = parseList(parser, XContentUtils::parseFloat);
    return TextEmbeddingFloatResults.Embedding.of(values);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ public enum CohereEmbeddingType {
*/
BINARY(DenseVectorFieldMapper.ElementType.BIT, RequestConstants.BIT);

/**
 * Every {@link CohereEmbeddingType} constant, passed as the set of valid values when parsing the
 * {@code embedding_type} setting.
 * <p>
 * Declared {@code final} so that the shared reference cannot be reassigned by other classes.
 * The {@link EnumSet} itself is still mutable; callers must treat it as read-only.
 */
public static final EnumSet<CohereEmbeddingType> ALL = EnumSet.allOf(CohereEmbeddingType.class);

private static final class RequestConstants {
private static final String FLOAT = "float";
private static final String INT8 = "int8";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;

import java.io.IOException;
import java.util.EnumSet;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
Expand Down Expand Up @@ -60,7 +59,7 @@ static CohereEmbeddingType parseEmbeddingType(
EMBEDDING_TYPE,
ModelConfigurations.SERVICE_SETTINGS,
CohereEmbeddingType::fromString,
EnumSet.allOf(CohereEmbeddingType.class),
CohereEmbeddingType.ALL,
validationException
),
CohereEmbeddingType.FLOAT
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,13 @@ private static String supportedEmbeddingTypes() {
* </pre>
*/
public static InferenceServiceResults fromResponse(Request request, HttpResult response) throws IOException {
    // The request is not consulted here; only the raw response body is handed to the byte-based overload.
    var body = response.body();
    return fromResponse(body);
}

public static InferenceServiceResults fromResponse(byte[] body) throws IOException {
var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);

try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) {
try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, body)) {
moveToFirstToken(jsonParser);

XContentParser.Token token = jsonParser.currentToken();
Expand Down
Loading
Loading