diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java index 65e30072d9870..667d7bf63efc9 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java @@ -40,6 +40,8 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenizationUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.DebertaV2Tokenization; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.DebertaV2TokenizationUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.EmptyConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfigUpdate; @@ -547,6 +549,13 @@ public List getNamedXContentParsers() { (p, c) -> XLMRobertaTokenization.fromXContent(p, (boolean) c) ) ); + namedXContent.add( + new NamedXContentRegistry.Entry( + Tokenization.class, + new ParseField(DebertaV2Tokenization.NAME), + (p, c) -> DebertaV2Tokenization.fromXContent(p, (boolean) c) + ) + ); namedXContent.add( new NamedXContentRegistry.Entry( @@ -583,6 +592,13 @@ public List getNamedXContentParsers() { (p, c) -> XLMRobertaTokenizationUpdate.fromXContent(p) ) ); + namedXContent.add( + new NamedXContentRegistry.Entry( + TokenizationUpdate.class, + DebertaV2TokenizationUpdate.NAME, + (p, c) -> DebertaV2TokenizationUpdate.fromXContent(p) + ) + ); return namedXContent; } @@ -791,6 +807,7 @@ public List getNamedWriteables() { ); namedWriteables.add(new NamedWriteableRegistry.Entry(Tokenization.class, RobertaTokenization.NAME, RobertaTokenization::new)); namedWriteables.add(new NamedWriteableRegistry.Entry(Tokenization.class, XLMRobertaTokenization.NAME, XLMRobertaTokenization::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(Tokenization.class, DebertaV2Tokenization.NAME, DebertaV2Tokenization::new)); namedWriteables.add( new NamedWriteableRegistry.Entry( @@ -827,6 +844,9 @@ public List getNamedWriteables() { XLMRobertaTokenizationUpdate::new ) ); + namedWriteables.add( + new NamedWriteableRegistry.Entry(TokenizationUpdate.class, DebertaV2Tokenization.NAME, DebertaV2TokenizationUpdate::new) + ); return namedWriteables; } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/DebertaV2Tokenization.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/DebertaV2Tokenization.java new file mode 100644 index 0000000000000..ce5464832b6d5 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/DebertaV2Tokenization.java @@ -0,0 +1,83 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.core.ml.inference.trainedmodel; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentParser; + +import java.io.IOException; + +public class DebertaV2Tokenization extends Tokenization { + + public static final String NAME = "deberta_v2"; + public static final String MASK_TOKEN = "[MASK]"; + + public static ConstructingObjectParser createParser(boolean ignoreUnknownFields) { + ConstructingObjectParser parser = new ConstructingObjectParser<>( + NAME, + ignoreUnknownFields, + a -> new DebertaV2Tokenization( + (Boolean) a[0], + (Boolean) a[1], + (Integer) a[2], + a[3] == null ? null : Truncate.fromString((String) a[3]), + (Integer) a[4] + ) + ); + declareCommonFields(parser); + return parser; + } + + private static final ConstructingObjectParser LENIENT_PARSER = createParser(true); + private static final ConstructingObjectParser STRICT_PARSER = createParser(false); + + public static DebertaV2Tokenization fromXContent(XContentParser parser, boolean lenient) { + return lenient ? LENIENT_PARSER.apply(parser, null) : STRICT_PARSER.apply(parser, null); + } + + public DebertaV2Tokenization( + Boolean doLowerCase, + Boolean withSpecialTokens, + Integer maxSequenceLength, + Truncate truncate, + Integer span + ) { + super(doLowerCase, withSpecialTokens, maxSequenceLength, truncate, span); + } + + public DebertaV2Tokenization(StreamInput in) throws IOException { + super(in); + } + + @Override + Tokenization buildWindowingTokenization(int updatedMaxSeqLength, int updatedSpan) { + return new DebertaV2Tokenization(doLowerCase, withSpecialTokens, updatedMaxSeqLength, truncate, updatedSpan); + } + + @Override + public String getMaskToken() { + return MASK_TOKEN; + } + + @Override + XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException { + return builder; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public String getName() { + return NAME; + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/DebertaV2TokenizationUpdate.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/DebertaV2TokenizationUpdate.java new file mode 100644 index 0000000000000..683b27793402d --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/DebertaV2TokenizationUpdate.java @@ -0,0 +1,90 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.core.ml.inference.trainedmodel; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.Optional; + +public class DebertaV2TokenizationUpdate extends AbstractTokenizationUpdate { + public static final ParseField NAME = new ParseField(DebertaV2Tokenization.NAME); + + public static ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + "deberta_v2_tokenization_update", + a -> new DebertaV2TokenizationUpdate(a[0] == null ? null : Tokenization.Truncate.fromString((String) a[0]), (Integer) a[1]) + ); + + static { + declareCommonParserFields(PARSER); + } + + public static DebertaV2TokenizationUpdate fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + public DebertaV2TokenizationUpdate(@Nullable Tokenization.Truncate truncate, @Nullable Integer span) { + super(truncate, span); + } + + public DebertaV2TokenizationUpdate(StreamInput in) throws IOException { + super(in); + } + + @Override + public Tokenization apply(Tokenization originalConfig) { + if (originalConfig instanceof DebertaV2Tokenization debertaV2Tokenization) { + if (isNoop()) { + return debertaV2Tokenization; + } + + Tokenization.validateSpanAndTruncate(getTruncate(), getSpan()); + + if (getTruncate() != null && getTruncate().isInCompatibleWithSpan() == false) { + // When truncate value is incompatible with span wipe out + // the existing span setting to avoid an invalid combination of settings. + // This avoids the user have to set span to the special unset value + return new DebertaV2Tokenization( + debertaV2Tokenization.doLowerCase(), + debertaV2Tokenization.withSpecialTokens(), + debertaV2Tokenization.maxSequenceLength(), + getTruncate(), + null + ); + } + + return new DebertaV2Tokenization( + debertaV2Tokenization.doLowerCase(), + debertaV2Tokenization.withSpecialTokens(), + debertaV2Tokenization.maxSequenceLength(), + Optional.ofNullable(this.getTruncate()).orElse(originalConfig.getTruncate()), + Optional.ofNullable(this.getSpan()).orElse(originalConfig.getSpan()) + ); + } + throw ExceptionsHelper.badRequestException( + "Tokenization config of type [{}] can not be updated with a request of type [{}]", + originalConfig.getName(), + getName() + ); + } + + @Override + public String getWriteableName() { + return NAME.getPreferredName(); + } + + @Override + public String getName() { + return NAME.getPreferredName(); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NlpConfigUpdate.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NlpConfigUpdate.java index 92e44edcd1259..328c851d63be6 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NlpConfigUpdate.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NlpConfigUpdate.java @@ -42,7 +42,9 @@ public static TokenizationUpdate tokenizationFromMap(Map map) { RobertaTokenizationUpdate.NAME.getPreferredName(), RobertaTokenizationUpdate::new, XLMRobertaTokenizationUpdate.NAME.getPreferredName(), - XLMRobertaTokenizationUpdate::new + XLMRobertaTokenizationUpdate::new, + DebertaV2Tokenization.NAME, + 
DebertaV2TokenizationUpdate::new ); Map tokenizationConfig = null; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/Tokenization.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/Tokenization.java index 4fec726b9fa5d..a7c46a68538c0 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/Tokenization.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/Tokenization.java @@ -36,7 +36,8 @@ public enum Truncate { public boolean isInCompatibleWithSpan() { return false; } - }; + }, + BALANCED; public boolean isInCompatibleWithSpan() { return true; diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NlpConfigUpdateTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NlpConfigUpdateTests.java index 8bc3a339ab0ee..83dc0b2a06376 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NlpConfigUpdateTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NlpConfigUpdateTests.java @@ -73,7 +73,9 @@ public void testTokenizationFromMap() { ); assertThat( e.getMessage(), - containsString("unknown tokenization type expecting one of [bert, bert_ja, mpnet, roberta, xlm_roberta] got [not_bert]") + containsString( + "unknown tokenization type expecting one of [bert, bert_ja, deberta_v2, mpnet, roberta, xlm_roberta] got [not_bert]" + ) ); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizer.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizer.java index 464c8eac8c9dd..1b53a7642abf3 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizer.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizer.java @@ -169,11 +169,6 @@ boolean isWithSpecialTokens() { return withSpecialTokens; } - @Override - int defaultSpanForChunking(int maxWindowSize) { - return (maxWindowSize - numExtraTokensForSingleSequence()) / 2; - } - @Override int getNumExtraTokensForSeqPair() { return 3; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/DebertaTokenizationResult.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/DebertaTokenizationResult.java new file mode 100644 index 0000000000000..2a50172fcc722 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/DebertaTokenizationResult.java @@ -0,0 +1,143 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ * + * this file was contributed to by a Generative AI model + */ + +package org.elasticsearch.xpack.ml.inference.nlp.tokenizers; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization; +import org.elasticsearch.xpack.ml.inference.nlp.NlpTask; + +import java.io.IOException; +import java.util.List; +import java.util.function.Function; +import java.util.stream.IntStream; +import java.util.stream.Stream; + +public class DebertaTokenizationResult extends TokenizationResult { + static final String REQUEST_ID = "request_id"; + static final String TOKENS = "tokens"; + static final String ARG1 = "arg_1"; + static final String ARG2 = "arg_2"; + + private static final Logger logger = LogManager.getLogger(DebertaTokenizationResult.class); + + protected DebertaTokenizationResult(List vocab, List tokenizations, int padTokenId) { + super(vocab, tokenizations, padTokenId); + } + + @Override + public NlpTask.Request buildRequest(String requestId, Tokenization.Truncate t) throws IOException { + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.startObject(); + builder.field(REQUEST_ID, requestId); + writePaddedTokens(TOKENS, builder); + writeAttentionMask(ARG1, builder); + writeTokenTypeIds(ARG2, builder); + builder.endObject(); + + // BytesReference.bytes closes the builder + BytesReference jsonRequest = BytesReference.bytes(builder); + return new NlpTask.Request(this, jsonRequest); + } + + static class DebertaTokensBuilder implements TokenizationResult.TokensBuilder { + private final int clsTokenId; + private final int sepTokenId; + private final boolean withSpecialTokens; + protected final Stream.Builder tokenIds; + protected final Stream.Builder tokenMap; + protected int seqPairOffset = 0; + + DebertaTokensBuilder(int clsTokenId, int sepTokenId, boolean withSpecialTokens) { + this.clsTokenId = clsTokenId; + this.sepTokenId = sepTokenId; + this.withSpecialTokens = withSpecialTokens; + this.tokenIds = Stream.builder(); + this.tokenMap = Stream.builder(); + } + + @Override + public TokensBuilder addSequence(List tokenIds, List tokenMap) { + // DeBERTa-v2 single sequence: [CLS] X [SEP] + if (withSpecialTokens) { + this.tokenIds.add(IntStream.of(clsTokenId)); + this.tokenMap.add(IntStream.of(SPECIAL_TOKEN_POSITION)); + } + this.tokenIds.add(tokenIds.stream().mapToInt(Integer::valueOf)); + this.tokenMap.add(tokenMap.stream().mapToInt(Integer::valueOf)); + if (withSpecialTokens) { + this.tokenIds.add(IntStream.of(sepTokenId)); + this.tokenMap.add(IntStream.of(SPECIAL_TOKEN_POSITION)); + } + return this; + } + + @Override + public TokensBuilder addSequencePair( + List tokenId1s, + List tokenMap1, + List tokenId2s, + List tokenMap2 + ) { + if (tokenId1s.isEmpty() || tokenId2s.isEmpty()) { + throw new IllegalArgumentException("Both sequences must have at least one token"); + } + + // DeBERTa-v2 pair of sequences: [CLS] A [SEP] B [SEP] + if (withSpecialTokens) { + tokenIds.add(IntStream.of(clsTokenId)); + tokenMap.add(IntStream.of(SPECIAL_TOKEN_POSITION)); + } + tokenIds.add(tokenId1s.stream().mapToInt(Integer::valueOf)); + tokenMap.add(tokenMap1.stream().mapToInt(Integer::valueOf)); + int previouslyFinalMap = tokenMap1.get(tokenMap1.size() - 1); + if (withSpecialTokens) { + tokenIds.add(IntStream.of(sepTokenId)); + 
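+            // special tokens contribute SPECIAL_TOKEN_POSITION to the token map, keeping tokenIds and tokenMap aligned one-to-one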
tokenMap.add(IntStream.of(SPECIAL_TOKEN_POSITION)); + } + tokenIds.add(tokenId2s.stream().mapToInt(Integer::valueOf)); + tokenMap.add(tokenMap2.stream().mapToInt(i -> i + previouslyFinalMap)); + if (withSpecialTokens) { + tokenIds.add(IntStream.of(sepTokenId)); + tokenMap.add(IntStream.of(SPECIAL_TOKEN_POSITION)); + } + seqPairOffset = withSpecialTokens ? tokenId1s.size() + 2 : tokenId1s.size(); + return this; + } + + @Override + public Tokens build( + List input, + boolean truncated, + List> allTokens, + int spanPrev, + int seqId + ) { + return new Tokens( + input, + allTokens, + truncated, + tokenIds.build().flatMapToInt(Function.identity()).toArray(), + tokenMap.build().flatMapToInt(Function.identity()).toArray(), + spanPrev, + seqId, + seqPairOffset + ); + } + + @Override + public Tokens build(String input, boolean truncated, List allTokens, int spanPrev, int seqId) { + return TokensBuilder.super.build(input, truncated, allTokens, spanPrev, seqId); + } + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/DebertaV2Tokenizer.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/DebertaV2Tokenizer.java new file mode 100644 index 0000000000000..3f7094bcce29d --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/DebertaV2Tokenizer.java @@ -0,0 +1,301 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + * + * This Java port DeBERTa-V2 tokenizer, was derived from + * Microsoft's DeBERTa-V2 project at https://github.com/microsoft/DeBERTa + * and + * Huggingface's DeBERTa-V2 transformers + * project at https://github.com/huggingface/transformers/blob/main/src/transformers/models/deberta_v2/tokenization_deberta_v2.py + */ + +package org.elasticsearch.xpack.ml.inference.nlp.tokenizers; + +import org.apache.lucene.analysis.Analyzer; +import org.apache.lucene.analysis.TokenStream; +import org.apache.lucene.analysis.tokenattributes.PositionIncrementAttribute; +import org.elasticsearch.common.util.set.Sets; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.DebertaV2Tokenization; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.ml.inference.nlp.NlpTask; + +import java.io.IOException; +import java.io.Reader; +import java.io.UncheckedIOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.OptionalInt; +import java.util.Set; +import java.util.SortedMap; +import java.util.TreeMap; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +public class DebertaV2Tokenizer extends NlpTokenizer { + + public static final String UNKNOWN_TOKEN = "[UNK]"; + public static final String SEPARATOR_TOKEN = "[SEP]"; + public static final String PAD_TOKEN = "[PAD]"; + public static final String CLASS_TOKEN = "[CLS]"; + public static final String MASK_TOKEN = "[MASK]"; + + private static final Set NEVER_SPLIT = Set.of(UNKNOWN_TOKEN, SEPARATOR_TOKEN, PAD_TOKEN, CLASS_TOKEN, MASK_TOKEN); + + private final DebertaAnalyzer debertaAnalyzer; + protected final List originalVocab; + private final SortedMap vocab; + protected final boolean withSpecialTokens; + protected final int sepTokenId; + private final int clsTokenId; + protected final int padTokenId; + 
private final int maxSequenceLength; + + protected DebertaV2Tokenizer( + List originalVocab, + SortedMap vocab, + List scores, + boolean withSpecialTokens, + int maxSequenceLength, + Set neverSplit + ) throws IOException { + this.originalVocab = originalVocab; + this.debertaAnalyzer = new DebertaAnalyzer( + originalVocab, + scores, + new ArrayList<>(Sets.union(NEVER_SPLIT, neverSplit)), + UNKNOWN_TOKEN + ); + this.vocab = vocab; + this.withSpecialTokens = withSpecialTokens; + this.maxSequenceLength = maxSequenceLength; + if (vocab.containsKey(UNKNOWN_TOKEN) == false) { + throw ExceptionsHelper.conflictStatusException("stored vocabulary is missing required [{}] token", UNKNOWN_TOKEN); + } + if (vocab.containsKey(PAD_TOKEN) == false) { + throw ExceptionsHelper.conflictStatusException("stored vocabulary is missing required [{}] token", PAD_TOKEN); + } + this.padTokenId = vocab.get(PAD_TOKEN); + if (withSpecialTokens) { + Set missingSpecialTokens = Sets.difference(Set.of(SEPARATOR_TOKEN, CLASS_TOKEN), vocab.keySet()); + if (missingSpecialTokens.isEmpty() == false) { + throw ExceptionsHelper.conflictStatusException("stored vocabulary is missing required {} token(s)", missingSpecialTokens); + } + this.sepTokenId = vocab.get(SEPARATOR_TOKEN); + this.clsTokenId = vocab.get(CLASS_TOKEN); + } else { + this.sepTokenId = -1; + this.clsTokenId = -1; + } + } + + @Override + int clsTokenId() { + return clsTokenId; + } + + @Override + int sepTokenId() { + return sepTokenId; + } + + @Override + int maxSequenceLength() { + return maxSequenceLength; + } + + @Override + boolean isWithSpecialTokens() { + return withSpecialTokens; + } + + @Override + int numExtraTokensForSingleSequence() { + // https://github.com/huggingface/transformers/blob/v4.44.0/src/transformers/models/deberta_v2/tokenization_deberta_v2.py#L164 + // single sequence: [CLS] X [SEP] + return 2; + } + + @Override + int getNumExtraTokensForSeqPair() { + // https://github.com/huggingface/transformers/blob/v4.44.0/src/transformers/models/deberta_v2/tokenization_deberta_v2.py#L165 + // pair of sequences: [CLS] A [SEP] B [SEP] + return 3; + } + + @Override + public TokenizationResult buildTokenizationResult(List tokenizations) { + return new DebertaTokenizationResult(originalVocab, tokenizations, padTokenId); + } + + @Override + public NlpTask.RequestBuilder requestBuilder() { + return (inputs, requestId, truncate, span, windowSize) -> buildTokenizationResult( + IntStream.range(0, inputs.size()) + .boxed() + .flatMap(seqId -> tokenize(inputs.get(seqId), truncate, span, seqId, windowSize).stream()) + .collect(Collectors.toList()) + ).buildRequest(requestId, truncate); + } + + @Override + public OptionalInt getPadTokenId() { + return OptionalInt.of(padTokenId); + } + + @Override + public String getPadToken() { + return PAD_TOKEN; + } + + @Override + public OptionalInt getMaskTokenId() { + Integer maskId = vocab.get(MASK_TOKEN); + if (maskId == null) { + return OptionalInt.empty(); + } + return OptionalInt.of(maskId); + } + + @Override + public String getMaskToken() { + return MASK_TOKEN; + } + + @Override + public List getVocabulary() { + return originalVocab; + } + + @Override + TokenizationResult.TokensBuilder createTokensBuilder(int clsTokenId, int sepTokenId, boolean withSpecialTokens) { + return new DebertaTokenizationResult.DebertaTokensBuilder(clsTokenId, sepTokenId, withSpecialTokens); + } + + public static DebertaV2Tokenizer.Builder builder(List vocab, List scores, DebertaV2Tokenization tokenization) { + return new 
DebertaV2Tokenizer.Builder(vocab, scores, tokenization); + } + + public static class Builder { + + protected final List originalVocab; + protected final List scores; + protected final SortedMap vocab; + protected boolean withSpecialTokens; + protected int maxSequenceLength; + protected Set neverSplit; + + protected Builder(List vocab, List scores, DebertaV2Tokenization tokenization) { + this.originalVocab = vocab; + this.vocab = buildSortedVocab(vocab); + this.scores = scores; + this.withSpecialTokens = tokenization.withSpecialTokens(); + this.maxSequenceLength = tokenization.maxSequenceLength(); + } + + private static SortedMap buildSortedVocab(List vocab) { + SortedMap sortedVocab = new TreeMap<>(); + for (int i = 0; i < vocab.size(); i++) { + sortedVocab.put(vocab.get(i), i); + } + return sortedVocab; + } + + public DebertaV2Tokenizer.Builder setNeverSplit(Set neverSplit) { + this.neverSplit = neverSplit; + return this; + } + + public DebertaV2Tokenizer.Builder setMaxSequenceLength(int maxSequenceLength) { + this.maxSequenceLength = maxSequenceLength; + return this; + } + + /** + * Include CLS and SEP tokens + * @param withSpecialTokens if true include CLS and SEP tokens + * @return this + */ + public DebertaV2Tokenizer.Builder setWithSpecialTokens(boolean withSpecialTokens) { + this.withSpecialTokens = withSpecialTokens; + return this; + } + + public DebertaV2Tokenizer build() throws IOException { + if (neverSplit == null) { + neverSplit = Collections.emptySet(); + } + + return new DebertaV2Tokenizer(originalVocab, vocab, scores, withSpecialTokens, maxSequenceLength, neverSplit); + } + } + + @Override + public InnerTokenization innerTokenize(String seq) { + List tokenPositionMap = new ArrayList<>(); + try (TokenStream ts = debertaAnalyzer.tokenStream("input", seq)) { + ts.reset(); + PositionIncrementAttribute tokenPos = ts.addAttribute(PositionIncrementAttribute.class); + int currPos = -1; // the PositionIncrement starts at one, so this aligns the first token at position 0 + while (ts.incrementToken()) { + currPos += tokenPos.getPositionIncrement(); + tokenPositionMap.add(currPos); + } + } catch (IOException ex) { + throw new UncheckedIOException(ex); + } + return new InnerTokenization(new ArrayList<>(debertaAnalyzer.getTokens()), tokenPositionMap); + } + + @Override + public void close() { + this.debertaAnalyzer.close(); + } + + static class DebertaAnalyzer extends Analyzer { + private final List vocabulary; + private final List neverSplit; + private final double[] scores; + private UnigramTokenizer innerTokenizer; + private final String unknownToken; + private final PrecompiledCharMapNormalizer.Config normalizer; + + DebertaAnalyzer(List vocabulary, List scores, List neverSplit, String unknownToken) throws IOException { + this.vocabulary = vocabulary; + this.neverSplit = neverSplit; + this.unknownToken = unknownToken; + this.scores = new double[scores.size()]; + int i = 0; + for (Double s : scores) { + this.scores[i++] = s; + } + normalizer = PrecompiledCharMapNormalizer.fromBase64EncodedResource( + "/org/elasticsearch/xpack/ml/inference.nlp.tokenizers/spm_precompiled_normalizer.txt" + ); + } + + @Override + protected Reader initReader(String fieldName, Reader reader) { + if (normalizer.offsets().length > 0) { + return new PrecompiledCharMapNormalizer(normalizer.offsets(), normalizer.utf8str(), reader); + } + return reader; + } + + @Override + protected TokenStreamComponents createComponents(String fieldName) { + this.innerTokenizer = UnigramTokenizer.build(neverSplit, vocabulary, 
scores, unknownToken, true); + return new TokenStreamComponents(this.innerTokenizer); + } + + public List getTokens() { + if (innerTokenizer != null) { + return innerTokenizer.getTokenizedValues(); + } else { + return List.of(); + } + } + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/NlpTokenizer.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/NlpTokenizer.java index 5014eb269b081..0b4a5b651d8d4 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/NlpTokenizer.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/NlpTokenizer.java @@ -11,6 +11,7 @@ import org.elasticsearch.core.Strings; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertJapaneseTokenization; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenization; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.DebertaV2Tokenization; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.MPNetTokenization; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RobertaTokenization; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization; @@ -26,6 +27,7 @@ import java.util.OptionalInt; import java.util.stream.Collectors; +import static java.lang.Math.min; import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig.TOKENIZATION; import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig.VOCABULARY; @@ -48,7 +50,9 @@ public abstract class NlpTokenizer implements Releasable { abstract int getNumExtraTokensForSeqPair(); - abstract int defaultSpanForChunking(int maxWindowSize); + int defaultSpanForChunking(int maxWindowSize) { + return (maxWindowSize - numExtraTokensForSingleSequence()) / 2; + } public abstract TokenizationResult buildTokenizationResult(List tokenizations); @@ -85,7 +89,7 @@ public final List tokenize( if (numTokens > windowSize) { switch (truncate) { - case FIRST, SECOND -> { + case FIRST, SECOND, BALANCED -> { // only one sequence exists in this case isTruncated = true; tokenIds = tokenIds.subList(0, isWithSpecialTokens() ? windowSize - numExtraTokensForSingleSequence() : windowSize); tokenPositionMap = tokenPositionMap.subList( @@ -123,7 +127,7 @@ public final List tokenize( int splitStartPos = 0; int spanPrev = -1; while (splitEndPos < tokenIds.size()) { - splitEndPos = Math.min( + splitEndPos = min( splitStartPos + (isWithSpecialTokens() ? 
windowSize - numExtraTokensForSingleSequence() : windowSize), tokenIds.size() ); @@ -232,6 +236,29 @@ public TokenizationResult.Tokens tokenize( tokenIdsSeq2 = tokenIdsSeq2.subList(0, maxSequenceLength() - extraTokens - tokenIdsSeq1.size()); tokenPositionMapSeq2 = tokenPositionMapSeq2.subList(0, maxSequenceLength() - extraTokens - tokenIdsSeq1.size()); } + case BALANCED -> { + isTruncated = true; + int firstSequenceLength = 0; + + if (tokenIdsSeq2.size() > (maxSequenceLength() - getNumExtraTokensForSeqPair()) / 2) { + firstSequenceLength = min(tokenIdsSeq1.size(), (maxSequenceLength() - getNumExtraTokensForSeqPair()) / 2); + } else { + firstSequenceLength = min( + tokenIdsSeq1.size(), + maxSequenceLength() - tokenIdsSeq2.size() - getNumExtraTokensForSeqPair() + ); + } + int secondSequenceLength = min( + tokenIdsSeq2.size(), + maxSequenceLength() - firstSequenceLength - getNumExtraTokensForSeqPair() + ); + + tokenIdsSeq1 = tokenIdsSeq1.subList(0, firstSequenceLength); + tokenPositionMapSeq1 = tokenPositionMapSeq1.subList(0, firstSequenceLength); + + tokenIdsSeq2 = tokenIdsSeq2.subList(0, secondSequenceLength); + tokenPositionMapSeq2 = tokenPositionMapSeq2.subList(0, secondSequenceLength); + } case NONE -> throw ExceptionsHelper.badRequestException( "Input too large. The tokenized input length [{}] exceeds the maximum sequence length [{}]", numTokens, @@ -355,7 +382,7 @@ public List tokenize(String seq1, String seq2, Tokeni } while (splitEndPos < tokenIdsSeq2.size()) { - splitEndPos = Math.min(splitStartPos + trueMaxSeqLength, tokenIdsSeq2.size()); + splitEndPos = min(splitStartPos + trueMaxSeqLength, tokenIdsSeq2.size()); // Make sure we do not end on a word if (splitEndPos != tokenIdsSeq2.size()) { while (splitEndPos > splitStartPos + 1 @@ -447,6 +474,9 @@ public static NlpTokenizer build(Vocabulary vocabulary, Tokenization params) thr if (params instanceof XLMRobertaTokenization xlmRobertaTokenization) { return XLMRobertaTokenizer.builder(vocabulary.get(), vocabulary.scores(), xlmRobertaTokenization).build(); } + if (params instanceof DebertaV2Tokenization debertaV2Tokenization) { + return DebertaV2Tokenizer.builder(vocabulary.get(), vocabulary.scores(), debertaV2Tokenization).build(); + } throw new IllegalArgumentException("unknown tokenization type [" + params.getName() + "]"); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/RobertaTokenizer.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/RobertaTokenizer.java index e884e84faa85d..6d58d2e2dc2cf 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/RobertaTokenizer.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/RobertaTokenizer.java @@ -106,11 +106,6 @@ int getNumExtraTokensForSeqPair() { return 4; } - @Override - int defaultSpanForChunking(int maxWindowSize) { - return (maxWindowSize - numExtraTokensForSingleSequence()) / 2; - } - @Override int numExtraTokensForSingleSequence() { return 2; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/UnigramTokenizer.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/UnigramTokenizer.java index acb1f6c038ef9..31deac066cba2 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/UnigramTokenizer.java +++ 
b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/UnigramTokenizer.java @@ -14,6 +14,7 @@ import org.apache.lucene.analysis.tokenattributes.OffsetAttribute; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.UnicodeUtil; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.util.Maps; import org.elasticsearch.core.Nullable; @@ -49,7 +50,13 @@ public final class UnigramTokenizer extends Tokenizer { private final CharTermAttribute termAtt = addAttribute(CharTermAttribute.class); private final OffsetAttribute offsetAtt = addAttribute(OffsetAttribute.class); - static UnigramTokenizer build(List neverSplit, List dictionary, double[] scores, String unknownToken) { + static UnigramTokenizer build( + List neverSplit, + List dictionary, + double[] scores, + String unknownToken, + boolean byteFallback + ) { if (dictionary.isEmpty()) { throw new IllegalArgumentException("vocab empty"); } @@ -84,7 +91,8 @@ static UnigramTokenizer build(List neverSplit, List dictionary, Optional.ofNullable(tokenToId.get(new BytesRef(unknownToken))) .orElseThrow( () -> new IllegalArgumentException("provided vocabulary does not contain the unknown token of [" + unknownToken + "]") - ) + ), + byteFallback ); } @@ -94,7 +102,7 @@ static UnigramTokenizer build(List neverSplit, List dictionary, private final double minScore; // This may be configurable in the future - private final boolean fuseUnk = true; + private boolean fuseUnk = true; private final double[] vocabScores; private final CharTrie neverSplit; private final CharArraySet neverSplitHash; @@ -104,6 +112,7 @@ static UnigramTokenizer build(List neverSplit, List dictionary, // This is a buffer that is reused per token for decoding the normalized char-sequence into utf-8 bytes // It's usage is NOT thread safe private byte[] normalizedByteBuffer = new byte[128]; + private boolean byteFallback = false; // If true, decompose unknown pieces into UTF-8 byte pieces public UnigramTokenizer( double minScore, @@ -127,6 +136,31 @@ public UnigramTokenizer( this.whitespaceTokenizer = new SimpleWhitespaceTokenizer(); } + public UnigramTokenizer( + double minScore, + double[] vocabScores, + CharTrie neverSplit, + CharArraySet neverSplitHash, + Map vocabToId, + BytesTrie vocabTrie, + int unknownTokenId, + boolean byteFallback + ) { + super(); + this.tokens = new LinkedList<>(); + this.tokenizedValues = new ArrayList<>(); + this.minScore = minScore; + this.neverSplit = neverSplit; + this.neverSplitHash = neverSplitHash; + this.vocabToId = vocabToId; + this.vocabTrie = vocabTrie; + this.unknownTokenId = unknownTokenId; + this.vocabScores = vocabScores; + this.whitespaceTokenizer = new SimpleWhitespaceTokenizer(); + this.byteFallback = byteFallback; + this.fuseUnk = byteFallback == false; + } + List getTokenizedValues() { return tokenizedValues; } @@ -231,6 +265,21 @@ public boolean incrementToken() throws IOException { return false; } + private int[] decomposeBytePieces(byte[] bytes) { + assert this.byteFallback; + + int[] pieces = new int[bytes.length]; + for (int i = 0; i < bytes.length; i++) { + BytesRef decomposedToken = new BytesRef(Strings.format("<0x%02X>", bytes[i])); + Integer piece = vocabToId.get(decomposedToken); + if (piece == null) { + piece = unknownTokenId; + } + pieces[i] = piece; + } + return pieces; + } + /** * This algorithm does the following: * @@ -309,7 +358,21 @@ List tokenize(CharSequence inputSequence, IntToIntFuncti while (endsAtBytes > 0) { BestPathNode node = 
bestPathNodes[endsAtBytes]; int startsAtBytes = node.startsAtBytePos; - if (node.id == unknownTokenId && fuseUnk) { + if (node.id == unknownTokenId && byteFallback) { + CharSequence multiByteSequence = inputSequence.subSequence(node.startsAtCharPos, endsAtChars); + byte[] bytes = multiByteSequence.toString().getBytes(StandardCharsets.UTF_8); + int[] pieces = decomposeBytePieces(bytes); + for (int i = pieces.length - 1; i >= 0; i--) { + results.add( + new DelimitedToken.Encoded( + Strings.format("<0x%02X>", bytes[i]), + pieces[i], + offsetCorrection.apply(node.startsAtCharPos), + offsetCorrection.apply(startsAtBytes + i) + ) + ); + } + } else if (node.id == unknownTokenId && fuseUnk) { unknownTokens.add( new DelimitedToken.Encoded( new String(normalizedByteBuffer, startsAtBytes, endsAtBytes - startsAtBytes, StandardCharsets.UTF_8), diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/XLMRobertaTokenizer.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/XLMRobertaTokenizer.java index 7a856d8e4735a..0e8793eb374ca 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/XLMRobertaTokenizer.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/XLMRobertaTokenizer.java @@ -101,11 +101,6 @@ int getNumExtraTokensForSeqPair() { return 4; } - @Override - int defaultSpanForChunking(int maxWindowSize) { - return (maxWindowSize - numExtraTokensForSingleSequence()) / 2; - } - @Override int numExtraTokensForSingleSequence() { return 2; @@ -284,7 +279,7 @@ protected Reader initReader(String fieldName, Reader reader) { @Override protected TokenStreamComponents createComponents(String fieldName) { - this.innerTokenizer = UnigramTokenizer.build(neverSplit, vocabulary, scores, unknownToken); + this.innerTokenizer = UnigramTokenizer.build(neverSplit, vocabulary, scores, unknownToken, false); return new TokenStreamComponents(this.innerTokenizer); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizerTests.java index 901fea45d9de9..ccebe3bf0ca98 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizerTests.java @@ -760,6 +760,119 @@ public void testTokenizeLargeInputMultiSequenceTruncation() { } + public void testTokenizeLargeInputMultiSequenceBalancedTruncation() { + try ( + BertTokenizer tokenizer = BertTokenizer.builder( + TEST_CASED_VOCAB, + new BertTokenization(null, true, 10, Tokenization.Truncate.BALANCED, -1) + ).build() + ) { + + { // both sequences are truncated + TokenizationResult.Tokens tokenization = tokenizer.tokenize( + "Elasticsearch is fun", + "Godzilla my little red car", + Tokenization.Truncate.BALANCED, + 0 + ); + + var tokenStream = Arrays.stream(tokenization.tokenIds()).mapToObj(TEST_CASED_VOCAB::get).collect(Collectors.toList()); + assertThat( + tokenStream, + contains( + BertTokenizer.CLASS_TOKEN, + "Elastic", + "##search", + "is", + BertTokenizer.SEPARATOR_TOKEN, + "God", + "##zilla", + "my", + "little", + BertTokenizer.SEPARATOR_TOKEN + ) + ); + } + + { // first sequence is too short to be truncated + TokenizationResult.Tokens tokenization = tokenizer.tokenize( + "Elasticsearch", + "Godzilla my little 
red car", + Tokenization.Truncate.BALANCED, + 0 + ); + + var tokenStream = Arrays.stream(tokenization.tokenIds()).mapToObj(TEST_CASED_VOCAB::get).collect(Collectors.toList()); + assertThat( + tokenStream, + contains( + BertTokenizer.CLASS_TOKEN, + "Elastic", + "##search", + BertTokenizer.SEPARATOR_TOKEN, + "God", + "##zilla", + "my", + "little", + "red", + BertTokenizer.SEPARATOR_TOKEN + ) + ); + } + + { // second sequence is too short to be truncated + TokenizationResult.Tokens tokenization = tokenizer.tokenize( + "Elasticsearch is my little red fun", + "Godzilla", + Tokenization.Truncate.BALANCED, + 0 + ); + + var tokenStream = Arrays.stream(tokenization.tokenIds()).mapToObj(TEST_CASED_VOCAB::get).collect(Collectors.toList()); + assertThat( + tokenStream, + contains( + BertTokenizer.CLASS_TOKEN, + "Elastic", + "##search", + "is", + "my", + "little", + BertTokenizer.SEPARATOR_TOKEN, + "God", + "##zilla", + BertTokenizer.SEPARATOR_TOKEN + ) + ); + } + + { // both sequences are too short to be truncated + TokenizationResult.Tokens tokenization = tokenizer.tokenize("Elasticsearch", "Godzilla", Tokenization.Truncate.BALANCED, 0); + + var tokenStream = Arrays.stream(tokenization.tokenIds()).mapToObj(TEST_CASED_VOCAB::get).collect(Collectors.toList()); + assertThat( + tokenStream, + contains( + BertTokenizer.CLASS_TOKEN, + "Elastic", + "##search", + BertTokenizer.SEPARATOR_TOKEN, + "God", + "##zilla", + BertTokenizer.SEPARATOR_TOKEN + ) + ); + } + + expectThrows( + ElasticsearchStatusException.class, + () -> BertTokenizer.builder(TEST_CASED_VOCAB, new BertTokenization(null, true, 8, Tokenization.Truncate.NONE, -1)) + .build() + .tokenize("Elasticsearch is fun", "Godzilla my little red car", Tokenization.Truncate.NONE, 0) + ); + } + } + public void testMultiSeqRequiresSpecialTokens() { try ( BertTokenizer tokenizer = BertTokenizer.builder( diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/DebertaV2TokenizerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/DebertaV2TokenizerTests.java new file mode 100644 index 0000000000000..bbe509da67452 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/DebertaV2TokenizerTests.java @@ -0,0 +1,206 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.ml.inference.nlp.tokenizers; + +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.DebertaV2Tokenization; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + +import static org.elasticsearch.xpack.ml.inference.nlp.tokenizers.DebertaV2Tokenizer.MASK_TOKEN; +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.not; + +public class DebertaV2TokenizerTests extends ESTestCase { + + private static final List TEST_CASE_VOCAB = List.of( + DebertaV2Tokenizer.CLASS_TOKEN, + DebertaV2Tokenizer.PAD_TOKEN, + DebertaV2Tokenizer.SEPARATOR_TOKEN, + DebertaV2Tokenizer.UNKNOWN_TOKEN, + "▁Ela", + "stic", + "search", + "▁is", + "▁fun", + "▁God", + "z", + "illa", + "▁my", + "▁little", + "▁red", + "▁car", + "▁😀", + "▁🇸🇴", + MASK_TOKEN, + ".", + "<0xC2>", + "<0xAD>", + "▁" + ); + private static final List TEST_CASE_SCORES = List.of( + 0.0, + 0.0, + 0.0, + 0.0, + -12.535264015197754, + -12.300995826721191, + -13.255199432373047, + -7.402246475219727, + -11.201482772827148, + -10.576351165771484, + -7.898513317108154, + -10.230172157287598, + -9.18289566040039, + -11.451579093933105, + -10.858806610107422, + -10.214239120483398, + -10.230172157287598, + -9.451579093933105, + 0.0, + -3.0, + 1.0, + 2.0, + -7.97025 + ); + + private List tokenStrings(List tokens) { + return tokens.stream().map(DelimitedToken::toString).collect(Collectors.toList()); + } + + public void testTokenize() throws IOException { + try ( + DebertaV2Tokenizer tokenizer = DebertaV2Tokenizer.builder( + TEST_CASE_VOCAB, + TEST_CASE_SCORES, + new DebertaV2Tokenization(false, false, null, Tokenization.Truncate.NONE, -1) + ).build() + ) { + TokenizationResult.Tokens tokenization = tokenizer.tokenize("Elasticsearch fun", Tokenization.Truncate.NONE, -1, 0, null) + .get(0); + assertThat(tokenStrings(tokenization.tokens().get(0)), contains("▁Ela", "stic", "search", "▁fun")); + assertArrayEquals(new int[] { 4, 5, 6, 8 }, tokenization.tokenIds()); + assertArrayEquals(new int[] { 0, 1, 2, 3 }, tokenization.tokenMap()); + } + } + + public void testSurrogatePair() throws IOException { + try ( + DebertaV2Tokenizer tokenizer = DebertaV2Tokenizer.builder( + TEST_CASE_VOCAB, + TEST_CASE_SCORES, + new DebertaV2Tokenization(false, false, null, Tokenization.Truncate.NONE, -1) + ).build() + ) { + TokenizationResult.Tokens tokenization = tokenizer.tokenize( + "Elastic" + "\u00AD" + "search 😀" + "\u00AD" + " fun", + Tokenization.Truncate.NONE, + -1, + 0, + null + ).get(0); + assertArrayEquals(new int[] { 4, 5, 20, 21, 6, 16, 20, 21, 8 }, tokenization.tokenIds()); + assertThat( + tokenStrings(tokenization.tokens().get(0)), + contains("▁Ela", "stic", "<0xC2>", "<0xAD>", "search", "▁\uD83D\uDE00", "<0xC2>", "<0xAD>", "▁fun") + ); + + tokenization = tokenizer.tokenize("😀", Tokenization.Truncate.NONE, -1, 0, null).get(0); + assertThat(tokenStrings(tokenization.tokens().get(0)), contains("▁\uD83D\uDE00")); + + tokenization = tokenizer.tokenize("Elasticsearch 😀", Tokenization.Truncate.NONE, -1, 0, null).get(0); + assertThat(tokenStrings(tokenization.tokens().get(0)), contains("▁Ela", "stic", "search", "▁\uD83D\uDE00")); + + tokenization = tokenizer.tokenize("Elasticsearch 😀 fun", Tokenization.Truncate.NONE, -1, 0, null).get(0); + 
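+            // "▁😀" is present in the test vocabulary, so it is emitted directly rather than decomposed into <0xNN> byte-fallback pieces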
assertThat(tokenStrings(tokenization.tokens().get(0)), contains("▁Ela", "stic", "search", "▁\uD83D\uDE00", "▁fun")); + + } + } + + public void testMultiByteEmoji() throws IOException { + try ( + DebertaV2Tokenizer tokenizer = DebertaV2Tokenizer.builder( + TEST_CASE_VOCAB, + TEST_CASE_SCORES, + new DebertaV2Tokenization(false, false, null, Tokenization.Truncate.NONE, -1) + ).build() + ) { + TokenizationResult.Tokens tokenization = tokenizer.tokenize("🇸🇴", Tokenization.Truncate.NONE, -1, 0, null).get(0); + assertThat(tokenStrings(tokenization.tokens().get(0)), contains("▁🇸🇴")); + assertThat(tokenization.tokenIds()[0], not(equalTo(3))); // not the unknown token + + tokenization = tokenizer.tokenize("🏁", Tokenization.Truncate.NONE, -1, 0, null).get(0); + assertThat(tokenStrings(tokenization.tokens().get(0)), contains("▁", "<0xF0>", "<0x9F>", "<0x8F>", "<0x81>")); + // contains the 4-byte sequence representing the emoji which is not in the vocab, due to byteFallback enabled + } + } + + public void testTokenizeWithNeverSplit() throws IOException { + try ( + DebertaV2Tokenizer tokenizer = DebertaV2Tokenizer.builder( + TEST_CASE_VOCAB, + TEST_CASE_SCORES, + new DebertaV2Tokenization(false, true, null, Tokenization.Truncate.NONE, -1) + ).build() + ) { + TokenizationResult.Tokens tokenization = tokenizer.tokenize( + "Elasticsearch ." + MASK_TOKEN + ".", + Tokenization.Truncate.NONE, + -1, + 0, + null + ).get(0); + assertThat(tokenStrings(tokenization.tokens().get(0)), contains("▁Ela", "stic", "search", "▁", ".", MASK_TOKEN, "▁", ".")); + } + } + + public void testMultiSeqTokenization() throws IOException { + try ( + DebertaV2Tokenizer tokenizer = DebertaV2Tokenizer.builder( + TEST_CASE_VOCAB, + TEST_CASE_SCORES, + new DebertaV2Tokenization(false, false, null, Tokenization.Truncate.NONE, -1) + ).setWithSpecialTokens(true).build() + ) { + TokenizationResult.Tokens tokenization = tokenizer.tokenize( + "Elasticsearch is fun", + "Godzilla my little red car", + Tokenization.Truncate.NONE, + 0 + ); + + var tokenStream = Arrays.stream(tokenization.tokenIds()).mapToObj(TEST_CASE_VOCAB::get).collect(Collectors.toList()); + assertThat( + tokenStream, + contains( + DebertaV2Tokenizer.CLASS_TOKEN, + "▁Ela", + "stic", + "search", + "▁is", + "▁fun", + DebertaV2Tokenizer.SEPARATOR_TOKEN, + "▁God", + "z", + "illa", + "▁my", + "▁little", + "▁red", + "▁car", + DebertaV2Tokenizer.SEPARATOR_TOKEN + ) + ); + } + } + +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/NlpTokenizerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/NlpTokenizerTests.java index fc2a31a06e187..ad6f44e77aafc 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/NlpTokenizerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/NlpTokenizerTests.java @@ -10,6 +10,7 @@ import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertJapaneseTokenization; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenization; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.DebertaV2Tokenization; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.MPNetTokenization; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RobertaTokenization; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization; @@ -44,6 +45,13 @@ public class NlpTokenizerTests extends ESTestCase { 
RobertaTokenizer.CLASS_TOKEN, RobertaTokenizer.MASK_TOKEN ); + public static final List DEBERTA_REQUIRED_VOCAB = List.of( + DebertaV2Tokenizer.UNKNOWN_TOKEN, + DebertaV2Tokenizer.SEPARATOR_TOKEN, + DebertaV2Tokenizer.PAD_TOKEN, + DebertaV2Tokenizer.CLASS_TOKEN, + DebertaV2Tokenizer.MASK_TOKEN + ); void validateBuilder(List vocab, Tokenization tokenization, Class expectedClass) throws IOException { Vocabulary vocabulary = new Vocabulary(vocab, "model-name", null, null); @@ -66,5 +74,8 @@ public void testBuildTokenizer() throws IOException { Tokenization xlmRoberta = new XLMRobertaTokenization(null, null, Tokenization.Truncate.NONE, -1); validateBuilder(ROBERTA_REQUIRED_VOCAB, xlmRoberta, XLMRobertaTokenizer.class); + + Tokenization debertaV2 = new DebertaV2Tokenization(false, null, null, Tokenization.Truncate.NONE, -1); + validateBuilder(DEBERTA_REQUIRED_VOCAB, debertaV2, DebertaV2Tokenizer.class); } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/UnigramTokenizerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/UnigramTokenizerTests.java index f97055b29ca7b..d1ce2fea9d1dc 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/UnigramTokenizerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/UnigramTokenizerTests.java @@ -39,8 +39,24 @@ public void testSimpleTokenization() throws IOException { public void testLessSimpleTokenization() throws IOException { TestNLPAnalyzer analyzer = new TestNLPAnalyzer( - List.of(UNKNOWN_TOKEN, PREFIX + "ab", "cd", PREFIX + "abc", "a", "b", "c", "ABC", "abcdabcd", "q", "r", "qr", ""), - List.of(0.0, 0.0, -0.1, -0.2, -0.3, -0.4, -0.5, -0.5, 20.0, 20.5, 20.5, -0.5, 0.0), + List.of( + UNKNOWN_TOKEN, + PREFIX + "ab", + "cd", + PREFIX + "abc", + "a", + "b", + "c", + "ABC", + "abcdabcd", + "q", + "r", + "qr", + "", + "aa", + "aaaa" + ), + List.of(0.0, 0.0, -0.1, -0.2, -0.3, -0.4, -0.5, -0.5, 20.0, 20.5, 20.5, -0.5, 0.0, -13.5467, -14.9644), UNKNOWN_TOKEN, new PrecompiledCharMapNormalizer.Config(new int[0], "") ); @@ -53,6 +69,31 @@ public void testLessSimpleTokenization() throws IOException { assertAnalyzesToNoCharFilter(analyzer, " \nabcd \n\n abcc \n", new String[] { PREFIX + "ab", "cd", PREFIX + "abc", "c" }); } + public void testLessSimpleTokenizationForRepeatingCharacters() throws IOException { + TestNLPAnalyzer analyzer = new TestNLPAnalyzer( + List.of(UNKNOWN_TOKEN, "HH", "HHHH", PREFIX + "H", "HHH", PREFIX + "HH", PREFIX, PREFIX + "HHH"), + List.of(0.0, -13.5467, -14.9644, -9.17478, -15.1165, -13.201, -7.97025, -15.602), + UNKNOWN_TOKEN, + PrecompiledCharMapNormalizer.fromBase64EncodedResource( + "/org/elasticsearch/xpack/ml/inference.nlp.tokenizers/spm_precompiled_normalizer.txt" + ) + ); + + assertAnalyzesToNoCharFilter(analyzer, "HHHHHHHHHHHH", new String[] { PREFIX, "HHHH", "HHHH", "HHHH" }); + assertAnalyzesToNoCharFilter(analyzer, "HHHHHHHHHHH", new String[] { PREFIX + "HHH", "HHHH", "HHHH" }); + assertAnalyzesToNoCharFilter(analyzer, "HHHHHHHHHH", new String[] { PREFIX + "HH", "HHHH", "HHHH" }); + assertAnalyzesToNoCharFilter(analyzer, "HHHHHHHHH", new String[] { PREFIX + "H", "HHHH", "HHHH" }); + assertAnalyzesToNoCharFilter(analyzer, "HHHHHHHH", new String[] { PREFIX, "HHHH", "HHHH" }); + assertAnalyzesToNoCharFilter(analyzer, "HHHHHHH", new String[] { PREFIX + "HHH", "HHHH" }); + assertAnalyzesToNoCharFilter(analyzer, "HHHHHH", new String[] { PREFIX + "HH", "HHHH" 
}); + assertAnalyzesToNoCharFilter(analyzer, "HHHHH", new String[] { PREFIX + "H", "HHHH" }); + assertAnalyzesToNoCharFilter(analyzer, "HHHH", new String[] { PREFIX, "HHHH" }); + assertAnalyzesToNoCharFilter(analyzer, "HHH", new String[] { PREFIX + "HHH" }); + assertAnalyzesToNoCharFilter(analyzer, "HH", new String[] { PREFIX + "HH" }); + assertAnalyzesToNoCharFilter(analyzer, "H", new String[] { PREFIX + "H" }); + + } + public void testLessSimpleTokenizationWithNeverSplit() throws IOException { TestNLPAnalyzer analyzer = new TestNLPAnalyzer( List.of( @@ -153,7 +194,7 @@ protected Reader initReader(String fieldName, Reader reader) { @Override protected TokenStreamComponents createComponents(String fieldName) { - UnigramTokenizer tokenizer = UnigramTokenizer.build(NEVER_SPLIT, dictionary, scores, unknownToken); + UnigramTokenizer tokenizer = UnigramTokenizer.build(NEVER_SPLIT, dictionary, scores, unknownToken, false); return new TokenStreamComponents(tokenizer); } }
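
Note on the two behaviours this patch introduces (the BALANCED truncation strategy and SentencePiece-style byte fallback): the snippet below is a standalone, simplified sketch, not code from the patch. It assumes a sequence pair needs 3 extra tokens ([CLS], [SEP], [SEP]) as in the DeBERTa-v2 and BERT cases above, and all names in it (BalancedTruncationSketch, balancedLengths, assemblePair, byteFallbackPieces) are made up purely for illustration.

import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;

public class BalancedTruncationSketch {

    static final int EXTRA_TOKENS_FOR_SEQ_PAIR = 3; // [CLS] A [SEP] B [SEP]

    // Mirrors the BALANCED branch added to NlpTokenizer: each sequence keeps at most half
    // of the usable budget, but a short sequence donates its unused share to the longer one.
    static int[] balancedLengths(int seq1Len, int seq2Len, int maxSequenceLength) {
        int usable = maxSequenceLength - EXTRA_TOKENS_FOR_SEQ_PAIR;
        int firstLen = seq2Len > usable / 2
            ? Math.min(seq1Len, usable / 2)
            : Math.min(seq1Len, usable - seq2Len);
        int secondLen = Math.min(seq2Len, usable - firstLen);
        return new int[] { firstLen, secondLen };
    }

    // Assembles the truncated pair in DeBERTa-v2 order: [CLS] A [SEP] B [SEP].
    static List<String> assemblePair(List<String> seq1, List<String> seq2, int maxSequenceLength) {
        int[] lens = balancedLengths(seq1.size(), seq2.size(), maxSequenceLength);
        List<String> out = new ArrayList<>();
        out.add("[CLS]");
        out.addAll(seq1.subList(0, lens[0]));
        out.add("[SEP]");
        out.addAll(seq2.subList(0, lens[1]));
        out.add("[SEP]");
        return out;
    }

    // Byte-fallback formatting used when a piece is missing from the vocabulary:
    // each UTF-8 byte becomes a <0xNN> piece, e.g. a soft hyphen (U+00AD) -> <0xC2><0xAD>.
    static List<String> byteFallbackPieces(String text) {
        List<String> pieces = new ArrayList<>();
        for (byte b : text.getBytes(StandardCharsets.UTF_8)) {
            pieces.add(String.format("<0x%02X>", b & 0xFF));
        }
        return pieces;
    }

    public static void main(String[] args) {
        List<String> seq1 = List.of("Elastic", "##search", "is", "fun");
        List<String> seq2 = List.of("God", "##zilla", "my", "little", "red", "car");
        // usable budget = 10 - 3 = 7; seq2 is longer than 7/2, so seq1 keeps 3 tokens
        // and seq2 keeps the remaining 4.
        System.out.println(assemblePair(seq1, seq2, 10));
        System.out.println(byteFallbackPieces("\u00AD")); // [<0xC2>, <0xAD>]
    }
}

Run as written, this reproduces the 3/4 split asserted in the new testTokenizeLargeInputMultiSequenceBalancedTruncation case and the <0xC2><0xAD> decomposition asserted in DebertaV2TokenizerTests.testSurrogatePair; it is only meant as a reading aid for those two pieces of logic.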