Skip to content

Commit 2b3a991

Browse files
committed
Updated the sparevector similar to Semantic query builder
1 parent a7f26f3 commit 2b3a991

File tree

1 file changed

+53
-37
lines changed

1 file changed

+53
-37
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/SparseVectorQueryBuilder.java

Lines changed: 53 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,16 @@
2525
import org.elasticsearch.index.query.QueryRewriteContext;
2626
import org.elasticsearch.index.query.SearchExecutionContext;
2727
import org.elasticsearch.inference.InferenceResults;
28+
import org.elasticsearch.inference.InputType;
29+
import org.elasticsearch.inference.TaskType;
2830
import org.elasticsearch.inference.WeightedToken;
2931
import org.elasticsearch.xcontent.ConstructingObjectParser;
3032
import org.elasticsearch.xcontent.ParseField;
3133
import org.elasticsearch.xcontent.XContentBuilder;
3234
import org.elasticsearch.xcontent.XContentParser;
33-
import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction;
34-
import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings;
35+
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
3536
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
3637
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
37-
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfigUpdate;
3838

3939
import java.io.IOException;
4040
import java.util.ArrayList;
@@ -272,60 +272,76 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
272272
throw new IllegalArgumentException("inference_id required to perform vector search on query string");
273273
}
274274

275-
CoordinatedInferenceAction.Request inferRequest = CoordinatedInferenceAction.Request.forTextInput(
275+
InferenceAction.Request inferenceRequest = new InferenceAction.Request(
276+
TaskType.ANY,
276277
inferenceId,
278+
null,
279+
null,
280+
null,
277281
List.of(query),
278-
TextExpansionConfigUpdate.EMPTY_UPDATE,
279-
false,
280-
null
282+
Map.of(),
283+
InputType.INTERNAL_SEARCH,
284+
null,
285+
false
281286
);
282-
inferRequest.setHighPriority(true);
283-
inferRequest.setPrefixType(TrainedModelPrefixStrings.PrefixType.SEARCH);
284287

285288
SetOnce<TextExpansionResults> textExpansionResultsSupplier = new SetOnce<>();
286289
queryRewriteContext.registerAsyncAction(
287290
(client, listener) -> executeAsyncWithOrigin(
288291
client,
289292
ML_ORIGIN,
290-
CoordinatedInferenceAction.INSTANCE,
291-
inferRequest,
293+
InferenceAction.INSTANCE,
294+
inferenceRequest,
292295
ActionListener.wrap(inferenceResponse -> {
293-
294-
List<InferenceResults> inferenceResults = inferenceResponse.getInferenceResults();
295-
if (inferenceResults.isEmpty()) {
296-
listener.onFailure(new IllegalStateException("inference response contain no results"));
297-
return;
298-
}
299-
if (inferenceResults.size() > 1) {
300-
listener.onFailure(new IllegalStateException("inference response should contain only one result"));
296+
List<? extends InferenceResults> inferenceResults = inferenceResponse.getResults().transformToCoordinationFormat();
297+
TextExpansionResults textExpansionResults;
298+
try {
299+
textExpansionResults = validateAndExtractTextExpansionResults(inferenceResults, inferenceId);
300+
} catch (Exception e) {
301+
listener.onFailure(e);
301302
return;
302303
}
303304

304-
if (inferenceResults.get(0) instanceof TextExpansionResults textExpansionResults) {
305-
textExpansionResultsSupplier.set(textExpansionResults);
306-
listener.onResponse(null);
307-
} else if (inferenceResults.get(0) instanceof WarningInferenceResults warning) {
308-
listener.onFailure(new IllegalStateException(warning.getWarning()));
309-
} else {
310-
listener.onFailure(
311-
new IllegalArgumentException(
312-
"expected a result of type ["
313-
+ TextExpansionResults.NAME
314-
+ "] received ["
315-
+ inferenceResults.get(0).getWriteableName()
316-
+ "]. Is ["
317-
+ inferenceId
318-
+ "] a compatible model?"
319-
)
320-
);
321-
}
305+
textExpansionResultsSupplier.set(textExpansionResults);
306+
listener.onResponse(null);
322307
}, listener::onFailure)
323308
)
324309
);
325310

326311
return new SparseVectorQueryBuilder(this, textExpansionResultsSupplier);
327312
}
328313

314+
private static TextExpansionResults validateAndExtractTextExpansionResults(
315+
List<? extends InferenceResults> inferenceResults,
316+
String inferenceId
317+
) {
318+
if (inferenceResults.isEmpty()) {
319+
throw new IllegalStateException("inference response contain no results");
320+
}
321+
if (inferenceResults.size() > 1) {
322+
throw new IllegalStateException("inference response should contain only one result");
323+
}
324+
325+
InferenceResults result = inferenceResults.getFirst();
326+
if (result instanceof TextExpansionResults textExpansionResults) {
327+
return textExpansionResults;
328+
}
329+
330+
if (result instanceof WarningInferenceResults warning) {
331+
throw new IllegalStateException(warning.getWarning());
332+
}
333+
334+
throw new IllegalArgumentException(
335+
"expected a result of type ["
336+
+ TextExpansionResults.NAME
337+
+ "] received ["
338+
+ result.getWriteableName()
339+
+ "]. Is ["
340+
+ inferenceId
341+
+ "] a compatible model?"
342+
);
343+
}
344+
329345
@Override
330346
protected boolean doEquals(SparseVectorQueryBuilder other) {
331347
return Objects.equals(fieldName, other.fieldName)

0 commit comments

Comments
 (0)