|
10 | 10 | import org.elasticsearch.test.ESTestCase; |
11 | 11 | import org.elasticsearch.xpack.core.ml.inference.results.TextSimilarityInferenceResults; |
12 | 12 | import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenization; |
| 13 | +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.DebertaV2Tokenization; |
13 | 14 | import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextSimilarityConfig; |
14 | 15 | import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization; |
15 | 16 | import org.elasticsearch.xpack.core.ml.inference.trainedmodel.VocabularyConfig; |
16 | 17 | import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizationResult; |
17 | 18 | import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer; |
| 19 | +import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.DebertaV2Tokenizer; |
18 | 20 | import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult; |
19 | 21 | import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchInferenceResult; |
20 | 22 |
|
21 | 23 | import java.io.IOException; |
22 | 24 | import java.util.List; |
23 | 25 |
|
24 | 26 | import static org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizerTests.TEST_CASED_VOCAB; |
| 27 | +import static org.elasticsearch.xpack.ml.inference.nlp.tokenizers.DebertaV2TokenizerTests.TEST_CASE_SCORES; |
| 28 | +import static org.elasticsearch.xpack.ml.inference.nlp.tokenizers.DebertaV2TokenizerTests.TEST_CASE_VOCAB; |
25 | 29 | import static org.hamcrest.Matchers.closeTo; |
26 | 30 | import static org.hamcrest.Matchers.equalTo; |
27 | 31 | import static org.hamcrest.Matchers.is; |
@@ -62,6 +66,33 @@ public void testProcessor() throws IOException { |
62 | 66 | assertThat(result.predictedValue(), closeTo(42, 1e-6)); |
63 | 67 | } |
64 | 68 |
|
| 69 | + public void testBalancedTruncationWithLongInput() throws IOException { |
| 70 | + String question = "Is Elasticsearch scalable?"; |
| 71 | + StringBuilder longInputBuilder = new StringBuilder(); |
| 72 | + for (int i = 0; i < 1000; i++) { |
| 73 | + longInputBuilder.append(TEST_CASE_VOCAB.get(randomIntBetween(0, TEST_CASE_VOCAB.size() - 1))).append(i).append(" "); |
| 74 | + } |
| 75 | + String longInput = longInputBuilder.toString().trim(); |
| 76 | + |
| 77 | + DebertaV2Tokenization tokenization = new DebertaV2Tokenization(false, true, null, Tokenization.Truncate.BALANCED, -1); |
| 78 | + DebertaV2Tokenizer tokenizer = DebertaV2Tokenizer.builder(TEST_CASE_VOCAB, TEST_CASE_SCORES, tokenization).build(); |
| 79 | + TextSimilarityConfig textSimilarityConfig = new TextSimilarityConfig( |
| 80 | + question, |
| 81 | + new VocabularyConfig(""), |
| 82 | + tokenization, |
| 83 | + "result", |
| 84 | + TextSimilarityConfig.SpanScoreFunction.MAX |
| 85 | + ); |
| 86 | + TextSimilarityProcessor processor = new TextSimilarityProcessor(tokenizer); |
| 87 | + TokenizationResult tokenizationResult = processor.getRequestBuilder(textSimilarityConfig) |
| 88 | + .buildRequest(List.of(longInput), "1", Tokenization.Truncate.BALANCED, -1, null) |
| 89 | + .tokenization(); |
| 90 | + |
| 91 | + // Assert that the tokenization result is as expected |
| 92 | + assertThat(tokenizationResult.anyTruncated(), is(true)); |
| 93 | + assertThat(tokenizationResult.getTokenization(0).tokenIds().length, equalTo(512)); |
| 94 | + } |
| 95 | + |
65 | 96 | public void testResultFunctions() { |
66 | 97 | BertTokenization tokenization = new BertTokenization(false, true, 384, Tokenization.Truncate.NONE, 128); |
67 | 98 | BertTokenizer tokenizer = BertTokenizer.builder(TEST_CASED_VOCAB, tokenization).build(); |
|
0 commit comments