Skip to content

Commit 09aa6ea

Browse files
Support ML Inference Search Processor Writing to Search Extension (#3061) (#3127)
(cherry picked from commit d9a56cf) Co-authored-by: Mingshi Liu <[email protected]>
1 parent 43dcefc commit 09aa6ea

File tree

4 files changed

+714
-19
lines changed

4 files changed

+714
-19
lines changed
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
package org.opensearch.ml.processor;
6+
7+
import java.io.IOException;
8+
import java.util.Map;
9+
10+
import org.opensearch.action.search.SearchResponse;
11+
import org.opensearch.action.search.SearchResponseSections;
12+
import org.opensearch.action.search.ShardSearchFailure;
13+
import org.opensearch.core.xcontent.XContentBuilder;
14+
15+
public class MLInferenceSearchResponse extends SearchResponse {
16+
private static final String EXT_SECTION_NAME = "ext";
17+
18+
private Map<String, Object> params;
19+
20+
public MLInferenceSearchResponse(
21+
Map<String, Object> params,
22+
SearchResponseSections internalResponse,
23+
String scrollId,
24+
int totalShards,
25+
int successfulShards,
26+
int skippedShards,
27+
long tookInMillis,
28+
ShardSearchFailure[] shardFailures,
29+
Clusters clusters
30+
) {
31+
super(internalResponse, scrollId, totalShards, successfulShards, skippedShards, tookInMillis, shardFailures, clusters);
32+
this.params = params;
33+
}
34+
35+
public void setParams(Map<String, Object> params) {
36+
this.params = params;
37+
}
38+
39+
public Map<String, Object> getParams() {
40+
return this.params;
41+
}
42+
43+
@Override
44+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
45+
builder.startObject();
46+
innerToXContent(builder, params);
47+
48+
if (this.params != null) {
49+
builder.startObject(EXT_SECTION_NAME);
50+
builder.field(MLInferenceSearchResponseProcessor.TYPE, this.params);
51+
52+
builder.endObject();
53+
}
54+
builder.endObject();
55+
return builder;
56+
}
57+
}

plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessor.java

Lines changed: 67 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import java.util.HashSet;
2121
import java.util.List;
2222
import java.util.Map;
23+
import java.util.Objects;
2324
import java.util.Set;
2425
import java.util.concurrent.atomic.AtomicBoolean;
2526

@@ -84,6 +85,9 @@ public class MLInferenceSearchResponseProcessor extends AbstractProcessor implem
8485
// it can be overwritten using max_prediction_tasks when creating processor
8586
public static final int DEFAULT_MAX_PREDICTION_TASKS = 10;
8687
public static final String DEFAULT_OUTPUT_FIELD_NAME = "inference_results";
88+
// allow to write to the extension of the search response, the path to point to search extension
89+
// is prefix with ext.ml_inference
90+
public static final String EXTENSION_PREFIX = "ext.ml_inference";
8791

8892
protected MLInferenceSearchResponseProcessor(
8993
String modelId,
@@ -158,7 +162,28 @@ public void processResponseAsync(
158162

159163
// if many to one, run rewriteResponseDocuments
160164
if (!oneToOne) {
161-
rewriteResponseDocuments(response, responseListener);
165+
// use MLInferenceSearchResponseProcessor to allow writing to extension
166+
// check if the search response is in the type of MLInferenceSearchResponse
167+
// if not, initiate a new one MLInferenceSearchResponse
168+
MLInferenceSearchResponse mlInferenceSearchResponse;
169+
170+
if (response instanceof MLInferenceSearchResponse) {
171+
mlInferenceSearchResponse = (MLInferenceSearchResponse) response;
172+
} else {
173+
mlInferenceSearchResponse = new MLInferenceSearchResponse(
174+
null,
175+
response.getInternalResponse(),
176+
response.getScrollId(),
177+
response.getTotalShards(),
178+
response.getSuccessfulShards(),
179+
response.getSkippedShards(),
180+
response.getSuccessfulShards(),
181+
response.getShardFailures(),
182+
response.getClusters()
183+
);
184+
}
185+
186+
rewriteResponseDocuments(mlInferenceSearchResponse, responseListener);
162187
} else {
163188
// if one to one, make one hit search response and run rewriteResponseDocuments
164189
GroupedActionListener<SearchResponse> combineResponseListener = getCombineResponseGroupedActionListener(
@@ -545,22 +570,37 @@ public void onResponse(Map<Integer, MLOutput> multipleMLOutputs) {
545570
} else {
546571
modelOutputValuePerDoc = modelOutputValue;
547572
}
548-
549-
if (sourceAsMap.containsKey(newDocumentFieldName)) {
550-
if (override) {
551-
sourceAsMapWithInference.remove(newDocumentFieldName);
552-
sourceAsMapWithInference.put(newDocumentFieldName, modelOutputValuePerDoc);
573+
// writing to search response extension
574+
if (newDocumentFieldName.startsWith(EXTENSION_PREFIX)) {
575+
Map<String, Object> params = ((MLInferenceSearchResponse) response).getParams();
576+
String paramsName = newDocumentFieldName.replaceFirst(EXTENSION_PREFIX + ".", "");
577+
578+
if (params != null) {
579+
params.put(paramsName, modelOutputValuePerDoc);
580+
((MLInferenceSearchResponse) response).setParams(params);
553581
} else {
554-
logger
555-
.debug(
556-
"{} already exists in the search response hit. Skip processing this field.",
557-
newDocumentFieldName
558-
);
559-
// TODO when the response has the same field name, should it throw exception? currently,
560-
// ingest processor quietly skip it
582+
Map<String, Object> newParams = new HashMap<>();
583+
newParams.put(paramsName, modelOutputValuePerDoc);
584+
((MLInferenceSearchResponse) response).setParams(newParams);
561585
}
562586
} else {
563-
sourceAsMapWithInference.put(newDocumentFieldName, modelOutputValuePerDoc);
587+
// writing to search response hits
588+
if (sourceAsMap.containsKey(newDocumentFieldName)) {
589+
if (override) {
590+
sourceAsMapWithInference.remove(newDocumentFieldName);
591+
sourceAsMapWithInference.put(newDocumentFieldName, modelOutputValuePerDoc);
592+
} else {
593+
logger
594+
.debug(
595+
"{} already exists in the search response hit. Skip processing this field.",
596+
newDocumentFieldName
597+
);
598+
// TODO when the response has the same field name, should it throw exception? currently,
599+
// ingest processor quietly skip it
600+
}
601+
} else {
602+
sourceAsMapWithInference.put(newDocumentFieldName, modelOutputValuePerDoc);
603+
}
564604
}
565605
}
566606
}
@@ -774,6 +814,19 @@ public MLInferenceSearchResponseProcessor create(
774814
+ ". Please adjust mappings."
775815
);
776816
}
817+
boolean writeToSearchExtension = false;
818+
819+
if (outputMaps != null) {
820+
writeToSearchExtension = outputMaps
821+
.stream()
822+
.filter(Objects::nonNull) // To avoid potential NullPointerExceptions from null outputMaps
823+
.flatMap(outputMap -> outputMap.keySet().stream())
824+
.anyMatch(key -> key.startsWith(EXTENSION_PREFIX));
825+
}
826+
827+
if (writeToSearchExtension & oneToOne) {
828+
throw new IllegalArgumentException("Write model response to search extension does not support when one_to_one is true.");
829+
}
777830

778831
return new MLInferenceSearchResponseProcessor(
779832
modelId,

0 commit comments

Comments
 (0)