Skip to content

Commit 008e0c0

Browse files
committed
Fixed tests
1 parent 13e875e commit 008e0c0

File tree

12 files changed

+85
-86
lines changed

12 files changed

+85
-86
lines changed

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java

Lines changed: 37 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -246,7 +246,7 @@ public void testMergingListener_Float() {
246246
for (int i = 0; i < numberOfWordsInPassage; i++) {
247247
passageBuilder.append("passage_input").append(i).append(" "); // chunk on whitespace
248248
}
249-
List<String> inputs = List.of("1st small", passageBuilder.toString(), "2nd small", "3rd small");
249+
List<String> inputs = List.of("a", passageBuilder.toString(), "bb", "ccc");
250250

251251
var finalListener = testListener();
252252
var batches = new EmbeddingRequestChunker(inputs, batchSize, chunkSize, overlap).batchRequestsWithListeners(finalListener);
@@ -275,34 +275,34 @@ public void testMergingListener_Float() {
275275
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
276276
var chunkedFloatResult = (ChunkedInferenceEmbedding) chunkedResult;
277277
assertThat(chunkedFloatResult.chunks(), hasSize(1));
278-
assertEquals("1st small", chunkedFloatResult.chunks().get(0).matchedText());
278+
assertEquals(new ChunkedInference.TextOffset(0, 1), chunkedFloatResult.chunks().get(0).offset());
279279
}
280280
{
281281
// this is the large input split in multiple chunks
282282
var chunkedResult = finalListener.results.get(1);
283283
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
284284
var chunkedFloatResult = (ChunkedInferenceEmbedding) chunkedResult;
285285
assertThat(chunkedFloatResult.chunks(), hasSize(6));
286-
assertThat(chunkedFloatResult.chunks().get(0).matchedText(), startsWith("passage_input0 "));
287-
assertThat(chunkedFloatResult.chunks().get(1).matchedText(), startsWith(" passage_input20 "));
288-
assertThat(chunkedFloatResult.chunks().get(2).matchedText(), startsWith(" passage_input40 "));
289-
assertThat(chunkedFloatResult.chunks().get(3).matchedText(), startsWith(" passage_input60 "));
290-
assertThat(chunkedFloatResult.chunks().get(4).matchedText(), startsWith(" passage_input80 "));
291-
assertThat(chunkedFloatResult.chunks().get(5).matchedText(), startsWith(" passage_input100 "));
286+
assertThat(chunkedFloatResult.chunks().get(0).offset(), equalTo(new ChunkedInference.TextOffset(0, 309)));
287+
assertThat(chunkedFloatResult.chunks().get(1).offset(), equalTo(new ChunkedInference.TextOffset(309, 629)));
288+
assertThat(chunkedFloatResult.chunks().get(2).offset(), equalTo(new ChunkedInference.TextOffset(629, 949)));
289+
assertThat(chunkedFloatResult.chunks().get(3).offset(), equalTo(new ChunkedInference.TextOffset(949, 1269)));
290+
assertThat(chunkedFloatResult.chunks().get(4).offset(), equalTo(new ChunkedInference.TextOffset(1269, 1589)));
291+
assertThat(chunkedFloatResult.chunks().get(5).offset(), equalTo(new ChunkedInference.TextOffset(1589, 1675)));
292292
}
293293
{
294294
var chunkedResult = finalListener.results.get(2);
295295
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
296296
var chunkedFloatResult = (ChunkedInferenceEmbedding) chunkedResult;
297297
assertThat(chunkedFloatResult.chunks(), hasSize(1));
298-
assertEquals("2nd small", chunkedFloatResult.chunks().get(0).matchedText());
298+
assertEquals(new ChunkedInference.TextOffset(0, 2), chunkedFloatResult.chunks().get(0).offset());
299299
}
300300
{
301301
var chunkedResult = finalListener.results.get(3);
302302
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
303303
var chunkedFloatResult = (ChunkedInferenceEmbedding) chunkedResult;
304304
assertThat(chunkedFloatResult.chunks(), hasSize(1));
305-
assertEquals("3rd small", chunkedFloatResult.chunks().get(0).matchedText());
305+
assertEquals(new ChunkedInference.TextOffset(0, 3), chunkedFloatResult.chunks().get(0).offset());
306306
}
307307
}
308308

