Skip to content

Commit 23bacd3

Browse files
[Backport 2.16] add initial MLInferenceSearchResponseProcessor (#2734)
* add initial MLInferenceSearchResponseProcessor (#2688) * add MLInferenceSearchResponseProcessor Signed-off-by: Mingshi Liu <[email protected]> * add ITs Signed-off-by: Mingshi Liu <[email protected]> * add code coverage Signed-off-by: Mingshi Liu <[email protected]> * add many_to_one flag Signed-off-by: Mingshi Liu <[email protected]> * avoid import * Signed-off-by: Mingshi Liu <[email protected]> * remove extra hits Signed-off-by: Mingshi Liu <[email protected]> * spotlessApply Signed-off-by: Mingshi Liu <[email protected]> * remove extra brackets Signed-off-by: Mingshi Liu <[email protected]> --------- Signed-off-by: Mingshi Liu <[email protected]> (cherry picked from commit 01084b4) * fix http package Signed-off-by: Mingshi Liu <[email protected]> --------- Signed-off-by: Mingshi Liu <[email protected]> Co-authored-by: Mingshi Liu <[email protected]>
1 parent ca6bbe7 commit 23bacd3

File tree

10 files changed

+2871
-2
lines changed

10 files changed

+2871
-2
lines changed

plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,7 @@
214214
import org.opensearch.ml.model.MLModelManager;
215215
import org.opensearch.ml.processor.MLInferenceIngestProcessor;
216216
import org.opensearch.ml.processor.MLInferenceSearchRequestProcessor;
217+
import org.opensearch.ml.processor.MLInferenceSearchResponseProcessor;
217218
import org.opensearch.ml.repackage.com.google.common.collect.ImmutableList;
218219
import org.opensearch.ml.rest.RestMLCreateConnectorAction;
219220
import org.opensearch.ml.rest.RestMLCreateControllerAction;
@@ -996,6 +997,12 @@ public Map<String, Processor.Factory<SearchResponseProcessor>> getResponseProces
996997
new GenerativeQAResponseProcessor.Factory(this.client, () -> this.ragSearchPipelineEnabled)
997998
);
998999

1000+
responseProcessors
1001+
.put(
1002+
MLInferenceSearchResponseProcessor.TYPE,
1003+
new MLInferenceSearchResponseProcessor.Factory(parameters.client, parameters.namedXContentRegistry)
1004+
);
1005+
9991006
return responseProcessors;
10001007
}
10011008

plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessor.java

Lines changed: 674 additions & 0 deletions
Large diffs are not rendered by default.

plugin/src/main/java/org/opensearch/ml/processor/ModelExecutor.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,15 @@ default String toString(Object originalFieldValue) {
281281
return StringUtils.toJson(originalFieldValue);
282282
}
283283

