Skip to content
Merged
6 changes: 6 additions & 0 deletions docs/changelog/137055.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 137055
summary: Do not create inference endpoint if ID is used in existing mappings
area: Machine Learning
type: bug
issues:
- 124272
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ public record MinimalServiceSettings(

public static final String SERVICE_FIELD = "service";
public static final String TASK_TYPE_FIELD = "task_type";
static final String DIMENSIONS_FIELD = "dimensions";
static final String SIMILARITY_FIELD = "similarity";
static final String ELEMENT_TYPE_FIELD = "element_type";
public static final String DIMENSIONS_FIELD = "dimensions";
public static final String SIMILARITY_FIELD = "similarity";
public static final String ELEMENT_TYPE_FIELD = "element_type";
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do these need to be public?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I realize it's for consistency, but having it in this PR implies an access-level behavior change, and there isn't one.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My mistake, I was originally using these in tests, but then realized that ServiceFields has the same Strings defined and already public. I'll put them back how they were. It would be nice if we could have a single source of truth for these sorts of constants, because they're defined independently in a lot of different places, which makes it a real pain if we ever want to update them.


private static final ConstructingObjectParser<MinimalServiceSettings, Void> PARSER = new ConstructingObjectParser<>(
"model_settings",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,13 +116,12 @@ public static Map<String, Set<String>> pipelineIdsByResource(ClusterState state,
}

/**
* @param state Current {@link ClusterState}
* @param metadata Current cluster state {@link Metadata}
* @return a map from Model or Deployment IDs or Aliases to each pipeline referencing them.
*/
public static Set<String> pipelineIdsForResource(ClusterState state, Set<String> ids) {
public static Set<String> pipelineIdsForResource(Metadata metadata, Set<String> ids) {
assert Transports.assertNotTransportThread("non-trivial nested loops over cluster state structures");
Set<String> pipelineIds = new HashSet<>();
Metadata metadata = state.metadata();
if (metadata == null) {
return pipelineIds;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
package org.elasticsearch.xpack.inference;

import org.apache.http.util.EntityUtils;
import org.elasticsearch.client.Request;
import org.elasticsearch.client.Response;
import org.elasticsearch.client.ResponseException;
import org.elasticsearch.common.Strings;
Expand Down Expand Up @@ -211,7 +212,7 @@ public void testDeleteEndpointWhileReferencedBySemanticText() throws IOException
final String endpointId = "endpoint_referenced_by_semantic_text";
final String searchEndpointId = "search_endpoint_referenced_by_semantic_text";
final String indexName = randomAlphaOfLength(10).toLowerCase();
final Function<String, String> buildErrorString = endpointName -> " Inference endpoint "
final Function<String, String> buildErrorString = endpointName -> "Inference endpoint "
+ endpointName
+ " is being used in the mapping for indexes: "
+ Set.of(indexName)
Expand Down Expand Up @@ -303,6 +304,138 @@ public void testDeleteEndpointWhileReferencedBySemanticTextAndPipeline() throws
deleteIndex(indexName);
}

public void testCreateEndpoint_withInferenceIdReferencedByPipeline() throws IOException {
    // Creating an endpoint must be rejected when its inference_id is still referenced by ingest pipelines.
    final String endpointId = "endpoint_referenced_by_pipeline";
    putModel(endpointId, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING);

    final String firstPipeline = "pipeline_referencing_model_1";
    final String secondPipeline = "pipeline_referencing_model_2";
    putPipeline(firstPipeline, endpointId);
    putPipeline(secondPipeline, endpointId);

    // Force-delete the endpoint; the pipeline references remain behind.
    deleteModel(endpointId, "force=true");

    // Re-creating an endpoint with the same ID should now fail with an error naming both pipelines.
    var exception = assertThrows(
        ResponseException.class,
        () -> putModel(endpointId, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING)
    );
    var message = exception.getMessage();
    assertThat(
        message,
        containsString(
            "Inference endpoint ["
                + endpointId
                + "] could not be created because the inference_id is already referenced by pipelines: ["
        )
    );
    assertThat(message, containsString(firstPipeline));
    assertThat(message, containsString(secondPipeline));
    assertThat(
        message,
        containsString(
            "Please either use a different inference_id or update the index mappings "
                + "and/or pipelines to refer to a different inference_id."
        )
    );

    // Clean up so later tests start from an empty pipeline registry.
    deletePipeline(firstPipeline);
    deletePipeline(secondPipeline);
}

public void testCreateEndpoint_withInferenceIdReferencedBySemanticText() throws IOException {
    // Creating an endpoint must be rejected when its inference_id is referenced by semantic_text
    // mappings of indices that contain documents, but allowed while those indices are empty.
    final String endpointId = "endpoint_referenced_by_semantic_text";
    final String otherEndpointId = "other_endpoint_referenced_by_semantic_text";
    // Lowercase with an explicit root locale: the default-locale toLowerCase() is locale-sensitive
    // (e.g. under a Turkish default locale 'I' maps to a dotless 'ı'), which could yield an
    // invalid or non-deterministic index name.
    final String indexName1 = randomAlphaOfLength(10).toLowerCase(java.util.Locale.ROOT);
    final String indexName2 = randomValueOtherThan(indexName1, () -> randomAlphaOfLength(10).toLowerCase(java.util.Locale.ROOT));

    putModel(endpointId, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
    putModel(otherEndpointId, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
    // Create two indices, one where the inference ID of the endpoint we'll be deleting and recreating is used for
    // inference_id and one where it's used for search_inference_id
    putSemanticText(endpointId, otherEndpointId, indexName1);
    putSemanticText(otherEndpointId, endpointId, indexName2);

    // Confirm that we can create the endpoint if there are no documents in the indices using it
    deleteModel(endpointId, "force=true");
    putModel(endpointId, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING);

    // Index a document into each index
    var request1 = new Request("PUT", indexName1 + "/_create/1");
    request1.setJsonEntity("{\"inference_field\": \"value\"}");
    assertStatusOkOrCreated(client().performRequest(request1));

    var request2 = new Request("PUT", indexName2 + "/_create/1");
    request2.setJsonEntity("{\"inference_field\": \"value\"}");
    assertStatusOkOrCreated(client().performRequest(request2));

    // Refresh so the indexed documents are visible to the reference check.
    assertStatusOkOrCreated(client().performRequest(new Request("GET", "_refresh")));

    deleteModel(endpointId, "force=true");

    // Try to create an inference endpoint with the same ID; it must fail and name both indices.
    ResponseException responseException = assertThrows(
        ResponseException.class,
        () -> putModel(endpointId, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING)
    );
    assertThat(
        responseException.getMessage(),
        containsString(
            "Inference endpoint ["
                + endpointId
                + "] could not be created because the inference_id is already being used in mappings for indices: ["
        )
    );
    assertThat(responseException.getMessage(), containsString(indexName1));
    assertThat(responseException.getMessage(), containsString(indexName2));
    assertThat(
        responseException.getMessage(),
        containsString(
            "Please either use a different inference_id or update the index mappings "
                + "and/or pipelines to refer to a different inference_id."
        )
    );

    // Clean up indices and the surviving endpoint.
    deleteIndex(indexName1);
    deleteIndex(indexName2);

    deleteModel(otherEndpointId, "force=true");
}

public void testCreateEndpoint_withInferenceIdReferencedBySemanticTextAndPipeline() throws IOException {
    // When the inference_id is referenced by BOTH an index mapping (with documents) and a
    // pipeline, the creation error must mention both kinds of references in one message.
    String endpointId = "endpoint_referenced_by_semantic_text";
    putModel(endpointId, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
    // Lowercase with an explicit root locale: the default-locale toLowerCase() is locale-sensitive
    // (e.g. under a Turkish default locale 'I' maps to a dotless 'ı'), which could yield an
    // invalid or non-deterministic index name.
    String indexName = randomAlphaOfLength(10).toLowerCase(java.util.Locale.ROOT);
    putSemanticText(endpointId, indexName);

    // Index a document into the index
    var indexDocRequest = new Request("PUT", indexName + "/_create/1");
    indexDocRequest.setJsonEntity("{\"inference_field\": \"value\"}");
    assertStatusOkOrCreated(client().performRequest(indexDocRequest));

    // Refresh so the indexed document is visible to the reference check.
    assertStatusOkOrCreated(client().performRequest(new Request("GET", "_refresh")));

    var pipelineId = "pipeline_referencing_model";
    putPipeline(pipelineId, endpointId);

    deleteModel(endpointId, "force=true");

    // Expect a single error covering both the mapping reference and the pipeline reference.
    String errorString = "Inference endpoint ["
        + endpointId
        + "] could not be created because the inference_id is already being used in mappings for indices: ["
        + indexName
        + "] and referenced by pipelines: ["
        + pipelineId
        + "]. Please either use a different inference_id or update the index mappings "
        + "and/or pipelines to refer to a different inference_id.";

    ResponseException responseException = assertThrows(
        ResponseException.class,
        () -> putModel(endpointId, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING)
    );
    assertThat(responseException.getMessage(), containsString(errorString));

    // Clean up the pipeline and index.
    deletePipeline(pipelineId);
    deleteIndex(indexName);
}

public void testUnsupportedStream() throws Exception {
String modelId = "streaming";
putModel(modelId, mockCompletionServiceModelConfig(TaskType.SPARSE_EMBEDDING, "streaming_completion_test_service"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public List<Factory> getInferenceServiceFactories() {
}

public static class TestInferenceService extends AbstractTestInferenceService {
private static final String NAME = "completion_test_service";
public static final String NAME = "completion_test_service";
private static final EnumSet<TaskType> SUPPORTED_TASK_TYPES = EnumSet.of(TaskType.COMPLETION);

public TestInferenceService(InferenceServiceFactoryContext context) {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ public List<Factory> getInferenceServiceFactories() {
}

public static class TestInferenceService extends AbstractTestInferenceService {
private static final String NAME = "streaming_completion_test_service";
public static final String NAME = "streaming_completion_test_service";
private static final String ALIAS = "streaming_completion_test_service_alias";
private static final Set<TaskType> supportedStreamingTasks = Set.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION);

Expand Down
Loading