Skip to content

Commit 635fda1

Browse files
authored
[ML] Do not create inference endpoint if ID is used in existing mappings (#137055)
When creating an inference endpoint, if the inference ID is used in incompatible semantic_text mappings, prevent the endpoint from being created.

Closes #124272

- Check if existing semantic text fields have compatible model settings
- Update and expand test coverage for the new behaviour
- Improve existing test InferenceServiceExtension implementations
- Move SemanticTextInfoExtractor from xpack.core.ml.utils to xpack.inference.common
1 parent 6c979d9 commit 635fda1

File tree

19 files changed

+652
-175
lines changed

19 files changed

+652
-175
lines changed

docs/changelog/137055.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
pr: 137055
2+
summary: Do not create inference endpoint if ID is used in existing mappings
3+
area: Machine Learning
4+
type: bug
5+
issues:
6+
- 124272

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/InferenceProcessorInfoExtractor.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,13 +116,12 @@ public static Map<String, Set<String>> pipelineIdsByResource(ClusterState state,
116116
}
117117

118118
/**
119-
* @param state Current {@link ClusterState}
119+
* @param metadata Current cluster state {@link Metadata}
120120
* @return the set of IDs of the pipelines referencing any of the given Model or Deployment IDs or Aliases.
121121
*/
122-
public static Set<String> pipelineIdsForResource(ClusterState state, Set<String> ids) {
122+
public static Set<String> pipelineIdsForResource(Metadata metadata, Set<String> ids) {
123123
assert Transports.assertNotTransportThread("non-trivial nested loops over cluster state structures");
124124
Set<String> pipelineIds = new HashSet<>();
125-
Metadata metadata = state.metadata();
126125
if (metadata == null) {
127126
return pipelineIds;
128127
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/SemanticTextInfoExtractor.java

Lines changed: 0 additions & 46 deletions
This file was deleted.

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,22 @@ static String mockDenseServiceModelConfig() {
171171
""";
172172
}
173173

174+
static String mockDenseServiceModelConfig(int dimensions) {
175+
return Strings.format("""
176+
{
177+
"task_type": "text_embedding",
178+
"service": "text_embedding_test_service",
179+
"service_settings": {
180+
"model": "my_dense_vector_model",
181+
"api_key": "abc64",
182+
"dimensions": %s
183+
},
184+
"task_settings": {
185+
}
186+
}
187+
""", dimensions);
188+
}
189+
174190
static String mockRerankServiceModelConfig() {
175191
return """
176192
{

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
package org.elasticsearch.xpack.inference;
1111

1212
import org.apache.http.util.EntityUtils;
13+
import org.elasticsearch.client.Request;
1314
import org.elasticsearch.client.Response;
1415
import org.elasticsearch.client.ResponseException;
1516
import org.elasticsearch.common.Strings;
@@ -211,7 +212,7 @@ public void testDeleteEndpointWhileReferencedBySemanticText() throws IOException
211212
final String endpointId = "endpoint_referenced_by_semantic_text";
212213
final String searchEndpointId = "search_endpoint_referenced_by_semantic_text";
213214
final String indexName = randomAlphaOfLength(10).toLowerCase();
214-
final Function<String, String> buildErrorString = endpointName -> " Inference endpoint "
215+
final Function<String, String> buildErrorString = endpointName -> "Inference endpoint "
215216
+ endpointName
216217
+ " is being used in the mapping for indexes: "
217218
+ Set.of(indexName)
@@ -303,6 +304,74 @@ public void testDeleteEndpointWhileReferencedBySemanticTextAndPipeline() throws
303304
deleteIndex(indexName);
304305
}
305306

307+
public void testCreateEndpoint_withInferenceIdReferencedBySemanticText() throws IOException {
308+
final String endpointId = "endpoint_referenced_by_semantic_text";
309+
final String otherEndpointId = "other_endpoint_referenced_by_semantic_text";
310+
final String indexName1 = randomAlphaOfLength(10).toLowerCase();
311+
final String indexName2 = randomValueOtherThan(indexName1, () -> randomAlphaOfLength(10).toLowerCase());
312+
313+
putModel(endpointId, mockDenseServiceModelConfig(128), TaskType.TEXT_EMBEDDING);
314+
putModel(otherEndpointId, mockDenseServiceModelConfig(), TaskType.TEXT_EMBEDDING);
315+
// Create two indices, one where the inference ID of the endpoint we'll be deleting and
316+
// recreating is used for inference_id and one where it's used for search_inference_id
317+
putSemanticText(endpointId, otherEndpointId, indexName1);
318+
putSemanticText(otherEndpointId, endpointId, indexName2);
319+
320+
// Confirm that we can create the endpoint with different settings if there
321+
// are documents in the indices which do not use the semantic text field
322+
var request = new Request("PUT", indexName1 + "/_create/1");
323+
request.setJsonEntity("{\"non_inference_field\": \"value\"}");
324+
assertStatusOkOrCreated(client().performRequest(request));
325+
326+
request = new Request("PUT", indexName2 + "/_create/1");
327+
request.setJsonEntity("{\"non_inference_field\": \"value\"}");
328+
assertStatusOkOrCreated(client().performRequest(request));
329+
330+
assertStatusOkOrCreated(client().performRequest(new Request("GET", "_refresh")));
331+
332+
deleteModel(endpointId, "force=true");
333+
putModel(endpointId, mockDenseServiceModelConfig(64), TaskType.TEXT_EMBEDDING);
334+
335+
// Index a document with the semantic text field into each index
336+
request = new Request("PUT", indexName1 + "/_create/2");
337+
request.setJsonEntity("{\"inference_field\": \"value\"}");
338+
assertStatusOkOrCreated(client().performRequest(request));
339+
340+
request = new Request("PUT", indexName2 + "/_create/2");
341+
request.setJsonEntity("{\"inference_field\": \"value\"}");
342+
assertStatusOkOrCreated(client().performRequest(request));
343+
344+
assertStatusOkOrCreated(client().performRequest(new Request("GET", "_refresh")));
345+
346+
deleteModel(endpointId, "force=true");
347+
348+
// Try to create an inference endpoint with the same ID but different dimensions
349+
// from when the document with the semantic text field was indexed
350+
ResponseException responseException = assertThrows(
351+
ResponseException.class,
352+
() -> putModel(endpointId, mockDenseServiceModelConfig(128), TaskType.TEXT_EMBEDDING)
353+
);
354+
assertThat(
355+
responseException.getMessage(),
356+
containsString(
357+
"Inference endpoint ["
358+
+ endpointId
359+
+ "] could not be created because the inference_id is being used in mappings with incompatible settings for indices: ["
360+
)
361+
);
362+
assertThat(responseException.getMessage(), containsString(indexName1));
363+
assertThat(responseException.getMessage(), containsString(indexName2));
364+
assertThat(
365+
responseException.getMessage(),
366+
containsString("Please either use a different inference_id or update the index mappings to refer to a different inference_id.")
367+
);
368+
369+
deleteIndex(indexName1);
370+
deleteIndex(indexName2);
371+
372+
deleteModel(otherEndpointId, "force=true");
373+
}
374+
306375
public void testUnsupportedStream() throws Exception {
307376
String modelId = "streaming";
308377
putModel(modelId, mockCompletionServiceModelConfig(TaskType.SPARSE_EMBEDDING, "streaming_completion_test_service"));

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestCompletionServiceExtension.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ public List<Factory> getInferenceServiceFactories() {
4949
}
5050

5151
public static class TestInferenceService extends AbstractTestInferenceService {
52-
private static final String NAME = "completion_test_service";
52+
public static final String NAME = "completion_test_service";
5353
private static final EnumSet<TaskType> SUPPORTED_TASK_TYPES = EnumSet.of(TaskType.COMPLETION);
5454

5555
public TestInferenceService(InferenceServiceFactoryContext context) {}

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,12 @@ public void chunkedInfer(
170170
private DenseEmbeddingFloatResults makeResults(List<String> input, ServiceSettings serviceSettings) {
171171
List<DenseEmbeddingFloatResults.Embedding> embeddings = new ArrayList<>();
172172
for (String inputString : input) {
173-
List<Float> floatEmbeddings = generateEmbedding(inputString, serviceSettings.dimensions(), serviceSettings.elementType());
173+
List<Float> floatEmbeddings = generateEmbedding(
174+
inputString,
175+
serviceSettings.dimensions(),
176+
serviceSettings.elementType(),
177+
serviceSettings.similarity()
178+
);
174179
embeddings.add(DenseEmbeddingFloatResults.Embedding.of(floatEmbeddings));
175180
}
176181
return new DenseEmbeddingFloatResults(embeddings);
@@ -206,7 +211,7 @@ protected ServiceSettings getServiceSettingsFromMap(Map<String, Object> serviceS
206211
* <ul>
207212
* <li>Unique to the input</li>
208213
* <li>Reproducible (i.e., given the same input, the same embedding should be generated)</li>
209-
* <li>Valid for the provided element type</li>
214+
* <li>Valid for the provided element type and similarity measure</li>
210215
* </ul>
211216
* <p>
212217
* The embedding is generated by:
@@ -216,6 +221,7 @@ protected ServiceSettings getServiceSettingsFromMap(Map<String, Object> serviceS
216221
* <li>converting the hash code value to a string</li>
217222
* <li>converting the string to a UTF-8 encoded byte array</li>
218223
* <li>repeatedly appending the byte array to the embedding until the desired number of dimensions are populated</li>
224+
* <li>converting the embedding to a unit vector if the similarity measure requires that</li>
219225
* </ul>
220226
* <p>
221227
* Since the hash code value, when interpreted as a string, is guaranteed to only contain digits and the "-" character, the UTF-8
@@ -226,11 +232,17 @@ protected ServiceSettings getServiceSettingsFromMap(Map<String, Object> serviceS
226232
* embedding byte.
227233
* </p>
228234
*
229-
* @param input The input string
230-
* @param dimensions The embedding dimension count
235+
* @param input The input string
236+
* @param dimensions The embedding dimension count
237+
* @param similarityMeasure The similarity measure
231238
* @return An embedding
232239
*/
233-
private static List<Float> generateEmbedding(String input, int dimensions, DenseVectorFieldMapper.ElementType elementType) {
240+
private static List<Float> generateEmbedding(
241+
String input,
242+
int dimensions,
243+
DenseVectorFieldMapper.ElementType elementType,
244+
SimilarityMeasure similarityMeasure
245+
) {
234246
int embeddingLength = getEmbeddingLength(elementType, dimensions);
235247
List<Float> embedding = new ArrayList<>(embeddingLength);
236248

@@ -248,6 +260,9 @@ private static List<Float> generateEmbedding(String input, int dimensions, Dense
248260
if (remainingLength > 0) {
249261
embedding.addAll(embeddingValues.subList(0, remainingLength));
250262
}
263+
if (similarityMeasure == SimilarityMeasure.DOT_PRODUCT) {
264+
embedding = toUnitVector(embedding);
265+
}
251266

252267
return embedding;
253268
}
@@ -263,6 +278,11 @@ private static int getEmbeddingLength(DenseVectorFieldMapper.ElementType element
263278
};
264279
}
265280

281+
private static List<Float> toUnitVector(List<Float> embedding) {
282+
var magnitude = (float) Math.sqrt(embedding.stream().reduce(0f, (a, b) -> a + (b * b)));
283+
return embedding.stream().map(v -> v / magnitude).toList();
284+
}
285+
266286
public static class Configuration {
267287
public static InferenceServiceConfiguration get() {
268288
return configuration.getOrCompute();
@@ -304,9 +324,13 @@ public record TestServiceSettings(
304324
public static TestServiceSettings fromMap(Map<String, Object> map) {
305325
ValidationException validationException = new ValidationException();
306326

307-
String model = (String) map.remove("model");
327+
String model = (String) map.remove("model_id");
328+
308329
if (model == null) {
309-
validationException.addValidationError("missing model");
330+
model = (String) map.remove("model");
331+
if (model == null) {
332+
validationException.addValidationError("missing model");
333+
}
310334
}
311335

312336
Integer dimensions = (Integer) map.remove("dimensions");

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,10 @@ public static TestServiceSettings fromMap(Map<String, Object> map) {
318318
String model = (String) map.remove("model_id");
319319

320320
if (model == null) {
321-
validationException.addValidationError("missing model");
321+
model = (String) map.remove("model");
322+
if (model == null) {
323+
validationException.addValidationError("missing model");
324+
}
322325
}
323326

324327
if (validationException.validationErrors().isEmpty() == false) {

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -268,10 +268,13 @@ public record TestServiceSettings(String model, String hiddenField, boolean shou
268268
public static TestServiceSettings fromMap(Map<String, Object> map) {
269269
ValidationException validationException = new ValidationException();
270270

271-
String model = (String) map.remove("model");
271+
String model = (String) map.remove("model_id");
272272

273273
if (model == null) {
274-
validationException.addValidationError("missing model");
274+
model = (String) map.remove("model");
275+
if (model == null) {
276+
validationException.addValidationError("missing model");
277+
}
275278
}
276279

277280
String hiddenField = (String) map.remove("hidden_field");

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ public List<Factory> getInferenceServiceFactories() {
5959
}
6060

6161
public static class TestInferenceService extends AbstractTestInferenceService {
62-
private static final String NAME = "streaming_completion_test_service";
62+
public static final String NAME = "streaming_completion_test_service";
6363
private static final String ALIAS = "streaming_completion_test_service_alias";
6464
private static final Set<TaskType> supportedStreamingTasks = Set.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION);
6565

@@ -343,12 +343,15 @@ public TestServiceSettings(StreamInput in) throws IOException {
343343
}
344344

345345
public static TestServiceSettings fromMap(Map<String, Object> map) {
346-
var modelId = map.remove("model").toString();
346+
String modelId = (String) map.remove("model_id");
347347

348348
if (modelId == null) {
349-
ValidationException validationException = new ValidationException();
350-
validationException.addValidationError("missing model id");
351-
throw validationException;
349+
modelId = (String) map.remove("model");
350+
if (modelId == null) {
351+
ValidationException validationException = new ValidationException();
352+
validationException.addValidationError("missing model id");
353+
throw validationException;
354+
}
352355
}
353356

354357
return new TestServiceSettings(modelId);

0 commit comments

Comments
 (0)