 import org.apache.logging.log4j.Level;
 import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.LatchedActionListener;
 import org.elasticsearch.action.support.PlainActionFuture;
 import org.elasticsearch.client.internal.Client;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.xpack.inference.InferencePlugin;
 import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests;
 import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
+import org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkingSettings;
 import org.elasticsearch.xpack.inference.services.ServiceFields;
 import org.junit.After;
 import org.junit.Before;
 import java.util.Map;
 import java.util.Optional;
 import java.util.Set;
+import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicReference;
@@ -936,17 +939,17 @@ public void testParsePersistedConfig() {
         }
     }

-    public void testChunkInfer_E5WithNullChunkingSettings() {
+    public void testChunkInfer_E5WithNullChunkingSettings() throws InterruptedException {
         testChunkInfer_e5(null);
     }

-    public void testChunkInfer_E5ChunkingSettingsSetAndFeatureFlagEnabled() {
+    public void testChunkInfer_E5ChunkingSettingsSetAndFeatureFlagEnabled() throws InterruptedException {
         assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled());
         testChunkInfer_e5(ChunkingSettingsTests.createRandomChunkingSettings());
     }

     @SuppressWarnings("unchecked")
-    private void testChunkInfer_e5(ChunkingSettings chunkingSettings) {
+    private void testChunkInfer_e5(ChunkingSettings chunkingSettings) throws InterruptedException {
         var mlTrainedModelResults = new ArrayList<InferenceResults>();
         mlTrainedModelResults.add(MlTextEmbeddingResultsTests.createRandomResults());
         mlTrainedModelResults.add(MlTextEmbeddingResultsTests.createRandomResults());
@@ -994,6 +997,9 @@ private void testChunkInfer_e5(ChunkingSettings chunkingSettings) {
             gotResults.set(true);
         }, ESTestCase::fail);

+        var latch = new CountDownLatch(1);
+        var latchedListener = new LatchedActionListener<>(resultsListener, latch);
+
         service.chunkedInfer(
             model,
             null,
@@ -1002,23 +1008,24 @@ private void testChunkInfer_e5(ChunkingSettings chunkingSettings) {
             InputType.SEARCH,
             new ChunkingOptions(null, null),
             InferenceAction.Request.DEFAULT_TIMEOUT,
-            ActionListener.runAfter(resultsListener, () -> terminate(threadPool))
+            latchedListener
         );

+        latch.await();
         assertTrue("Listener not called", gotResults.get());
     }

-    public void testChunkInfer_SparseWithNullChunkingSettings() {
+    public void testChunkInfer_SparseWithNullChunkingSettings() throws InterruptedException {
         testChunkInfer_Sparse(null);
     }

-    public void testChunkInfer_SparseWithChunkingSettingsSetAndFeatureFlagEnabled() {
+    public void testChunkInfer_SparseWithChunkingSettingsSetAndFeatureFlagEnabled() throws InterruptedException {
         assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled());
         testChunkInfer_Sparse(ChunkingSettingsTests.createRandomChunkingSettings());
     }

     @SuppressWarnings("unchecked")
-    private void testChunkInfer_Sparse(ChunkingSettings chunkingSettings) {
+    private void testChunkInfer_Sparse(ChunkingSettings chunkingSettings) throws InterruptedException {
         var mlTrainedModelResults = new ArrayList<InferenceResults>();
         mlTrainedModelResults.add(TextExpansionResultsTests.createRandomResults());
         mlTrainedModelResults.add(TextExpansionResultsTests.createRandomResults());
@@ -1042,6 +1049,7 @@ private void testChunkInfer_Sparse(ChunkingSettings chunkingSettings) {
         var service = createService(client);

         var gotResults = new AtomicBoolean();
+
         var resultsListener = ActionListener.<List<ChunkedInferenceServiceResults>>wrap(chunkedResponse -> {
             assertThat(chunkedResponse, hasSize(2));
             assertThat(chunkedResponse.get(0), instanceOf(InferenceChunkedSparseEmbeddingResults.class));
@@ -1061,6 +1069,9 @@ private void testChunkInfer_Sparse(ChunkingSettings chunkingSettings) {
             gotResults.set(true);
         }, ESTestCase::fail);

+        var latch = new CountDownLatch(1);
+        var latchedListener = new LatchedActionListener<>(resultsListener, latch);
+
         service.chunkedInfer(
             model,
             null,
@@ -1069,23 +1080,24 @@ private void testChunkInfer_Sparse(ChunkingSettings chunkingSettings) {
             InputType.SEARCH,
             new ChunkingOptions(null, null),
             InferenceAction.Request.DEFAULT_TIMEOUT,
-            ActionListener.runAfter(resultsListener, () -> terminate(threadPool))
+            latchedListener
         );

+        latch.await();
         assertTrue("Listener not called", gotResults.get());
     }

-    public void testChunkInfer_ElserWithNullChunkingSettings() {
+    public void testChunkInfer_ElserWithNullChunkingSettings() throws InterruptedException {
         testChunkInfer_Elser(null);
     }

-    public void testChunkInfer_ElserWithChunkingSettingsSetAndFeatureFlagEnabled() {
+    public void testChunkInfer_ElserWithChunkingSettingsSetAndFeatureFlagEnabled() throws InterruptedException {
         assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled());
         testChunkInfer_Elser(ChunkingSettingsTests.createRandomChunkingSettings());
     }

     @SuppressWarnings("unchecked")
-    private void testChunkInfer_Elser(ChunkingSettings chunkingSettings) {
+    private void testChunkInfer_Elser(ChunkingSettings chunkingSettings) throws InterruptedException {
         var mlTrainedModelResults = new ArrayList<InferenceResults>();
         mlTrainedModelResults.add(TextExpansionResultsTests.createRandomResults());
         mlTrainedModelResults.add(TextExpansionResultsTests.createRandomResults());
@@ -1129,6 +1141,9 @@ private void testChunkInfer_Elser(ChunkingSettings chunkingSettings) {
             gotResults.set(true);
         }, ESTestCase::fail);

+        var latch = new CountDownLatch(1);
+        var latchedListener = new LatchedActionListener<>(resultsListener, latch);
+
         service.chunkedInfer(
             model,
             null,
@@ -1137,9 +1152,10 @@ private void testChunkInfer_Elser(ChunkingSettings chunkingSettings) {
             InputType.SEARCH,
             new ChunkingOptions(null, null),
             InferenceAction.Request.DEFAULT_TIMEOUT,
-            ActionListener.runAfter(resultsListener, () -> terminate(threadPool))
+            latchedListener
         );

+        latch.await();
         assertTrue("Listener not called", gotResults.get());
     }

@@ -1200,7 +1216,7 @@ public void testChunkInferSetsTokenization() {
     }

     @SuppressWarnings("unchecked")
-    public void testChunkInfer_FailsBatch() {
+    public void testChunkInfer_FailsBatch() throws InterruptedException {
         var mlTrainedModelResults = new ArrayList<InferenceResults>();
         mlTrainedModelResults.add(MlTextEmbeddingResultsTests.createRandomResults());
         mlTrainedModelResults.add(MlTextEmbeddingResultsTests.createRandomResults());
@@ -1236,6 +1252,9 @@ public void testChunkInfer_FailsBatch() {
             gotResults.set(true);
         }, ESTestCase::fail);

+        var latch = new CountDownLatch(1);
+        var latchedListener = new LatchedActionListener<>(resultsListener, latch);
+
         service.chunkedInfer(
             model,
             null,
@@ -1244,9 +1263,93 @@ public void testChunkInfer_FailsBatch() {
             InputType.SEARCH,
             new ChunkingOptions(null, null),
             InferenceAction.Request.DEFAULT_TIMEOUT,
-            ActionListener.runAfter(resultsListener, () -> terminate(threadPool))
+            latchedListener
+        );
+
+        latch.await();
+        assertTrue("Listener not called", gotResults.get());
+    }
+
+    @SuppressWarnings("unchecked")
+    public void testChunkingLargeDocument() throws InterruptedException {
+        assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled());
+
+        int wordsPerChunk = 10;
+        int numBatches = randomIntBetween(3, 6);
+        int numChunks = randomIntBetween(
+            ((numBatches - 1) * ElasticsearchInternalService.EMBEDDING_MAX_BATCH_SIZE) + 1,
+            numBatches * ElasticsearchInternalService.EMBEDDING_MAX_BATCH_SIZE
+        );
+
+        // build a doc with enough words to make numChunks of chunks
+        int numWords = numChunks * wordsPerChunk;
+        var docBuilder = new StringBuilder();
+        for (int i = 0; i < numWords; i++) {
+            docBuilder.append("word ");
+        }
+
+        // how many response objects to return in each batch
+        int[] numResponsesPerBatch = new int[numBatches];
+        for (int i = 0; i < numBatches - 1; i++) {
+            numResponsesPerBatch[i] = ElasticsearchInternalService.EMBEDDING_MAX_BATCH_SIZE;
+        }
+        numResponsesPerBatch[numBatches - 1] = numChunks % ElasticsearchInternalService.EMBEDDING_MAX_BATCH_SIZE;
+
+        var batchIndex = new AtomicInteger();
+        Client client = mock(Client.class);
+        when(client.threadPool()).thenReturn(threadPool);
+
+        // mock the inference response
+        doAnswer(invocationOnMock -> {
+            var listener = (ActionListener<InferModelAction.Response>) invocationOnMock.getArguments()[2];
+
+            var mlTrainedModelResults = new ArrayList<InferenceResults>();
+            for (int i = 0; i < numResponsesPerBatch[batchIndex.get()]; i++) {
+                mlTrainedModelResults.add(MlTextEmbeddingResultsTests.createRandomResults());
+            }
+            batchIndex.incrementAndGet();
+            var response = new InferModelAction.Response(mlTrainedModelResults, "foo", true);
+            listener.onResponse(response);
+            return null;
+        }).when(client).execute(same(InferModelAction.INSTANCE), any(InferModelAction.Request.class), any(ActionListener.class));
+
+        var service = createService(client);
+
+        var gotResults = new AtomicBoolean();
+        var resultsListener = ActionListener.<List<ChunkedInferenceServiceResults>>wrap(chunkedResponse -> {
+            assertThat(chunkedResponse, hasSize(1));
+            assertThat(chunkedResponse.get(0), instanceOf(InferenceChunkedTextEmbeddingFloatResults.class));
+            var sparseResults = (InferenceChunkedTextEmbeddingFloatResults) chunkedResponse.get(0);
+            assertThat(sparseResults.chunks(), hasSize(numChunks));
+
+            gotResults.set(true);
+        }, ESTestCase::fail);
+
+        // Create model using the word boundary chunker.
+        var model = new MultilingualE5SmallModel(
+            "foo",
+            TaskType.TEXT_EMBEDDING,
+            "e5",
+            new MultilingualE5SmallInternalServiceSettings(1, 1, "cross-platform", null),
+            new WordBoundaryChunkingSettings(wordsPerChunk, 0)
+        );
+
+        var latch = new CountDownLatch(1);
+        var latchedListener = new LatchedActionListener<>(resultsListener, latch);
+
+        // For the given input we know how many requests will be made
+        service.chunkedInfer(
+            model,
+            null,
+            List.of(docBuilder.toString()),
+            Map.of(),
+            InputType.SEARCH,
+            new ChunkingOptions(null, null),
+            InferenceAction.Request.DEFAULT_TIMEOUT,
+            latchedListener
         );

+        latch.await();
         assertTrue("Listener not called", gotResults.get());
     }