|
25 | 25 | import org.elasticsearch.index.query.QueryRewriteContext; |
26 | 26 | import org.elasticsearch.index.query.SearchExecutionContext; |
27 | 27 | import org.elasticsearch.inference.InferenceResults; |
| 28 | +import org.elasticsearch.inference.InputType; |
| 29 | +import org.elasticsearch.inference.TaskType; |
28 | 30 | import org.elasticsearch.inference.WeightedToken; |
29 | 31 | import org.elasticsearch.xcontent.ConstructingObjectParser; |
30 | 32 | import org.elasticsearch.xcontent.ParseField; |
31 | 33 | import org.elasticsearch.xcontent.XContentBuilder; |
32 | 34 | import org.elasticsearch.xcontent.XContentParser; |
33 | | -import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction; |
34 | | -import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings; |
| 35 | +import org.elasticsearch.xpack.core.inference.action.InferenceAction; |
35 | 36 | import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; |
36 | 37 | import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; |
37 | | -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfigUpdate; |
38 | 38 |
|
39 | 39 | import java.io.IOException; |
40 | 40 | import java.util.ArrayList; |
@@ -272,60 +272,76 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) { |
272 | 272 | throw new IllegalArgumentException("inference_id required to perform vector search on query string"); |
273 | 273 | } |
274 | 274 |
|
275 | | - CoordinatedInferenceAction.Request inferRequest = CoordinatedInferenceAction.Request.forTextInput( |
| 275 | + InferenceAction.Request inferenceRequest = new InferenceAction.Request( |
| 276 | + TaskType.ANY, |
276 | 277 | inferenceId, |
| 278 | + null, |
| 279 | + null, |
| 280 | + null, |
277 | 281 | List.of(query), |
278 | | - TextExpansionConfigUpdate.EMPTY_UPDATE, |
279 | | - false, |
280 | | - null |
| 282 | + Map.of(), |
| 283 | + InputType.INTERNAL_SEARCH, |
| 284 | + null, |
| 285 | + false |
281 | 286 | ); |
282 | | - inferRequest.setHighPriority(true); |
283 | | - inferRequest.setPrefixType(TrainedModelPrefixStrings.PrefixType.SEARCH); |
284 | 287 |
|
285 | 288 | SetOnce<TextExpansionResults> textExpansionResultsSupplier = new SetOnce<>(); |
286 | 289 | queryRewriteContext.registerAsyncAction( |
287 | 290 | (client, listener) -> executeAsyncWithOrigin( |
288 | 291 | client, |
289 | 292 | ML_ORIGIN, |
290 | | - CoordinatedInferenceAction.INSTANCE, |
291 | | - inferRequest, |
| 293 | + InferenceAction.INSTANCE, |
| 294 | + inferenceRequest, |
292 | 295 | ActionListener.wrap(inferenceResponse -> { |
293 | | - |
294 | | - List<InferenceResults> inferenceResults = inferenceResponse.getInferenceResults(); |
295 | | - if (inferenceResults.isEmpty()) { |
296 | | - listener.onFailure(new IllegalStateException("inference response contain no results")); |
297 | | - return; |
298 | | - } |
299 | | - if (inferenceResults.size() > 1) { |
300 | | - listener.onFailure(new IllegalStateException("inference response should contain only one result")); |
| 296 | + List<? extends InferenceResults> inferenceResults = inferenceResponse.getResults().transformToCoordinationFormat(); |
| 297 | + TextExpansionResults textExpansionResults; |
| 298 | + try { |
| 299 | + textExpansionResults = validateAndExtractTextExpansionResults(inferenceResults, inferenceId); |
| 300 | + } catch (Exception e) { |
| 301 | + listener.onFailure(e); |
301 | 302 | return; |
302 | 303 | } |
303 | 304 |
|
304 | | - if (inferenceResults.get(0) instanceof TextExpansionResults textExpansionResults) { |
305 | | - textExpansionResultsSupplier.set(textExpansionResults); |
306 | | - listener.onResponse(null); |
307 | | - } else if (inferenceResults.get(0) instanceof WarningInferenceResults warning) { |
308 | | - listener.onFailure(new IllegalStateException(warning.getWarning())); |
309 | | - } else { |
310 | | - listener.onFailure( |
311 | | - new IllegalArgumentException( |
312 | | - "expected a result of type [" |
313 | | - + TextExpansionResults.NAME |
314 | | - + "] received [" |
315 | | - + inferenceResults.get(0).getWriteableName() |
316 | | - + "]. Is [" |
317 | | - + inferenceId |
318 | | - + "] a compatible model?" |
319 | | - ) |
320 | | - ); |
321 | | - } |
| 305 | + textExpansionResultsSupplier.set(textExpansionResults); |
| 306 | + listener.onResponse(null); |
322 | 307 | }, listener::onFailure) |
323 | 308 | ) |
324 | 309 | ); |
325 | 310 |
|
326 | 311 | return new SparseVectorQueryBuilder(this, textExpansionResultsSupplier); |
327 | 312 | } |
328 | 313 |
|
| 314 | + private static TextExpansionResults validateAndExtractTextExpansionResults( |
| 315 | + List<? extends InferenceResults> inferenceResults, |
| 316 | + String inferenceId |
| 317 | + ) { |
| 318 | + if (inferenceResults.isEmpty()) { |
| 319 | + throw new IllegalStateException("inference response contain no results"); |
| 320 | + } |
| 321 | + if (inferenceResults.size() > 1) { |
| 322 | + throw new IllegalStateException("inference response should contain only one result"); |
| 323 | + } |
| 324 | + |
| 325 | + InferenceResults result = inferenceResults.getFirst(); |
| 326 | + if (result instanceof TextExpansionResults textExpansionResults) { |
| 327 | + return textExpansionResults; |
| 328 | + } |
| 329 | + |
| 330 | + if (result instanceof WarningInferenceResults warning) { |
| 331 | + throw new IllegalStateException(warning.getWarning()); |
| 332 | + } |
| 333 | + |
| 334 | + throw new IllegalArgumentException( |
| 335 | + "expected a result of type [" |
| 336 | + + TextExpansionResults.NAME |
| 337 | + + "] received [" |
| 338 | + + result.getWriteableName() |
| 339 | + + "]. Is [" |
| 340 | + + inferenceId |
| 341 | + + "] a compatible model?" |
| 342 | + ); |
| 343 | + } |
| 344 | + |
329 | 345 | @Override |
330 | 346 | protected boolean doEquals(SparseVectorQueryBuilder other) { |
331 | 347 | return Objects.equals(fieldName, other.fieldName) |
|
0 commit comments