Merged
6 changes: 6 additions & 0 deletions docs/changelog/137055.yaml
@@ -0,0 +1,6 @@
pr: 137055
summary: Do not create inference endpoint if ID is used in existing mappings
area: Machine Learning
type: bug
issues:
- 124272
@@ -62,9 +62,9 @@ public record MinimalServiceSettings(

public static final String SERVICE_FIELD = "service";
public static final String TASK_TYPE_FIELD = "task_type";
static final String DIMENSIONS_FIELD = "dimensions";
static final String SIMILARITY_FIELD = "similarity";
static final String ELEMENT_TYPE_FIELD = "element_type";
public static final String DIMENSIONS_FIELD = "dimensions";
public static final String SIMILARITY_FIELD = "similarity";
public static final String ELEMENT_TYPE_FIELD = "element_type";
Member

Do these need to be public?

Member

I realize it's for consistency, but having it in this PR implies a behavior/access change, and there isn't one.

Contributor Author

My mistake. I was originally using these in tests, but then realized that ServiceFields defines the same Strings and they're already public, so I'll put them back how they were. It would be nice if we could have a single source of truth for these sorts of constants; they're defined independently in a lot of different places, which makes it a real pain if we ever want to update them.
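
A minimal sketch of what that single source of truth could look like (the class name and layout are hypothetical, not something introduced by this PR):

// Hypothetical shared holder for the field-name constants that MinimalServiceSettings, ServiceFields
// and others currently define independently.
public final class InferenceFieldNames {
    public static final String SERVICE = "service";
    public static final String TASK_TYPE = "task_type";
    public static final String DIMENSIONS = "dimensions";
    public static final String SIMILARITY = "similarity";
    public static final String ELEMENT_TYPE = "element_type";

    private InferenceFieldNames() {}
}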


private static final ConstructingObjectParser<MinimalServiceSettings, Void> PARSER = new ConstructingObjectParser<>(
"model_settings",
@@ -116,13 +116,12 @@ public static Map<String, Set<String>> pipelineIdsByResource(ClusterState state,
}

/**
* @param state Current {@link ClusterState}
* @param metadata Current cluster state {@link Metadata}
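* @param ids Model or Deployment IDs or Aliases to look up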
* @return the IDs of all pipelines that reference any of the given Model or Deployment IDs or Aliases.
*/
public static Set<String> pipelineIdsForResource(ClusterState state, Set<String> ids) {
public static Set<String> pipelineIdsForResource(Metadata metadata, Set<String> ids) {
assert Transports.assertNotTransportThread("non-trivial nested loops over cluster state structures");
Set<String> pipelineIds = new HashSet<>();
Metadata metadata = state.metadata();
if (metadata == null) {
return pipelineIds;
}
Expand Down

This file was deleted.

@@ -171,6 +171,22 @@ static String mockDenseServiceModelConfig() {
""";
}

static String mockDenseServiceModelConfig(int dimensions) {
return Strings.format("""
{
"task_type": "text_embedding",
"service": "text_embedding_test_service",
"service_settings": {
"model": "my_dense_vector_model",
"api_key": "abc64",
"dimensions": %s
},
"task_settings": {
}
}
""", dimensions);
}

static String mockRerankServiceModelConfig() {
return """
{
@@ -10,6 +10,7 @@
package org.elasticsearch.xpack.inference;

import org.apache.http.util.EntityUtils;
import org.elasticsearch.client.Request;
import org.elasticsearch.client.Response;
import org.elasticsearch.client.ResponseException;
import org.elasticsearch.common.Strings;
@@ -211,7 +212,7 @@ public void testDeleteEndpointWhileReferencedBySemanticText() throws IOException
final String endpointId = "endpoint_referenced_by_semantic_text";
final String searchEndpointId = "search_endpoint_referenced_by_semantic_text";
final String indexName = randomAlphaOfLength(10).toLowerCase();
final Function<String, String> buildErrorString = endpointName -> " Inference endpoint "
final Function<String, String> buildErrorString = endpointName -> "Inference endpoint "
+ endpointName
+ " is being used in the mapping for indexes: "
+ Set.of(indexName)
@@ -303,6 +304,74 @@ public void testDeleteEndpointWhileReferencedBySemanticTextAndPipeline() throws
deleteIndex(indexName);
}

public void testCreateEndpoint_withInferenceIdReferencedBySemanticText() throws IOException {
final String endpointId = "endpoint_referenced_by_semantic_text";
final String otherEndpointId = "other_endpoint_referenced_by_semantic_text";
final String indexName1 = randomAlphaOfLength(10).toLowerCase();
final String indexName2 = randomValueOtherThan(indexName1, () -> randomAlphaOfLength(10).toLowerCase());

putModel(endpointId, mockDenseServiceModelConfig(128), TaskType.TEXT_EMBEDDING);
putModel(otherEndpointId, mockDenseServiceModelConfig(), TaskType.TEXT_EMBEDDING);
// Create two indices, one where the inference ID of the endpoint we'll be deleting and
// recreating is used for inference_id and one where it's used for search_inference_id
putSemanticText(endpointId, otherEndpointId, indexName1);
putSemanticText(otherEndpointId, endpointId, indexName2);

// Confirm that we can recreate the endpoint with different settings when the only documents
// in the indices do not use the semantic text field
var request = new Request("PUT", indexName1 + "/_create/1");
request.setJsonEntity("{\"non_inference_field\": \"value\"}");
assertStatusOkOrCreated(client().performRequest(request));

request = new Request("PUT", indexName2 + "/_create/1");
request.setJsonEntity("{\"non_inference_field\": \"value\"}");
assertStatusOkOrCreated(client().performRequest(request));

assertStatusOkOrCreated(client().performRequest(new Request("GET", "_refresh")));

deleteModel(endpointId, "force=true");
putModel(endpointId, mockDenseServiceModelConfig(64), TaskType.TEXT_EMBEDDING);

// Index a document with the semantic text field into each index
request = new Request("PUT", indexName1 + "/_create/2");
request.setJsonEntity("{\"inference_field\": \"value\"}");
assertStatusOkOrCreated(client().performRequest(request));

request = new Request("PUT", indexName2 + "/_create/2");
request.setJsonEntity("{\"inference_field\": \"value\"}");
assertStatusOkOrCreated(client().performRequest(request));

assertStatusOkOrCreated(client().performRequest(new Request("GET", "_refresh")));

deleteModel(endpointId, "force=true");

// Try to create an inference endpoint with the same ID but different dimensions
// from when the document with the semantic text field was indexed
ResponseException responseException = assertThrows(
ResponseException.class,
() -> putModel(endpointId, mockDenseServiceModelConfig(128), TaskType.TEXT_EMBEDDING)
);
Comment on lines +348 to +353
Member

Seems like this should fail even if there aren't any docs in the index.

Contributor Author

Do you mean that it should fail if no docs with semantic text fields were ever indexed, or it should fail if docs were indexed but then removed from the index? In the former case, we wouldn't have stored the model details, because that information is only populated when we get the inference results back when ingesting a document with a semantic text field. In the latter case, the endpoint creation does fail with these changes, because the semantic text field mapping isn't cleared when documents are deleted.

Since we don't persist information about the model until a doc with a semantic text field is indexed, I don't think there's any risk to the user if they create an endpoint with certain settings, delete it, then create a new one with different settings, but never index any documents with semantic text fields before creating the endpoint for the second time.
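
A minimal sketch of the sequence being described, written against this test class's helpers (the endpoint and index names are illustrative):

// Create the endpoint and map a semantic_text field to it, but never index a document that uses that field.
putModel("my-endpoint", mockDenseServiceModelConfig(128), TaskType.TEXT_EMBEDDING);
putSemanticText("my-endpoint", "my-endpoint", "my-index");
deleteModel("my-endpoint", "force=true");

// No inference results were ever generated, so no model settings were persisted in the mapping and
// recreating the endpoint with different dimensions is allowed.
putModel("my-endpoint", mockDenseServiceModelConfig(64), TaskType.TEXT_EMBEDDING);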

Member

Since we don't persist information about the model until a doc with a semantic text field is indexed, I don't think there's any risk to the user if they create an endpoint with certain settings,

Yep, you are correct. I have been out of the loop here for a bit. LGTM!

assertThat(
responseException.getMessage(),
containsString(
"Inference endpoint ["
+ endpointId
+ "] could not be created because the inference_id is being used in mappings with incompatible settings for indices: ["
)
);
assertThat(responseException.getMessage(), containsString(indexName1));
assertThat(responseException.getMessage(), containsString(indexName2));
assertThat(
responseException.getMessage(),
containsString("Please either use a different inference_id or update the index mappings to refer to a different inference_id.")
);

deleteIndex(indexName1);
deleteIndex(indexName2);

deleteModel(otherEndpointId, "force=true");
}

public void testUnsupportedStream() throws Exception {
String modelId = "streaming";
putModel(modelId, mockCompletionServiceModelConfig(TaskType.SPARSE_EMBEDDING, "streaming_completion_test_service"));
@@ -49,7 +49,7 @@ public List<Factory> getInferenceServiceFactories() {
}

public static class TestInferenceService extends AbstractTestInferenceService {
private static final String NAME = "completion_test_service";
public static final String NAME = "completion_test_service";
private static final EnumSet<TaskType> SUPPORTED_TASK_TYPES = EnumSet.of(TaskType.COMPLETION);

public TestInferenceService(InferenceServiceFactoryContext context) {}
@@ -170,7 +170,12 @@ public void chunkedInfer(
private DenseEmbeddingFloatResults makeResults(List<String> input, ServiceSettings serviceSettings) {
List<DenseEmbeddingFloatResults.Embedding> embeddings = new ArrayList<>();
for (String inputString : input) {
List<Float> floatEmbeddings = generateEmbedding(inputString, serviceSettings.dimensions(), serviceSettings.elementType());
List<Float> floatEmbeddings = generateEmbedding(
inputString,
serviceSettings.dimensions(),
serviceSettings.elementType(),
serviceSettings.similarity()
);
embeddings.add(DenseEmbeddingFloatResults.Embedding.of(floatEmbeddings));
}
return new DenseEmbeddingFloatResults(embeddings);
@@ -206,7 +211,7 @@ protected ServiceSettings getServiceSettingsFromMap(Map<String, Object> serviceS
* <ul>
* <li>Unique to the input</li>
* <li>Reproducible (i.e given the same input, the same embedding should be generated)</li>
* <li>Valid for the provided element type</li>
* <li>Valid for the provided element type and similarity measure</li>
* </ul>
* <p>
* The embedding is generated by:
@@ -216,6 +221,7 @@ protected ServiceSettings getServiceSettingsFromMap(Map<String, Object> serviceS
* <li>converting the hash code value to a string</li>
* <li>converting the string to a UTF-8 encoded byte array</li>
* <li>repeatedly appending the byte array to the embedding until the desired number of dimensions are populated</li>
* <li>converting the embedding to a unit vector if the similarity measure requires that</li>
* </ul>
* <p>
* Since the hash code value, when interpreted as a string, is guaranteed to only contain digits and the "-" character, the UTF-8
@@ -226,11 +232,17 @@ protected ServiceSettings getServiceSettingsFromMap(Map<String, Object> serviceS
* embedding byte.
* </p>
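* <p>
* For example, the input "foo" hashes to 101574; "101574" encodes to the UTF-8 bytes [49, 48, 49, 53, 55, 52], so for a
* required embedding length of 8 the raw embedding would be [49, 48, 49, 53, 55, 52, 49, 48] before any unit-vector
* normalization is applied.
* </p>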
*
* @param input The input string
* @param dimensions The embedding dimension count
* @param input The input string
* @param dimensions The embedding dimension count
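* @param elementType The embedding element type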
* @param similarityMeasure The similarity measure
* @return An embedding
*/
private static List<Float> generateEmbedding(String input, int dimensions, DenseVectorFieldMapper.ElementType elementType) {
private static List<Float> generateEmbedding(
String input,
int dimensions,
DenseVectorFieldMapper.ElementType elementType,
SimilarityMeasure similarityMeasure
) {
int embeddingLength = getEmbeddingLength(elementType, dimensions);
List<Float> embedding = new ArrayList<>(embeddingLength);

@@ -248,6 +260,9 @@ private static List<Float> generateEmbedding(String input, int dimensions, Dense
if (remainingLength > 0) {
embedding.addAll(embeddingValues.subList(0, remainingLength));
}
if (similarityMeasure == SimilarityMeasure.DOT_PRODUCT) {
embedding = toUnitVector(embedding);
}

return embedding;
}
@@ -263,6 +278,11 @@ private static int getEmbeddingLength(DenseVectorFieldMapper.ElementType element
};
}

// Scales the embedding to unit length (L2 norm of 1), e.g. [3.0f, 4.0f] becomes [0.6f, 0.8f]; used above when the
// similarity measure is DOT_PRODUCT.
private static List<Float> toUnitVector(List<Float> embedding) {
var magnitude = (float) Math.sqrt(embedding.stream().reduce(0f, (a, b) -> a + (b * b)));
return embedding.stream().map(v -> v / magnitude).toList();
}

public static class Configuration {
public static InferenceServiceConfiguration get() {
return configuration.getOrCompute();
@@ -304,9 +324,13 @@ public record TestServiceSettings(
public static TestServiceSettings fromMap(Map<String, Object> map) {
ValidationException validationException = new ValidationException();

String model = (String) map.remove("model");
String model = (String) map.remove("model_id");

if (model == null) {
validationException.addValidationError("missing model");
model = (String) map.remove("model");
if (model == null) {
validationException.addValidationError("missing model");
}
}

Integer dimensions = (Integer) map.remove("dimensions");
@@ -318,7 +318,10 @@ public static TestServiceSettings fromMap(Map<String, Object> map) {
String model = (String) map.remove("model_id");

if (model == null) {
validationException.addValidationError("missing model");
model = (String) map.remove("model");
if (model == null) {
validationException.addValidationError("missing model");
}
}

if (validationException.validationErrors().isEmpty() == false) {
@@ -268,10 +268,13 @@ public record TestServiceSettings(String model, String hiddenField, boolean shou
public static TestServiceSettings fromMap(Map<String, Object> map) {
ValidationException validationException = new ValidationException();

String model = (String) map.remove("model");
String model = (String) map.remove("model_id");

if (model == null) {
validationException.addValidationError("missing model");
model = (String) map.remove("model");
if (model == null) {
validationException.addValidationError("missing model");
}
}

String hiddenField = (String) map.remove("hidden_field");
@@ -59,7 +59,7 @@ public List<Factory> getInferenceServiceFactories() {
}

public static class TestInferenceService extends AbstractTestInferenceService {
private static final String NAME = "streaming_completion_test_service";
public static final String NAME = "streaming_completion_test_service";
private static final String ALIAS = "streaming_completion_test_service_alias";
private static final Set<TaskType> supportedStreamingTasks = Set.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION);

@@ -343,12 +343,15 @@ public TestServiceSettings(StreamInput in) throws IOException {
}

public static TestServiceSettings fromMap(Map<String, Object> map) {
var modelId = map.remove("model").toString();
String modelId = (String) map.remove("model_id");

if (modelId == null) {
ValidationException validationException = new ValidationException();
validationException.addValidationError("missing model id");
throw validationException;
modelId = (String) map.remove("model");
if (modelId == null) {
ValidationException validationException = new ValidationException();
validationException.addValidationError("missing model id");
throw validationException;
}
}

return new TestServiceSettings(modelId);
Expand Down