Skip to content

Commit b5ef793

Browse files
committed
Moved to core
1 parent 3c481ca commit b5ef793

File tree

2 files changed

+30
-29
lines changed

2 files changed

+30
-29
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/search/vectors/SparseVectorQueryBuilder.java

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,16 @@
2525
import org.elasticsearch.index.query.QueryRewriteContext;
2626
import org.elasticsearch.index.query.SearchExecutionContext;
2727
import org.elasticsearch.inference.InferenceResults;
28+
import org.elasticsearch.inference.InferenceServiceResults;
29+
import org.elasticsearch.inference.TaskType;
2830
import org.elasticsearch.inference.WeightedToken;
2931
import org.elasticsearch.xcontent.ConstructingObjectParser;
3032
import org.elasticsearch.xcontent.ParseField;
3133
import org.elasticsearch.xcontent.XContentBuilder;
3234
import org.elasticsearch.xcontent.XContentParser;
33-
import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction;
34-
import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings;
35+
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
3536
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
3637
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
37-
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfigUpdate;
3838

3939
import java.io.IOException;
4040
import java.util.ArrayList;
@@ -44,7 +44,7 @@
4444

4545
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
4646
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
47-
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
47+
import static org.elasticsearch.xpack.core.ClientHelper.INFERENCE_ORIGIN;
4848
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
4949

5050
public class SparseVectorQueryBuilder extends AbstractQueryBuilder<SparseVectorQueryBuilder> {
@@ -272,27 +272,29 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
272272
throw new IllegalArgumentException("inference_id required to perform vector search on query string");
273273
}
274274

275-
// TODO move this to xpack core and use inference APIs
276-
CoordinatedInferenceAction.Request inferRequest = CoordinatedInferenceAction.Request.forTextInput(
275+
InferenceAction.Request inferRequest = new InferenceAction.Request(
276+
TaskType.SPARSE_EMBEDDING,
277277
inferenceId,
278-
List.of(query),
279-
TextExpansionConfigUpdate.EMPTY_UPDATE,
280-
false,
281-
null
278+
null, // query field (not needed for sparse embedding)
279+
null, // returnDocuments (not needed)
280+
null, // topN (not needed)
281+
List.of(query), // input text
282+
Map.of(), // taskSettings (empty for now)
283+
null, // input type not allowed for sparse_embedding task type
284+
null, // timeout (use default)
285+
false // not streaming
282286
);
283-
inferRequest.setHighPriority(true);
284-
inferRequest.setPrefixType(TrainedModelPrefixStrings.PrefixType.SEARCH);
285287

286288
SetOnce<TextExpansionResults> textExpansionResultsSupplier = new SetOnce<>();
287289
queryRewriteContext.registerAsyncAction(
288290
(client, listener) -> executeAsyncWithOrigin(
289291
client,
290-
ML_ORIGIN,
291-
CoordinatedInferenceAction.INSTANCE,
292+
INFERENCE_ORIGIN,
293+
InferenceAction.INSTANCE,
292294
inferRequest,
293295
ActionListener.wrap(inferenceResponse -> {
294296

295-
List<InferenceResults> inferenceResults = inferenceResponse.getInferenceResults();
297+
List<? extends InferenceResults> inferenceResults = inferenceResponse.getResults().transformToCoordinationFormat();
296298
if (inferenceResults.isEmpty()) {
297299
listener.onFailure(new IllegalStateException("inference response contain no results"));
298300
return;

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/search/vectors/SparseVectorQueryBuilderTests.java

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,9 @@
3939
import org.elasticsearch.test.AbstractQueryTestCase;
4040
import org.elasticsearch.test.index.IndexVersionUtils;
4141
import org.elasticsearch.xpack.core.XPackClientPlugin;
42-
import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction;
43-
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
44-
import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings;
45-
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
42+
import org.elasticsearch.inference.TaskType;
43+
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
44+
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
4645

4746
import java.io.IOException;
4847
import java.lang.reflect.Method;
@@ -118,15 +117,14 @@ protected Settings createTestIndexSettings() {
118117
@Override
119118
protected boolean canSimulateMethod(Method method, Object[] args) throws NoSuchMethodException {
120119
return method.equals(Client.class.getMethod("execute", ActionType.class, ActionRequest.class, ActionListener.class))
121-
&& (args[0] instanceof CoordinatedInferenceAction);
120+
&& (args[0] instanceof InferenceAction);
122121
}
123122

124123
@Override
125124
protected Object simulateMethod(Method method, Object[] args) {
126-
CoordinatedInferenceAction.Request request = (CoordinatedInferenceAction.Request) args[1];
127-
assertNull(request.getInferenceTimeout());
128-
assertEquals(TrainedModelPrefixStrings.PrefixType.SEARCH, request.getPrefixType());
129-
assertEquals(CoordinatedInferenceAction.Request.RequestModelType.NLP_MODEL, request.getRequestModelType());
125+
InferenceAction.Request request = (InferenceAction.Request) args[1];
126+
assertEquals(TaskType.SPARSE_EMBEDDING, request.getTaskType());
127+
assertNull(request.getInputType()); // Should be null for sparse_embedding
130128

131129
// Randomisation cannot be used here as {@code #doAssertLuceneQuery}
132130
// asserts that 2 rewritten queries are the same
@@ -135,12 +133,13 @@ protected Object simulateMethod(Method method, Object[] args) {
135133
tokens.add(new WeightedToken(Integer.toString(i), (i + 1) * 1.0f));
136134
}
137135

138-
var response = InferModelAction.Response.builder()
139-
.setId(request.getModelId())
140-
.addInferenceResults(List.of(new TextExpansionResults("foo", tokens, randomBoolean())))
141-
.build();
136+
var embeddings = List.of(
137+
new SparseEmbeddingResults.Embedding(tokens, randomBoolean())
138+
);
139+
var results = new SparseEmbeddingResults(embeddings);
140+
var response = new InferenceAction.Response(results);
142141
@SuppressWarnings("unchecked") // We matched the method above.
143-
ActionListener<InferModelAction.Response> listener = (ActionListener<InferModelAction.Response>) args[2];
142+
ActionListener<InferenceAction.Response> listener = (ActionListener<InferenceAction.Response>) args[2];
144143
listener.onResponse(response);
145144
return null;
146145
}

0 commit comments

Comments
 (0)