Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenizationUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.DebertaV2Tokenization;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.DebertaV2TokenizationUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.EmptyConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfigUpdate;
Expand Down Expand Up @@ -547,6 +549,13 @@ public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {
(p, c) -> XLMRobertaTokenization.fromXContent(p, (boolean) c)
)
);
namedXContent.add(
new NamedXContentRegistry.Entry(
Tokenization.class,
new ParseField(DebertaV2Tokenization.NAME),
(p, c) -> DebertaV2Tokenization.fromXContent(p, (boolean) c)
)
);

namedXContent.add(
new NamedXContentRegistry.Entry(
Expand Down Expand Up @@ -583,6 +592,13 @@ public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {
(p, c) -> XLMRobertaTokenizationUpdate.fromXContent(p)
)
);
namedXContent.add(
new NamedXContentRegistry.Entry(
TokenizationUpdate.class,
DebertaV2TokenizationUpdate.NAME,
(p, c) -> DebertaV2TokenizationUpdate.fromXContent(p)
)
);

return namedXContent;
}
Expand Down Expand Up @@ -791,6 +807,7 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
);
namedWriteables.add(new NamedWriteableRegistry.Entry(Tokenization.class, RobertaTokenization.NAME, RobertaTokenization::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(Tokenization.class, XLMRobertaTokenization.NAME, XLMRobertaTokenization::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(Tokenization.class, DebertaV2Tokenization.NAME, DebertaV2Tokenization::new));

namedWriteables.add(
new NamedWriteableRegistry.Entry(
Expand Down Expand Up @@ -827,6 +844,9 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
XLMRobertaTokenizationUpdate::new
)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(TokenizationUpdate.class, DebertaV2Tokenization.NAME, DebertaV2TokenizationUpdate::new)
);

return namedWriteables;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.ml.inference.trainedmodel;

import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;

import java.io.IOException;

/**
 * Tokenization configuration for DeBERTa V2 models.
 *
 * All behaviour is inherited from {@link Tokenization}; this subclass only
 * contributes the registered name ({@code deberta_v2}) and the mask token
 * used by fill-mask style tasks.
 */
public class DebertaV2Tokenization extends Tokenization {

    public static final String NAME = "deberta_v2";
    public static final String MASK_TOKEN = "[MASK]";

    // Lenient parsing is used when reading configs persisted by older nodes,
    // strict parsing when handling user requests.
    private static final ConstructingObjectParser<DebertaV2Tokenization, Void> LENIENT_PARSER = createParser(true);
    private static final ConstructingObjectParser<DebertaV2Tokenization, Void> STRICT_PARSER = createParser(false);

    /**
     * Builds a parser over the common tokenization fields declared by
     * {@link Tokenization#declareCommonFields}.
     *
     * @param ignoreUnknownFields whether unrecognised fields are tolerated (lenient mode)
     */
    public static ConstructingObjectParser<DebertaV2Tokenization, Void> createParser(boolean ignoreUnknownFields) {
        ConstructingObjectParser<DebertaV2Tokenization, Void> objectParser = new ConstructingObjectParser<>(
            NAME,
            ignoreUnknownFields,
            args -> {
                Truncate truncate = args[3] == null ? null : Truncate.fromString((String) args[3]);
                return new DebertaV2Tokenization((Boolean) args[0], (Boolean) args[1], (Integer) args[2], truncate, (Integer) args[4]);
            }
        );
        declareCommonFields(objectParser);
        return objectParser;
    }

    public static DebertaV2Tokenization fromXContent(XContentParser parser, boolean lenient) {
        if (lenient) {
            return LENIENT_PARSER.apply(parser, null);
        }
        return STRICT_PARSER.apply(parser, null);
    }

    /**
     * @param doLowerCase       whether input text is lower-cased before tokenizing; may be null for the default
     * @param withSpecialTokens whether special tokens (CLS/SEP equivalents) are added; may be null for the default
     * @param maxSequenceLength maximum token sequence length; may be null for the default
     * @param truncate          truncation strategy; may be null for the default
     * @param span              windowing span for long sequences; may be null for the default
     */
    public DebertaV2Tokenization(
        Boolean doLowerCase,
        Boolean withSpecialTokens,
        Integer maxSequenceLength,
        Truncate truncate,
        Integer span
    ) {
        super(doLowerCase, withSpecialTokens, maxSequenceLength, truncate, span);
    }

    public DebertaV2Tokenization(StreamInput in) throws IOException {
        super(in);
    }

    // Produces a copy with the sequence length and span adjusted for chunked
    // (windowed) inference; all other settings are carried over unchanged.
    @Override
    Tokenization buildWindowingTokenization(int updatedMaxSeqLength, int updatedSpan) {
        return new DebertaV2Tokenization(doLowerCase, withSpecialTokens, updatedMaxSeqLength, truncate, updatedSpan);
    }

    @Override
    public String getMaskToken() {
        return MASK_TOKEN;
    }

    // No DeBERTa-specific fields beyond the common ones written by the parent.
    @Override
    XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException {
        return builder;
    }

    @Override
    public String getWriteableName() {
        return NAME;
    }

    @Override
    public String getName() {
        return NAME;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.ml.inference.trainedmodel;

import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
import java.util.Optional;

/**
 * A partial update to a {@link DebertaV2Tokenization}: may change the
 * truncation strategy and/or the windowing span while preserving every other
 * setting of the original tokenization config.
 */
public class DebertaV2TokenizationUpdate extends AbstractTokenizationUpdate {
    public static final ParseField NAME = new ParseField(DebertaV2Tokenization.NAME);

    // final: parsers are stateless and must never be reassigned once built.
    public static final ConstructingObjectParser<DebertaV2TokenizationUpdate, Void> PARSER = new ConstructingObjectParser<>(
        "deberta_v2_tokenization_update",
        a -> new DebertaV2TokenizationUpdate(a[0] == null ? null : Tokenization.Truncate.fromString((String) a[0]), (Integer) a[1])
    );

    static {
        declareCommonParserFields(PARSER);
    }

    public static DebertaV2TokenizationUpdate fromXContent(XContentParser parser) {
        return PARSER.apply(parser, null);
    }

    /**
     * @param truncate new truncation strategy, or null to keep the original's
     * @param span     new windowing span, or null to keep the original's
     */
    public DebertaV2TokenizationUpdate(@Nullable Tokenization.Truncate truncate, @Nullable Integer span) {
        super(truncate, span);
    }

    public DebertaV2TokenizationUpdate(StreamInput in) throws IOException {
        super(in);
    }

    /**
     * Applies this update to an existing config, returning a new
     * {@link DebertaV2Tokenization} with the updated values merged in.
     *
     * @throws org.elasticsearch.ElasticsearchStatusException if {@code originalConfig}
     *         is not a DeBERTa V2 tokenization, or the truncate/span combination is invalid
     */
    @Override
    public Tokenization apply(Tokenization originalConfig) {
        if (originalConfig instanceof DebertaV2Tokenization debertaV2Tokenization) {
            if (isNoop()) {
                return debertaV2Tokenization;
            }

            Tokenization.validateSpanAndTruncate(getTruncate(), getSpan());

            if (getTruncate() != null && getTruncate().isInCompatibleWithSpan() == false) {
                // When the truncate value is incompatible with span, wipe out
                // the existing span setting to avoid an invalid combination of settings.
                // This avoids the user having to set span to the special unset value.
                return new DebertaV2Tokenization(
                    debertaV2Tokenization.doLowerCase(),
                    debertaV2Tokenization.withSpecialTokens(),
                    debertaV2Tokenization.maxSequenceLength(),
                    getTruncate(),
                    null
                );
            }

            // Fall back to the original's values for any field not set on this update.
            return new DebertaV2Tokenization(
                debertaV2Tokenization.doLowerCase(),
                debertaV2Tokenization.withSpecialTokens(),
                debertaV2Tokenization.maxSequenceLength(),
                Optional.ofNullable(this.getTruncate()).orElse(originalConfig.getTruncate()),
                Optional.ofNullable(this.getSpan()).orElse(originalConfig.getSpan())
            );
        }
        throw ExceptionsHelper.badRequestException(
            "Tokenization config of type [{}] can not be updated with a request of type [{}]",
            originalConfig.getName(),
            getName()
        );
    }

    @Override
    public String getWriteableName() {
        return NAME.getPreferredName();
    }

    @Override
    public String getName() {
        return NAME.getPreferredName();
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ public static TokenizationUpdate tokenizationFromMap(Map<String, Object> map) {
RobertaTokenizationUpdate.NAME.getPreferredName(),
RobertaTokenizationUpdate::new,
XLMRobertaTokenizationUpdate.NAME.getPreferredName(),
XLMRobertaTokenizationUpdate::new
XLMRobertaTokenizationUpdate::new,
DebertaV2Tokenization.NAME,
DebertaV2TokenizationUpdate::new
);

Map<String, Object> tokenizationConfig = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ public enum Truncate {
public boolean isInCompatibleWithSpan() {
return false;
}
};
},
BALANCED;

public boolean isInCompatibleWithSpan() {
return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,9 @@ public void testTokenizationFromMap() {
);
assertThat(
e.getMessage(),
containsString("unknown tokenization type expecting one of [bert, bert_ja, mpnet, roberta, xlm_roberta] got [not_bert]")
containsString(
"unknown tokenization type expecting one of [bert, bert_ja, deberta_v2, mpnet, roberta, xlm_roberta] got [not_bert]"
)
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,11 +169,6 @@ boolean isWithSpecialTokens() {
return withSpecialTokens;
}

@Override
int defaultSpanForChunking(int maxWindowSize) {
return (maxWindowSize - numExtraTokensForSingleSequence()) / 2;
}

@Override
int getNumExtraTokensForSeqPair() {
return 3;
Expand Down
Loading
Loading