Skip to content

Commit 30adfd6

Browse files
committed
updating tests
1 parent bf0ed8e commit 30adfd6

File tree

5 files changed

+119
-4
lines changed

5 files changed

+119
-4
lines changed

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse;
1616
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
1717
import org.elasticsearch.xpack.core.inference.results.InferenceByteEmbedding;
18+
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingBitResults;
1819
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults;
1920
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults;
2021
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
@@ -421,6 +422,79 @@ public void testMergingListener_Byte() {
421422
}
422423
}
423424

425+
public void testMergingListener_Bit() {
426+
int batchSize = 5;
427+
int chunkSize = 20;
428+
int overlap = 0;
429+
// passage will be chunked into batchSize + 1 parts
430+
// and spread over 2 batch requests
431+
int numberOfWordsInPassage = (chunkSize * batchSize) + 5;
432+
433+
var passageBuilder = new StringBuilder();
434+
for (int i = 0; i < numberOfWordsInPassage; i++) {
435+
passageBuilder.append("passage_input").append(i).append(" "); // chunk on whitespace
436+
}
437+
List<String> inputs = List.of("1st small", passageBuilder.toString(), "2nd small", "3rd small");
438+
439+
var finalListener = testListener();
440+
var batches = new EmbeddingRequestChunker(inputs, batchSize, chunkSize, overlap, EmbeddingRequestChunker.EmbeddingType.BIT)
441+
.batchRequestsWithListeners(finalListener);
442+
assertThat(batches, hasSize(2));
443+
444+
// 4 inputs in 2 batches
445+
{
446+
var embeddings = new ArrayList<InferenceByteEmbedding>();
447+
for (int i = 0; i < batchSize; i++) {
448+
embeddings.add(new InferenceByteEmbedding(new byte[] { randomByte() }));
449+
}
450+
batches.get(0).listener().onResponse(new InferenceTextEmbeddingBitResults(embeddings));
451+
}
452+
{
453+
var embeddings = new ArrayList<InferenceByteEmbedding>();
454+
for (int i = 0; i < 4; i++) { // 4 requests in the 2nd batch
455+
embeddings.add(new InferenceByteEmbedding(new byte[] { randomByte() }));
456+
}
457+
batches.get(1).listener().onResponse(new InferenceTextEmbeddingBitResults(embeddings));
458+
}
459+
460+
assertNotNull(finalListener.results);
461+
assertThat(finalListener.results, hasSize(4));
462+
{
463+
var chunkedResult = finalListener.results.get(0);
464+
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbeddingByte.class));
465+
var chunkedByteResult = (ChunkedInferenceEmbeddingByte) chunkedResult;
466+
assertThat(chunkedByteResult.chunks(), hasSize(1));
467+
assertEquals("1st small", chunkedByteResult.chunks().get(0).matchedText());
468+
}
469+
{
470+
// this is the large input split in multiple chunks
471+
var chunkedResult = finalListener.results.get(1);
472+
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbeddingByte.class));
473+
var chunkedByteResult = (ChunkedInferenceEmbeddingByte) chunkedResult;
474+
assertThat(chunkedByteResult.chunks(), hasSize(6));
475+
assertThat(chunkedByteResult.chunks().get(0).matchedText(), startsWith("passage_input0 "));
476+
assertThat(chunkedByteResult.chunks().get(1).matchedText(), startsWith(" passage_input20 "));
477+
assertThat(chunkedByteResult.chunks().get(2).matchedText(), startsWith(" passage_input40 "));
478+
assertThat(chunkedByteResult.chunks().get(3).matchedText(), startsWith(" passage_input60 "));
479+
assertThat(chunkedByteResult.chunks().get(4).matchedText(), startsWith(" passage_input80 "));
480+
assertThat(chunkedByteResult.chunks().get(5).matchedText(), startsWith(" passage_input100 "));
481+
}
482+
{
483+
var chunkedResult = finalListener.results.get(2);
484+
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbeddingByte.class));
485+
var chunkedByteResult = (ChunkedInferenceEmbeddingByte) chunkedResult;
486+
assertThat(chunkedByteResult.chunks(), hasSize(1));
487+
assertEquals("2nd small", chunkedByteResult.chunks().get(0).matchedText());
488+
}
489+
{
490+
var chunkedResult = finalListener.results.get(3);
491+
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbeddingByte.class));
492+
var chunkedByteResult = (ChunkedInferenceEmbeddingByte) chunkedResult;
493+
assertThat(chunkedByteResult.chunks(), hasSize(1));
494+
assertEquals("3rd small", chunkedByteResult.chunks().get(0).matchedText());
495+
}
496+
}
497+
424498
public void testMergingListener_Sparse() {
425499
int batchSize = 4;
426500
int chunkSize = 10;

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/InferenceTextEmbeddingBitResultsTests.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,17 @@ public void testTransformToCoordinationFormat() {
102102
);
103103
}
104104

