@@ -1036,6 +1036,126 @@ public void testProcessResponseSuccessWriteToExt() throws Exception {
10361036 @ Override
10371037 public void onResponse (SearchResponse newSearchResponse ) {
10381038 assertEquals (newSearchResponse .getHits ().getHits ().length , 5 );
1039+ MLInferenceSearchResponse mLInferenceSearchResponse = (MLInferenceSearchResponse ) newSearchResponse ;
1040+ String resultsInResponse = (String ) mLInferenceSearchResponse .getParams ().get ("llm_response" );
1041+ assertEquals ("there is 1 value" , resultsInResponse );
1042+ }
1043+
1044+ @ Override
1045+ public void onFailure (Exception e ) {
1046+ throw new RuntimeException (e );
1047+ }
1048+
1049+ };
1050+ responseProcessor .processResponseAsync (request , response , responseContext , listener );
1051+ verify (client , times (1 )).execute (any (), any (), any ());
1052+ }
1053+
1054+ /**
1055+ * Tests the successful processing of a response with a single pair of input and output mappings.
1056+ * read the query text into model config with query extensions
1057+ * read the prediction outcome as array and store in search extension
1058+ * @throws Exception if an error occurs during the test
1059+ */
1060+ @ Test
1061+ public void testProcessResponseSuccessArrayWriteToExt () throws Exception {
1062+ String documentField = "text" ;
1063+ String modelInputField = "context" ;
1064+ List <Map <String , String >> inputMap = new ArrayList <>();
1065+ Map <String , String > input = new HashMap <>();
1066+ input .put (modelInputField , documentField );
1067+ inputMap .add (input );
1068+
1069+ String newDocumentField = "ext.ml_inference.results" ;
1070+ String modelOutputField = "results[*].document.text" ;
1071+ List <Map <String , String >> outputMap = new ArrayList <>();
1072+ Map <String , String > output = new HashMap <>();
1073+ output .put (newDocumentField , modelOutputField );
1074+ outputMap .add (output );
1075+ Map <String , String > modelConfig = new HashMap <>();
1076+ modelConfig .put ("query" , "positive review" );
1077+ MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor (
1078+ "model1" ,
1079+ inputMap ,
1080+ outputMap ,
1081+ optionalInputMaps ,
1082+ optionalOutputMaps ,
1083+ modelConfig ,
1084+ DEFAULT_MAX_PREDICTION_TASKS ,
1085+ PROCESSOR_TAG ,
1086+ DESCRIPTION ,
1087+ false ,
1088+ "remote" ,
1089+ false ,
1090+ false ,
1091+ false ,
1092+ "{ \" parameters\" : ${ml_inference.parameters} }" ,
1093+ client ,
1094+ TEST_XCONTENT_REGISTRY_FOR_QUERY ,
1095+ false
1096+ );
1097+
1098+ SearchRequest request = getSearchRequest ();
1099+ String fieldName = "text" ;
1100+ SearchResponse response = getSearchResponse (5 , true , fieldName );
1101+
1102+ Map <String , Object > inferenceResultMap = new HashMap <>();
1103+
1104+ Map <String , Object > doc1 = new HashMap <>();
1105+ Map <String , Object > doc1Text = new HashMap <>();
1106+ doc1Text .put ("text" , "value1" );
1107+ doc1 .put ("document" , doc1Text );
1108+ doc1 .put ("index" , 0.0 );
1109+ doc1 .put ("relevance_score" , 2.6480842E-5 );
1110+
1111+ Map <String , Object > doc2 = new HashMap <>();
1112+ Map <String , Object > doc2Text = new HashMap <>();
1113+ doc2Text .put ("text" , "value5" );
1114+ doc2 .put ("document" , doc2Text );
1115+ doc2 .put ("index" , 4.0 );
1116+ doc2 .put ("relevance_score" , 2.5071593E-5 );
1117+
1118+ Map <String , Object > doc3 = new HashMap <>();
1119+ Map <String , Object > doc3Text = new HashMap <>();
1120+ doc3Text .put ("text" , "value4" );
1121+ doc3 .put ("document" , doc3Text );
1122+ doc3 .put ("index" , 3.0 );
1123+ doc3 .put ("relevance_score" , 2.373734E-5 );
1124+
1125+ Map <String , Object > doc4 = new HashMap <>();
1126+ Map <String , Object > doc4Text = new HashMap <>();
1127+ doc4Text .put ("text" , "value2" );
1128+ doc4 .put ("document" , doc4Text );
1129+ doc4 .put ("index" , 1.0 );
1130+ doc4 .put ("relevance_score" , 2.1112483E-5 );
1131+
1132+ Map <String , Object > doc5 = new HashMap <>();
1133+ Map <String , Object > doc5Text = new HashMap <>();
1134+ doc5Text .put ("text" , "value3" );
1135+ doc5 .put ("document" , doc5Text );
1136+ doc5 .put ("index" , 2.0 );
1137+ doc5 .put ("relevance_score" , 1.6187581E-5 );
1138+
1139+ inferenceResultMap .put ("results" , Arrays .asList (doc1 , doc2 , doc3 , doc4 , doc5 ));
1140+
1141+ ModelTensor modelTensor = ModelTensor .builder ().dataAsMap (inferenceResultMap ).build ();
1142+ ModelTensors modelTensors = ModelTensors .builder ().mlModelTensors (Arrays .asList (modelTensor )).build ();
1143+ ModelTensorOutput mlModelTensorOutput = ModelTensorOutput .builder ().mlModelOutputs (Arrays .asList (modelTensors )).build ();
1144+
1145+ doAnswer (invocation -> {
1146+ ActionListener <MLTaskResponse > actionListener = invocation .getArgument (2 );
1147+ actionListener .onResponse (MLTaskResponse .builder ().output (mlModelTensorOutput ).build ());
1148+ return null ;
1149+ }).when (client ).execute (any (), any (), any ());
1150+
1151+ ActionListener <SearchResponse > listener = new ActionListener <>() {
1152+ @ Override
1153+ public void onResponse (SearchResponse newSearchResponse ) {
1154+ assertEquals (newSearchResponse .getHits ().getHits ().length , 5 );
1155+ MLInferenceSearchResponse mLInferenceSearchResponse = (MLInferenceSearchResponse ) newSearchResponse ;
1156+ List <Map <String , Object >> results = (List <Map <String , Object >>) inferenceResultMap .get ("results" );
1157+ List <String > resultsInResponse = (List <String >) mLInferenceSearchResponse .getParams ().get ("results" );
1158+ assertEquals (results .size (), resultsInResponse .size ());
10391159 }
10401160
10411161 @ Override
@@ -5166,17 +5286,151 @@ public void testOutputMapsExceedInputMaps() throws Exception {
51665286 } catch (IllegalArgumentException e ) {
51675287 assertEquals (
51685288 e .getMessage (),
5169- "when output_maps/optional_output_maps and input_maps/optional_input_maps are provided, their length needs to match. The input is in length of 2, while output_maps is in the length of 2 . Please adjust mappings."
5289+ "when output_maps/optional_output_maps and input_maps/optional_input_maps are provided, their length needs to match. The input is in length of 2, while output_maps is in the length of 3 . Please adjust mappings."
51705290 );
51715291
51725292 }
51735293 }
51745294
51755295 /**
5176- * Tests the creation of the MLInferenceSearchResponseProcessor with optional fields.
5177- *
5178- * @throws Exception if an error occurs during the test
5179- */
5296+ * Tests the case where only the input maps are provided in the configuration.
5297+ *
5298+ * @throws Exception if an error occurs during the test
5299+ */
5300+ public void testOnlyInputMapsProvided () throws Exception {
5301+ Map <String , Object > config = new HashMap <>();
5302+ config .put (MODEL_ID , "model2" );
5303+ List <Map <String , String >> inputMap = new ArrayList <>();
5304+ Map <String , String > input0 = new HashMap <>();
5305+ input0 .put ("inputs" , "text" );
5306+ inputMap .add (input0 );
5307+ Map <String , String > input1 = new HashMap <>();
5308+ input1 .put ("inputs" , "hashtag" );
5309+ inputMap .add (input1 );
5310+ config .put (INPUT_MAP , inputMap );
5311+ config .put (MAX_PREDICTION_TASKS , 2 );
5312+ String processorTag = randomAlphaOfLength (10 );
5313+
5314+ factory .create (Collections .emptyMap (), processorTag , null , false , config , null );
5315+ }
5316+
5317+ /**
5318+ * Tests the case where the input maps and empty output map are provided in the configuration.
5319+ *
5320+ * @throws Exception if an error occurs during the test
5321+ */
5322+ public void testInputMapsEmptyOutputMapProvided () throws Exception {
5323+ Map <String , Object > config = new HashMap <>();
5324+ config .put (MODEL_ID , "model2" );
5325+ List <Map <String , String >> inputMap = new ArrayList <>();
5326+ Map <String , String > input0 = new HashMap <>();
5327+ input0 .put ("inputs" , "text" );
5328+ inputMap .add (input0 );
5329+ Map <String , String > input1 = new HashMap <>();
5330+ input1 .put ("inputs" , "hashtag" );
5331+ inputMap .add (input1 );
5332+ config .put (INPUT_MAP , inputMap );
5333+ config .put (MAX_PREDICTION_TASKS , 2 );
5334+ String processorTag = randomAlphaOfLength (10 );
5335+
5336+ List <Map <String , String >> outputMap = new ArrayList <>();
5337+ config .put (OUTPUT_MAP , outputMap );
5338+
5339+ factory .create (Collections .emptyMap (), processorTag , null , false , config , null );
5340+ }
5341+
5342+ /**
5343+ * Tests the case where only the Optional input maps are provided in the configuration.
5344+ *
5345+ * @throws Exception if an error occurs during the test
5346+ */
5347+ public void testOnlyOptionalInputMapsProvided () throws Exception {
5348+ Map <String , Object > config = new HashMap <>();
5349+ config .put (MODEL_ID , "model2" );
5350+ List <Map <String , String >> inputMap = new ArrayList <>();
5351+ Map <String , String > input0 = new HashMap <>();
5352+ input0 .put ("inputs" , "text" );
5353+ inputMap .add (input0 );
5354+ Map <String , String > input1 = new HashMap <>();
5355+ input1 .put ("inputs" , "hashtag" );
5356+ inputMap .add (input1 );
5357+ config .put (OPTIONAL_INPUT_MAP , inputMap );
5358+ config .put (MAX_PREDICTION_TASKS , 2 );
5359+ String processorTag = randomAlphaOfLength (10 );
5360+
5361+ factory .create (Collections .emptyMap (), processorTag , null , false , config , null );
5362+
5363+ }
5364+
5365+ /**
5366+ * Tests the case where only the Optional input maps are provided in the configuration.
5367+ *
5368+ * @throws Exception if an error occurs during the test
5369+ */
5370+ public void testOnlyOptionalInputMapsEmptyOptionalOutputProvided () throws Exception {
5371+ Map <String , Object > config = new HashMap <>();
5372+ config .put (MODEL_ID , "model2" );
5373+ List <Map <String , String >> inputMap = new ArrayList <>();
5374+ Map <String , String > input0 = new HashMap <>();
5375+ input0 .put ("inputs" , "text" );
5376+ inputMap .add (input0 );
5377+ Map <String , String > input1 = new HashMap <>();
5378+ input1 .put ("inputs" , "hashtag" );
5379+ inputMap .add (input1 );
5380+ config .put (OPTIONAL_INPUT_MAP , inputMap );
5381+ config .put (MAX_PREDICTION_TASKS , 2 );
5382+ String processorTag = randomAlphaOfLength (10 );
5383+ List <Map <String , String >> outputMap = new ArrayList <>();
5384+ config .put (OPTIONAL_OUTPUT_MAP , outputMap );
5385+ factory .create (Collections .emptyMap (), processorTag , null , false , config , null );
5386+
5387+ }
5388+
5389+ /**
5390+ * Tests the case where only the output maps are provided in the configuration.
5391+ *
5392+ * @throws Exception if an error occurs during the test
5393+ */
5394+ public void testOnlyOutputMapsProvided () throws Exception {
5395+ Map <String , Object > config = new HashMap <>();
5396+ config .put (MODEL_ID , "model2" );
5397+ List <Map <String , String >> outputMap = new ArrayList <>();
5398+ Map <String , String > output = new HashMap <>();
5399+ output .put ("text_embedding" , "$.inference_results[0].output[0].data" );
5400+ outputMap .add (output );
5401+ config .put (OUTPUT_MAP , outputMap );
5402+ config .put (MAX_PREDICTION_TASKS , 2 );
5403+ String processorTag = randomAlphaOfLength (10 );
5404+
5405+ factory .create (Collections .emptyMap (), processorTag , null , false , config , null );
5406+ }
5407+
5408+ /**
5409+ * Tests the case where only the output maps are provided in the configuration.
5410+ *
5411+ * @throws Exception if an error occurs during the test
5412+ */
5413+ public void testOnlyOutputMapsEmptyInputProvided () throws Exception {
5414+ Map <String , Object > config = new HashMap <>();
5415+ config .put (MODEL_ID , "model2" );
5416+ List <Map <String , String >> inputMap = new ArrayList <>();
5417+ List <Map <String , String >> outputMap = new ArrayList <>();
5418+ Map <String , String > output = new HashMap <>();
5419+ output .put ("text_embedding" , "$.inference_results[0].output[0].data" );
5420+ outputMap .add (output );
5421+ config .put (INPUT_MAP , inputMap );
5422+ config .put (OUTPUT_MAP , outputMap );
5423+ config .put (MAX_PREDICTION_TASKS , 2 );
5424+ String processorTag = randomAlphaOfLength (10 );
5425+
5426+ factory .create (Collections .emptyMap (), processorTag , null , false , config , null );
5427+ }
5428+
5429+ /**
5430+ * Tests the creation of the MLInferenceSearchResponseProcessor with optional fields.
5431+ *
5432+ * @throws Exception if an error occurs during the test
5433+ */
51805434 public void testCreateOptionalFields () throws Exception {
51815435 Map <String , Object > config = new HashMap <>();
51825436 config .put (MODEL_ID , "model2" );
0 commit comments