66package org .opensearch .ml .processor ;
77
88import static java .lang .Math .max ;
9+ import static org .opensearch .ml .common .utils .StringUtils .toJson ;
910import static org .opensearch .ml .processor .InferenceProcessorAttributes .INPUT_MAP ;
1011import static org .opensearch .ml .processor .InferenceProcessorAttributes .MAX_PREDICTION_TASKS ;
1112import static org .opensearch .ml .processor .InferenceProcessorAttributes .MODEL_CONFIG ;
5556import org .opensearch .search .pipeline .Processor ;
5657import org .opensearch .search .pipeline .SearchResponseProcessor ;
5758
58- import com .jayway .jsonpath .Configuration ;
5959import com .jayway .jsonpath .JsonPath ;
60- import com .jayway .jsonpath .Option ;
6160
6261public class MLInferenceSearchResponseProcessor extends AbstractProcessor implements SearchResponseProcessor , ModelExecutor {
6362
63+ public static final String REQUEST_PREFIX = "_request." ;
6464 private final NamedXContentRegistry xContentRegistry ;
6565 private static final Logger logger = LogManager .getLogger (MLInferenceSearchResponseProcessor .class );
6666 private final InferenceProcessorAttributes inferenceProcessorAttributes ;
@@ -155,6 +155,8 @@ public void processResponseAsync(
155155 try {
156156 SearchHit [] hits = response .getHits ().getHits ();
157157 // skip processing when there is no hit
158+
159+ String queryString = request .source ().toString ();
158160 if (hits .length == 0 ) {
159161 responseListener .onResponse (response );
160162 return ;
@@ -183,7 +185,7 @@ public void processResponseAsync(
183185 );
184186 }
185187
186- rewriteResponseDocuments (mlInferenceSearchResponse , responseListener );
188+ rewriteResponseDocuments (mlInferenceSearchResponse , responseListener , queryString );
187189 } else {
188190 // if one to one, make one hit search response and run rewriteResponseDocuments
189191 GroupedActionListener <SearchResponse > combineResponseListener = getCombineResponseGroupedActionListener (
@@ -198,7 +200,7 @@ public void processResponseAsync(
198200 newHits [0 ] = hit ;
199201 SearchResponse oneHitResponse = SearchResponseUtil .replaceHits (newHits , response );
200202 ActionListener <SearchResponse > oneHitListener = getOneHitListener (combineResponseListener , isOneHitListenerFailed );
201- rewriteResponseDocuments (oneHitResponse , oneHitListener );
203+ rewriteResponseDocuments (oneHitResponse , oneHitListener , queryString );
202204 // if any OneHitListener failure, try stop the rest of the predictions
203205 if (isOneHitListenerFailed .get ()) {
204206 break ;
@@ -305,9 +307,11 @@ public void onFailure(Exception e) {
305307 *
306308 * @param response the search response
307309 * @param responseListener the listener to be notified when the response is processed
310+ * @param queryString the query body in string format, for example, "{ \"query\": { \"match_all\": {} } }\n"
308311 * @throws IOException if an I/O error occurs during the rewriting process
309312 */
310- private void rewriteResponseDocuments (SearchResponse response , ActionListener <SearchResponse > responseListener ) throws IOException {
313+ private void rewriteResponseDocuments (SearchResponse response , ActionListener <SearchResponse > responseListener , String queryString )
314+ throws IOException {
311315 List <Map <String , String >> processInputMap = inferenceProcessorAttributes .getInputMaps ();
312316 List <Map <String , String >> processOutputMap = inferenceProcessorAttributes .getOutputMaps ();
313317 int inputMapSize = (processInputMap == null ) ? 0 : processInputMap .size ();
@@ -329,7 +333,7 @@ private void rewriteResponseDocuments(SearchResponse response, ActionListener<Se
329333 );
330334 SearchHit [] hits = response .getHits ().getHits ();
331335 for (int inputMapIndex = 0 ; inputMapIndex < max (inputMapSize , 1 ); inputMapIndex ++) {
332- processPredictions (hits , processInputMap , inputMapIndex , batchPredictionListener , hitCountInPredictions );
336+ processPredictions (hits , processInputMap , inputMapIndex , batchPredictionListener , hitCountInPredictions , queryString );
333337 }
334338 }
335339
@@ -341,56 +345,80 @@ private void rewriteResponseDocuments(SearchResponse response, ActionListener<Se
341345 * @param inputMapIndex the index of the input mapping to process
342346 * @param batchPredictionListener the listener to be notified when the predictions are processed
343347 * @param hitCountInPredictions a map to keep track of the count of hits that have the required input fields for each round of prediction
348+ * @param queryString the query body in string format, for example, "{ \"query\": { \"match_all\": {} } }\n"
344349 * @throws IOException if an I/O error occurs during the prediction process
345350 */
346351 private void processPredictions (
347352 SearchHit [] hits ,
348353 List <Map <String , String >> processInputMap ,
349354 int inputMapIndex ,
350355 GroupedActionListener <Map <Integer , MLOutput >> batchPredictionListener ,
351- Map <Integer , Integer > hitCountInPredictions
356+ Map <Integer , Integer > hitCountInPredictions ,
357+ String queryString
352358 ) throws IOException {
353359
354360 Map <String , String > modelParameters = new HashMap <>();
355361 Map <String , String > modelConfigs = new HashMap <>();
356362
357363 if (inferenceProcessorAttributes .getModelConfigMaps () != null ) {
358- modelParameters .putAll (inferenceProcessorAttributes .getModelConfigMaps ());
359- modelConfigs .putAll (inferenceProcessorAttributes .getModelConfigMaps ());
364+ Map <String , String > modelConfigMapsInput = inferenceProcessorAttributes .getModelConfigMaps ();
365+
366+ modelParameters .putAll (modelConfigMapsInput );
367+ modelConfigs .putAll (modelConfigMapsInput );
368+
360369 }
361370
362371 Map <String , Object > modelInputParameters = new HashMap <>();
363372
364373 Map <String , String > inputMapping ;
365374 if (processInputMap != null && !processInputMap .isEmpty ()) {
366375 inputMapping = processInputMap .get (inputMapIndex );
376+ boolean isRequestInputMissing = checkIsRequestInputMissing (queryString , inputMapping );
377+ if (isRequestInputMissing ) {
378+ if (!ignoreMissing ) {
379+ throw new IllegalArgumentException (
380+ "Missing required input field in query body. input_map: " + inputMapping .values () + ", query body:" + queryString
381+ );
382+ }
383+ }
367384
368385 for (SearchHit hit : hits ) {
369386 Map <String , Object > document = hit .getSourceAsMap ();
370- boolean isModelInputMissing = checkIsModelInputMissing (document , inputMapping );
371- if (!isModelInputMissing ) {
387+ boolean isDocumentFieldMissing = checkIsDocumentFieldMissing (document , inputMapping );
388+ if (!isDocumentFieldMissing ) {
372389 MapUtils .incrementCounter (hitCountInPredictions , inputMapIndex );
373390 for (Map .Entry <String , String > entry : inputMapping .entrySet ()) {
374391 // model field as key, document field name as value
375392 String modelInputFieldName = entry .getKey ();
376393 String documentFieldName = entry .getValue ();
377-
378- Object documentJson = JsonPath .parse (document ).read ("$" );
379- Configuration configuration = Configuration
380- .builder ()
381- .options (Option .SUPPRESS_EXCEPTIONS , Option .DEFAULT_PATH_LEAF_TO_NULL )
382- .build ();
383-
384- Object documentValue = JsonPath .using (configuration ).parse (documentJson ).read (documentFieldName );
385- if (documentValue != null ) {
386- // when not existed in the map, add into the modelInputParameters map
387- updateModelInputParameters (modelInputParameters , modelInputFieldName , documentValue );
394+ // read the query string when the mapping field name starts with "$._request." or "_request."
395+ // skip when modelInputParameters already has this modelInputFieldName to avoid duplicate read
396+ if (StringUtils .isValidJSONPath (documentFieldName )
397+ && (documentFieldName .startsWith ("$." + REQUEST_PREFIX ) || documentFieldName .startsWith (REQUEST_PREFIX ))
398+ && !modelInputParameters .containsKey (modelInputFieldName )) {
399+ String requestFieldName = documentFieldName .replaceFirst (REQUEST_PREFIX , "" );
400+
401+ Object queryText = JsonPath .using (suppressExceptionConfiguration ).parse (queryString ).read (requestFieldName );
402+ if (queryText != null ) {
403+ modelInputParameters .put (modelInputFieldName , toJson (queryText ));
404+ }
405+ } else {
406+ Object documentValue = JsonPath .using (suppressExceptionConfiguration ).parse (document ).read (documentFieldName );
407+ if (documentValue != null ) {
408+ // when not existed in the map, add into the modelInputParameters map
409+ updateModelInputParameters (modelInputParameters , modelInputFieldName , documentValue );
410+ }
388411 }
389412 }
390413 } else { // when document does not contain the documentFieldName, skip when ignoreMissing
391414 if (!ignoreMissing ) {
392415 throw new IllegalArgumentException (
393- "cannot find all required input fields: " + inputMapping .values () + " in hit:" + hit
416+ "cannot find all required input fields: "
417+ + inputMapping .values ()
418+ + " in hit:"
419+ + hit
420+ + " and query body:"
421+ + queryString
394422 );
395423 }
396424 }
@@ -542,11 +570,11 @@ public void onResponse(Map<Integer, MLOutput> multipleMLOutputs) {
542570 Map <String , String > inputMapping = getDefaultInputMapping (sourceAsMap , mappingIndex , processInputMap );
543571 Map <String , String > outputMapping = getDefaultOutputMapping (mappingIndex , processOutputMap );
544572
545- boolean isModelInputMissing = false ;
573+ boolean isDocumentFieldMissing = false ;
546574 if (processInputMap != null && !processInputMap .isEmpty ()) {
547- isModelInputMissing = checkIsModelInputMissing (document , inputMapping );
575+ isDocumentFieldMissing = checkIsDocumentFieldMissing (document , inputMapping );
548576 }
549- if (!isModelInputMissing ) {
577+ if (!isDocumentFieldMissing ) {
550578 // Iterate over outputMapping
551579 for (Map .Entry <String , String > outputMapEntry : outputMapping .entrySet ()) {
552580
@@ -637,22 +665,45 @@ public void onFailure(Exception e) {
637665
638666 /**
639667 * Checks if the document is missing any of the required input fields specified in the input mapping.
668+ * When model config contains the default model_input value, it's not considered as missing model input.
640669 *
641670 * @param document the document map
642671 * @param inputMapping the input mapping
643672 * @return true if the document is missing any of the required input fields, false otherwise
644673 */
645- private boolean checkIsModelInputMissing (Map <String , Object > document , Map <String , String > inputMapping ) {
646- boolean isModelInputMissing = false ;
647- for (Map .Entry <String , String > inputMapEntry : inputMapping .entrySet ()) {
648- String oldDocumentFieldName = inputMapEntry .getValue ();
649- boolean checkSingleModelInputPresent = hasField (document , oldDocumentFieldName );
650- if (!checkSingleModelInputPresent ) {
651- isModelInputMissing = true ;
652- break ;
653- }
654- }
655- return isModelInputMissing ;
674+ private boolean checkIsDocumentFieldMissing (Map <String , Object > document , Map <String , String > inputMapping ) {
675+ return inputMapping
676+ .values ()
677+ .stream ()
678+ .filter (fieldName -> !(fieldName .startsWith ("$." + REQUEST_PREFIX ) || fieldName .startsWith (REQUEST_PREFIX )))
679+ .anyMatch (fieldName -> {
680+ boolean isFieldPresentInDocument = document != null && hasField (document , fieldName );
681+ boolean isFieldPresentInModelConfig = this .inferenceProcessorAttributes .modelConfigMaps != null
682+ && this .inferenceProcessorAttributes .modelConfigMaps .containsKey (fieldName );
683+ return !isFieldPresentInDocument && !isFieldPresentInModelConfig ;
684+ });
685+ }
686+
687+ /**
688+ * Checks if the request is missing any of the required input fields specified in the input mapping.
689+ * When model config contains the default model_input value, it's not considered as missing model input.
690+ *
691+ * @param queryString the query body in string format, e.g., "{ \"query\": { \"match_all\": {} } }\n"
692+ * @param inputMapping the input mapping
693+ * @return true if the document is missing any of the required input fields, false otherwise
694+ */
695+ private boolean checkIsRequestInputMissing (String queryString , Map <String , String > inputMapping ) {
696+ return inputMapping
697+ .values ()
698+ .stream ()
699+ .filter (fieldName -> fieldName .startsWith ("$." + REQUEST_PREFIX ) || fieldName .startsWith (REQUEST_PREFIX ))
700+ .map (fieldName -> fieldName .replaceFirst (REQUEST_PREFIX , "" ))
701+ .anyMatch (requestFieldName -> {
702+ boolean isFieldPresentInQuery = queryString != null && hasField (queryString , requestFieldName );
703+ boolean isFieldPresentInModelConfig = this .inferenceProcessorAttributes .modelConfigMaps != null
704+ && this .inferenceProcessorAttributes .modelConfigMaps .containsKey (requestFieldName );
705+ return !isFieldPresentInQuery && !isFieldPresentInModelConfig ;
706+ });
656707 }
657708
658709 /**
0 commit comments