@@ -318,7 +318,7 @@ public void testMergingListener_Byte() {
318318
for (int i = 0; i < numberOfWordsInPassage; i++) {
319319
passageBuilder.append("passage_input").append(i).append(" "); // chunk on whitespace
320320
}
321-
List<String> inputs = List.of("1st small", passageBuilder.toString(), "2nd small", "3rd small");
321+
List<String> inputs = List.of("a", passageBuilder.toString(), "bb", "ccc");
322322

323323
var finalListener = testListener();
324324
var batches = new EmbeddingRequestChunker(inputs, batchSize, chunkSize, overlap).batchRequestsWithListeners(finalListener);
@@ -347,34 +347,34 @@ public void testMergingListener_Byte() {
347347
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
348348
var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult;
349349
assertThat(chunkedByteResult.chunks(), hasSize(1));
350-
assertEquals("1st small", chunkedByteResult.chunks().get(0).matchedText());
350+
assertEquals(new ChunkedInference.TextOffset(0, 1), chunkedByteResult.chunks().get(0).offset());
351351
}
352352
{
353353
// this is the large input split in multiple chunks
354354
var chunkedResult = finalListener.results.get(1);
355355
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
356356
var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult;
357357
assertThat(chunkedByteResult.chunks(), hasSize(6));
358-
assertThat(chunkedByteResult.chunks().get(0).matchedText(), startsWith("passage_input0 "));
359-
assertThat(chunkedByteResult.chunks().get(1).matchedText(), startsWith(" passage_input20 "));
360-
assertThat(chunkedByteResult.chunks().get(2).matchedText(), startsWith(" passage_input40 "));
361-
assertThat(chunkedByteResult.chunks().get(3).matchedText(), startsWith(" passage_input60 "));
362-
assertThat(chunkedByteResult.chunks().get(4).matchedText(), startsWith(" passage_input80 "));
363-
assertThat(chunkedByteResult.chunks().get(5).matchedText(), startsWith(" passage_input100 "));
358+
assertThat(chunkedByteResult.chunks().get(0).offset(), equalTo(new ChunkedInference.TextOffset(0, 309)));
359+
assertThat(chunkedByteResult.chunks().get(1).offset(), equalTo(new ChunkedInference.TextOffset(309, 629)));
360+
assertThat(chunkedByteResult.chunks().get(2).offset(), equalTo(new ChunkedInference.TextOffset(629, 949)));
361+
assertThat(chunkedByteResult.chunks().get(3).offset(), equalTo(new ChunkedInference.TextOffset(949, 1269)));
362+
assertThat(chunkedByteResult.chunks().get(4).offset(), equalTo(new ChunkedInference.TextOffset(1269, 1589)));
363+
assertThat(chunkedByteResult.chunks().get(5).offset(), equalTo(new ChunkedInference.TextOffset(1589, 1675)));
364364
}
365365
{
366366
var chunkedResult = finalListener.results.get(2);
367367
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
368368
var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult;
369369
assertThat(chunkedByteResult.chunks(), hasSize(1));
370-
assertEquals("2nd small", chunkedByteResult.chunks().get(0).matchedText());
370+
assertEquals(new ChunkedInference.TextOffset(0, 2), chunkedByteResult.chunks().get(0).offset());
371371
}
372372
{
373373
var chunkedResult = finalListener.results.get(3);
374374
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
375375
var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult;
376376
assertThat(chunkedByteResult.chunks(), hasSize(1));
377-
assertEquals("3rd small", chunkedByteResult.chunks().get(0).matchedText());
377+
assertEquals(new ChunkedInference.TextOffset(0, 3), chunkedByteResult.chunks().get(0).offset());
378378
}
379379
}
380380

