Skip to content

Commit f6efd38

Browse files
committed
Updates after the refactor
1 parent ee17a62 commit f6efd38

File tree

4 files changed

+82
-81
lines changed

4 files changed

+82
-81
lines changed

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java

Lines changed: 73 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
1414
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
1515
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
16+
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBitResults;
1617
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
1718
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
1819
import org.elasticsearch.xpack.core.ml.search.WeightedToken;
@@ -377,78 +378,78 @@ public void testMergingListener_Byte() {
377378
}
378379
}
379380

380-
// public void testMergingListener_Bit() {
381-
// int batchSize = 5;
382-
// int chunkSize = 20;
383-
// int overlap = 0;
384-
// // passage will be chunked into batchSize + 1 parts
385-
// // and spread over 2 batch requests
386-
// int numberOfWordsInPassage = (chunkSize * batchSize) + 5;
387-
//
388-
// var passageBuilder = new StringBuilder();
389-
// for (int i = 0; i < numberOfWordsInPassage; i++) {
390-
// passageBuilder.append("passage_input").append(i).append(" "); // chunk on whitespace
391-
// }
392-
// List<String> inputs = List.of("1st small", passageBuilder.toString(), "2nd small", "3rd small");
393-
//
394-
// var finalListener = testListener();
395-
// var batches = new EmbeddingRequestChunker(inputs, batchSize, chunkSize, overlap, EmbeddingRequestChunker.EmbeddingType.BIT)
396-
// .batchRequestsWithListeners(finalListener);
397-
// assertThat(batches, hasSize(2));
398-
//
399-
// // 4 inputs in 2 batches
400-
// {
401-
// var embeddings = new ArrayList<InferenceByteEmbedding>();
402-
// for (int i = 0; i < batchSize; i++) {
403-
// embeddings.add(new InferenceByteEmbedding(new byte[] { randomByte() }));
404-
// }
405-
// batches.get(0).listener().onResponse(new InferenceTextEmbeddingBitResults(embeddings));
406-
// }
407-
// {
408-
// var embeddings = new ArrayList<InferenceByteEmbedding>();
409-
// for (int i = 0; i < 4; i++) { // 4 requests in the 2nd batch
410-
// embeddings.add(new InferenceByteEmbedding(new byte[] { randomByte() }));
411-
// }
412-
// batches.get(1).listener().onResponse(new InferenceTextEmbeddingBitResults(embeddings));
413-
// }
414-
//
415-
// assertNotNull(finalListener.results);
416-
// assertThat(finalListener.results, hasSize(4));
417-
// {
418-
// var chunkedResult = finalListener.results.get(0);
419-
// assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbeddingByte.class));
420-
// var chunkedByteResult = (ChunkedInferenceEmbeddingByte) chunkedResult;
421-
// assertThat(chunkedByteResult.chunks(), hasSize(1));
422-
// assertEquals("1st small", chunkedByteResult.chunks().get(0).matchedText());
423-
// }
424-
// {
425-
// // this is the large input split in multiple chunks
426-
// var chunkedResult = finalListener.results.get(1);
427-
// assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbeddingByte.class));
428-
// var chunkedByteResult = (ChunkedInferenceEmbeddingByte) chunkedResult;
429-
// assertThat(chunkedByteResult.chunks(), hasSize(6));
430-
// assertThat(chunkedByteResult.chunks().get(0).matchedText(), startsWith("passage_input0 "));
431-
// assertThat(chunkedByteResult.chunks().get(1).matchedText(), startsWith(" passage_input20 "));
432-
// assertThat(chunkedByteResult.chunks().get(2).matchedText(), startsWith(" passage_input40 "));
433-
// assertThat(chunkedByteResult.chunks().get(3).matchedText(), startsWith(" passage_input60 "));
434-
// assertThat(chunkedByteResult.chunks().get(4).matchedText(), startsWith(" passage_input80 "));
435-
// assertThat(chunkedByteResult.chunks().get(5).matchedText(), startsWith(" passage_input100 "));
436-
// }
437-
// {
438-
// var chunkedResult = finalListener.results.get(2);
439-
// assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbeddingByte.class));
440-
// var chunkedByteResult = (ChunkedInferenceEmbeddingByte) chunkedResult;
441-
// assertThat(chunkedByteResult.chunks(), hasSize(1));
442-
// assertEquals("2nd small", chunkedByteResult.chunks().get(0).matchedText());
443-
// }
444-
// {
445-
// var chunkedResult = finalListener.results.get(3);
446-
// assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbeddingByte.class));
447-
// var chunkedByteResult = (ChunkedInferenceEmbeddingByte) chunkedResult;
448-
// assertThat(chunkedByteResult.chunks(), hasSize(1));
449-
// assertEquals("3rd small", chunkedByteResult.chunks().get(0).matchedText());
450-
// }
451-
// }
381+
public void testMergingListener_Bit() {
382+
int batchSize = 5;
383+
int chunkSize = 20;
384+
int overlap = 0;
385+
// passage will be chunked into batchSize + 1 parts
386+
// and spread over 2 batch requests
387+
int numberOfWordsInPassage = (chunkSize * batchSize) + 5;
388+
389+
var passageBuilder = new StringBuilder();
390+
for (int i = 0; i < numberOfWordsInPassage; i++) {
391+
passageBuilder.append("passage_input").append(i).append(" "); // chunk on whitespace
392+
}
393+
List<String> inputs = List.of("1st small", passageBuilder.toString(), "2nd small", "3rd small");
394+
395+
var finalListener = testListener();
396+
var batches = new EmbeddingRequestChunker(inputs, batchSize, chunkSize, overlap)
397+
.batchRequestsWithListeners(finalListener);
398+
assertThat(batches, hasSize(2));
399+
400+
// 4 inputs in 2 batches
401+
{
402+
var embeddings = new ArrayList<TextEmbeddingByteResults.Embedding>();
403+
for (int i = 0; i < batchSize; i++) {
404+
embeddings.add(new TextEmbeddingByteResults.Embedding(new byte[] { randomByte() }));
405+
}
406+
batches.get(0).listener().onResponse(new TextEmbeddingBitResults(embeddings));
407+
}
408+
{
409+
var embeddings = new ArrayList<TextEmbeddingByteResults.Embedding>();
410+
for (int i = 0; i < 4; i++) { // 4 requests in the 2nd batch
411+
embeddings.add(new TextEmbeddingByteResults.Embedding(new byte[] { randomByte() }));
412+
}
413+
batches.get(1).listener().onResponse(new TextEmbeddingBitResults(embeddings));
414+
}
415+
416+
assertNotNull(finalListener.results);
417+
assertThat(finalListener.results, hasSize(4));
418+
{
419+
var chunkedResult = finalListener.results.get(0);
420+
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
421+
var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult;
422+
assertThat(chunkedByteResult.chunks(), hasSize(1));
423+
assertEquals("1st small", chunkedByteResult.chunks().get(0).matchedText());
424+
}
425+
{
426+
// this is the large input split in multiple chunks
427+
var chunkedResult = finalListener.results.get(1);
428+
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
429+
var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult;
430+
assertThat(chunkedByteResult.chunks(), hasSize(6));
431+
assertThat(chunkedByteResult.chunks().get(0).matchedText(), startsWith("passage_input0 "));
432+
assertThat(chunkedByteResult.chunks().get(1).matchedText(), startsWith(" passage_input20 "));
433+
assertThat(chunkedByteResult.chunks().get(2).matchedText(), startsWith(" passage_input40 "));
434+
assertThat(chunkedByteResult.chunks().get(3).matchedText(), startsWith(" passage_input60 "));
435+
assertThat(chunkedByteResult.chunks().get(4).matchedText(), startsWith(" passage_input80 "));
436+
assertThat(chunkedByteResult.chunks().get(5).matchedText(), startsWith(" passage_input100 "));
437+
}
438+
{
439+
var chunkedResult = finalListener.results.get(2);
440+
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
441+
var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult;
442+
assertThat(chunkedByteResult.chunks(), hasSize(1));
443+
assertEquals("2nd small", chunkedByteResult.chunks().get(0).matchedText());
444+
}
445+
{
446+
var chunkedResult = finalListener.results.get(3);
447+
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
448+
var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult;
449+
assertThat(chunkedByteResult.chunks(), hasSize(1));
450+
assertEquals("3rd small", chunkedByteResult.chunks().get(0).matchedText());
451+
}
452+
}
452453

