Skip to content

Commit 04cdac2

Browse files
authored
[8.x][Inference API] Pass model ID in sparse model inference request body (#121125)
1 parent 44c7fea commit 04cdac2

10 files changed

+69
-82
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceSparseEmbeddingsRequest.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,11 @@ public HttpRequest createHttpRequest() {
5555
var httpPost = new HttpPost(uri);
5656
var usageContext = inputTypeToUsageContext(inputType);
5757
var requestEntity = Strings.toString(
58-
new ElasticInferenceServiceSparseEmbeddingsRequestEntity(truncationResult.input(), usageContext)
58+
new ElasticInferenceServiceSparseEmbeddingsRequestEntity(
59+
truncationResult.input(),
60+
model.getServiceSettings().modelId(),
61+
usageContext
62+
)
5963
);
6064

6165
ByteArrayEntity byteEntity = new ByteArrayEntity(requestEntity.getBytes(StandardCharsets.UTF_8));

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceSparseEmbeddingsRequestEntity.java

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,29 +18,34 @@
1818

1919
public record ElasticInferenceServiceSparseEmbeddingsRequestEntity(
2020
List<String> inputs,
21+
String modelId,
2122
@Nullable ElasticInferenceServiceUsageContext usageContext
2223
) implements ToXContentObject {
2324

2425
private static final String INPUT_FIELD = "input";
26+
27+
private static final String MODEL_ID_FIELD = "model_id";
28+
2529
private static final String USAGE_CONTEXT = "usage_context";
2630

2731
public ElasticInferenceServiceSparseEmbeddingsRequestEntity {
2832
Objects.requireNonNull(inputs);
33+
Objects.requireNonNull(modelId);
2934
}
3035

3136
@Override
3237
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
3338
builder.startObject();
3439
builder.startArray(INPUT_FIELD);
3540

36-
{
37-
for (String input : inputs) {
38-
builder.value(input);
39-
}
41+
for (String input : inputs) {
42+
builder.value(input);
4043
}
4144

4245
builder.endArray();
4346

47+
builder.field(MODEL_ID_FIELD, modelId);
48+
4449
// optional field
4550
if ((usageContext == ElasticInferenceServiceUsageContext.UNSPECIFIED) == false) {
4651
builder.field(USAGE_CONTEXT, usageContext);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsModel.java

Lines changed: 1 addition & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,11 @@
2020
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
2121
import org.elasticsearch.xpack.inference.external.action.elastic.ElasticInferenceServiceActionVisitor;
2222
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
23-
import org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels;
2423

2524
import java.net.URI;
2625
import java.net.URISyntaxException;
27-
import java.util.Locale;
2826
import java.util.Map;
2927

30-
import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.ELASTIC_INFERENCE_SERVICE_IDENTIFIER;
31-
3228
public class ElasticInferenceServiceSparseEmbeddingsModel extends ElasticInferenceServiceExecutableActionModel {
3329

3430
private final URI uri;
@@ -95,36 +91,15 @@ public URI uri() {
9591
}
9692

9793
private URI createUri() throws ElasticsearchStatusException {
98-
String modelId = getServiceSettings().modelId();
99-
String modelIdUriPath;
100-
101-
switch (modelId) {
102-
case ElserModels.ELSER_V2_MODEL -> modelIdUriPath = "ELSERv2";
103-
default -> throw new ElasticsearchStatusException(
104-
String.format(
105-
Locale.ROOT,
106-
"Unsupported model [%s] for service [%s] and task type [%s]",
107-
modelId,
108-
ELASTIC_INFERENCE_SERVICE_IDENTIFIER,
109-
TaskType.SPARSE_EMBEDDING
110-
),
111-
RestStatus.BAD_REQUEST
112-
);
113-
}
114-
11594
try {
11695
// TODO, consider transforming the base URL into a URI for better error handling.
117-
return new URI(
118-
elasticInferenceServiceComponents().elasticInferenceServiceUrl() + "/api/v1/embed/text/sparse/" + modelIdUriPath
119-
);
96+
return new URI(elasticInferenceServiceComponents().elasticInferenceServiceUrl() + "/api/v1/embed/text/sparse");
12097
} catch (URISyntaxException e) {
12198
throw new ElasticsearchStatusException(
12299
"Failed to create URI for service ["
123100
+ this.getConfigurations().getService()
124101
+ "] with taskType ["
125102
+ this.getTaskType()
126-
+ "] with model ["
127-
+ this.getServiceSettings().modelId()
128103
+ "]: "
129104
+ e.getMessage(),
130105
RestStatus.BAD_REQUEST,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsServiceSettings.java

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import org.elasticsearch.inference.ServiceSettings;
1818
import org.elasticsearch.xcontent.XContentBuilder;
1919
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
20-
import org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels;
2120
import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
2221
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
2322

@@ -61,10 +60,6 @@ public static ElasticInferenceServiceSparseEmbeddingsServiceSettings fromMap(
6160
context
6261
);
6362

64-
if (modelId != null && ElserModels.isValidEisModel(modelId) == false) {
65-
validationException.addValidationError("unknown ELSER model id [" + modelId + "]");
66-
}
67-
6863
if (validationException.validationErrors().isEmpty() == false) {
6964
throw validationException;
7065
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/elastic/ElasticInferenceServiceActionCreatorTests.java

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForElserAction() throws IOExce
9090

9191
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
9292

93-
var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(getUrl(webServer));
93+
var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(getUrl(webServer), "my-model-id");
9494
var actionCreator = new ElasticInferenceServiceActionCreator(
9595
sender,
9696
createWithEmptySettings(threadPool),
@@ -120,10 +120,11 @@ public void testExecute_ReturnsSuccessfulResponse_ForElserAction() throws IOExce
120120
assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()));
121121

122122
var requestMap = entityAsMap(webServer.requests().get(0).getBody());
123-
assertThat(requestMap.size(), is(1));
123+
assertThat(requestMap.size(), is(2));
124124
assertThat(requestMap.get("input"), instanceOf(List.class));
125125
var inputList = (List<String>) requestMap.get("input");
126126
assertThat(inputList, contains("hello world"));
127+
assertThat(requestMap.get("model_id"), is("my-model-id"));
127128
}
128129
}
129130

@@ -151,7 +152,7 @@ public void testSend_FailsFromInvalidResponseFormat_ForElserAction() throws IOEx
151152

152153
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
153154

154-
var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(getUrl(webServer));
155+
var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(getUrl(webServer), "my-model-id");
155156
var actionCreator = new ElasticInferenceServiceActionCreator(
156157
sender,
157158
createWithEmptySettings(threadPool),
@@ -174,10 +175,11 @@ public void testSend_FailsFromInvalidResponseFormat_ForElserAction() throws IOEx
174175
assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()));
175176

176177
var requestMap = entityAsMap(webServer.requests().get(0).getBody());
177-
assertThat(requestMap.size(), is(1));
178+
assertThat(requestMap.size(), is(2));
178179
assertThat(requestMap.get("input"), instanceOf(List.class));
179180
var inputList = (List<String>) requestMap.get("input");
180181
assertThat(inputList, contains("hello world"));
182+
assertThat(requestMap.get("model_id"), is("my-model-id"));
181183
}
182184
}
183185

@@ -208,7 +210,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating() throws IOExc
208210
webServer.enqueue(new MockResponse().setResponseCode(413).setBody(responseJsonContentTooLarge));
209211
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
210212

211-
var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(getUrl(webServer));
213+
var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(getUrl(webServer), "my-model-id");
212214
var actionCreator = new ElasticInferenceServiceActionCreator(
213215
sender,
214216
createWithEmptySettings(threadPool),
@@ -273,7 +275,7 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException {
273275
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
274276

275277
// truncated to 1 token = 3 characters
276-
var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(getUrl(webServer), 1);
278+
var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(getUrl(webServer), "my-model-id", 1);
277279
var actionCreator = new ElasticInferenceServiceActionCreator(
278280
sender,
279281
createWithEmptySettings(threadPool),

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceSparseEmbeddingsRequestEntityTests.java

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,18 +24,21 @@ public class ElasticInferenceServiceSparseEmbeddingsRequestEntityTests extends E
2424
public void testToXContent_SingleInput_UnspecifiedUsageContext() throws IOException {
2525
var entity = new ElasticInferenceServiceSparseEmbeddingsRequestEntity(
2626
List.of("abc"),
27+
"my-model-id",
2728
ElasticInferenceServiceUsageContext.UNSPECIFIED
2829
);
2930
String xContentString = xContentEntityToString(entity);
3031
assertThat(xContentString, equalToIgnoringWhitespaceInJsonString("""
3132
{
32-
"input": ["abc"]
33+
"input": ["abc"],
34+
"model_id": "my-model-id"
3335
}"""));
3436
}
3537

3638
public void testToXContent_MultipleInputs_UnspecifiedUsageContext() throws IOException {
3739
var entity = new ElasticInferenceServiceSparseEmbeddingsRequestEntity(
3840
List.of("abc", "def"),
41+
"my-model-id",
3942
ElasticInferenceServiceUsageContext.UNSPECIFIED
4043
);
4144
String xContentString = xContentEntityToString(entity);
@@ -44,28 +47,39 @@ public void testToXContent_MultipleInputs_UnspecifiedUsageContext() throws IOExc
4447
"input": [
4548
"abc",
4649
"def"
47-
]
50+
],
51+
"model_id": "my-model-id"
4852
}
4953
"""));
5054
}
5155

5256
public void testToXContent_MultipleInputs_SearchUsageContext() throws IOException {
53-
var entity = new ElasticInferenceServiceSparseEmbeddingsRequestEntity(List.of("abc"), ElasticInferenceServiceUsageContext.SEARCH);
57+
var entity = new ElasticInferenceServiceSparseEmbeddingsRequestEntity(
58+
List.of("abc"),
59+
"my-model-id",
60+
ElasticInferenceServiceUsageContext.SEARCH
61+
);
5462
String xContentString = xContentEntityToString(entity);
5563
assertThat(xContentString, equalToIgnoringWhitespaceInJsonString("""
5664
{
5765
"input": ["abc"],
66+
"model_id": "my-model-id",
5867
"usage_context": "search"
5968
}
6069
"""));
6170
}
6271

6372
public void testToXContent_MultipleInputs_IngestUsageContext() throws IOException {
64-
var entity = new ElasticInferenceServiceSparseEmbeddingsRequestEntity(List.of("abc"), ElasticInferenceServiceUsageContext.INGEST);
73+
var entity = new ElasticInferenceServiceSparseEmbeddingsRequestEntity(
74+
List.of("abc"),
75+
"my-model-id",
76+
ElasticInferenceServiceUsageContext.INGEST
77+
);
6578
String xContentString = xContentEntityToString(entity);
6679
assertThat(xContentString, equalToIgnoringWhitespaceInJsonString("""
6780
{
6881
"input": ["abc"],
82+
"model_id": "my-model-id",
6983
"usage_context": "ingest"
7084
}
7185
"""));

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceSparseEmbeddingsRequestTests.java

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,25 +34,29 @@ public class ElasticInferenceServiceSparseEmbeddingsRequestTests extends ESTestC
3434
public void testCreateHttpRequest_UsageContextSearch() throws IOException {
3535
var url = "http://eis-gateway.com";
3636
var input = "input";
37+
var modelId = "my-model-id";
3738

38-
var request = createRequest(url, input, InputType.SEARCH);
39+
var request = createRequest(url, modelId, input, InputType.SEARCH);
3940
var httpRequest = request.createHttpRequest();
4041

4142
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
4243
var httpPost = (HttpPost) httpRequest.httpRequestBase();
4344

4445
assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
4546
var requestMap = entityAsMap(httpPost.getEntity().getContent());
46-
assertThat(requestMap.size(), equalTo(2));
47+
48+
assertThat(requestMap.size(), equalTo(3));
4749
assertThat(requestMap.get("input"), is(List.of(input)));
50+
assertThat(requestMap.get("model_id"), is(modelId));
4851
assertThat(requestMap.get("usage_context"), equalTo("search"));
4952
}
5053

5154
public void testTraceContextPropagatedThroughHTTPHeaders() {
5255
var url = "http://eis-gateway.com";
5356
var input = "input";
57+
var modelId = "my-model-id";
5458

55-
var request = createRequest(url, input, InputType.UNSPECIFIED);
59+
var request = createRequest(url, modelId, input, InputType.UNSPECIFIED);
5660
var httpRequest = request.createHttpRequest();
5761

5862
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
@@ -68,24 +72,27 @@ public void testTraceContextPropagatedThroughHTTPHeaders() {
6872
public void testTruncate_ReducesInputTextSizeByHalf() throws IOException {
6973
var url = "http://eis-gateway.com";
7074
var input = "abcd";
75+
var modelId = "my-model-id";
7176

72-
var request = createRequest(url, input, InputType.UNSPECIFIED);
77+
var request = createRequest(url, modelId, input, InputType.UNSPECIFIED);
7378
var truncatedRequest = request.truncate();
7479

7580
var httpRequest = truncatedRequest.createHttpRequest();
7681
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
7782

7883
var httpPost = (HttpPost) httpRequest.httpRequestBase();
7984
var requestMap = entityAsMap(httpPost.getEntity().getContent());
80-
assertThat(requestMap, aMapWithSize(1));
85+
assertThat(requestMap, aMapWithSize(2));
8186
assertThat(requestMap.get("input"), is(List.of("ab")));
87+
assertThat(requestMap.get("model_id"), is(modelId));
8288
}
8389

8490
public void testIsTruncated_ReturnsTrue() {
8591
var url = "http://eis-gateway.com";
8692
var input = "abcd";
93+
var modelId = "my-model-id";
8794

88-
var request = createRequest(url, input, InputType.UNSPECIFIED);
95+
var request = createRequest(url, modelId, input, InputType.UNSPECIFIED);
8996
assertFalse(request.getTruncationInfo()[0]);
9097

9198
var truncatedRequest = request.truncate();
@@ -109,8 +116,8 @@ public void testInputTypeToUsageContext_Unknown_DefaultToUnspecified() {
109116
assertThat(inputTypeToUsageContext(InputType.CLUSTERING), equalTo(ElasticInferenceServiceUsageContext.UNSPECIFIED));
110117
}
111118

112-
public ElasticInferenceServiceSparseEmbeddingsRequest createRequest(String url, String input, InputType inputType) {
113-
var embeddingsModel = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(url);
119+
public ElasticInferenceServiceSparseEmbeddingsRequest createRequest(String url, String modelId, String input, InputType inputType) {
120+
var embeddingsModel = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(url, modelId);
114121

115122
return new ElasticInferenceServiceSparseEmbeddingsRequest(
116123
TruncatorTests.createTruncator(),

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsModelTests.java

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,20 +11,19 @@
1111
import org.elasticsearch.inference.EmptyTaskSettings;
1212
import org.elasticsearch.inference.TaskType;
1313
import org.elasticsearch.test.ESTestCase;
14-
import org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels;
1514

1615
public class ElasticInferenceServiceSparseEmbeddingsModelTests extends ESTestCase {
1716

18-
public static ElasticInferenceServiceSparseEmbeddingsModel createModel(String url) {
19-
return createModel(url, null);
17+
public static ElasticInferenceServiceSparseEmbeddingsModel createModel(String url, String modelId) {
18+
return createModel(url, modelId, null);
2019
}
2120

22-
public static ElasticInferenceServiceSparseEmbeddingsModel createModel(String url, Integer maxInputTokens) {
21+
public static ElasticInferenceServiceSparseEmbeddingsModel createModel(String url, String modelId, Integer maxInputTokens) {
2322
return new ElasticInferenceServiceSparseEmbeddingsModel(
2423
"id",
2524
TaskType.SPARSE_EMBEDDING,
2625
"service",
27-
new ElasticInferenceServiceSparseEmbeddingsServiceSettings(ElserModels.ELSER_V2_MODEL, maxInputTokens, null),
26+
new ElasticInferenceServiceSparseEmbeddingsServiceSettings(modelId, maxInputTokens, null),
2827
EmptyTaskSettings.INSTANCE,
2928
EmptySecretSettings.INSTANCE,
3029
new ElasticInferenceServiceComponents(url)

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsServiceSettingsTests.java

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
package org.elasticsearch.xpack.inference.services.elastic;
99

1010
import org.elasticsearch.common.Strings;
11-
import org.elasticsearch.common.ValidationException;
1211
import org.elasticsearch.common.io.stream.Writeable;
1312
import org.elasticsearch.test.AbstractWireSerializingTestCase;
1413
import org.elasticsearch.xcontent.XContentBuilder;
@@ -23,7 +22,6 @@
2322
import java.util.Map;
2423

2524
import static org.elasticsearch.xpack.inference.services.elasticsearch.ElserModelsTests.randomElserModel;
26-
import static org.hamcrest.Matchers.containsString;
2725
import static org.hamcrest.Matchers.is;
2826

2927
public class ElasticInferenceServiceSparseEmbeddingsServiceSettingsTests extends AbstractWireSerializingTestCase<
@@ -47,7 +45,7 @@ protected ElasticInferenceServiceSparseEmbeddingsServiceSettings mutateInstance(
4745
}
4846

4947
public void testFromMap() {
50-
var modelId = ElserModels.ELSER_V2_MODEL;
48+
var modelId = "my-model-id";
5149

5250
var serviceSettings = ElasticInferenceServiceSparseEmbeddingsServiceSettings.fromMap(
5351
new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId)),
@@ -57,20 +55,6 @@ public void testFromMap() {
5755
assertThat(serviceSettings, is(new ElasticInferenceServiceSparseEmbeddingsServiceSettings(modelId, null, null)));
5856
}
5957

60-
public void testFromMap_InvalidElserModelId() {
61-
var invalidModelId = "invalid";
62-
63-
ValidationException validationException = expectThrows(
64-
ValidationException.class,
65-
() -> ElasticInferenceServiceSparseEmbeddingsServiceSettings.fromMap(
66-
new HashMap<>(Map.of(ServiceFields.MODEL_ID, invalidModelId)),
67-
ConfigurationParseContext.REQUEST
68-
)
69-
);
70-
71-
assertThat(validationException.getMessage(), containsString(Strings.format("unknown ELSER model id [%s]", invalidModelId)));
72-
}
73-
7458
public void testToXContent_WritesAllFields() throws IOException {
7559
var modelId = ElserModels.ELSER_V1_MODEL;
7660
var maxInputTokens = 10;

0 commit comments

Comments
 (0)