 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xpack.core.ml.inference.results.TextSimilarityInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenization;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.DebertaV2Tokenization;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextSimilarityConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.VocabularyConfig;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizationResult;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer;
+import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.DebertaV2Tokenizer;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
 import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchInferenceResult;
 
 import java.io.IOException;
 import java.util.List;
 
 import static org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizerTests.TEST_CASED_VOCAB;
+import static org.elasticsearch.xpack.ml.inference.nlp.tokenizers.DebertaV2TokenizerTests.TEST_CASE_SCORES;
+import static org.elasticsearch.xpack.ml.inference.nlp.tokenizers.DebertaV2TokenizerTests.TEST_CASE_VOCAB;
 import static org.hamcrest.Matchers.closeTo;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.is;
@@ -62,6 +66,33 @@ public void testProcessor() throws IOException {
         assertThat(result.predictedValue(), closeTo(42, 1e-6));
     }
 
+    public void testBalancedTruncationWithLongInput() throws IOException {
+        String question = "Is Elasticsearch scalable?";
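+        // Build a document far longer than the maximum sequence length so that truncation has to kick in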
+        StringBuilder longInputBuilder = new StringBuilder();
+        for (int i = 0; i < 1000; i++) {
+            longInputBuilder.append(TEST_CASE_VOCAB.get(randomIntBetween(0, TEST_CASE_VOCAB.size() - 1))).append(i).append(" ");
+        }
+        String longInput = longInputBuilder.toString().trim();
+
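+        // BALANCED truncation with no explicit max sequence length (null), so the tokenizer's default limit applies;
+        // a span of -1 means the input is not windowed into multiple spans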
+        DebertaV2Tokenization tokenization = new DebertaV2Tokenization(false, true, null, Tokenization.Truncate.BALANCED, -1);
+        DebertaV2Tokenizer tokenizer = DebertaV2Tokenizer.builder(TEST_CASE_VOCAB, TEST_CASE_SCORES, tokenization).build();
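+        // The config carries the question text; the long document is supplied later as the inference input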
+        TextSimilarityConfig textSimilarityConfig = new TextSimilarityConfig(
+            question,
+            new VocabularyConfig(""),
+            tokenization,
+            "result",
+            TextSimilarityConfig.SpanScoreFunction.MAX
+        );
+        TextSimilarityProcessor processor = new TextSimilarityProcessor(tokenizer);
+        TokenizationResult tokenizationResult = processor.getRequestBuilder(textSimilarityConfig)
+            .buildRequest(List.of(longInput), "1", Tokenization.Truncate.BALANCED, -1, null)
+            .tokenization();
+
+        // The long input forces truncation: the resulting sequence is cut down to the default 512-token limit
+        assertThat(tokenizationResult.anyTruncated(), is(true));
+        assertThat(tokenizationResult.getTokenization(0).tokenIds().length, equalTo(512));
+    }
+
     public void testResultFunctions() {
         BertTokenization tokenization = new BertTokenization(false, true, 384, Tokenization.Truncate.NONE, 128);
         BertTokenizer tokenizer = BertTokenizer.builder(TEST_CASED_VOCAB, tokenization).build();