66
77import static org .opensearch .ml .processor .InferenceProcessorAttributes .*;
88
9+ import java .io .IOException ;
910import java .util .ArrayList ;
1011import java .util .Collection ;
1112import java .util .HashMap ;
13+ import java .util .HashSet ;
1214import java .util .List ;
1315import java .util .Map ;
1416import java .util .Set ;
1517import java .util .function .BiConsumer ;
1618
19+ import org .apache .logging .log4j .LogManager ;
20+ import org .apache .logging .log4j .Logger ;
1721import org .opensearch .action .ActionRequest ;
1822import org .opensearch .action .support .GroupedActionListener ;
1923import org .opensearch .client .Client ;
2024import org .opensearch .core .action .ActionListener ;
2125import org .opensearch .core .common .Strings ;
26+ import org .opensearch .core .xcontent .NamedXContentRegistry ;
2227import org .opensearch .ingest .AbstractProcessor ;
2328import org .opensearch .ingest .ConfigurationUtils ;
2429import org .opensearch .ingest .IngestDocument ;
2530import org .opensearch .ingest .Processor ;
2631import org .opensearch .ingest .ValueSource ;
27- import org .opensearch .ml .common .output .model .ModelTensorOutput ;
32+ import org .opensearch .ml .common .FunctionName ;
33+ import org .opensearch .ml .common .output .MLOutput ;
2834import org .opensearch .ml .common .transport .MLTaskResponse ;
2935import org .opensearch .ml .common .transport .prediction .MLPredictionTaskAction ;
3036import org .opensearch .ml .common .utils .StringUtils ;
4248 */
4349public class MLInferenceIngestProcessor extends AbstractProcessor implements ModelExecutor {
4450
51+ private static final Logger logger = LogManager .getLogger (MLInferenceIngestProcessor .class );
52+
4553 public static final String DOT_SYMBOL = "." ;
4654 private final InferenceProcessorAttributes inferenceProcessorAttributes ;
4755 private final boolean ignoreMissing ;
56+ private final String functionName ;
57+ private final boolean fullResponsePath ;
4858 private final boolean ignoreFailure ;
59+ private final boolean override ;
60+ private final String modelInput ;
4961 private final ScriptService scriptService ;
5062 private static Client client ;
5163 public static final String TYPE = "ml_inference" ;
5264 public static final String DEFAULT_OUTPUT_FIELD_NAME = "inference_results" ;
5365 // allows a mapped field that is not present in the document to be ignored; and when the output field is not found in the
5466 // prediction outcomes, returns the whole prediction outcome by skipping filtering
5567 public static final String IGNORE_MISSING = "ignore_missing" ;
68+ public static final String OVERRIDE = "override" ;
69+ public static final String FUNCTION_NAME = "function_name" ;
70+ public static final String FULL_RESPONSE_PATH = "full_response_path" ;
71+ public static final String MODEL_INPUT = "model_input" ;
5672 // By default, the ml inference processor allows a maximum of 10 prediction tasks to run in parallel;
5773 // this can be overridden using max_prediction_tasks when creating the processor
5874 public static final int DEFAULT_MAX_PREDICTION_TASKS = 10 ;
75+ private final NamedXContentRegistry xContentRegistry ;
5976
6077 private Configuration suppressExceptionConfiguration = Configuration
6178 .builder ()
@@ -71,9 +88,14 @@ protected MLInferenceIngestProcessor(
7188 String tag ,
7289 String description ,
7390 boolean ignoreMissing ,
91+ String functionName ,
92+ boolean fullResponsePath ,
7493 boolean ignoreFailure ,
94+ boolean override ,
95+ String modelInput ,
7596 ScriptService scriptService ,
76- Client client
97+ Client client ,
98+ NamedXContentRegistry xContentRegistry
7799 ) {
78100 super (tag , description );
79101 this .inferenceProcessorAttributes = new InferenceProcessorAttributes (
@@ -84,9 +106,14 @@ protected MLInferenceIngestProcessor(
84106 maxPredictionTask
85107 );
86108 this .ignoreMissing = ignoreMissing ;
109+ this .functionName = functionName ;
110+ this .fullResponsePath = fullResponsePath ;
87111 this .ignoreFailure = ignoreFailure ;
112+ this .override = override ;
113+ this .modelInput = modelInput ;
88114 this .scriptService = scriptService ;
89115 this .client = client ;
116+ this .xContentRegistry = xContentRegistry ;
90117 }
91118
92119 /**
@@ -162,10 +189,48 @@ private void processPredictions(
162189 List <Map <String , String >> processOutputMap ,
163190 int inputMapIndex ,
164191 int inputMapSize
165- ) {
192+ ) throws IOException {
166193 Map <String , String > modelParameters = new HashMap <>();
194+ Map <String , String > modelConfigs = new HashMap <>();
195+
167196 if (inferenceProcessorAttributes .getModelConfigMaps () != null ) {
168197 modelParameters .putAll (inferenceProcessorAttributes .getModelConfigMaps ());
198+ modelConfigs .putAll (inferenceProcessorAttributes .getModelConfigMaps ());
199+ }
200+
201+ Map <String , Object > ingestDocumentSourceAndMetaData = new HashMap <>();
202+ ingestDocumentSourceAndMetaData .putAll (ingestDocument .getSourceAndMetadata ());
203+ ingestDocumentSourceAndMetaData .put (IngestDocument .INGEST_KEY , ingestDocument .getIngestMetadata ());
204+
205+ Map <String , List <String >> newOutputMapping = new HashMap <>();
206+ if (processOutputMap != null ) {
207+
208+ Map <String , String > outputMapping = processOutputMap .get (inputMapIndex );
209+ for (Map .Entry <String , String > entry : outputMapping .entrySet ()) {
210+ String newDocumentFieldName = entry .getKey ();
211+ List <String > dotPathsInArray = writeNewDotPathForNestedObject (ingestDocumentSourceAndMetaData , newDocumentFieldName );
212+ newOutputMapping .put (newDocumentFieldName , dotPathsInArray );
213+ }
214+
215+ for (Map .Entry <String , String > entry : outputMapping .entrySet ()) {
216+ String newDocumentFieldName = entry .getKey ();
217+ List <String > dotPaths = newOutputMapping .get (newDocumentFieldName );
218+
219+ int existingFields = 0 ;
220+ for (String path : dotPaths ) {
221+ if (ingestDocument .hasField (path )) {
222+ existingFields ++;
223+ }
224+ }
225+ if (!override && existingFields == dotPaths .size ()) {
226+ logger .debug ("{} already exists in the ingest document. Removing it from output mapping" , newDocumentFieldName );
227+ newOutputMapping .remove (newDocumentFieldName );
228+ }
229+ }
230+ if (newOutputMapping .size () == 0 ) {
231+ batchPredictionListener .onResponse (null );
232+ return ;
233+ }
169234 }
170235 // when no input mapping is provided, default to read all fields from documents as model input
171236 if (inputMapSize == 0 ) {
@@ -184,15 +249,30 @@ private void processPredictions(
184249 }
185250 }
186251
187- ActionRequest request = getRemoteModelInferenceRequest (modelParameters , inferenceProcessorAttributes .getModelId ());
252+ Set <String > inputMapKeys = new HashSet <>(modelParameters .keySet ());
253+ inputMapKeys .removeAll (modelConfigs .keySet ());
254+
255+ Map <String , String > inputMappings = new HashMap <>();
256+ for (String k : inputMapKeys ) {
257+ inputMappings .put (k , modelParameters .get (k ));
258+ }
259+ ActionRequest request = getMLModelInferenceRequest (
260+ xContentRegistry ,
261+ modelParameters ,
262+ modelConfigs ,
263+ inputMappings ,
264+ inferenceProcessorAttributes .getModelId (),
265+ functionName ,
266+ modelInput
267+ );
188268
189269 client .execute (MLPredictionTaskAction .INSTANCE , request , new ActionListener <>() {
190270
191271 @ Override
192272 public void onResponse (MLTaskResponse mlTaskResponse ) {
193- ModelTensorOutput modelTensorOutput = ( ModelTensorOutput ) mlTaskResponse .getOutput ();
273+ MLOutput mlOutput = mlTaskResponse .getOutput ();
194274 if (processOutputMap == null || processOutputMap .isEmpty ()) {
195- appendFieldValue (modelTensorOutput , null , DEFAULT_OUTPUT_FIELD_NAME , ingestDocument );
275+ appendFieldValue (mlOutput , null , DEFAULT_OUTPUT_FIELD_NAME , ingestDocument );
196276 } else {
197277 // outMapping serves as a filter to modelTensorOutput, the fields that are not specified
198278 // in the outputMapping will not write to document
@@ -202,14 +282,10 @@ public void onResponse(MLTaskResponse mlTaskResponse) {
202282 // document field as key, model field as value
203283 String newDocumentFieldName = entry .getKey ();
204284 String modelOutputFieldName = entry .getValue ();
205- if (ingestDocument .hasField (newDocumentFieldName )) {
206- throw new IllegalArgumentException (
207- "document already has field name "
208- + newDocumentFieldName
209- + ". Not allow to overwrite the same field name, please check output_map."
210- );
285+ if (!newOutputMapping .containsKey (newDocumentFieldName )) {
286+ continue ;
211287 }
212- appendFieldValue (modelTensorOutput , modelOutputFieldName , newDocumentFieldName , ingestDocument );
288+ appendFieldValue (mlOutput , modelOutputFieldName , newDocumentFieldName , ingestDocument );
213289 }
214290 }
215291 batchPredictionListener .onResponse (null );
@@ -305,63 +381,61 @@ private String getFieldPath(IngestDocument ingestDocument, String documentFieldN
305381 /**
306382 * Appends the model output value to the specified field in the IngestDocument without modifying the source.
307383 *
308- * @param modelTensorOutput the ModelTensorOutput containing the model output
384+ * @param mlOutput the MLOutput containing the model output
309385 * @param modelOutputFieldName the name of the field in the model output
310386 * @param newDocumentFieldName the name of the field in the IngestDocument to append the value to
311387 * @param ingestDocument the IngestDocument to append the value to
312388 */
313389 private void appendFieldValue (
314- ModelTensorOutput modelTensorOutput ,
390+ MLOutput mlOutput ,
315391 String modelOutputFieldName ,
316392 String newDocumentFieldName ,
317393 IngestDocument ingestDocument
318394 ) {
319- Object modelOutputValue = null ;
320395
321- if (modelTensorOutput .getMlModelOutputs () != null && modelTensorOutput .getMlModelOutputs ().size () > 0 ) {
396+ if (mlOutput == null ) {
397+ throw new RuntimeException ("model inference output is null" );
398+ }
322399
323- modelOutputValue = getModelOutputValue (modelTensorOutput , modelOutputFieldName , ignoreMissing );
400+ Object modelOutputValue = getModelOutputValue (mlOutput , modelOutputFieldName , ignoreMissing , fullResponsePath );
324401
325- Map <String , Object > ingestDocumentSourceAndMetaData = new HashMap <>();
326- ingestDocumentSourceAndMetaData .putAll (ingestDocument .getSourceAndMetadata ());
327- ingestDocumentSourceAndMetaData .put (IngestDocument .INGEST_KEY , ingestDocument .getIngestMetadata ());
328- List <String > dotPathsInArray = writeNewDotPathForNestedObject (ingestDocumentSourceAndMetaData , newDocumentFieldName );
402+ Map <String , Object > ingestDocumentSourceAndMetaData = new HashMap <>();
403+ ingestDocumentSourceAndMetaData .putAll (ingestDocument .getSourceAndMetadata ());
404+ ingestDocumentSourceAndMetaData .put (IngestDocument .INGEST_KEY , ingestDocument .getIngestMetadata ());
405+ List <String > dotPathsInArray = writeNewDotPathForNestedObject (ingestDocumentSourceAndMetaData , newDocumentFieldName );
329406
330- if (dotPathsInArray .size () == 1 ) {
331- ValueSource ingestValue = ValueSource .wrap (modelOutputValue , scriptService );
407+ if (dotPathsInArray .size () == 1 ) {
408+ ValueSource ingestValue = ValueSource .wrap (modelOutputValue , scriptService );
409+ TemplateScript .Factory ingestField = ConfigurationUtils
410+ .compileTemplate (TYPE , tag , dotPathsInArray .get (0 ), dotPathsInArray .get (0 ), scriptService );
411+ ingestDocument .setFieldValue (ingestField , ingestValue , ignoreMissing );
412+ } else {
413+ if (!(modelOutputValue instanceof List )) {
414+ throw new IllegalArgumentException ("Model output is not an array, cannot assign to array in documents." );
415+ }
416+ List <?> modelOutputValueArray = (List <?>) modelOutputValue ;
417+ // check length of the prediction array to be the same of the document array
418+ if (dotPathsInArray .size () != modelOutputValueArray .size ()) {
419+ throw new RuntimeException (
420+ "the prediction field: "
421+ + modelOutputFieldName
422+ + " is an array in size of "
423+ + modelOutputValueArray .size ()
424+ + " but the document field array from field "
425+ + newDocumentFieldName
426+ + " is in size of "
427+ + dotPathsInArray .size ()
428+ );
429+ }
430+ // Iterate over dotPathInArray
431+ for (int i = 0 ; i < dotPathsInArray .size (); i ++) {
432+ String dotPathInArray = dotPathsInArray .get (i );
433+ Object modelOutputValueInArray = modelOutputValueArray .get (i );
434+ ValueSource ingestValue = ValueSource .wrap (modelOutputValueInArray , scriptService );
332435 TemplateScript .Factory ingestField = ConfigurationUtils
333- .compileTemplate (TYPE , tag , dotPathsInArray . get ( 0 ), dotPathsInArray . get ( 0 ) , scriptService );
436+ .compileTemplate (TYPE , tag , dotPathInArray , dotPathInArray , scriptService );
334437 ingestDocument .setFieldValue (ingestField , ingestValue , ignoreMissing );
335- } else {
336- if (!(modelOutputValue instanceof List )) {
337- throw new IllegalArgumentException ("Model output is not an array, cannot assign to array in documents." );
338- }
339- List <?> modelOutputValueArray = (List <?>) modelOutputValue ;
340- // check length of the prediction array to be the same of the document array
341- if (dotPathsInArray .size () != modelOutputValueArray .size ()) {
342- throw new RuntimeException (
343- "the prediction field: "
344- + modelOutputFieldName
345- + " is an array in size of "
346- + modelOutputValueArray .size ()
347- + " but the document field array from field "
348- + newDocumentFieldName
349- + " is in size of "
350- + dotPathsInArray .size ()
351- );
352- }
353- // Iterate over dotPathInArray
354- for (int i = 0 ; i < dotPathsInArray .size (); i ++) {
355- String dotPathInArray = dotPathsInArray .get (i );
356- Object modelOutputValueInArray = modelOutputValueArray .get (i );
357- ValueSource ingestValue = ValueSource .wrap (modelOutputValueInArray , scriptService );
358- TemplateScript .Factory ingestField = ConfigurationUtils
359- .compileTemplate (TYPE , tag , dotPathInArray , dotPathInArray , scriptService );
360- ingestDocument .setFieldValue (ingestField , ingestValue , ignoreMissing );
361- }
362438 }
363- } else {
364- throw new RuntimeException ("model inference output cannot be null" );
365439 }
366440 }
367441
@@ -374,16 +448,18 @@ public static class Factory implements Processor.Factory {
374448
375449 private final ScriptService scriptService ;
376450 private final Client client ;
451+ private final NamedXContentRegistry xContentRegistry ;
377452
/**
 * Creates a Factory that keeps references to the collaborators it is given.
 *
 * @param scriptService the ScriptService instance to be used by the Factory
 * @param client the Client instance to be used by the Factory
 * @param xContentRegistry the NamedXContentRegistry instance to be used by the Factory
 */
public Factory(ScriptService scriptService, Client client, NamedXContentRegistry xContentRegistry) {
    this.xContentRegistry = xContentRegistry;
    this.client = client;
    this.scriptService = scriptService;
}
388464
389465 /**
@@ -410,6 +486,14 @@ public MLInferenceIngestProcessor create(
410486 int maxPredictionTask = ConfigurationUtils
411487 .readIntProperty (TYPE , processorTag , config , MAX_PREDICTION_TASKS , DEFAULT_MAX_PREDICTION_TASKS );
412488 boolean ignoreMissing = ConfigurationUtils .readBooleanProperty (TYPE , processorTag , config , IGNORE_MISSING , false );
489+ boolean override = ConfigurationUtils .readBooleanProperty (TYPE , processorTag , config , OVERRIDE , false );
490+ String functionName = ConfigurationUtils
491+ .readStringProperty (TYPE , processorTag , config , FUNCTION_NAME , FunctionName .REMOTE .name ());
492+ String modelInput = ConfigurationUtils
493+ .readStringProperty (TYPE , processorTag , config , MODEL_INPUT , "{ \" parameters\" : ${ml_inference.parameters} }" );
494+ boolean defaultValue = !functionName .equalsIgnoreCase ("remote" );
495+ boolean fullResponsePath = ConfigurationUtils .readBooleanProperty (TYPE , processorTag , config , FULL_RESPONSE_PATH , defaultValue );
496+
413497 boolean ignoreFailure = ConfigurationUtils
414498 .readBooleanProperty (TYPE , processorTag , config , ConfigurationUtils .IGNORE_FAILURE_KEY , false );
415499 // convert model config user input data structure to Map<String, String>
@@ -440,9 +524,14 @@ public MLInferenceIngestProcessor create(
440524 processorTag ,
441525 description ,
442526 ignoreMissing ,
527+ functionName ,
528+ fullResponsePath ,
443529 ignoreFailure ,
530+ override ,
531+ modelInput ,
444532 scriptService ,
445- client
533+ client ,
534+ xContentRegistry
446535 );
447536 }
448537 }
0 commit comments