Skip to content

Commit 5da01ac

Browse files
[ML] Switch TextExpansionQueryBuilder and TextEmbeddingQueryVectorBuilder to return 400 instead of 500 errors (#135800)
* Switching to bad requests * Update docs/changelog/135800.yaml * Renaming test * Using yaml test * fixing test
1 parent ebf5091 commit 5da01ac

File tree

7 files changed

+126
-3
lines changed

7 files changed

+126
-3
lines changed

docs/changelog/135800.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
pr: 135800
2+
summary: Switch `TextExpansionQueryBuilder` and `TextEmbeddingQueryVectorBuilder`
3+
to return 400 instead of 500 errors
4+
area: Machine Learning
5+
type: bug
6+
issues: []

test/framework/src/main/java/org/elasticsearch/test/AbstractQueryVectorBuilderTestCase.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ public final void testVectorFetch() throws Exception {
172172
*/
173173
protected abstract ActionResponse createResponse(float[] array, T builder);
174174

175-
protected static float[] randomVector(int dim) {
175+
public static float[] randomVector(int dim) {
176176
float[] vector = new float[dim];
177177
for (int i = 0; i < vector.length; i++) {
178178
vector[i] = randomFloat();

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/TextExpansionQueryBuilder.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
188188
listener.onFailure(new IllegalStateException(warning.getWarning()));
189189
} else {
190190
listener.onFailure(
191-
new IllegalStateException(
191+
new IllegalArgumentException(
192192
"expected a result of type ["
193193
+ TextExpansionResults.NAME
194194
+ "] received ["

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/vectors/TextEmbeddingQueryVectorBuilder.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ public void buildVector(Client client, ActionListener<float[]> listener) {
132132
} else if (response.getInferenceResults().get(0) instanceof WarningInferenceResults warning) {
133133
listener.onFailure(new IllegalStateException(warning.getWarning()));
134134
} else {
135-
throw new IllegalStateException(
135+
throw new IllegalArgumentException(
136136
"expected a result of type ["
137137
+ MlTextEmbeddingResults.NAME
138138
+ "] received ["

x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/47_semantic_text_knn.yml

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -624,3 +624,43 @@ setup:
624624

625625
- match: { error.root_cause.0.type: "illegal_argument_exception" }
626626
- match: { error.root_cause.0.reason: "Field [inference_field] does not use a [text_embedding] model" }
627+
628+
---
629+
"Text expansion query against semantic_text field using a dense vector model returns an failure":
630+
- requires:
631+
cluster_features: "search.new_semantic_query_interceptors"
632+
reason: New semantic query interceptors
633+
test_runner_features: [ "allowed_warnings" ]
634+
635+
- do:
636+
indices.create:
637+
index: test-semantic-text-index-using-dense-vector
638+
body:
639+
mappings:
640+
properties:
641+
inference_field:
642+
type: semantic_text
643+
inference_id: dense-inference-id
644+
645+
- do:
646+
index:
647+
index: test-semantic-text-index-using-dense-vector
648+
id: doc_1
649+
body:
650+
inference_field: [ "inference test", "another inference test" ]
651+
refresh: true
652+
653+
- do:
654+
catch: bad_request
655+
search:
656+
index: test-semantic-text-index-using-dense-vector
657+
body:
658+
query:
659+
text_expansion:
660+
embedding:
661+
model_id: dense-inference-id
662+
model_text: "octopus comforter smells"
663+
allowed_warnings:
664+
- "text_expansion is deprecated. Use sparse_vector instead."
665+
- match: { error.root_cause.0.type: "illegal_argument_exception" }
666+
- match: { error.root_cause.0.reason: "expected a result of type [text_expansion_result] received [text_embedding_result]. Is [dense-inference-id] a compatible model?" }
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.ml.vectors;
9+
10+
import org.elasticsearch.action.ActionListener;
11+
import org.elasticsearch.action.ActionRequest;
12+
import org.elasticsearch.action.ActionResponse;
13+
import org.elasticsearch.action.ActionType;
14+
import org.elasticsearch.action.support.PlainActionFuture;
15+
import org.elasticsearch.core.TimeValue;
16+
import org.elasticsearch.inference.WeightedToken;
17+
import org.elasticsearch.test.ESTestCase;
18+
import org.elasticsearch.test.client.NoOpClient;
19+
import org.elasticsearch.threadpool.ThreadPool;
20+
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
21+
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
22+
import org.elasticsearch.xpack.core.ml.vectors.TextEmbeddingQueryVectorBuilder;
23+
24+
import java.util.List;
25+
26+
import static org.hamcrest.Matchers.containsString;
27+
28+
public class TextEmbeddingQueryVectorBuilderFailureTests extends ESTestCase {
29+
30+
private static class AssertingClient extends NoOpClient {
31+
private final TextEmbeddingQueryVectorBuilder queryVectorBuilder;
32+
33+
AssertingClient(ThreadPool threadPool, TextEmbeddingQueryVectorBuilder queryVectorBuilder) {
34+
super(threadPool);
35+
this.queryVectorBuilder = queryVectorBuilder;
36+
}
37+
38+
@Override
39+
@SuppressWarnings("unchecked")
40+
protected <Request extends ActionRequest, Response extends ActionResponse> void doExecute(
41+
ActionType<Response> action,
42+
Request request,
43+
ActionListener<Response> listener
44+
) {
45+
listener.onResponse((Response) createResponse(queryVectorBuilder.getModelId()));
46+
}
47+
}
48+
49+
public void testReceiving_TextExpansionResults_ThrowsBadRequestException() {
50+
var queryVectorBuilder = createTestInstance();
51+
52+
try (var threadPool = createThreadPool()) {
53+
final var client = new AssertingClient(threadPool, queryVectorBuilder);
54+
PlainActionFuture<float[]> future = new PlainActionFuture<>();
55+
queryVectorBuilder.buildVector(client, future);
56+
57+
var exception = expectThrows(IllegalArgumentException.class, () -> future.actionGet(TimeValue.timeValueSeconds(30)));
58+
assertThat(
59+
exception.getMessage(),
60+
containsString("expected a result of type [text_embedding_result] received [text_expansion_result]")
61+
);
62+
}
63+
}
64+
65+
private static ActionResponse createResponse(String modelId) {
66+
return new InferModelAction.Response(
67+
List.of(new TextExpansionResults("foo", List.of(new WeightedToken("toke", 0.1f)), randomBoolean())),
68+
modelId,
69+
true
70+
);
71+
}
72+
73+
private static TextEmbeddingQueryVectorBuilder createTestInstance() {
74+
return new TextEmbeddingQueryVectorBuilder(randomAlphaOfLength(4), randomAlphaOfLength(4));
75+
}
76+
}

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/vectors/TextEmbeddingQueryVectorBuilderTests.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ protected void doAssertClientRequest(ActionRequest request, TextEmbeddingQueryVe
4646
assertEquals(CoordinatedInferenceAction.Request.RequestModelType.NLP_MODEL, inferRequest.getRequestModelType());
4747
}
4848

49+
@Override
4950
public ActionResponse createResponse(float[] array, TextEmbeddingQueryVectorBuilder builder) {
5051
double[] embedding = new double[array.length];
5152
for (int i = 0; i < embedding.length; i++) {

0 commit comments

Comments
 (0)