
Commit faf18e9

Improve test

1 parent 19deea7 commit faf18e9

1 file changed: +10 −17 lines

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java

Lines changed: 10 additions & 17 deletions
@@ -69,6 +69,7 @@
 import org.mockito.Mockito;
 
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.EnumSet;
 import java.util.HashMap;
 import java.util.List;
@@ -1274,40 +1275,32 @@ public void testChunkInfer_FailsBatch() throws InterruptedException {
     public void testChunkingLargeDocument() throws InterruptedException {
         assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled());
 
-        int wordsPerChunk = 10;
         int numBatches = randomIntBetween(3, 6);
-        int numChunks = randomIntBetween(
-            ((numBatches - 1) * ElasticsearchInternalService.EMBEDDING_MAX_BATCH_SIZE) + 1,
-            numBatches * ElasticsearchInternalService.EMBEDDING_MAX_BATCH_SIZE
-        );
-
-        // build a doc with enough words to make numChunks of chunks
-        int numWords = numChunks * wordsPerChunk;
-        var input = "word ".repeat(numWords);
 
         // how many response objects to return in each batch
         int[] numResponsesPerBatch = new int[numBatches];
         for (int i = 0; i < numBatches - 1; i++) {
             numResponsesPerBatch[i] = ElasticsearchInternalService.EMBEDDING_MAX_BATCH_SIZE;
         }
-        numResponsesPerBatch[numBatches - 1] = numChunks % ElasticsearchInternalService.EMBEDDING_MAX_BATCH_SIZE;
-        if (numResponsesPerBatch[numBatches - 1] == 0) {
-            numResponsesPerBatch[numBatches - 1] = ElasticsearchInternalService.EMBEDDING_MAX_BATCH_SIZE;
-        }
+        numResponsesPerBatch[numBatches - 1] = randomIntBetween(1, ElasticsearchInternalService.EMBEDDING_MAX_BATCH_SIZE);
+        int numChunks = Arrays.stream(numResponsesPerBatch).sum();
+
+        // build a doc with enough words to make numChunks of chunks
+        int wordsPerChunk = 10;
+        int numWords = numChunks * wordsPerChunk;
+        var input = "word ".repeat(numWords);
 
-        var batchIndex = new AtomicInteger();
         Client client = mock(Client.class);
         when(client.threadPool()).thenReturn(threadPool);
 
         // mock the inference response
         doAnswer(invocationOnMock -> {
+            var request = (InferModelAction.Request) invocationOnMock.getArguments()[1];
             var listener = (ActionListener<InferModelAction.Response>) invocationOnMock.getArguments()[2];
-
             var mlTrainedModelResults = new ArrayList<InferenceResults>();
-            for (int i = 0; i < numResponsesPerBatch[batchIndex.get()]; i++) {
+            for (int i = 0; i < request.numberOfDocuments(); i++) {
                 mlTrainedModelResults.add(MlTextEmbeddingResultsTests.createRandomResults());
             }
-            batchIndex.incrementAndGet();
             var response = new InferModelAction.Response(mlTrainedModelResults, "foo", true);
             listener.onResponse(response);
             return null;

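The substance of the change: the stubbed client now sizes each mocked response from the incoming request via request.numberOfDocuments(), instead of replaying the precomputed numResponsesPerBatch array through a shared AtomicInteger cursor. That in turn lets the last batch be a plain randomIntBetween(1, EMBEDDING_MAX_BATCH_SIZE), with numChunks derived by summing, rather than the old modulo computation plus its special case for an exact multiple. Below is a minimal sketch of the same stubbing pattern in isolation; BatchRequest and BatchClient are hypothetical stand-ins for InferModelAction.Request and Client, not types from the commit.

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;

public class RequestSizedStubSketch {

    // Hypothetical stand-ins for InferModelAction.Request and Client.
    interface BatchRequest {
        int numberOfDocuments();
    }

    interface BatchClient {
        void execute(BatchRequest request, Consumer<List<String>> listener);
    }

    public static void main(String[] args) {
        BatchClient client = mock(BatchClient.class);

        // Size each stubbed reply from the incoming request rather than from
        // shared mutable state, mirroring the pattern the commit switches to.
        doAnswer(invocation -> {
            BatchRequest request = invocation.getArgument(0);
            Consumer<List<String>> listener = invocation.getArgument(1);
            List<String> results = new ArrayList<>();
            for (int i = 0; i < request.numberOfDocuments(); i++) {
                results.add("embedding-" + i);
            }
            listener.accept(results);
            return null;
        }).when(client).execute(any(), any());

        // Each call gets exactly as many results as documents it sent.
        BatchRequest threeDocs = mock(BatchRequest.class);
        when(threeDocs.numberOfDocuments()).thenReturn(3);
        client.execute(threeDocs, results -> System.out.println(results.size())); // prints 3
    }
}

Deriving the response size from the request keeps the stub correct for any number of batches and makes it insensitive to call order, which is what renders the removed batchIndex counter unnecessary.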