|
20 | 20 | import java.util.HashSet; |
21 | 21 | import java.util.List; |
22 | 22 | import java.util.Map; |
| 23 | +import java.util.Objects; |
23 | 24 | import java.util.Set; |
24 | 25 | import java.util.concurrent.atomic.AtomicBoolean; |
25 | 26 |
|
@@ -84,6 +85,9 @@ public class MLInferenceSearchResponseProcessor extends AbstractProcessor implem |
84 | 85 | // it can be overwritten using max_prediction_tasks when creating processor |
85 | 86 | public static final int DEFAULT_MAX_PREDICTION_TASKS = 10; |
86 | 87 | public static final String DEFAULT_OUTPUT_FIELD_NAME = "inference_results"; |
| 88 | + // allow to write to the extension of the search response, the path to point to search extension |
| 89 | + // is prefix with ext.ml_inference |
| 90 | + public static final String EXTENSION_PREFIX = "ext.ml_inference"; |
87 | 91 |
|
88 | 92 | protected MLInferenceSearchResponseProcessor( |
89 | 93 | String modelId, |
@@ -158,7 +162,28 @@ public void processResponseAsync( |
158 | 162 |
|
159 | 163 | // if many to one, run rewriteResponseDocuments |
160 | 164 | if (!oneToOne) { |
161 | | - rewriteResponseDocuments(response, responseListener); |
| 165 | + // use MLInferenceSearchResponseProcessor to allow writing to extension |
| 166 | + // check if the search response is in the type of MLInferenceSearchResponse |
| 167 | + // if not, initiate a new one MLInferenceSearchResponse |
| 168 | + MLInferenceSearchResponse mlInferenceSearchResponse; |
| 169 | + |
| 170 | + if (response instanceof MLInferenceSearchResponse) { |
| 171 | + mlInferenceSearchResponse = (MLInferenceSearchResponse) response; |
| 172 | + } else { |
| 173 | + mlInferenceSearchResponse = new MLInferenceSearchResponse( |
| 174 | + null, |
| 175 | + response.getInternalResponse(), |
| 176 | + response.getScrollId(), |
| 177 | + response.getTotalShards(), |
| 178 | + response.getSuccessfulShards(), |
| 179 | + response.getSkippedShards(), |
| 180 | + response.getSuccessfulShards(), |
| 181 | + response.getShardFailures(), |
| 182 | + response.getClusters() |
| 183 | + ); |
| 184 | + } |
| 185 | + |
| 186 | + rewriteResponseDocuments(mlInferenceSearchResponse, responseListener); |
162 | 187 | } else { |
163 | 188 | // if one to one, make one hit search response and run rewriteResponseDocuments |
164 | 189 | GroupedActionListener<SearchResponse> combineResponseListener = getCombineResponseGroupedActionListener( |
@@ -545,22 +570,37 @@ public void onResponse(Map<Integer, MLOutput> multipleMLOutputs) { |
545 | 570 | } else { |
546 | 571 | modelOutputValuePerDoc = modelOutputValue; |
547 | 572 | } |
548 | | - |
549 | | - if (sourceAsMap.containsKey(newDocumentFieldName)) { |
550 | | - if (override) { |
551 | | - sourceAsMapWithInference.remove(newDocumentFieldName); |
552 | | - sourceAsMapWithInference.put(newDocumentFieldName, modelOutputValuePerDoc); |
| 573 | + // writing to search response extension |
| 574 | + if (newDocumentFieldName.startsWith(EXTENSION_PREFIX)) { |
| 575 | + Map<String, Object> params = ((MLInferenceSearchResponse) response).getParams(); |
| 576 | + String paramsName = newDocumentFieldName.replaceFirst(EXTENSION_PREFIX + ".", ""); |
| 577 | + |
| 578 | + if (params != null) { |
| 579 | + params.put(paramsName, modelOutputValuePerDoc); |
| 580 | + ((MLInferenceSearchResponse) response).setParams(params); |
553 | 581 | } else { |
554 | | - logger |
555 | | - .debug( |
556 | | - "{} already exists in the search response hit. Skip processing this field.", |
557 | | - newDocumentFieldName |
558 | | - ); |
559 | | - // TODO when the response has the same field name, should it throw exception? currently, |
560 | | - // ingest processor quietly skip it |
| 582 | + Map<String, Object> newParams = new HashMap<>(); |
| 583 | + newParams.put(paramsName, modelOutputValuePerDoc); |
| 584 | + ((MLInferenceSearchResponse) response).setParams(newParams); |
561 | 585 | } |
562 | 586 | } else { |
563 | | - sourceAsMapWithInference.put(newDocumentFieldName, modelOutputValuePerDoc); |
| 587 | + // writing to search response hits |
| 588 | + if (sourceAsMap.containsKey(newDocumentFieldName)) { |
| 589 | + if (override) { |
| 590 | + sourceAsMapWithInference.remove(newDocumentFieldName); |
| 591 | + sourceAsMapWithInference.put(newDocumentFieldName, modelOutputValuePerDoc); |
| 592 | + } else { |
| 593 | + logger |
| 594 | + .debug( |
| 595 | + "{} already exists in the search response hit. Skip processing this field.", |
| 596 | + newDocumentFieldName |
| 597 | + ); |
| 598 | + // TODO when the response has the same field name, should it throw exception? currently, |
| 599 | + // ingest processor quietly skip it |
| 600 | + } |
| 601 | + } else { |
| 602 | + sourceAsMapWithInference.put(newDocumentFieldName, modelOutputValuePerDoc); |
| 603 | + } |
564 | 604 | } |
565 | 605 | } |
566 | 606 | } |
@@ -774,6 +814,19 @@ public MLInferenceSearchResponseProcessor create( |
774 | 814 | + ". Please adjust mappings." |
775 | 815 | ); |
776 | 816 | } |
| 817 | + boolean writeToSearchExtension = false; |
| 818 | + |
| 819 | + if (outputMaps != null) { |
| 820 | + writeToSearchExtension = outputMaps |
| 821 | + .stream() |
| 822 | + .filter(Objects::nonNull) // To avoid potential NullPointerExceptions from null outputMaps |
| 823 | + .flatMap(outputMap -> outputMap.keySet().stream()) |
| 824 | + .anyMatch(key -> key.startsWith(EXTENSION_PREFIX)); |
| 825 | + } |
| 826 | + |
| 827 | + if (writeToSearchExtension & oneToOne) { |
| 828 | + throw new IllegalArgumentException("Write model response to search extension does not support when one_to_one is true."); |
| 829 | + } |
777 | 830 |
|
778 | 831 | return new MLInferenceSearchResponseProcessor( |
779 | 832 | modelId, |
|
0 commit comments