284+
default boolean hasField(Object json, String path) {
285+
Object value = JsonPath.using(suppressExceptionConfiguration).parse(json).read(path);
286+
287+
if (value != null) {
288+
return true;
289+
}
290+
return false;
291+
}
292+
284293
/**
285294
* Writes a new dot path for a nested object within the given JSON object.
286295
* This method is useful when dealing with arrays or nested objects in the JSON structure.
@@ -321,5 +330,4 @@ default List<String> writeNewDotPathForNestedObject(Object json, String dotPath)
321330
default String convertToDotPath(String path) {
322331
return path.replaceAll("\\[(\\d+)\\]", "$1\\.").replaceAll("\\['(.*?)']", "$1\\.").replaceAll("^\\$", "").replaceAll("\\.$", "");
323332
}
324-
325333
}
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.utils;
7+
8+
import java.util.HashMap;
9+
import java.util.Map;
10+
11+
public class MapUtils {
12+
13+
/**
14+
* Increments the counter for the given key in the specified version.
15+
* If the key doesn't exist, it initializes the counter to 0.
16+
*
17+
* @param version the version of the counter
18+
* @param key the key for which the counter needs to be incremented
19+
*/
20+
public static void incrementCounter(Map<Integer, Map<String, Integer>> versionedCounters, int version, String key) {
21+
Map<String, Integer> counters = versionedCounters.computeIfAbsent(version, k -> new HashMap<>());
22+
counters.put(key, counters.getOrDefault(key, -1) + 1);
23+
}
24+
25+
/**
26+
* Retrieves the counter value for the given key in the specified version.
27+
* If the key doesn't exist, it returns 0.
28+
*
29+
* @param version the version of the counter
30+
* @param key the key for which the counter needs to be retrieved
31+
* @return the counter value for the given key
32+
*/
33+
public static int getCounter(Map<Integer, Map<String, Integer>> versionedCounters, int version, String key) {
34+
Map<String, Integer> counters = versionedCounters.get(version);
35+
return counters != null ? counters.getOrDefault(key, -1) : 0;
36+
}
37+
38+
/**
39+
* Increments the counter value for the given key in the provided counters map.
40+
* If the key does not exist in the map, it is added with an initial counter value of 0.
41+
*
42+
* @param counters A map that stores integer counters for each integer key.
43+
* @param key The integer key for which the counter needs to be incremented.
44+
*/
45+
public static void incrementCounter(Map<Integer, Integer> counters, int key) {
46+
counters.put(key, counters.getOrDefault(key, 0) + 1);
47+
}
48+
49+
public static int getCounter(Map<Integer, Integer> counters, int key) {
50+
return counters.getOrDefault(key, 0);
51+
}
52+
53+
}
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.utils;
7+
8+
import org.opensearch.action.search.SearchResponse;
9+
import org.opensearch.action.search.SearchResponseSections;
10+
import org.opensearch.search.SearchHit;
11+
import org.opensearch.search.SearchHits;
12+
import org.opensearch.search.aggregations.InternalAggregations;
13+
import org.opensearch.search.internal.InternalSearchResponse;
14+
import org.opensearch.search.profile.SearchProfileShardResults;
15+
16+
public class SearchResponseUtil {
17+
private SearchResponseUtil() {}
18+
19+
/**
20+
* Construct a new {@link SearchResponse} based on an existing one, replacing just the {@link SearchHits}.
21+
* @param newHits new {@link SearchHits}
22+
* @param response the existing search response
23+
* @return a new search response where the {@link SearchHits} has been replaced
24+
*/
25+
public static SearchResponse replaceHits(SearchHits newHits, SearchResponse response) {
26+
SearchResponseSections searchResponseSections;
27+
if (response.getAggregations() == null || response.getAggregations() instanceof InternalAggregations) {
28+
// We either have no aggregations, or we have Writeable InternalAggregations.
29+
// Either way, we can produce a Writeable InternalSearchResponse.
30+
searchResponseSections = new InternalSearchResponse(
31+
newHits,
32+
(InternalAggregations) response.getAggregations(),
33+
response.getSuggest(),
34+
new SearchProfileShardResults(response.getProfileResults()),
35+
response.isTimedOut(),
36+
response.isTerminatedEarly(),
37+
response.getNumReducePhases()
38+
);
39+
} else {
40+
// We have non-Writeable Aggregations, so the whole SearchResponseSections is non-Writeable.
41+
searchResponseSections = new SearchResponseSections(
42+
newHits,
43+
response.getAggregations(),
44+
response.getSuggest(),
45+
response.isTimedOut(),
46+
response.isTerminatedEarly(),
47+
new SearchProfileShardResults(response.getProfileResults()),
48+
response.getNumReducePhases()
49+
);
50+
}
51+
52+
return new SearchResponse(
53+
searchResponseSections,
54+
response.getScrollId(),
55+
response.getTotalShards(),
56+
response.getSuccessfulShards(),
57+
response.getSkippedShards(),
58+
response.getTook().millis(),
59+
response.getShardFailures(),
60+
response.getClusters(),
61+
response.pointInTimeId()
62+
);
63+
}
64+
65+
/**
66+
* Convenience method when only replacing the {@link SearchHit} array within the {@link SearchHits} in a {@link SearchResponse}.
67+
* @param newHits the new array of {@link SearchHit} elements.
68+
* @param response the search response to update
69+
* @return a {@link SearchResponse} where the underlying array of {@link SearchHit} within the {@link SearchHits} has been replaced.
70+
*/
71+
public static SearchResponse replaceHits(SearchHit[] newHits, SearchResponse response) {
72+
if (response.getHits() == null) {
73+
throw new IllegalStateException("Response must have hits");
74+
}
75+
SearchHits searchHits = new SearchHits(
76+
newHits,
77+
response.getHits().getTotalHits(),
78+
response.getHits().getMaxScore(),
79+
response.getHits().getSortFields(),
80+
response.getHits().getCollapseField(),
81+
response.getHits().getCollapseValues()
82+
);
83+
return replaceHits(searchHits, response);
84+
}
85+
}

plugin/src/test/java/org/opensearch/ml/plugin/MachineLearningPluginTests.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
import org.opensearch.ml.common.spi.tools.Tool;
4040
import org.opensearch.ml.engine.tools.MLModelTool;
4141
import org.opensearch.ml.processor.MLInferenceSearchRequestProcessor;
42+
import org.opensearch.ml.processor.MLInferenceSearchResponseProcessor;
4243
import org.opensearch.plugins.ExtensiblePlugin;
4344
import org.opensearch.plugins.SearchPipelinePlugin;
4445
import org.opensearch.plugins.SearchPlugin;
@@ -85,10 +86,11 @@ public void testGetRequestProcessors() {
8586
public void testGetResponseProcessors() {
8687
SearchPipelinePlugin.Parameters parameters = mock(SearchPipelinePlugin.Parameters.class);
8788
Map<String, ?> responseProcessors = plugin.getResponseProcessors(parameters);
88-
assertEquals(1, responseProcessors.size());
89+
assertEquals(2, responseProcessors.size());
8990
assertTrue(
9091
responseProcessors.get(GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE) instanceof GenerativeQAResponseProcessor.Factory
9192
);
93+
assertTrue(responseProcessors.get(MLInferenceSearchResponseProcessor.TYPE) instanceof MLInferenceSearchResponseProcessor.Factory);
9294
}
9395

9496
@Test

0 commit comments

Comments
 (0)