 import org.mockito.Mockito;

 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.EnumSet;
 import java.util.HashMap;
 import java.util.List;
@@ -1274,40 +1275,32 @@ public void testChunkInfer_FailsBatch() throws InterruptedException { |
     public void testChunkingLargeDocument() throws InterruptedException {
         assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled());

-        int wordsPerChunk = 10;
         int numBatches = randomIntBetween(3, 6);
-        int numChunks = randomIntBetween(
-            ((numBatches - 1) * ElasticsearchInternalService.EMBEDDING_MAX_BATCH_SIZE) + 1,
-            numBatches * ElasticsearchInternalService.EMBEDDING_MAX_BATCH_SIZE
-        );
-
-        // build a doc with enough words to make numChunks of chunks
-        int numWords = numChunks * wordsPerChunk;
-        var input = "word ".repeat(numWords);

         // how many response objects to return in each batch
         int[] numResponsesPerBatch = new int[numBatches];
         for (int i = 0; i < numBatches - 1; i++) {
             numResponsesPerBatch[i] = ElasticsearchInternalService.EMBEDDING_MAX_BATCH_SIZE;
         }
-        numResponsesPerBatch[numBatches - 1] = numChunks % ElasticsearchInternalService.EMBEDDING_MAX_BATCH_SIZE;
-        if (numResponsesPerBatch[numBatches - 1] == 0) {
-            numResponsesPerBatch[numBatches - 1] = ElasticsearchInternalService.EMBEDDING_MAX_BATCH_SIZE;
-        }
+        numResponsesPerBatch[numBatches - 1] = randomIntBetween(1, ElasticsearchInternalService.EMBEDDING_MAX_BATCH_SIZE);
+        int numChunks = Arrays.stream(numResponsesPerBatch).sum();
+
+        // build a doc with enough words to make numChunks of chunks
+        int wordsPerChunk = 10;
+        int numWords = numChunks * wordsPerChunk;
+        var input = "word ".repeat(numWords);

-        var batchIndex = new AtomicInteger();
         Client client = mock(Client.class);
         when(client.threadPool()).thenReturn(threadPool);

         // mock the inference response
         doAnswer(invocationOnMock -> {
+            var request = (InferModelAction.Request) invocationOnMock.getArguments()[1];
             var listener = (ActionListener<InferModelAction.Response>) invocationOnMock.getArguments()[2];
-
             var mlTrainedModelResults = new ArrayList<InferenceResults>();
-            for (int i = 0; i < numResponsesPerBatch[batchIndex.get()]; i++) {
+            for (int i = 0; i < request.numberOfDocuments(); i++) {
                 mlTrainedModelResults.add(MlTextEmbeddingResultsTests.createRandomResults());
             }
-            batchIndex.incrementAndGet();
             var response = new InferModelAction.Response(mlTrainedModelResults, "foo", true);
             listener.onResponse(response);
             return null;
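Two things are worth noting about this change. First, the test's arithmetic is inverted: instead of picking numChunks at random and deriving the final batch size with a modulo (plus a special case for a remainder of 0), it now picks the final batch size directly and sums the array, which removes the edge case. Second, the stateful AtomicInteger batch counter is gone: the doAnswer stub now sizes its response from the request it receives, so it cannot drift out of sync with the batching logic. Below is a minimal, self-contained sketch of that request-driven stubbing pattern; the Inferencer interface and Consumer listener are hypothetical stand-ins for the real Client/InferModelAction types, which would need a full Elasticsearch test fixture.

    import static org.mockito.ArgumentMatchers.any;
    import static org.mockito.Mockito.doAnswer;
    import static org.mockito.Mockito.mock;

    import java.util.ArrayList;
    import java.util.List;
    import java.util.function.Consumer;

    public class RequestDrivenStubSketch {

        // Hypothetical stand-in for the production inference client; the real
        // test stubs Client#execute with InferModelAction.Request/Response.
        interface Inferencer {
            void infer(List<String> docs, Consumer<List<String>> listener);
        }

        public static void main(String[] args) {
            Inferencer inferencer = mock(Inferencer.class);

            // Size the stubbed response from the request itself, so the stub
            // needs no external counter tracking which batch is being served.
            doAnswer(invocation -> {
                List<String> docs = invocation.getArgument(0);
                Consumer<List<String>> listener = invocation.getArgument(1);
                var results = new ArrayList<String>();
                for (int i = 0; i < docs.size(); i++) {
                    results.add("embedding-" + i);
                }
                listener.accept(results);
                return null;
            }).when(inferencer).infer(any(), any());

            // One result per input document, regardless of batch size.
            inferencer.infer(List.of("a", "b", "c"), results -> System.out.println(results.size())); // prints 3
        }
    }

The design point carries over directly: because the stub derives its output length from its input, the same answer serves every batch, and adding or reordering batches in the test requires no changes to the mock.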