Skip to content

Commit fc352bd

Browse files
[Inference API] Replace model_id with inference_id in inference API except when stored (#111366) (#111417)
* Replace model_id with inference_id in inference API except when storing ModelConfigs
* Update docs/changelog/111366.yaml
* Replace missed literals in tests
1 parent 6ee8747 commit fc352bd

File tree

7 files changed: +38 −14 lines changed

docs/changelog/111366.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
pr: 111366
2+
summary: "[Inference API] Replace `model_id` with `inference_id` in inference API\
3+
\ except when stored"
4+
area: Machine Learning
5+
type: bug
6+
issues: []

server/src/main/java/org/elasticsearch/inference/ModelConfigurations.java

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,11 @@
2020

2121
public class ModelConfigurations implements ToFilteredXContentObject, VersionedNamedWriteable {
2222

23-
public static final String MODEL_ID = "model_id";
23+
// The inference ID is serialized under two different field names: "model_id" when the configuration is stored in an index, and
24+
// "inference_id" when it is returned in an API response (e.g. from a GetInferenceModelAction)
25+
public static final String INDEX_ONLY_ID_FIELD_NAME = "model_id";
26+
public static final String INFERENCE_ID_FIELD_NAME = "inference_id";
27+
public static final String USE_ID_FOR_INDEX = "for_index";
2428
public static final String SERVICE = "service";
2529
public static final String SERVICE_SETTINGS = "service_settings";
2630
public static final String TASK_SETTINGS = "task_settings";
@@ -119,7 +123,11 @@ public TaskSettings getTaskSettings() {
119123
@Override
120124
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
121125
builder.startObject();
122-
builder.field(MODEL_ID, inferenceEntityId);
126+
if (params.paramAsBoolean(USE_ID_FOR_INDEX, false)) {
127+
builder.field(INDEX_ONLY_ID_FIELD_NAME, inferenceEntityId);
128+
} else {
129+
builder.field(INFERENCE_ID_FIELD_NAME, inferenceEntityId);
130+
}
123131
builder.field(TaskType.NAME, taskType.toString());
124132
builder.field(SERVICE, service);
125133
builder.field(SERVICE_SETTINGS, serviceSettings);
@@ -131,7 +139,11 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
131139
@Override
132140
public XContentBuilder toFilteredXContent(XContentBuilder builder, Params params) throws IOException {
133141
builder.startObject();
134-
builder.field(MODEL_ID, inferenceEntityId);
142+
if (params.paramAsBoolean(USE_ID_FOR_INDEX, false)) {
143+
builder.field(INDEX_ONLY_ID_FIELD_NAME, inferenceEntityId);
144+
} else {
145+
builder.field(INFERENCE_ID_FIELD_NAME, inferenceEntityId);
146+
}
135147
builder.field(TaskType.NAME, taskType.toString());
136148
builder.field(SERVICE, service);
137149
builder.field(SERVICE_SETTINGS, serviceSettings.getFilteredXContentObject());

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ public void testGet() throws IOException {
4848

4949
var singleModel = getModels("se_model_1", TaskType.SPARSE_EMBEDDING);
5050
assertThat(singleModel, hasSize(1));
51-
assertEquals("se_model_1", singleModel.get(0).get("model_id"));
51+
assertEquals("se_model_1", singleModel.get(0).get("inference_id"));
5252

5353
for (int i = 0; i < 5; i++) {
5454
deleteModel("se_model_" + i, TaskType.SPARSE_EMBEDDING);
@@ -81,7 +81,7 @@ public void testGetModelWithAnyTaskType() throws IOException {
8181
String inferenceEntityId = "sparse_embedding_model";
8282
putModel(inferenceEntityId, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
8383
var singleModel = getModels(inferenceEntityId, TaskType.ANY);
84-
assertEquals(inferenceEntityId, singleModel.get(0).get("model_id"));
84+
assertEquals(inferenceEntityId, singleModel.get(0).get("inference_id"));
8585
assertEquals(TaskType.SPARSE_EMBEDDING.toString(), singleModel.get(0).get("task_type"));
8686
}
8787

@@ -90,7 +90,7 @@ public void testApisWithoutTaskType() throws IOException {
9090
String modelId = "no_task_type_in_url";
9191
putModel(modelId, mockSparseServiceModelConfig(TaskType.SPARSE_EMBEDDING));
9292
var singleModel = getModel(modelId);
93-
assertEquals(modelId, singleModel.get("model_id"));
93+
assertEquals(modelId, singleModel.get("inference_id"));
9494
assertEquals(TaskType.SPARSE_EMBEDDING.toString(), singleModel.get("task_type"));
9595

9696
var inference = inferOnMockService(modelId, List.of(randomAlphaOfLength(10)));

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockDenseInferenceServiceIT.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ public void testMockService() throws IOException {
2222
var model = getModels(inferenceEntityId, TaskType.TEXT_EMBEDDING).get(0);
2323

2424
for (var modelMap : List.of(putModel, model)) {
25-
assertEquals(inferenceEntityId, modelMap.get("model_id"));
25+
assertEquals(inferenceEntityId, modelMap.get("inference_id"));
2626
assertEquals(TaskType.TEXT_EMBEDDING, TaskType.fromString((String) modelMap.get("task_type")));
2727
assertEquals("text_embedding_test_service", modelMap.get("service"));
2828
}

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockSparseInferenceServiceIT.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ public void testMockService() throws IOException {
2424
var model = getModels(inferenceEntityId, TaskType.SPARSE_EMBEDDING).get(0);
2525

2626
for (var modelMap : List.of(putModel, model)) {
27-
assertEquals(inferenceEntityId, modelMap.get("model_id"));
27+
assertEquals(inferenceEntityId, modelMap.get("inference_id"));
2828
assertEquals(TaskType.SPARSE_EMBEDDING, TaskType.fromString((String) modelMap.get("task_type")));
2929
assertEquals("test_service", modelMap.get("service"));
3030
}
@@ -77,7 +77,7 @@ public void testMockService_DoesNotReturnHiddenField_InModelResponses() throws I
7777
var model = getModels(inferenceEntityId, TaskType.SPARSE_EMBEDDING).get(0);
7878

7979
for (var modelMap : List.of(putModel, model)) {
80-
assertEquals(inferenceEntityId, modelMap.get("model_id"));
80+
assertEquals(inferenceEntityId, modelMap.get("inference_id"));
8181
assertThat(modelMap.get("service_settings"), is(Map.of("model", "my_model")));
8282
assertEquals(TaskType.SPARSE_EMBEDDING, TaskType.fromString((String) modelMap.get("task_type")));
8383
assertEquals("test_service", modelMap.get("service"));
@@ -95,7 +95,7 @@ public void testMockService_DoesReturnHiddenField_InModelResponses() throws IOEx
9595
var model = getModels(inferenceEntityId, TaskType.SPARSE_EMBEDDING).get(0);
9696

9797
for (var modelMap : List.of(putModel, model)) {
98-
assertEquals(inferenceEntityId, modelMap.get("model_id"));
98+
assertEquals(inferenceEntityId, modelMap.get("inference_id"));
9999
assertThat(modelMap.get("service_settings"), is(Map.of("model", "my_model", "hidden_field", "my_hidden_value")));
100100
assertEquals(TaskType.SPARSE_EMBEDDING, TaskType.fromString((String) modelMap.get("task_type")));
101101
assertEquals("test_service", modelMap.get("service"));

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
405405
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
406406
builder.startObject();
407407
builder.field("unknown_field", "foo");
408-
builder.field(MODEL_ID, getInferenceEntityId());
408+
builder.field(INDEX_ONLY_ID_FIELD_NAME, getInferenceEntityId());
409409
builder.field(TaskType.NAME, getTaskType().toString());
410410
builder.field(SERVICE, getService());
411411
builder.field(SERVICE_SETTINGS, getServiceSettings());
@@ -431,7 +431,7 @@ private static class ModelWithUnknownField extends ModelConfigurations {
431431
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
432432
builder.startObject();
433433
builder.field("unknown_field", "foo");
434-
builder.field(MODEL_ID, getInferenceEntityId());
434+
builder.field(INDEX_ONLY_ID_FIELD_NAME, getInferenceEntityId());
435435
builder.field(TaskType.NAME, getTaskType().toString());
436436
builder.field(SERVICE, getService());
437437
builder.field(SERVICE_SETTINGS, getServiceSettings());

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,10 @@ public static UnparsedModel unparsedModelFromMap(ModelConfigMap modelConfigMap)
7474
if (modelConfigMap.config() == null) {
7575
throw new ElasticsearchStatusException("Missing config map", RestStatus.BAD_REQUEST);
7676
}
77-
String inferenceEntityId = ServiceUtils.removeStringOrThrowIfNull(modelConfigMap.config(), ModelConfigurations.MODEL_ID);
77+
String inferenceEntityId = ServiceUtils.removeStringOrThrowIfNull(
78+
modelConfigMap.config(),
79+
ModelConfigurations.INDEX_ONLY_ID_FIELD_NAME
80+
);
7881
String service = ServiceUtils.removeStringOrThrowIfNull(modelConfigMap.config(), ModelConfigurations.SERVICE);
7982
String taskTypeStr = ServiceUtils.removeStringOrThrowIfNull(modelConfigMap.config(), TaskType.NAME);
8083
TaskType taskType = TaskType.fromString(taskTypeStr);
@@ -375,7 +378,10 @@ public void deleteModel(String inferenceEntityId, ActionListener<Boolean> listen
375378
private static IndexRequest createIndexRequest(String docId, String indexName, ToXContentObject body, boolean allowOverwriting) {
376379
try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
377380
var request = new IndexRequest(indexName);
378-
XContentBuilder source = body.toXContent(builder, ToXContent.EMPTY_PARAMS);
381+
XContentBuilder source = body.toXContent(
382+
builder,
383+
new ToXContent.MapParams(Map.of(ModelConfigurations.USE_ID_FOR_INDEX, Boolean.TRUE.toString()))
384+
);
379385
var operation = allowOverwriting ? DocWriteRequest.OpType.INDEX : DocWriteRequest.OpType.CREATE;
380386

381387
return request.opType(operation).id(docId).source(source);

0 commit comments

Comments
 (0)