105+
public void testGetFirstEmbeddingSize() {
106+
var firstEmbeddingSize = new InferenceTextEmbeddingBitResults(
107+
List.of(
108+
new InferenceByteEmbedding(new byte[] { (byte) 23, (byte) 24 }),
109+
new InferenceByteEmbedding(new byte[] { (byte) 25, (byte) 26 })
110+
)
111+
).getFirstEmbeddingSize();
112+
113+
assertThat(firstEmbeddingSize, is(16));
114+
}
115+
105116
@Override
106117
protected Writeable.Reader<InferenceTextEmbeddingBitResults> instanceReader() {
107118
return InferenceTextEmbeddingBitResults::new;

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/InferenceTextEmbeddingByteResultsTests.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,17 @@ public void testTransformToCoordinationFormat() {
102102
);
103103
}
104104

105+
public void testGetFirstEmbeddingSize() {
106+
var firstEmbeddingSize = new InferenceTextEmbeddingByteResults(
107+
List.of(
108+
new InferenceByteEmbedding(new byte[] { (byte) 23, (byte) 24 }),
109+
new InferenceByteEmbedding(new byte[] { (byte) 25, (byte) 26 })
110+
)
111+
).getFirstEmbeddingSize();
112+
113+
assertThat(firstEmbeddingSize, is(2));
114+
}
115+
105116
@Override
106117
protected Writeable.Reader<InferenceTextEmbeddingByteResults> instanceReader() {
107118
return InferenceTextEmbeddingByteResults::new;

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,17 @@ public void testTransformToCoordinationFormat() {
108108
);
109109
}
110110

111+
public void testGetFirstEmbeddingSize() {
112+
var firstEmbeddingSize = new InferenceTextEmbeddingFloatResults(
113+
List.of(
114+
new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.1F, 0.2F }),
115+
new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.3F, 0.4F })
116+
)
117+
).getFirstEmbeddingSize();
118+
119+
assertThat(firstEmbeddingSize, is(2));
120+
}
121+
111122
@Override
112123
protected Writeable.Reader<InferenceTextEmbeddingFloatResults> instanceReader() {
113124
return InferenceTextEmbeddingFloatResults::new;

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1099,20 +1099,21 @@ private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure si
10991099

11001100
try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) {
11011101
var embeddingSize = randomNonNegativeInt();
1102+
var embeddingType = randomFrom(CohereEmbeddingType.values());
11021103
var model = CohereEmbeddingsModelTests.createModel(
11031104
randomAlphaOfLength(10),
11041105
randomAlphaOfLength(10),
11051106
CohereEmbeddingsTaskSettings.EMPTY_SETTINGS,
11061107
randomNonNegativeInt(),
11071108
randomNonNegativeInt(),
11081109
randomAlphaOfLength(10),
1109-
randomFrom(CohereEmbeddingType.values()),
1110+
embeddingType,
11101111
similarityMeasure
11111112
);
11121113

11131114
Model updatedModel = service.updateModelWithEmbeddingDetails(model, embeddingSize);
11141115

1115-
SimilarityMeasure expectedSimilarityMeasure = similarityMeasure == null ? CohereService.defaultSimilarity() : similarityMeasure;
1116+
SimilarityMeasure expectedSimilarityMeasure = similarityMeasure == null ? CohereService.defaultSimilarity(embeddingType) : similarityMeasure;
11161117
assertEquals(expectedSimilarityMeasure, updatedModel.getServiceSettings().similarity());
11171118
assertEquals(embeddingSize, updatedModel.getServiceSettings().dimensions().intValue());
11181119
}
@@ -1579,8 +1580,15 @@ public void testChunkedInfer_BatchesCalls_Bytes() throws IOException {
15791580
}
15801581
}
15811582

1582-
public void testDefaultSimilarity() {
1583-
assertEquals(SimilarityMeasure.DOT_PRODUCT, CohereService.defaultSimilarity());
1583+
public void testDefaultSimilarity_BinaryEmbedding() {
1584+
assertEquals(SimilarityMeasure.L2_NORM, CohereService.defaultSimilarity(CohereEmbeddingType.BINARY));
1585+
assertEquals(SimilarityMeasure.L2_NORM, CohereService.defaultSimilarity(CohereEmbeddingType.BIT));
1586+
}
1587+
1588+
public void testDefaultSimilarity_NotBinaryEmbedding() {
1589+
assertEquals(SimilarityMeasure.DOT_PRODUCT, CohereService.defaultSimilarity(CohereEmbeddingType.FLOAT));
1590+
assertEquals(SimilarityMeasure.DOT_PRODUCT, CohereService.defaultSimilarity(CohereEmbeddingType.BYTE));
1591+
assertEquals(SimilarityMeasure.DOT_PRODUCT, CohereService.defaultSimilarity(CohereEmbeddingType.INT8));
15841592
}
15851593

15861594
public void testInfer_StreamRequest() throws Exception {

0 commit comments

Comments
 (0)