Skip to content

Commit 0dad969

Browse files
authored
[Inference API] Propagate dimensions for dense text embedding generation using EIS (elastic#137328) (elastic#137364)
1 parent 47862e8 commit 0dad969

File tree

3 files changed: 45 additions and 9 deletions

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceDenseTextEmbeddingsRequest.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,12 @@ public HttpRequestBase createHttpRequestBase() {
5656
var httpPost = new HttpPost(uri);
5757
var usageContext = ElasticInferenceServiceUsageContext.fromInputType(inputType);
5858
var requestEntity = Strings.toString(
59-
new ElasticInferenceServiceDenseTextEmbeddingsRequestEntity(inputs, model.getServiceSettings().modelId(), usageContext)
59+
new ElasticInferenceServiceDenseTextEmbeddingsRequestEntity(
60+
inputs,
61+
model.getServiceSettings().modelId(),
62+
usageContext,
63+
model.getServiceSettings().dimensions()
64+
)
6065
);
6166

6267
ByteArrayEntity byteEntity = new ByteArrayEntity(requestEntity.getBytes(StandardCharsets.UTF_8));

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceDenseTextEmbeddingsRequestEntity.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,14 @@
1919
public record ElasticInferenceServiceDenseTextEmbeddingsRequestEntity(
2020
List<String> inputs,
2121
String modelId,
22-
@Nullable ElasticInferenceServiceUsageContext usageContext
22+
@Nullable ElasticInferenceServiceUsageContext usageContext,
23+
@Nullable Integer dimensions
2324
) implements ToXContentObject {
2425

2526
private static final String INPUT_FIELD = "input";
2627
private static final String MODEL_FIELD = "model";
2728
private static final String USAGE_CONTEXT = "usage_context";
29+
private static final String DIMENSIONS = "dimensions";
2830

2931
public ElasticInferenceServiceDenseTextEmbeddingsRequestEntity {
3032
Objects.requireNonNull(inputs);
@@ -49,6 +51,11 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
4951
builder.field(USAGE_CONTEXT, usageContext);
5052
}
5153

54+
// optional field
55+
if (Objects.nonNull(dimensions)) {
56+
builder.field(DIMENSIONS, dimensions);
57+
}
58+
5259
builder.endObject();
5360

5461
return builder;

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceDenseTextEmbeddingsRequestEntityTests.java

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ public void testToXContent_SingleInput_UnspecifiedUsageContext() throws IOExcept
2525
var entity = new ElasticInferenceServiceDenseTextEmbeddingsRequestEntity(
2626
List.of("abc"),
2727
"my-model-id",
28-
ElasticInferenceServiceUsageContext.UNSPECIFIED
28+
ElasticInferenceServiceUsageContext.UNSPECIFIED,
29+
null
2930
);
3031
String xContentString = xContentEntityToString(entity);
3132
assertThat(xContentString, equalToIgnoringWhitespaceInJsonString("""
@@ -39,7 +40,8 @@ public void testToXContent_MultipleInputs_UnspecifiedUsageContext() throws IOExc
3940
var entity = new ElasticInferenceServiceDenseTextEmbeddingsRequestEntity(
4041
List.of("abc", "def"),
4142
"my-model-id",
42-
ElasticInferenceServiceUsageContext.UNSPECIFIED
43+
ElasticInferenceServiceUsageContext.UNSPECIFIED,
44+
null
4345
);
4446
String xContentString = xContentEntityToString(entity);
4547
assertThat(xContentString, equalToIgnoringWhitespaceInJsonString("""
@@ -57,7 +59,8 @@ public void testToXContent_SingleInput_SearchUsageContext() throws IOException {
5759
var entity = new ElasticInferenceServiceDenseTextEmbeddingsRequestEntity(
5860
List.of("abc"),
5961
"my-model-id",
60-
ElasticInferenceServiceUsageContext.SEARCH
62+
ElasticInferenceServiceUsageContext.SEARCH,
63+
null
6164
);
6265
String xContentString = xContentEntityToString(entity);
6366
assertThat(xContentString, equalToIgnoringWhitespaceInJsonString("""
@@ -73,7 +76,8 @@ public void testToXContent_SingleInput_IngestUsageContext() throws IOException {
7376
var entity = new ElasticInferenceServiceDenseTextEmbeddingsRequestEntity(
7477
List.of("abc"),
7578
"my-model-id",
76-
ElasticInferenceServiceUsageContext.INGEST
79+
ElasticInferenceServiceUsageContext.INGEST,
80+
null
7781
);
7882
String xContentString = xContentEntityToString(entity);
7983
assertThat(xContentString, equalToIgnoringWhitespaceInJsonString("""
@@ -85,11 +89,29 @@ public void testToXContent_SingleInput_IngestUsageContext() throws IOException {
8589
"""));
8690
}
8791

92+
public void testToXContent_SingleInput_DimensionsSpecified() throws IOException {
93+
var entity = new ElasticInferenceServiceDenseTextEmbeddingsRequestEntity(
94+
List.of("abc"),
95+
"my-model-id",
96+
ElasticInferenceServiceUsageContext.UNSPECIFIED,
97+
100
98+
);
99+
String xContentString = xContentEntityToString(entity);
100+
assertThat(xContentString, equalToIgnoringWhitespaceInJsonString("""
101+
{
102+
"input": ["abc"],
103+
"model": "my-model-id",
104+
"dimensions": 100
105+
}
106+
"""));
107+
}
108+
88109
public void testToXContent_MultipleInputs_SearchUsageContext() throws IOException {
89110
var entity = new ElasticInferenceServiceDenseTextEmbeddingsRequestEntity(
90111
List.of("first input", "second input", "third input"),
91112
"my-dense-model",
92-
ElasticInferenceServiceUsageContext.SEARCH
113+
ElasticInferenceServiceUsageContext.SEARCH,
114+
null
93115
);
94116
String xContentString = xContentEntityToString(entity);
95117
assertThat(xContentString, equalToIgnoringWhitespaceInJsonString("""
@@ -109,7 +131,8 @@ public void testToXContent_MultipleInputs_IngestUsageContext() throws IOExceptio
109131
var entity = new ElasticInferenceServiceDenseTextEmbeddingsRequestEntity(
110132
List.of("document one", "document two"),
111133
"embedding-model-v2",
112-
ElasticInferenceServiceUsageContext.INGEST
134+
ElasticInferenceServiceUsageContext.INGEST,
135+
null
113136
);
114137
String xContentString = xContentEntityToString(entity);
115138
assertThat(xContentString, equalToIgnoringWhitespaceInJsonString("""
@@ -128,7 +151,8 @@ public void testToXContent_EmptyInput_UnspecifiedUsageContext() throws IOExcepti
128151
var entity = new ElasticInferenceServiceDenseTextEmbeddingsRequestEntity(
129152
List.of(""),
130153
"my-model-id",
131-
ElasticInferenceServiceUsageContext.UNSPECIFIED
154+
ElasticInferenceServiceUsageContext.UNSPECIFIED,
155+
null
132156
);
133157
String xContentString = xContentEntityToString(entity);
134158
assertThat(xContentString, equalToIgnoringWhitespaceInJsonString("""

0 commit comments

Comments (0)