@@ -367,7 +367,7 @@ public void testExplicitNull() throws Exception {
367367 }
368368
369369 @ SuppressWarnings ({ "unchecked" , "rawtypes" })
370- public void testSkipGeneratingInference () throws Exception {
370+ public void testHandleEmptyInput () throws Exception {
371371 StaticModel model = StaticModel .createRandomInstance ();
372372 ShardBulkInferenceActionFilter filter = createFilter (
373373 threadPool ,
@@ -383,37 +383,33 @@ public void testSkipGeneratingInference() throws Exception {
383383 BulkShardRequest bulkShardRequest = (BulkShardRequest ) request ;
384384 IndexRequest actualRequest = getIndexRequestOrNull (bulkShardRequest .items ()[0 ].request ());
385385
386- // Create: Empty string
387- assertThat (XContentMapValues .extractValue ("obj" , actualRequest .sourceAsMap (), EXPLICIT_NULL ), equalTo ("" ));
388- assertNull (XContentMapValues .extractValue (InferenceMetadataFieldsMapper .NAME , actualRequest .sourceAsMap (), EXPLICIT_NULL ));
386+ // Create with Empty string
387+ assertInferenceResults (useLegacyFormat , actualRequest , "semantic_text_field" , useLegacyFormat ? EXPLICIT_NULL : "" , 0 );
389388
390- // Create: whitespace only
389+ // Create with whitespace only
391390 actualRequest = getIndexRequestOrNull (bulkShardRequest .items ()[1 ].request ());
392- assertThat (XContentMapValues .extractValue ("obj.field" , actualRequest .sourceAsMap (), EXPLICIT_NULL ), equalTo ("" ));
393- assertNull (XContentMapValues .extractValue (InferenceMetadataFieldsMapper .NAME , actualRequest .sourceAsMap (), EXPLICIT_NULL ));
391+ assertInferenceResults (useLegacyFormat , actualRequest , "semantic_text_field" , useLegacyFormat ? EXPLICIT_NULL : " " , 0 );
394392
395- // Update: Empty string
393+ // Update with multiple Whitespaces
396394 actualRequest = getIndexRequestOrNull (bulkShardRequest .items ()[2 ].request ());
397- assertThat (XContentMapValues .extractValue ("obj" , actualRequest .sourceAsMap (), EXPLICIT_NULL ), equalTo (" " ));
398- assertNull (XContentMapValues .extractValue (InferenceMetadataFieldsMapper .NAME , actualRequest .sourceAsMap (), EXPLICIT_NULL ));
399-
400- // Update: whitespace only
401- actualRequest = getIndexRequestOrNull (bulkShardRequest .items ()[3 ].request ());
402- assertThat (XContentMapValues .extractValue ("obj.field" , actualRequest .sourceAsMap (), EXPLICIT_NULL ), equalTo (" " ));
403- assertNull (XContentMapValues .extractValue (InferenceMetadataFieldsMapper .NAME , actualRequest .sourceAsMap (), EXPLICIT_NULL ));
395+ assertInferenceResults (useLegacyFormat , actualRequest , "semantic_text_field" , useLegacyFormat ? EXPLICIT_NULL : " " , 0 );
404396 } finally {
405397 chainExecuted .countDown ();
406398 }
407399 };
408400 ActionListener actionListener = mock (ActionListener .class );
409401 Task task = mock (Task .class );
402+ Map <String , InferenceFieldMetadata > inferenceFieldMap = Map .of (
403+ "semantic_text_field" ,
404+ new InferenceFieldMetadata ("semantic_text_field" , model .getInferenceEntityId (), new String [] { "semantic_text_field" })
405+ );
410406
411- BulkItemRequest [] items = new BulkItemRequest [4 ];
412- items [0 ] = new BulkItemRequest (0 , new IndexRequest ("index" ).source (Map .of ("obj" , "" )));
413- items [1 ] = new BulkItemRequest (1 , new IndexRequest ("index" ).source (Map .of ("obj" , Map .of ("field" , "" ))));
414- items [2 ] = new BulkItemRequest (2 , new UpdateRequest ().doc (new IndexRequest ("index" ).source (Map .of ("obj" , " " ))));
415- items [3 ] = new BulkItemRequest (3 , new UpdateRequest ().doc (new IndexRequest ("index" ).source (Map .of ("obj" , Map .of ("field" , " " )))));
407+ BulkItemRequest [] items = new BulkItemRequest [3 ];
408+ items [0 ] = new BulkItemRequest (0 , new IndexRequest ("index" ).source (Map .of ("semantic_text_field" , "" )));
409+ items [1 ] = new BulkItemRequest (1 , new IndexRequest ("index" ).source (Map .of ("semantic_text_field" , " " )));
410+ items [2 ] = new BulkItemRequest (2 , new UpdateRequest ().doc (new IndexRequest ("index" ).source (Map .of ("semantic_text_field" , " " ))));
416411 BulkShardRequest request = new BulkShardRequest (new ShardId ("test" , "test" , 0 ), WriteRequest .RefreshPolicy .NONE , items );
412+ request .setInferenceFieldMap (inferenceFieldMap );
417413 filter .apply (task , TransportShardBulkAction .ACTION_NAME , request , actionListener , actionFilterChain );
418414 awaitLatch (chainExecuted , 10 , TimeUnit .SECONDS );
419415 }
@@ -655,9 +651,8 @@ private static void assertInferenceResults(
655651 assertNotNull (chunks );
656652 assertThat (chunks .size (), equalTo (expectedChunkCount ));
657653 } else {
658- // If the expected chunk count is 0, we expect that no inference has been performed. In this case, the source should not be
659- // transformed, and thus the semantic text field structure should not be created.
660- assertNull (chunks );
654+ // If the expected chunk count is 0, we expect that no inference has been performed.
655+ assertTrue (chunks == null || chunks .isEmpty ());
661656 }
662657 } else {
663658 assertThat (XContentMapValues .extractValue (fieldName , requestMap , EXPLICIT_NULL ), equalTo (expectedOriginalValue ));
0 commit comments