@@ -390,7 +390,7 @@ public void testMergingListener_Bit() {
390390
for (int i = 0; i < numberOfWordsInPassage; i++) {
391391
passageBuilder.append("passage_input").append(i).append(" "); // chunk on whitespace
392392
}
393-
List<String> inputs = List.of("1st small", passageBuilder.toString(), "2nd small", "3rd small");
393+
List<String> inputs = List.of("a", passageBuilder.toString(), "bb", "ccc");
394394

395395
var finalListener = testListener();
396396
var batches = new EmbeddingRequestChunker(inputs, batchSize, chunkSize, overlap).batchRequestsWithListeners(finalListener);
@@ -419,34 +419,34 @@ public void testMergingListener_Bit() {
419419
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
420420
var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult;
421421
assertThat(chunkedByteResult.chunks(), hasSize(1));
422-
assertEquals("1st small", chunkedByteResult.chunks().get(0).matchedText());
422+
assertEquals(new ChunkedInference.TextOffset(0, 1), chunkedByteResult.chunks().get(0).offset());
423423
}
424424
{
425425
// this is the large input split in multiple chunks
426426
var chunkedResult = finalListener.results.get(1);
427427
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
428428
var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult;
429429
assertThat(chunkedByteResult.chunks(), hasSize(6));
430-
assertThat(chunkedByteResult.chunks().get(0).matchedText(), startsWith("passage_input0 "));
431-
assertThat(chunkedByteResult.chunks().get(1).matchedText(), startsWith(" passage_input20 "));
432-
assertThat(chunkedByteResult.chunks().get(2).matchedText(), startsWith(" passage_input40 "));
433-
assertThat(chunkedByteResult.chunks().get(3).matchedText(), startsWith(" passage_input60 "));
434-
assertThat(chunkedByteResult.chunks().get(4).matchedText(), startsWith(" passage_input80 "));
435-
assertThat(chunkedByteResult.chunks().get(5).matchedText(), startsWith(" passage_input100 "));
430+
assertThat(chunkedByteResult.chunks().get(0).offset(), equalTo(new ChunkedInference.TextOffset(0, 309)));
431+
assertThat(chunkedByteResult.chunks().get(1).offset(), equalTo(new ChunkedInference.TextOffset(309, 629)));
432+
assertThat(chunkedByteResult.chunks().get(2).offset(), equalTo(new ChunkedInference.TextOffset(629, 949)));
433+
assertThat(chunkedByteResult.chunks().get(3).offset(), equalTo(new ChunkedInference.TextOffset(949, 1269)));
434+
assertThat(chunkedByteResult.chunks().get(4).offset(), equalTo(new ChunkedInference.TextOffset(1269, 1589)));
435+
assertThat(chunkedByteResult.chunks().get(5).offset(), equalTo(new ChunkedInference.TextOffset(1589, 1675)));
436436
}
437437
{
438438
var chunkedResult = finalListener.results.get(2);
439439
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
440440
var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult;
441441
assertThat(chunkedByteResult.chunks(), hasSize(1));
442-
assertEquals("2nd small", chunkedByteResult.chunks().get(0).matchedText());
442+
assertEquals(new ChunkedInference.TextOffset(0, 2), chunkedByteResult.chunks().get(0).offset());
443443
}
444444
{
445445
var chunkedResult = finalListener.results.get(3);
446446
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
447447
var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult;
448448
assertThat(chunkedByteResult.chunks(), hasSize(1));
449-
assertEquals("3rd small", chunkedByteResult.chunks().get(0).matchedText());
449+
assertEquals(new ChunkedInference.TextOffset(0, 3), chunkedByteResult.chunks().get(0).offset());
450450
}
451451
}
452452

