Skip to content

Commit 4f01193

Browse files
Enable pass query string to input_map in ml inference search response processor (#2899) (#3129)
* enable add query_text to model_config Signed-off-by: Mingshi Liu <[email protected]> * change javadoc Signed-off-by: Mingshi Liu <[email protected]> * add more tests Signed-off-by: Mingshi Liu <[email protected]> * use standard json path config Signed-off-by: Mingshi Liu <[email protected]> * add example in javadoc Signed-off-by: Mingshi Liu <[email protected]> * read query mapping from input_map Signed-off-by: Mingshi Liu <[email protected]> * recognize query mapping by prefix _request. Signed-off-by: Mingshi Liu <[email protected]> --------- Signed-off-by: Mingshi Liu <[email protected]> (cherry picked from commit 083abad) Co-authored-by: Mingshi Liu <[email protected]>
1 parent 09aa6ea commit 4f01193

File tree

5 files changed

+508
-47
lines changed

5 files changed

+508
-47
lines changed

common/build.gradle

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ dependencies {
2626
compileOnly group: 'org.apache.commons', name: 'commons-text', version: '1.10.0'
2727
compileOnly group: 'com.google.code.gson', name: 'gson', version: '2.10.1'
2828
compileOnly group: 'org.json', name: 'json', version: '20231013'
29-
3029
implementation('com.google.guava:guava:32.1.2-jre') {
3130
exclude group: 'com.google.guava', module: 'failureaccess'
3231
exclude group: 'com.google.code.findbugs', module: 'jsr305'

common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import static org.opensearch.ml.common.utils.StringUtils.TO_STRING_FUNCTION_NAME;
1010
import static org.opensearch.ml.common.utils.StringUtils.collectToStringPrefixes;
1111
import static org.opensearch.ml.common.utils.StringUtils.getJsonPath;
12+
import static org.opensearch.ml.common.utils.StringUtils.isValidJSONPath;
1213
import static org.opensearch.ml.common.utils.StringUtils.obtainFieldNameFromJsonPath;
1314
import static org.opensearch.ml.common.utils.StringUtils.parseParameters;
1415
import static org.opensearch.ml.common.utils.StringUtils.toJson;
@@ -457,4 +458,53 @@ public void testGetJsonPath_ValidJsonPathWithoutSource() {
457458
String result = getJsonPath(input);
458459
assertEquals("$.response.body.data[*].embedding", result);
459460
}
461+
462+
@Test
463+
public void testisValidJSONPath_InvalidInputs() {
464+
Assert.assertFalse(isValidJSONPath("..bar"));
465+
Assert.assertFalse(isValidJSONPath("."));
466+
Assert.assertFalse(isValidJSONPath(".."));
467+
Assert.assertFalse(isValidJSONPath("foo.bar."));
468+
Assert.assertFalse(isValidJSONPath(".foo.bar."));
469+
}
470+
471+
@Test
472+
public void testisValidJSONPath_NullInput() {
473+
Assert.assertFalse(isValidJSONPath(null));
474+
}
475+
476+
@Test
477+
public void testisValidJSONPath_EmptyInput() {
478+
Assert.assertFalse(isValidJSONPath(""));
479+
}
480+
481+
@Test
482+
public void testisValidJSONPath_ValidInputs() {
483+
Assert.assertTrue(isValidJSONPath("foo"));
484+
Assert.assertTrue(isValidJSONPath("foo.bar"));
485+
Assert.assertTrue(isValidJSONPath("foo.bar.baz"));
486+
Assert.assertTrue(isValidJSONPath("foo.bar.baz.qux"));
487+
Assert.assertTrue(isValidJSONPath(".foo"));
488+
Assert.assertTrue(isValidJSONPath("$.foo"));
489+
Assert.assertTrue(isValidJSONPath(".foo.bar"));
490+
Assert.assertTrue(isValidJSONPath("$.foo.bar"));
491+
}
492+
493+
@Test
494+
public void testisValidJSONPath_WithFilter() {
495+
Assert.assertTrue(isValidJSONPath("$.store['book']"));
496+
Assert.assertTrue(isValidJSONPath("$['store']['book'][0]['title']"));
497+
Assert.assertTrue(isValidJSONPath("$.store.book[0]"));
498+
Assert.assertTrue(isValidJSONPath("$.store.book[1,2]"));
499+
Assert.assertTrue(isValidJSONPath("$.store.book[-1:] "));
500+
Assert.assertTrue(isValidJSONPath("$.store.book[0:2]"));
501+
Assert.assertTrue(isValidJSONPath("$.store.book[*]"));
502+
Assert.assertTrue(isValidJSONPath("$.store.book[?(@.price < 10)]"));
503+
Assert.assertTrue(isValidJSONPath("$.store.book[?(@.author == 'J.K. Rowling')]"));
504+
Assert.assertTrue(isValidJSONPath("$..author"));
505+
Assert.assertTrue(isValidJSONPath("$..book[?(@.price > 15)]"));
506+
Assert.assertTrue(isValidJSONPath("$.store.book[0,1]"));
507+
Assert.assertTrue(isValidJSONPath("$['store','warehouse']"));
508+
Assert.assertTrue(isValidJSONPath("$..book[?(@.price > 20)].title"));
509+
}
460510
}

plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessor.java

Lines changed: 88 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
package org.opensearch.ml.processor;
77

88
import static java.lang.Math.max;
9+
import static org.opensearch.ml.common.utils.StringUtils.toJson;
910
import static org.opensearch.ml.processor.InferenceProcessorAttributes.INPUT_MAP;
1011
import static org.opensearch.ml.processor.InferenceProcessorAttributes.MAX_PREDICTION_TASKS;
1112
import static org.opensearch.ml.processor.InferenceProcessorAttributes.MODEL_CONFIG;
@@ -55,12 +56,11 @@
5556
import org.opensearch.search.pipeline.Processor;
5657
import org.opensearch.search.pipeline.SearchResponseProcessor;
5758

58-
import com.jayway.jsonpath.Configuration;
5959
import com.jayway.jsonpath.JsonPath;
60-
import com.jayway.jsonpath.Option;
6160

6261
public class MLInferenceSearchResponseProcessor extends AbstractProcessor implements SearchResponseProcessor, ModelExecutor {
6362

63+
public static final String REQUEST_PREFIX = "_request.";
6464
private final NamedXContentRegistry xContentRegistry;
6565
private static final Logger logger = LogManager.getLogger(MLInferenceSearchResponseProcessor.class);
6666
private final InferenceProcessorAttributes inferenceProcessorAttributes;
@@ -155,6 +155,8 @@ public void processResponseAsync(
155155
try {
156156
SearchHit[] hits = response.getHits().getHits();
157157
// skip processing when there is no hit
158+
159+
String queryString = request.source().toString();
158160
if (hits.length == 0) {
159161
responseListener.onResponse(response);
160162
return;
@@ -183,7 +185,7 @@ public void processResponseAsync(
183185
);
184186
}
185187

186-
rewriteResponseDocuments(mlInferenceSearchResponse, responseListener);
188+
rewriteResponseDocuments(mlInferenceSearchResponse, responseListener, queryString);
187189
} else {
188190
// if one to one, make one hit search response and run rewriteResponseDocuments
189191
GroupedActionListener<SearchResponse> combineResponseListener = getCombineResponseGroupedActionListener(
@@ -198,7 +200,7 @@ public void processResponseAsync(
198200
newHits[0] = hit;
199201
SearchResponse oneHitResponse = SearchResponseUtil.replaceHits(newHits, response);
200202
ActionListener<SearchResponse> oneHitListener = getOneHitListener(combineResponseListener, isOneHitListenerFailed);
201-
rewriteResponseDocuments(oneHitResponse, oneHitListener);
203+
rewriteResponseDocuments(oneHitResponse, oneHitListener, queryString);
202204
// if any OneHitListener failure, try stop the rest of the predictions
203205
if (isOneHitListenerFailed.get()) {
204206
break;
@@ -305,9 +307,11 @@ public void onFailure(Exception e) {
305307
*
306308
* @param response the search response
307309
* @param responseListener the listener to be notified when the response is processed
310+
* @param queryString the query body in string format, for example, "{ \"query\": { \"match_all\": {} } }\n"
308311
* @throws IOException if an I/O error occurs during the rewriting process
309312
*/
310-
private void rewriteResponseDocuments(SearchResponse response, ActionListener<SearchResponse> responseListener) throws IOException {
313+
private void rewriteResponseDocuments(SearchResponse response, ActionListener<SearchResponse> responseListener, String queryString)
314+
throws IOException {
311315
List<Map<String, String>> processInputMap = inferenceProcessorAttributes.getInputMaps();
312316
List<Map<String, String>> processOutputMap = inferenceProcessorAttributes.getOutputMaps();
313317
int inputMapSize = (processInputMap == null) ? 0 : processInputMap.size();
@@ -329,7 +333,7 @@ private void rewriteResponseDocuments(SearchResponse response, ActionListener<Se
329333
);
330334
SearchHit[] hits = response.getHits().getHits();
331335
for (int inputMapIndex = 0; inputMapIndex < max(inputMapSize, 1); inputMapIndex++) {
332-
processPredictions(hits, processInputMap, inputMapIndex, batchPredictionListener, hitCountInPredictions);
336+
processPredictions(hits, processInputMap, inputMapIndex, batchPredictionListener, hitCountInPredictions, queryString);
333337
}
334338
}
335339

@@ -341,56 +345,80 @@ private void rewriteResponseDocuments(SearchResponse response, ActionListener<Se
341345
* @param inputMapIndex the index of the input mapping to process
342346
* @param batchPredictionListener the listener to be notified when the predictions are processed
343347
* @param hitCountInPredictions a map to keep track of the count of hits that have the required input fields for each round of prediction
348+
* @param queryString the query body in string format, for example, "{ \"query\": { \"match_all\": {} } }\n"
344349
* @throws IOException if an I/O error occurs during the prediction process
345350
*/
346351
private void processPredictions(
347352
SearchHit[] hits,
348353
List<Map<String, String>> processInputMap,
349354
int inputMapIndex,
350355
GroupedActionListener<Map<Integer, MLOutput>> batchPredictionListener,
351-
Map<Integer, Integer> hitCountInPredictions
356+
Map<Integer, Integer> hitCountInPredictions,
357+
String queryString
352358
) throws IOException {
353359

354360
Map<String, String> modelParameters = new HashMap<>();
355361
Map<String, String> modelConfigs = new HashMap<>();
356362

357363
if (inferenceProcessorAttributes.getModelConfigMaps() != null) {
358-
modelParameters.putAll(inferenceProcessorAttributes.getModelConfigMaps());
359-
modelConfigs.putAll(inferenceProcessorAttributes.getModelConfigMaps());
364+
Map<String, String> modelConfigMapsInput = inferenceProcessorAttributes.getModelConfigMaps();
365+
366+
modelParameters.putAll(modelConfigMapsInput);
367+
modelConfigs.putAll(modelConfigMapsInput);
368+
360369
}
361370

362371
Map<String, Object> modelInputParameters = new HashMap<>();
363372

364373
Map<String, String> inputMapping;
365374
if (processInputMap != null && !processInputMap.isEmpty()) {
366375
inputMapping = processInputMap.get(inputMapIndex);
376+
boolean isRequestInputMissing = checkIsRequestInputMissing(queryString, inputMapping);
377+
if (isRequestInputMissing) {
378+
if (!ignoreMissing) {
379+
throw new IllegalArgumentException(
380+
"Missing required input field in query body. input_map: " + inputMapping.values() + ", query body:" + queryString
381+
);
382+
}
383+
}
367384

368385
for (SearchHit hit : hits) {
369386
Map<String, Object> document = hit.getSourceAsMap();
370-
boolean isModelInputMissing = checkIsModelInputMissing(document, inputMapping);
371-
if (!isModelInputMissing) {
387+
boolean isDocumentFieldMissing = checkIsDocumentFieldMissing(document, inputMapping);
388+
if (!isDocumentFieldMissing) {
372389
MapUtils.incrementCounter(hitCountInPredictions, inputMapIndex);
373390
for (Map.Entry<String, String> entry : inputMapping.entrySet()) {
374391
// model field as key, document field name as value
375392
String modelInputFieldName = entry.getKey();
376393
String documentFieldName = entry.getValue();
377-
378-
Object documentJson = JsonPath.parse(document).read("$");
379-
Configuration configuration = Configuration
380-
.builder()
381-
.options(Option.SUPPRESS_EXCEPTIONS, Option.DEFAULT_PATH_LEAF_TO_NULL)
382-
.build();
383-
384-
Object documentValue = JsonPath.using(configuration).parse(documentJson).read(documentFieldName);
385-
if (documentValue != null) {
386-
// when not existed in the map, add into the modelInputParameters map
387-
updateModelInputParameters(modelInputParameters, modelInputFieldName, documentValue);
394+
// read the query string when the mapping field name starts with "$._request." or "_request."
395+
// skip when modelInputParameters already has this modelInputFieldName to avoid duplicate read
396+
if (StringUtils.isValidJSONPath(documentFieldName)
397+
&& (documentFieldName.startsWith("$." + REQUEST_PREFIX) || documentFieldName.startsWith(REQUEST_PREFIX))
398+
&& !modelInputParameters.containsKey(modelInputFieldName)) {
399+
String requestFieldName = documentFieldName.replaceFirst(REQUEST_PREFIX, "");
400+
401+
Object queryText = JsonPath.using(suppressExceptionConfiguration).parse(queryString).read(requestFieldName);
402+
if (queryText != null) {
403+
modelInputParameters.put(modelInputFieldName, toJson(queryText));
404+
}
405+
} else {
406+
Object documentValue = JsonPath.using(suppressExceptionConfiguration).parse(document).read(documentFieldName);
407+
if (documentValue != null) {
408+
// when not existed in the map, add into the modelInputParameters map
409+
updateModelInputParameters(modelInputParameters, modelInputFieldName, documentValue);
410+
}
388411
}
389412
}
390413
} else { // when document does not contain the documentFieldName, skip when ignoreMissing
391414
if (!ignoreMissing) {
392415
throw new IllegalArgumentException(
393-
"cannot find all required input fields: " + inputMapping.values() + " in hit:" + hit
416+
"cannot find all required input fields: "
417+
+ inputMapping.values()
418+
+ " in hit:"
419+
+ hit
420+
+ " and query body:"
421+
+ queryString
394422
);
395423
}
396424
}
@@ -542,11 +570,11 @@ public void onResponse(Map<Integer, MLOutput> multipleMLOutputs) {
542570
Map<String, String> inputMapping = getDefaultInputMapping(sourceAsMap, mappingIndex, processInputMap);
543571
Map<String, String> outputMapping = getDefaultOutputMapping(mappingIndex, processOutputMap);
544572

545-
boolean isModelInputMissing = false;
573+
boolean isDocumentFieldMissing = false;
546574
if (processInputMap != null && !processInputMap.isEmpty()) {
547-
isModelInputMissing = checkIsModelInputMissing(document, inputMapping);
575+
isDocumentFieldMissing = checkIsDocumentFieldMissing(document, inputMapping);
548576
}
549-
if (!isModelInputMissing) {
577+
if (!isDocumentFieldMissing) {
550578
// Iterate over outputMapping
551579
for (Map.Entry<String, String> outputMapEntry : outputMapping.entrySet()) {
552580

@@ -637,22 +665,45 @@ public void onFailure(Exception e) {
637665

638666
/**
639667
* Checks if the document is missing any of the required input fields specified in the input mapping.
668+
* When model config contains the default model_input value, it's not considered as missing model input.
640669
*
641670
* @param document the document map
642671
* @param inputMapping the input mapping
643672
* @return true if the document is missing any of the required input fields, false otherwise
644673
*/
645-
private boolean checkIsModelInputMissing(Map<String, Object> document, Map<String, String> inputMapping) {
646-
boolean isModelInputMissing = false;
647-
for (Map.Entry<String, String> inputMapEntry : inputMapping.entrySet()) {
648-
String oldDocumentFieldName = inputMapEntry.getValue();
649-
boolean checkSingleModelInputPresent = hasField(document, oldDocumentFieldName);
650-
if (!checkSingleModelInputPresent) {
651-
isModelInputMissing = true;
652-
break;
653-
}
654-
}
655-
return isModelInputMissing;
674+
private boolean checkIsDocumentFieldMissing(Map<String, Object> document, Map<String, String> inputMapping) {
675+
return inputMapping
676+
.values()
677+
.stream()
678+
.filter(fieldName -> !(fieldName.startsWith("$." + REQUEST_PREFIX) || fieldName.startsWith(REQUEST_PREFIX)))
679+
.anyMatch(fieldName -> {
680+
boolean isFieldPresentInDocument = document != null && hasField(document, fieldName);
681+
boolean isFieldPresentInModelConfig = this.inferenceProcessorAttributes.modelConfigMaps != null
682+
&& this.inferenceProcessorAttributes.modelConfigMaps.containsKey(fieldName);
683+
return !isFieldPresentInDocument && !isFieldPresentInModelConfig;
684+
});
685+
}
686+
687+
/**
688+
* Checks if the request is missing any of the required input fields specified in the input mapping.
689+
* When model config contains the default model_input value, it's not considered as missing model input.
690+
*
691+
* @param queryString the query body in string format, e.g., "{ \"query\": { \"match_all\": {} } }\n"
692+
* @param inputMapping the input mapping
693+
* @return true if the document is missing any of the required input fields, false otherwise
694+
*/
695+
private boolean checkIsRequestInputMissing(String queryString, Map<String, String> inputMapping) {
696+
return inputMapping
697+
.values()
698+
.stream()
699+
.filter(fieldName -> fieldName.startsWith("$." + REQUEST_PREFIX) || fieldName.startsWith(REQUEST_PREFIX))
700+
.map(fieldName -> fieldName.replaceFirst(REQUEST_PREFIX, ""))
701+
.anyMatch(requestFieldName -> {
702+
boolean isFieldPresentInQuery = queryString != null && hasField(queryString, requestFieldName);
703+
boolean isFieldPresentInModelConfig = this.inferenceProcessorAttributes.modelConfigMaps != null
704+
&& this.inferenceProcessorAttributes.modelConfigMaps.containsKey(requestFieldName);
705+
return !isFieldPresentInQuery && !isFieldPresentInModelConfig;
706+
});
656707
}
657708

658709
/**

plugin/src/main/java/org/opensearch/ml/processor/ModelExecutor.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -282,8 +282,12 @@ default String toString(Object originalFieldValue) {
282282
}
283283

284284
default boolean hasField(Object json, String path) {
285-
Object value = JsonPath.using(suppressExceptionConfiguration).parse(json).read(path);
286-
285+
Object value;
286+
if (json instanceof String) {
287+
value = JsonPath.using(suppressExceptionConfiguration).parse((String) json).read(path);
288+
} else {
289+
value = JsonPath.using(suppressExceptionConfiguration).parse(json).read(path);
290+
}
287291
if (value != null) {
288292
return true;
289293
}

0 commit comments

Comments
 (0)