|
25 | 25 | import org.elasticsearch.index.query.QueryRewriteContext; |
26 | 26 | import org.elasticsearch.index.query.SearchExecutionContext; |
27 | 27 | import org.elasticsearch.inference.InferenceResults; |
| 28 | +import org.elasticsearch.inference.InferenceServiceResults; |
| 29 | +import org.elasticsearch.inference.TaskType; |
28 | 30 | import org.elasticsearch.inference.WeightedToken; |
29 | 31 | import org.elasticsearch.xcontent.ConstructingObjectParser; |
30 | 32 | import org.elasticsearch.xcontent.ParseField; |
31 | 33 | import org.elasticsearch.xcontent.XContentBuilder; |
32 | 34 | import org.elasticsearch.xcontent.XContentParser; |
33 | | -import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction; |
34 | | -import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings; |
| 35 | +import org.elasticsearch.xpack.core.inference.action.InferenceAction; |
35 | 36 | import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; |
36 | 37 | import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; |
37 | | -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfigUpdate; |
38 | 38 |
|
39 | 39 | import java.io.IOException; |
40 | 40 | import java.util.ArrayList; |
|
44 | 44 |
|
45 | 45 | import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; |
46 | 46 | import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; |
47 | | -import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; |
| 47 | +import static org.elasticsearch.xpack.core.ClientHelper.INFERENCE_ORIGIN; |
48 | 48 | import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; |
49 | 49 |
|
50 | 50 | public class SparseVectorQueryBuilder extends AbstractQueryBuilder<SparseVectorQueryBuilder> { |
@@ -272,27 +272,29 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) { |
272 | 272 | throw new IllegalArgumentException("inference_id required to perform vector search on query string"); |
273 | 273 | } |
274 | 274 |
|
275 | | - // TODO move this to xpack core and use inference APIs |
276 | | - CoordinatedInferenceAction.Request inferRequest = CoordinatedInferenceAction.Request.forTextInput( |
| 275 | + InferenceAction.Request inferRequest = new InferenceAction.Request( |
| 276 | + TaskType.SPARSE_EMBEDDING, |
277 | 277 | inferenceId, |
278 | | - List.of(query), |
279 | | - TextExpansionConfigUpdate.EMPTY_UPDATE, |
280 | | - false, |
281 | | - null |
| 278 | + null, // query field (not needed for sparse embedding) |
| 279 | + null, // returnDocuments (not needed) |
| 280 | + null, // topN (not needed) |
| 281 | + List.of(query), // input text |
| 282 | + Map.of(), // taskSettings (empty for now) |
| 283 | + null, // input type not allowed for sparse_embedding task type |
| 284 | + null, // timeout (use default) |
| 285 | + false // not streaming |
282 | 286 | ); |
283 | | - inferRequest.setHighPriority(true); |
284 | | - inferRequest.setPrefixType(TrainedModelPrefixStrings.PrefixType.SEARCH); |
285 | 287 |
|
286 | 288 | SetOnce<TextExpansionResults> textExpansionResultsSupplier = new SetOnce<>(); |
287 | 289 | queryRewriteContext.registerAsyncAction( |
288 | 290 | (client, listener) -> executeAsyncWithOrigin( |
289 | 291 | client, |
290 | | - ML_ORIGIN, |
291 | | - CoordinatedInferenceAction.INSTANCE, |
| 292 | + INFERENCE_ORIGIN, |
| 293 | + InferenceAction.INSTANCE, |
292 | 294 | inferRequest, |
293 | 295 | ActionListener.wrap(inferenceResponse -> { |
294 | 296 |
|
295 | | - List<InferenceResults> inferenceResults = inferenceResponse.getInferenceResults(); |
| 297 | + List<? extends InferenceResults> inferenceResults = inferenceResponse.getResults().transformToCoordinationFormat(); |
296 | 298 | if (inferenceResults.isEmpty()) { |
297 | 299 | listener.onFailure(new IllegalStateException("inference response contain no results")); |
298 | 300 | return; |
|
0 commit comments