@@ -462,7 +462,7 @@ public void testMergingListener_Sparse() {
462462
for (int i = 0; i < numberOfWordsInPassage; i++) {
463463
passageBuilder.append("passage_input").append(i).append(" "); // chunk on whitespace
464464
}
465-
List<String> inputs = List.of("1st small", "2nd small", "3rd small", passageBuilder.toString());
465+
List<String> inputs = List.of("a", "bb", "ccc", passageBuilder.toString());
466466

467467
var finalListener = testListener();
468468
var batches = new EmbeddingRequestChunker(inputs, batchSize, chunkSize, overlap).batchRequestsWithListeners(finalListener);
@@ -498,31 +498,31 @@ public void testMergingListener_Sparse() {
498498
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
499499
var chunkedSparseResult = (ChunkedInferenceEmbedding) chunkedResult;
500500
assertThat(chunkedSparseResult.chunks(), hasSize(1));
501-
assertEquals("1st small", chunkedSparseResult.chunks().get(0).matchedText());
501+
assertEquals(new ChunkedInference.TextOffset(0, 1), chunkedSparseResult.chunks().get(0).offset());
502502
}
503503
{
504504
var chunkedResult = finalListener.results.get(1);
505505
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
506506
var chunkedSparseResult = (ChunkedInferenceEmbedding) chunkedResult;
507507
assertThat(chunkedSparseResult.chunks(), hasSize(1));
508-
assertEquals("2nd small", chunkedSparseResult.chunks().get(0).matchedText());
508+
assertEquals(new ChunkedInference.TextOffset(0, 2), chunkedSparseResult.chunks().get(0).offset());
509509
}
510510
{
511511
var chunkedResult = finalListener.results.get(2);
512512
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
513513
var chunkedSparseResult = (ChunkedInferenceEmbedding) chunkedResult;
514514
assertThat(chunkedSparseResult.chunks(), hasSize(1));
515-
assertEquals("3rd small", chunkedSparseResult.chunks().get(0).matchedText());
515+
assertEquals(new ChunkedInference.TextOffset(0, 3), chunkedSparseResult.chunks().get(0).offset());
516516
}
517517
{
518518
// this is the large input split in multiple chunks
519519
var chunkedResult = finalListener.results.get(3);
520520
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
521521
var chunkedSparseResult = (ChunkedInferenceEmbedding) chunkedResult;
522522
assertThat(chunkedSparseResult.chunks(), hasSize(9)); // passage is split into 9 chunks, 10 words each
523-
assertThat(chunkedSparseResult.chunks().get(0).matchedText(), startsWith("passage_input0 "));
524-
assertThat(chunkedSparseResult.chunks().get(1).matchedText(), startsWith(" passage_input10 "));
525-
assertThat(chunkedSparseResult.chunks().get(8).matchedText(), startsWith(" passage_input80 "));
523+
assertThat(chunkedSparseResult.chunks().get(0).offset(), equalTo(new ChunkedInference.TextOffset(0, 149)));
524+
assertThat(chunkedSparseResult.chunks().get(1).offset(), equalTo(new ChunkedInference.TextOffset(149, 309)));
525+
assertThat(chunkedSparseResult.chunks().get(8).offset(), equalTo(new ChunkedInference.TextOffset(1269, 1350)));
526526
}
527527
}
528528

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1444,7 +1444,7 @@ private void testChunkedInfer(AmazonBedrockEmbeddingsModel model) throws IOExcep
14441444
service.chunkedInfer(
14451445
model,
14461446
null,
1447-
List.of("abc", "xyz"),
1447+
List.of("a", "bb"),
14481448
new HashMap<>(),
14491449
InputType.INGEST,
14501450
InferenceAction.Request.DEFAULT_TIMEOUT,
@@ -1457,7 +1457,7 @@ private void testChunkedInfer(AmazonBedrockEmbeddingsModel model) throws IOExcep
14571457
assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class));
14581458
var floatResult = (ChunkedInferenceEmbedding) results.get(0);
14591459
assertThat(floatResult.chunks(), hasSize(1));
1460-
assertEquals("abc", floatResult.chunks().get(0).matchedText());
1460+
assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().get(0).offset());
14611461
assertThat(floatResult.chunks().get(0), instanceOf(TextEmbeddingFloatResults.Chunk.class));
14621462
assertArrayEquals(
14631463
new float[] { 0.123F, 0.678F },
@@ -1469,7 +1469,7 @@ private void testChunkedInfer(AmazonBedrockEmbeddingsModel model) throws IOExcep
14691469
assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class));
14701470
var floatResult = (ChunkedInferenceEmbedding) results.get(1);
14711471
assertThat(floatResult.chunks(), hasSize(1));
1472-
assertEquals("xyz", floatResult.chunks().get(0).matchedText());
1472+
assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().get(0).offset());
14731473
assertThat(floatResult.chunks().get(0), instanceOf(TextEmbeddingFloatResults.Chunk.class));
14741474
assertArrayEquals(
14751475
new float[] { 0.223F, 0.278F },

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1192,7 +1192,7 @@ private void testChunkedInfer(AzureAiStudioEmbeddingsModel model) throws IOExcep
11921192
service.chunkedInfer(
11931193
model,
11941194
null,
1195-
List.of("foo", "bar"),
1195+
List.of("a", "bb"),
11961196
new HashMap<>(),
11971197
InputType.INGEST,
11981198
InferenceAction.Request.DEFAULT_TIMEOUT,
@@ -1205,7 +1205,7 @@ private void testChunkedInfer(AzureAiStudioEmbeddingsModel model) throws IOExcep
12051205
assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class));
12061206
var floatResult = (ChunkedInferenceEmbedding) results.get(0);
12071207
assertThat(floatResult.chunks(), hasSize(1));
1208-
assertEquals("foo", floatResult.chunks().get(0).matchedText());
1208+
assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().get(0).offset());
12091209
assertThat(floatResult.chunks().get(0), instanceOf(TextEmbeddingFloatResults.Chunk.class));
12101210
assertArrayEquals(
12111211
new float[] { 0.0123f, -0.0123f },
@@ -1217,7 +1217,7 @@ private void testChunkedInfer(AzureAiStudioEmbeddingsModel model) throws IOExcep
12171217
assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class));
12181218
var floatResult = (ChunkedInferenceEmbedding) results.get(1);
12191219
assertThat(floatResult.chunks(), hasSize(1));
1220-
assertEquals("bar", floatResult.chunks().get(0).matchedText());
1220+
assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().get(0).offset());
12211221
assertThat(floatResult.chunks().get(0), instanceOf(TextEmbeddingFloatResults.Chunk.class));
12221222
assertArrayEquals(
12231223
new float[] { 1.0123f, -1.0123f },
@@ -1233,7 +1233,7 @@ private void testChunkedInfer(AzureAiStudioEmbeddingsModel model) throws IOExcep
12331233

12341234
var requestMap = entityAsMap(webServer.requests().get(0).getBody());
12351235
assertThat(requestMap.size(), Matchers.is(2));
1236-
assertThat(requestMap.get("input"), Matchers.is(List.of("foo", "bar")));
1236+
assertThat(requestMap.get("input"), Matchers.is(List.of("a", "bb")));
12371237
assertThat(requestMap.get("user"), Matchers.is("user"));
12381238
}
12391239
}

0 commit comments

Comments (0)