453454
public void testMergingListener_Sparse() {
454455
int batchSize = 4;

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingBitResultsTests.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -106,10 +106,10 @@ public void testTransformToCoordinationFormat() {
106106
}
107107

108108
public void testGetFirstEmbeddingSize() {
109-
var firstEmbeddingSize = new InferenceTextEmbeddingBitResults(
109+
var firstEmbeddingSize = new TextEmbeddingBitResults(
110110
List.of(
111-
new InferenceByteEmbedding(new byte[] { (byte) 23, (byte) 24 }),
112-
new InferenceByteEmbedding(new byte[] { (byte) 25, (byte) 26 })
111+
new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 23, (byte) 24 }),
112+
new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 25, (byte) 26 })
113113
)
114114
).getFirstEmbeddingSize();
115115

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingByteResultsTests.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,10 +105,10 @@ public void testTransformToCoordinationFormat() {
105105
}
106106

107107
public void testGetFirstEmbeddingSize() {
108-
var firstEmbeddingSize = new InferenceTextEmbeddingByteResults(
108+
var firstEmbeddingSize = new TextEmbeddingByteResults(
109109
List.of(
110-
new InferenceByteEmbedding(new byte[] { (byte) 23, (byte) 24 }),
111-
new InferenceByteEmbedding(new byte[] { (byte) 25, (byte) 26 })
110+
new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 23, (byte) 24 }),
111+
new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 25, (byte) 26 })
112112
)
113113
).getFirstEmbeddingSize();
114114

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -106,10 +106,10 @@ public void testTransformToCoordinationFormat() {
106106
}
107107

108108
public void testGetFirstEmbeddingSize() {
109-
var firstEmbeddingSize = new InferenceTextEmbeddingFloatResults(
109+
var firstEmbeddingSize = new TextEmbeddingFloatResults(
110110
List.of(
111-
new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.1F, 0.2F }),
112-
new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.3F, 0.4F })
111+
new TextEmbeddingFloatResults.Embedding(new float[] { 0.1F, 0.2F }),
112+
new TextEmbeddingFloatResults.Embedding(new float[] { 0.3F, 0.4F })
113113
)
114114
).getFirstEmbeddingSize();
115115

0 commit comments

Comments
 (0)