Skip to content

Commit 4c7de69

Browse files
committed
Pass model ID in sparse model inference request body
1 parent 730ec3e commit 4c7de69

10 files changed: +150 -89 lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceSparseEmbeddingsRequest.java

Lines changed: 8 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -48,7 +48,14 @@ public ElasticInferenceServiceSparseEmbeddingsRequest(
4848
@Override
4949
public HttpRequest createHttpRequest() {
5050
var httpPost = new HttpPost(uri);
51-
var requestEntity = Strings.toString(new ElasticInferenceServiceSparseEmbeddingsRequestEntity(truncationResult.input()));
51+
var usageContext = inputTypeToUsageContext(inputType);
52+
var requestEntity = Strings.toString(
53+
new ElasticInferenceServiceSparseEmbeddingsRequestEntity(
54+
truncationResult.input(),
55+
model.getServiceSettings().modelId(),
56+
usageContext
57+
)
58+
);
5259

5360
ByteArrayEntity byteEntity = new ByteArrayEntity(requestEntity.getBytes(StandardCharsets.UTF_8));
5461
httpPost.setEntity(byteEntity);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceSparseEmbeddingsRequestEntity.java

Lines changed: 18 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -14,26 +14,39 @@
1414
import java.util.List;
1515
import java.util.Objects;
1616

17-
public record ElasticInferenceServiceSparseEmbeddingsRequestEntity(List<String> inputs) implements ToXContentObject {
17+
public record ElasticInferenceServiceSparseEmbeddingsRequestEntity(
18+
List<String> inputs,
19+
String modelId,
20+
@Nullable ElasticInferenceServiceUsageContext usageContext
21+
) implements ToXContentObject {
1822

1923
private static final String INPUT_FIELD = "input";
24+
private static final String MODEL_ID_FIELD = "model_id";
25+
private static final String USAGE_CONTEXT = "usage_context";
2026

2127
public ElasticInferenceServiceSparseEmbeddingsRequestEntity {
2228
Objects.requireNonNull(inputs);
29+
Objects.requireNonNull(modelId);
2330
}
2431

2532
@Override
2633
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
2734
builder.startObject();
2835
builder.startArray(INPUT_FIELD);
2936

30-
{
31-
for (String input : inputs) {
32-
builder.value(input);
33-
}
37+
for (String input : inputs) {
38+
builder.value(input);
3439
}
3540

3641
builder.endArray();
42+
43+
builder.field(MODEL_ID_FIELD, modelId);
44+
45+
// optional field
46+
if ((usageContext == ElasticInferenceServiceUsageContext.UNSPECIFIED) == false) {
47+
builder.field(USAGE_CONTEXT, usageContext);
48+
}
49+
3750
builder.endObject();
3851

3952
return builder;

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsModel.java

Lines changed: 1 addition & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -20,15 +20,11 @@
2020
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
2121
import org.elasticsearch.xpack.inference.external.action.elastic.ElasticInferenceServiceActionVisitor;
2222
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
23-
import org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels;
2423

2524
import java.net.URI;
2625
import java.net.URISyntaxException;
27-
import java.util.Locale;
2826
import java.util.Map;
2927

30-
import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.ELASTIC_INFERENCE_SERVICE_IDENTIFIER;
31-
3228
public class ElasticInferenceServiceSparseEmbeddingsModel extends ElasticInferenceServiceExecutableActionModel {
3329

3430
private final URI uri;
@@ -95,36 +91,15 @@ public URI uri() {
9591
}
9692

9793
private URI createUri() throws ElasticsearchStatusException {
98-
String modelId = getServiceSettings().modelId();
99-
String modelIdUriPath;
100-
101-
switch (modelId) {
102-
case ElserModels.ELSER_V2_MODEL -> modelIdUriPath = "ELSERv2";
103-
default -> throw new ElasticsearchStatusException(
104-
String.format(
105-
Locale.ROOT,
106-
"Unsupported model [%s] for service [%s] and task type [%s]",
107-
modelId,
108-
ELASTIC_INFERENCE_SERVICE_IDENTIFIER,
109-
TaskType.SPARSE_EMBEDDING
110-
),
111-
RestStatus.BAD_REQUEST
112-
);
113-
}
114-
11594
try {
11695
// TODO, consider transforming the base URL into a URI for better error handling.
117-
return new URI(
118-
elasticInferenceServiceComponents().elasticInferenceServiceUrl() + "/api/v1/embed/text/sparse/" + modelIdUriPath
119-
);
96+
return new URI(elasticInferenceServiceComponents().elasticInferenceServiceUrl() + "/api/v1/embed/text/sparse");
12097
} catch (URISyntaxException e) {
12198
throw new ElasticsearchStatusException(
12299
"Failed to create URI for service ["
123100
+ this.getConfigurations().getService()
124101
+ "] with taskType ["
125102
+ this.getTaskType()
126-
+ "] with model ["
127-
+ this.getServiceSettings().modelId()
128103
+ "]: "
129104
+ e.getMessage(),
130105
RestStatus.BAD_REQUEST,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsServiceSettings.java

Lines changed: 0 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,6 @@
1717
import org.elasticsearch.inference.ServiceSettings;
1818
import org.elasticsearch.xcontent.XContentBuilder;
1919
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
20-
import org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels;
2120
import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
2221
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
2322

@@ -61,10 +60,6 @@ public static ElasticInferenceServiceSparseEmbeddingsServiceSettings fromMap(
6160
context
6261
);
6362

64-
if (modelId != null && ElserModels.isValidEisModel(modelId) == false) {
65-
validationException.addValidationError("unknown ELSER model id [" + modelId + "]");
66-
}
67-
6863
if (validationException.validationErrors().isEmpty() == false) {
6964
throw validationException;
7065
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/elastic/ElasticInferenceServiceActionCreatorTests.java

Lines changed: 32 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -89,8 +89,13 @@ public void testExecute_ReturnsSuccessfulResponse_ForElserAction() throws IOExce
8989

9090
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
9191

92-
var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(getUrl(webServer));
93-
var actionCreator = new ElasticInferenceServiceActionCreator(sender, createWithEmptySettings(threadPool), createTraceContext());
92+
var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(getUrl(webServer), "my-model-id");
93+
var actionCreator = new ElasticInferenceServiceActionCreator(
94+
sender,
95+
createWithEmptySettings(threadPool),
96+
createTraceContext(),
97+
InputType.UNSPECIFIED
98+
);
9499
var action = actionCreator.create(model);
95100

96101
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
@@ -114,10 +119,11 @@ public void testExecute_ReturnsSuccessfulResponse_ForElserAction() throws IOExce
114119
assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()));
115120

116121
var requestMap = entityAsMap(webServer.requests().get(0).getBody());
117-
assertThat(requestMap.size(), is(1));
122+
assertThat(requestMap.size(), is(2));
118123
assertThat(requestMap.get("input"), instanceOf(List.class));
119124
var inputList = (List<String>) requestMap.get("input");
120125
assertThat(inputList, contains("hello world"));
126+
assertThat(requestMap.get("model_id"), is("my-model-id"));
121127
}
122128
}
123129

@@ -145,8 +151,13 @@ public void testSend_FailsFromInvalidResponseFormat_ForElserAction() throws IOEx
145151

146152
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
147153

148-
var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(getUrl(webServer));
149-
var actionCreator = new ElasticInferenceServiceActionCreator(sender, createWithEmptySettings(threadPool), createTraceContext());
154+
var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(getUrl(webServer), "my-model-id");
155+
var actionCreator = new ElasticInferenceServiceActionCreator(
156+
sender,
157+
createWithEmptySettings(threadPool),
158+
createTraceContext(),
159+
InputType.UNSPECIFIED
160+
);
150161
var action = actionCreator.create(model);
151162

152163
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
@@ -163,10 +174,11 @@ public void testSend_FailsFromInvalidResponseFormat_ForElserAction() throws IOEx
163174
assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()));
164175

165176
var requestMap = entityAsMap(webServer.requests().get(0).getBody());
166-
assertThat(requestMap.size(), is(1));
177+
assertThat(requestMap.size(), is(2));
167178
assertThat(requestMap.get("input"), instanceOf(List.class));
168179
var inputList = (List<String>) requestMap.get("input");
169180
assertThat(inputList, contains("hello world"));
181+
assertThat(requestMap.get("model_id"), is("my-model-id"));
170182
}
171183
}
172184

@@ -197,8 +209,13 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating() throws IOExc
197209
webServer.enqueue(new MockResponse().setResponseCode(413).setBody(responseJsonContentTooLarge));
198210
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
199211

200-
var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(getUrl(webServer));
201-
var actionCreator = new ElasticInferenceServiceActionCreator(sender, createWithEmptySettings(threadPool), createTraceContext());
212+
var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(getUrl(webServer), "my-model-id");
213+
var actionCreator = new ElasticInferenceServiceActionCreator(
214+
sender,
215+
createWithEmptySettings(threadPool),
216+
createTraceContext(),
217+
InputType.UNSPECIFIED
218+
);
202219
var action = actionCreator.create(model);
203220

204221
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
@@ -257,8 +274,13 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException {
257274
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
258275

259276
// truncated to 1 token = 3 characters
260-
var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(getUrl(webServer), 1);
261-
var actionCreator = new ElasticInferenceServiceActionCreator(sender, createWithEmptySettings(threadPool), createTraceContext());
277+
var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(getUrl(webServer), "my-model-id", 1);
278+
var actionCreator = new ElasticInferenceServiceActionCreator(
279+
sender,
280+
createWithEmptySettings(threadPool),
281+
createTraceContext(),
282+
InputType.UNSPECIFIED
283+
);
262284
var action = actionCreator.create(model);
263285

264286
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceSparseEmbeddingsRequestEntityTests.java

Lines changed: 48 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -20,24 +20,66 @@
2020

2121
public class ElasticInferenceServiceSparseEmbeddingsRequestEntityTests extends ESTestCase {
2222

23-
public void testToXContent_SingleInput() throws IOException {
24-
var entity = new ElasticInferenceServiceSparseEmbeddingsRequestEntity(List.of("abc"));
23+
public void testToXContent_SingleInput_UnspecifiedUsageContext() throws IOException {
24+
var entity = new ElasticInferenceServiceSparseEmbeddingsRequestEntity(
25+
List.of("abc"),
26+
"my-model-id",
27+
ElasticInferenceServiceUsageContext.UNSPECIFIED
28+
);
2529
String xContentString = xContentEntityToString(entity);
2630
assertThat(xContentString, equalToIgnoringWhitespaceInJsonString("""
2731
{
28-
"input": ["abc"]
32+
"input": ["abc"],
33+
"model_id": "my-model-id"
2934
}"""));
3035
}
3136

32-
public void testToXContent_MultipleInputs() throws IOException {
33-
var entity = new ElasticInferenceServiceSparseEmbeddingsRequestEntity(List.of("abc", "def"));
37+
public void testToXContent_MultipleInputs_UnspecifiedUsageContext() throws IOException {
38+
var entity = new ElasticInferenceServiceSparseEmbeddingsRequestEntity(
39+
List.of("abc", "def"),
40+
"my-model-id",
41+
ElasticInferenceServiceUsageContext.UNSPECIFIED
42+
);
3443
String xContentString = xContentEntityToString(entity);
3544
assertThat(xContentString, equalToIgnoringWhitespaceInJsonString("""
3645
{
3746
"input": [
3847
"abc",
3948
"def"
40-
]
49+
],
50+
"model_id": "my-model-id"
51+
}
52+
"""));
53+
}
54+
55+
public void testToXContent_MultipleInputs_SearchUsageContext() throws IOException {
56+
var entity = new ElasticInferenceServiceSparseEmbeddingsRequestEntity(
57+
List.of("abc"),
58+
"my-model-id",
59+
ElasticInferenceServiceUsageContext.SEARCH
60+
);
61+
String xContentString = xContentEntityToString(entity);
62+
assertThat(xContentString, equalToIgnoringWhitespaceInJsonString("""
63+
{
64+
"input": ["abc"],
65+
"model_id": "my-model-id",
66+
"usage_context": "search"
67+
}
68+
"""));
69+
}
70+
71+
public void testToXContent_MultipleInputs_IngestUsageContext() throws IOException {
72+
var entity = new ElasticInferenceServiceSparseEmbeddingsRequestEntity(
73+
List.of("abc"),
74+
"my-model-id",
75+
ElasticInferenceServiceUsageContext.INGEST
76+
);
77+
String xContentString = xContentEntityToString(entity);
78+
assertThat(xContentString, equalToIgnoringWhitespaceInJsonString("""
79+
{
80+
"input": ["abc"],
81+
"model_id": "my-model-id",
82+
"usage_context": "ingest"
4183
}
4284
"""));
4385
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceSparseEmbeddingsRequestTests.java

Lines changed: 32 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -31,24 +31,28 @@ public class ElasticInferenceServiceSparseEmbeddingsRequestTests extends ESTestC
3131
public void testCreateHttpRequest() throws IOException {
3232
var url = "http://eis-gateway.com";
3333
var input = "input";
34+
var modelId = "my-model-id";
3435

35-
var request = createRequest(url, input);
36+
var request = createRequest(url, modelId, input, InputType.SEARCH);
3637
var httpRequest = request.createHttpRequest();
3738

3839
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
3940
var httpPost = (HttpPost) httpRequest.httpRequestBase();
4041

4142
assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
4243
var requestMap = entityAsMap(httpPost.getEntity().getContent());
43-
assertThat(requestMap.size(), equalTo(1));
44+
assertThat(requestMap.size(), equalTo(3));
4445
assertThat(requestMap.get("input"), is(List.of(input)));
46+
assertThat(requestMap.get("model_id"), is(modelId));
47+
assertThat(requestMap.get("usage_context"), equalTo("search"));
4548
}
4649

4750
public void testTraceContextPropagatedThroughHTTPHeaders() {
4851
var url = "http://eis-gateway.com";
4952
var input = "input";
53+
var modelId = "my-model-id";
5054

51-
var request = createRequest(url, input);
55+
var request = createRequest(url, modelId, input, InputType.UNSPECIFIED);
5256
var httpRequest = request.createHttpRequest();
5357

5458
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
@@ -64,32 +68,52 @@ public void testTraceContextPropagatedThroughHTTPHeaders() {
6468
public void testTruncate_ReducesInputTextSizeByHalf() throws IOException {
6569
var url = "http://eis-gateway.com";
6670
var input = "abcd";
71+
var modelId = "my-model-id";
6772

68-
var request = createRequest(url, input);
73+
var request = createRequest(url, modelId, input, InputType.UNSPECIFIED);
6974
var truncatedRequest = request.truncate();
7075

7176
var httpRequest = truncatedRequest.createHttpRequest();
7277
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
7378

7479
var httpPost = (HttpPost) httpRequest.httpRequestBase();
7580
var requestMap = entityAsMap(httpPost.getEntity().getContent());
76-
assertThat(requestMap, aMapWithSize(1));
81+
assertThat(requestMap, aMapWithSize(2));
7782
assertThat(requestMap.get("input"), is(List.of("ab")));
83+
assertThat(requestMap.get("model_id"), is(modelId));
7884
}
7985

8086
public void testIsTruncated_ReturnsTrue() {
8187
var url = "http://eis-gateway.com";
8288
var input = "abcd";
89+
var modelId = "my-model-id";
8390

84-
var request = createRequest(url, input);
91+
var request = createRequest(url, modelId, input, InputType.UNSPECIFIED);
8592
assertFalse(request.getTruncationInfo()[0]);
8693

8794
var truncatedRequest = request.truncate();
8895
assertTrue(truncatedRequest.getTruncationInfo()[0]);
8996
}
9097

91-
public ElasticInferenceServiceSparseEmbeddingsRequest createRequest(String url, String input) {
92-
var embeddingsModel = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(url);
98+
public void testInputTypeToUsageContext_Search() {
99+
assertThat(inputTypeToUsageContext(InputType.SEARCH), equalTo(ElasticInferenceServiceUsageContext.SEARCH));
100+
}
101+
102+
public void testInputTypeToUsageContext_Ingest() {
103+
assertThat(inputTypeToUsageContext(InputType.INGEST), equalTo(ElasticInferenceServiceUsageContext.INGEST));
104+
}
105+
106+
public void testInputTypeToUsageContext_Unspecified() {
107+
assertThat(inputTypeToUsageContext(InputType.UNSPECIFIED), equalTo(ElasticInferenceServiceUsageContext.UNSPECIFIED));
108+
}
109+
110+
public void testInputTypeToUsageContext_Unknown_DefaultToUnspecified() {
111+
assertThat(inputTypeToUsageContext(InputType.CLASSIFICATION), equalTo(ElasticInferenceServiceUsageContext.UNSPECIFIED));
112+
assertThat(inputTypeToUsageContext(InputType.CLUSTERING), equalTo(ElasticInferenceServiceUsageContext.UNSPECIFIED));
113+
}
114+
115+
public ElasticInferenceServiceSparseEmbeddingsRequest createRequest(String url, String modelId, String input, InputType inputType) {
116+
var embeddingsModel = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(url, modelId);
93117

94118
return new ElasticInferenceServiceSparseEmbeddingsRequest(
95119
TruncatorTests.createTruncator(),

0 commit comments

Comments
 (0)