Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/changelog/135800.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 135800
summary: Switch `TextExpansionQueryBuilder` and `TextEmbeddingQueryVectorBuilder`
to return 400 instead of 500 errors
area: Machine Learning
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ public final void testVectorFetch() throws Exception {
*/
protected abstract ActionResponse createResponse(float[] array, T builder);

protected static float[] randomVector(int dim) {
public static float[] randomVector(int dim) {
float[] vector = new float[dim];
for (int i = 0; i < vector.length; i++) {
vector[i] = randomFloat();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
listener.onFailure(new IllegalStateException(warning.getWarning()));
} else {
listener.onFailure(
new IllegalStateException(
new IllegalArgumentException(
"expected a result of type ["
+ TextExpansionResults.NAME
+ "] received ["
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ public void buildVector(Client client, ActionListener<float[]> listener) {
} else if (response.getInferenceResults().get(0) instanceof WarningInferenceResults warning) {
listener.onFailure(new IllegalStateException(warning.getWarning()));
} else {
throw new IllegalStateException(
throw new IllegalArgumentException(
"expected a result of type ["
+ MlTextEmbeddingResults.NAME
+ "] received ["
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -624,3 +624,43 @@ setup:

- match: { error.root_cause.0.type: "illegal_argument_exception" }
- match: { error.root_cause.0.reason: "Field [inference_field] does not use a [text_embedding] model" }

---
"Text expansion query against semantic_text field using a dense vector model returns an failure":
- requires:
cluster_features: "search.new_semantic_query_interceptors"
reason: New semantic query interceptors
test_runner_features: [ "allowed_warnings" ]

- do:
indices.create:
index: test-semantic-text-index-using-dense-vector
body:
mappings:
properties:
inference_field:
type: semantic_text
inference_id: dense-inference-id

- do:
index:
index: test-semantic-text-index-using-dense-vector
id: doc_1
body:
inference_field: [ "inference test", "another inference test" ]
refresh: true

- do:
catch: bad_request
search:
index: test-semantic-text-index-using-dense-vector
body:
query:
text_expansion:
embedding:
model_id: dense-inference-id
model_text: "octopus comforter smells"
allowed_warnings:
- "text_expansion is deprecated. Use sparse_vector instead."
- match: { error.root_cause.0.type: "illegal_argument_exception" }
- match: { error.root_cause.0.reason: "expected a result of type [text_expansion_result] received [text_embedding_result]. Is [dense-inference-id] a compatible model?" }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: why do we need to use a semantic text field to get this error? That feels unrelated since it's not even used in the query.
We can also move it to resources/rest-api-spec/test/ml/search_knn_query_vector_builder.yml ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah good point, I'll move it and remove the dependency

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm I can't seem to get it to work with a regular dense vector field:

XPackRestIT > test {p0=ml/search_knn_query_vector_builder/Text expansion query against semantic_text field using a dense vector model returns an failure} FAILED
    java.lang.AssertionError: Failure at [ml/search_knn_query_vector_builder:112]: expected [400] status code but api [search] returned [403 Forbidden] [{"error":{"root_cause":[{"type":"status_exception","reason":"Trained model [text_embedding_model] is configured for task [text_embedding] but called with task [text_expansion]"

I'm going to leave it in the semantic text tests.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That looks like another bug?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean we should return a 400 in that situation too instead of a 403? Yeah I typically would think the 403 would be for permissions issues.

Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.ml.vectors;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.WeightedToken;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.client.NoOpClient;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
import org.elasticsearch.xpack.core.ml.vectors.TextEmbeddingQueryVectorBuilder;

import java.util.List;

import static org.hamcrest.Matchers.containsString;

public class TextEmbeddingQueryVectorBuilderFailureTests extends ESTestCase {

private static class AssertingClient extends NoOpClient {
private final TextEmbeddingQueryVectorBuilder queryVectorBuilder;

AssertingClient(ThreadPool threadPool, TextEmbeddingQueryVectorBuilder queryVectorBuilder) {
super(threadPool);
this.queryVectorBuilder = queryVectorBuilder;
}

@Override
@SuppressWarnings("unchecked")
protected <Request extends ActionRequest, Response extends ActionResponse> void doExecute(
ActionType<Response> action,
Request request,
ActionListener<Response> listener
) {
listener.onResponse((Response) createResponse(queryVectorBuilder.getModelId()));
}
}

public void testReceiving_TextExpansionResults_ThrowsBadRequestException() {
var queryVectorBuilder = createTestInstance();

try (var threadPool = createThreadPool()) {
final var client = new AssertingClient(threadPool, queryVectorBuilder);
PlainActionFuture<float[]> future = new PlainActionFuture<>();
queryVectorBuilder.buildVector(client, future);

var exception = expectThrows(IllegalArgumentException.class, () -> future.actionGet(TimeValue.timeValueSeconds(30)));
assertThat(
exception.getMessage(),
containsString("expected a result of type [text_embedding_result] received [text_expansion_result]")
);
}
}

private static ActionResponse createResponse(String modelId) {
return new InferModelAction.Response(
List.of(new TextExpansionResults("foo", List.of(new WeightedToken("toke", 0.1f)), randomBoolean())),
modelId,
true
);
}

private static TextEmbeddingQueryVectorBuilder createTestInstance() {
return new TextEmbeddingQueryVectorBuilder(randomAlphaOfLength(4), randomAlphaOfLength(4));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ protected void doAssertClientRequest(ActionRequest request, TextEmbeddingQueryVe
assertEquals(CoordinatedInferenceAction.Request.RequestModelType.NLP_MODEL, inferRequest.getRequestModelType());
}

@Override
public ActionResponse createResponse(float[] array, TextEmbeddingQueryVectorBuilder builder) {
double[] embedding = new double[array.length];
for (int i = 0; i < embedding.length; i++) {
Expand Down