Skip to content

Commit 6aee35b

Browse files
committed
[ML] Bedrock Cohere support for embedding types
Add support for passing embedding types in the service settings, enabling float, int8, and binary embeddings returned in the response. Close #126526
1 parent 36280d2 commit 6aee35b

File tree

15 files changed

+824
-682
lines changed

15 files changed

+824
-682
lines changed

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ static TransportVersion def(int id) {
158158
public static final TransportVersion RERANK_COMMON_OPTIONS_ADDED_8_19 = def(8_841_0_15);
159159
public static final TransportVersion REMOTE_EXCEPTION_8_19 = def(8_841_0_16);
160160
public static final TransportVersion AMAZON_BEDROCK_TASK_SETTINGS_8_19 = def(8_841_0_17);
161+
public static final TransportVersion AMAZON_BEDROCK_EMBEDDING_TYPES_8_19 = def(8_841_0_18);
161162
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0 = def(9_000_0_00);
162163
public static final TransportVersion REMOVE_SNAPSHOT_FAILURES_90 = def(9_000_0_01);
163164
public static final TransportVersion TRANSPORT_STATS_HANDLING_TIME_REQUIRED_90 = def(9_000_0_02);
@@ -217,6 +218,7 @@ static TransportVersion def(int id) {
217218
public static final TransportVersion SEMANTIC_TEXT_CHUNKING_CONFIG = def(9_047_00_0);
218219
public static final TransportVersion REPO_ANALYSIS_COPY_BLOB = def(9_048_00_0);
219220
public static final TransportVersion AMAZON_BEDROCK_TASK_SETTINGS = def(9_049_00_0);
221+
public static final TransportVersion AMAZON_BEDROCK_EMBEDDING_TYPES = def(9_050_00_0);
220222

221223
/*
222224
* STOP! READ THIS FIRST! No, really,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -351,7 +351,8 @@ public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
351351
serviceSettings.dimensionsSetByUser(),
352352
serviceSettings.maxInputTokens(),
353353
similarityToUse,
354-
serviceSettings.rateLimitSettings()
354+
serviceSettings.rateLimitSettings(),
355+
serviceSettings.embeddingType()
355356
);
356357

357358
return new AmazonBedrockEmbeddingsModel(embeddingsModel, updatedServiceSettings);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsServiceSettings.java

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
package org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings;
99

10+
import org.elasticsearch.TransportVersions;
1011
import org.elasticsearch.common.ValidationException;
1112
import org.elasticsearch.common.io.stream.StreamInput;
1213
import org.elasticsearch.common.io.stream.StreamOutput;
@@ -19,6 +20,7 @@
1920
import org.elasticsearch.xpack.inference.services.ServiceUtils;
2021
import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProvider;
2122
import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockServiceSettings;
23+
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType;
2224
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
2325

2426
import java.io.IOException;
@@ -29,17 +31,20 @@
2931
import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS;
3032
import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY;
3133
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean;
34+
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum;
3235
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger;
3336
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity;
3437

3538
public class AmazonBedrockEmbeddingsServiceSettings extends AmazonBedrockServiceSettings {
3639
public static final String NAME = "amazon_bedrock_embeddings_service_settings";
40+
static final String EMBEDDING_TYPE = "embedding_type";
3741
static final String DIMENSIONS_SET_BY_USER = "dimensions_set_by_user";
3842

3943
private final Integer dimensions;
4044
private final Boolean dimensionsSetByUser;
4145
private final Integer maxInputTokens;
4246
private final SimilarityMeasure similarity;
47+
private final CohereEmbeddingType embeddingType;
4348

4449
public static AmazonBedrockEmbeddingsServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
4550
ValidationException validationException = new ValidationException();
@@ -71,6 +76,15 @@ private static AmazonBedrockEmbeddingsServiceSettings embeddingSettingsFromMap(
7176

7277
Boolean dimensionsSetByUser = extractOptionalBoolean(map, DIMENSIONS_SET_BY_USER, validationException);
7378

79+
var embeddingType = extractOptionalEnum(
80+
map,
81+
EMBEDDING_TYPE,
82+
ModelConfigurations.SERVICE_SETTINGS,
83+
CohereEmbeddingType::fromString,
84+
CohereEmbeddingType.ALL,
85+
validationException
86+
);
87+
7488
switch (context) {
7589
case REQUEST -> {
7690
if (dimensionsSetByUser != null) {
@@ -102,7 +116,8 @@ private static AmazonBedrockEmbeddingsServiceSettings embeddingSettingsFromMap(
102116
dimensionsSetByUser,
103117
maxTokens,
104118
similarity,
105-
baseSettings.rateLimitSettings()
119+
baseSettings.rateLimitSettings(),
120+
embeddingType
106121
);
107122
}
108123

@@ -112,6 +127,9 @@ public AmazonBedrockEmbeddingsServiceSettings(StreamInput in) throws IOException
112127
dimensionsSetByUser = in.readBoolean();
113128
maxInputTokens = in.readOptionalVInt();
114129
similarity = in.readOptionalEnum(SimilarityMeasure.class);
130+
embeddingType = in.getTransportVersion().onOrAfter(TransportVersions.AMAZON_BEDROCK_EMBEDDING_TYPES)
131+
? in.readOptionalEnum(CohereEmbeddingType.class)
132+
: null;
115133
}
116134

117135
public AmazonBedrockEmbeddingsServiceSettings(
@@ -122,13 +140,15 @@ public AmazonBedrockEmbeddingsServiceSettings(
122140
Boolean dimensionsSetByUser,
123141
@Nullable Integer maxInputTokens,
124142
@Nullable SimilarityMeasure similarity,
125-
RateLimitSettings rateLimitSettings
143+
RateLimitSettings rateLimitSettings,
144+
@Nullable CohereEmbeddingType embeddingType
126145
) {
127146
super(region, model, provider, rateLimitSettings);
128147
this.dimensions = dimensions;
129148
this.dimensionsSetByUser = dimensionsSetByUser;
130149
this.maxInputTokens = maxInputTokens;
131150
this.similarity = similarity;
151+
this.embeddingType = embeddingType;
132152
}
133153

134154
@Override
@@ -138,6 +158,9 @@ public void writeTo(StreamOutput out) throws IOException {
138158
out.writeBoolean(dimensionsSetByUser);
139159
out.writeOptionalVInt(maxInputTokens);
140160
out.writeOptionalEnum(similarity);
161+
if (out.getTransportVersion().onOrAfter(TransportVersions.AMAZON_BEDROCK_EMBEDDING_TYPES)) {
162+
out.writeOptionalEnum(embeddingType);
163+
}
141164
}
142165

143166
@Override
@@ -169,6 +192,9 @@ protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder buil
169192
if (similarity != null) {
170193
builder.field(SIMILARITY, similarity);
171194
}
195+
if (embeddingType != null) {
196+
builder.field(EMBEDDING_TYPE, embeddingType);
197+
}
172198

173199
return builder;
174200
}
@@ -192,6 +218,10 @@ public Integer maxInputTokens() {
192218
return maxInputTokens;
193219
}
194220

221+
public CohereEmbeddingType embeddingType() {
222+
return embeddingType;
223+
}
224+
195225
@Override
196226
public DenseVectorFieldMapper.ElementType elementType() {
197227
return DenseVectorFieldMapper.ElementType.FLOAT;
@@ -210,12 +240,23 @@ public boolean equals(Object o) {
210240
&& Objects.equals(dimensionsSetByUser, that.dimensionsSetByUser)
211241
&& Objects.equals(maxInputTokens, that.maxInputTokens)
212242
&& Objects.equals(similarity, that.similarity)
213-
&& Objects.equals(rateLimitSettings, that.rateLimitSettings);
243+
&& Objects.equals(rateLimitSettings, that.rateLimitSettings)
244+
&& Objects.equals(embeddingType, that.embeddingType);
214245
}
215246

216247
@Override
217248
public int hashCode() {
218-
return Objects.hash(region, model, provider, dimensions, dimensionsSetByUser, maxInputTokens, similarity, rateLimitSettings);
249+
return Objects.hash(
250+
region,
251+
model,
252+
provider,
253+
dimensions,
254+
dimensionsSetByUser,
255+
maxInputTokens,
256+
similarity,
257+
rateLimitSettings,
258+
embeddingType
259+
);
219260
}
220261

221262
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/request/embeddings/AmazonBedrockCohereEmbeddingsRequestEntity.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import org.elasticsearch.xcontent.ToXContentObject;
1313
import org.elasticsearch.xcontent.XContentBuilder;
1414
import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsTaskSettings;
15+
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType;
1516

1617
import java.io.IOException;
1718
import java.util.List;
@@ -22,7 +23,8 @@
2223
public record AmazonBedrockCohereEmbeddingsRequestEntity(
2324
List<String> input,
2425
@Nullable InputType inputType,
25-
AmazonBedrockEmbeddingsTaskSettings taskSettings
26+
AmazonBedrockEmbeddingsTaskSettings taskSettings,
27+
@Nullable CohereEmbeddingType embeddingType
2628
) implements ToXContentObject {
2729

2830
private static final String TEXTS_FIELD = "texts";
@@ -32,6 +34,7 @@ public record AmazonBedrockCohereEmbeddingsRequestEntity(
3234
private static final String CLUSTERING = "clustering";
3335
private static final String CLASSIFICATION = "classification";
3436
private static final String TRUNCATE = "truncate";
37+
private static final String EMBEDDING_TYPES = "embedding_types";
3538

3639
public AmazonBedrockCohereEmbeddingsRequestEntity {
3740
Objects.requireNonNull(input);
@@ -54,6 +57,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
5457
builder.field(TRUNCATE, taskSettings.cohereTruncation().name());
5558
}
5659

60+
if (embeddingType != null) {
61+
builder.field(EMBEDDING_TYPES, List.of(embeddingType.toRequestString()));
62+
}
63+
5764
builder.endObject();
5865
return builder;
5966
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/request/embeddings/AmazonBedrockEmbeddingsEntityFactory.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,12 @@ public static ToXContent createEntity(
3939
return new AmazonBedrockTitanEmbeddingsRequestEntity(truncatedInput.get(0));
4040
}
4141
case COHERE -> {
42-
return new AmazonBedrockCohereEmbeddingsRequestEntity(truncatedInput, inputType, model.getTaskSettings());
42+
return new AmazonBedrockCohereEmbeddingsRequestEntity(
43+
truncatedInput,
44+
inputType,
45+
model.getTaskSettings(),
46+
model.getServiceSettings().embeddingType()
47+
);
4348
}
4449
default -> {
4550
return null;

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/response/embeddings/AmazonBedrockEmbeddingsResponse.java

Lines changed: 18 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import org.elasticsearch.xpack.inference.services.amazonbedrock.request.AmazonBedrockRequest;
2323
import org.elasticsearch.xpack.inference.services.amazonbedrock.request.embeddings.AmazonBedrockEmbeddingsRequest;
2424
import org.elasticsearch.xpack.inference.services.amazonbedrock.response.AmazonBedrockResponse;
25+
import org.elasticsearch.xpack.inference.services.cohere.response.CohereEmbeddingsResponseEntity;
2526

2627
import java.io.IOException;
2728
import java.nio.charset.StandardCharsets;
@@ -48,13 +49,25 @@ public InferenceServiceResults accept(AmazonBedrockRequest request) {
4849
throw new ElasticsearchException("unexpected request type [" + request.getClass() + "]");
4950
}
5051

51-
public static TextEmbeddingFloatResults fromResponse(InvokeModelResponse response, AmazonBedrockProvider provider) {
52+
private static InferenceServiceResults fromResponse(InvokeModelResponse response, AmazonBedrockProvider provider) {
5253
var charset = StandardCharsets.UTF_8;
5354
var bodyText = String.valueOf(charset.decode(response.body().asByteBuffer()));
5455

56+
try {
57+
if (provider == AmazonBedrockProvider.COHERE) {
58+
return CohereEmbeddingsResponseEntity.fromResponse(bodyText.getBytes(StandardCharsets.UTF_8));
59+
} else {
60+
return fromResponse(bodyText, provider);
61+
}
62+
} catch (IOException e) {
63+
throw new ElasticsearchException(e);
64+
}
65+
}
66+
67+
private static TextEmbeddingFloatResults fromResponse(String response, AmazonBedrockProvider provider) {
5568
var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
5669

57-
try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, bodyText)) {
70+
try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response)) {
5871
// move to the first token
5972
jsonParser.nextToken();
6073

@@ -71,15 +84,10 @@ public static TextEmbeddingFloatResults fromResponse(InvokeModelResponse respons
7184

7285
private static List<TextEmbeddingFloatResults.Embedding> parseEmbeddings(XContentParser jsonParser, AmazonBedrockProvider provider)
7386
throws IOException {
74-
switch (provider) {
75-
case AMAZONTITAN -> {
76-
return parseTitanEmbeddings(jsonParser);
77-
}
78-
case COHERE -> {
79-
return parseCohereEmbeddings(jsonParser);
80-
}
81-
default -> throw new IOException("Unsupported provider [" + provider + "]");
87+
if (provider == AmazonBedrockProvider.AMAZONTITAN) {
88+
return parseTitanEmbeddings(jsonParser);
8289
}
90+
throw new IOException("Unsupported provider [" + provider + "]");
8391
}
8492

8593
private static List<TextEmbeddingFloatResults.Embedding> parseTitanEmbeddings(XContentParser parser) throws IOException {
@@ -96,32 +104,4 @@ private static List<TextEmbeddingFloatResults.Embedding> parseTitanEmbeddings(XC
96104
return List.of(embeddingValues);
97105
}
98106

99-
private static List<TextEmbeddingFloatResults.Embedding> parseCohereEmbeddings(XContentParser parser) throws IOException {
100-
/*
101-
Cohere response:
102-
{
103-
"embeddings": [
104-
[< array of 1024 floats >],
105-
...
106-
],
107-
"id": string,
108-
"response_type" : "embeddings_floats",
109-
"texts": [string]
110-
}
111-
*/
112-
positionParserAtTokenAfterField(parser, "embeddings", FAILED_TO_FIND_FIELD_TEMPLATE);
113-
114-
List<TextEmbeddingFloatResults.Embedding> embeddingList = parseList(
115-
parser,
116-
AmazonBedrockEmbeddingsResponse::parseCohereEmbeddingsListItem
117-
);
118-
119-
return embeddingList;
120-
}
121-
122-
private static TextEmbeddingFloatResults.Embedding parseCohereEmbeddingsListItem(XContentParser parser) throws IOException {
123-
List<Float> embeddingValuesList = parseList(parser, XContentUtils::parseFloat);
124-
return TextEmbeddingFloatResults.Embedding.of(embeddingValuesList);
125-
}
126-
127107
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingType.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ public enum CohereEmbeddingType {
4646
*/
4747
BINARY(DenseVectorFieldMapper.ElementType.BIT, RequestConstants.BIT);
4848

49+
public static EnumSet<CohereEmbeddingType> ALL = EnumSet.allOf(CohereEmbeddingType.class);
50+
4951
private static final class RequestConstants {
5052
private static final String FLOAT = "float";
5153
private static final String INT8 = "int8";

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettings.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
2424

2525
import java.io.IOException;
26-
import java.util.EnumSet;
2726
import java.util.Locale;
2827
import java.util.Map;
2928
import java.util.Objects;
@@ -60,7 +59,7 @@ static CohereEmbeddingType parseEmbeddingType(
6059
EMBEDDING_TYPE,
6160
ModelConfigurations.SERVICE_SETTINGS,
6261
CohereEmbeddingType::fromString,
63-
EnumSet.allOf(CohereEmbeddingType.class),
62+
CohereEmbeddingType.ALL,
6463
validationException
6564
),
6665
CohereEmbeddingType.FLOAT

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/response/CohereEmbeddingsResponseEntity.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,9 +137,13 @@ private static String supportedEmbeddingTypes() {
137137
* </pre>
138138
*/
139139
public static InferenceServiceResults fromResponse(Request request, HttpResult response) throws IOException {
140+
return fromResponse(response.body());
141+
}
142+
143+
public static InferenceServiceResults fromResponse(byte[] body) throws IOException {
140144
var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
141145

142-
try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) {
146+
try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, body)) {
143147
moveToFirstToken(jsonParser);
144148

145149
XContentParser.Token token = jsonParser.currentToken();

0 commit comments

Comments
 (0)