Skip to content

Commit bd73ea1

Browse files
authored
OPENNLP-1518: Roberta-based Models - Add support for utilization via Onnx (#998)
1 parent f9d250b commit bd73ea1

File tree

8 files changed

+266
-27
lines changed

8 files changed

+266
-27
lines changed

opennlp-api/src/main/java/opennlp/tools/tokenize/WordpieceTokenizer.java

Lines changed: 48 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,27 @@
4545
*/
4646
public class WordpieceTokenizer implements Tokenizer {
4747

48-
private static final Pattern PUNCTUATION_PATTERN = Pattern.compile("\\p{Punct}+");
49-
private static final String CLASSIFICATION_TOKEN = "[CLS]";
50-
private static final String SEPARATOR_TOKEN = "[SEP]";
51-
private static final String UNKNOWN_TOKEN = "[UNK]";
48+
/** BERT classification token: {@code [CLS]}. */
49+
public static final String BERT_CLS_TOKEN = "[CLS]";
50+
/** BERT separator token: {@code [SEP]}. */
51+
public static final String BERT_SEP_TOKEN = "[SEP]";
52+
/** BERT unknown token: {@code [UNK]}. */
53+
public static final String BERT_UNK_TOKEN = "[UNK]";
54+
55+
/** RoBERTa classification token: {@code <s>}. */
56+
public static final String ROBERTA_CLS_TOKEN = "<s>";
57+
/** RoBERTa separator token. */
58+
public static final String ROBERTA_SEP_TOKEN = "</s>";
59+
/** RoBERTa unknown token. */
60+
public static final String ROBERTA_UNK_TOKEN = "<unk>";
61+
62+
private static final Pattern PUNCTUATION_PATTERN =
63+
Pattern.compile("\\p{Punct}+");
5264

5365
private final Set<String> vocabulary;
66+
private final String classificationToken;
67+
private final String separatorToken;
68+
private final String unknownToken;
5469
private int maxTokenLength = 50;
5570

5671
/**
@@ -60,7 +75,7 @@ public class WordpieceTokenizer implements Tokenizer {
6075
* @param vocabulary A set of tokens considered the vocabulary.
6176
*/
6277
public WordpieceTokenizer(Set<String> vocabulary) {
63-
this.vocabulary = vocabulary;
78+
this(vocabulary, BERT_CLS_TOKEN, BERT_SEP_TOKEN, BERT_UNK_TOKEN);
6479
}
6580

6681
/**
@@ -75,6 +90,29 @@ public WordpieceTokenizer(Set<String> vocabulary, int maxTokenLength) {
7590
this.maxTokenLength = maxTokenLength;
7691
}
7792

93+
/**
94+
* Initializes a {@link WordpieceTokenizer} with a
95+
* {@code vocabulary} and custom special tokens.
96+
* This allows support for models like RoBERTa that
97+
* use different special tokens instead of the BERT
98+
* defaults.
99+
*
100+
* @param vocabulary The vocabulary.
101+
* @param classificationToken The CLS token.
102+
* @param separatorToken The SEP token.
103+
* @param unknownToken The UNK token.
104+
*/
105+
public WordpieceTokenizer(
106+
final Set<String> vocabulary,
107+
final String classificationToken,
108+
final String separatorToken,
109+
final String unknownToken) {
110+
this.vocabulary = vocabulary;
111+
this.classificationToken = classificationToken;
112+
this.separatorToken = separatorToken;
113+
this.unknownToken = unknownToken;
114+
}
115+
78116
@Override
79117
public Span[] tokenizePos(final String text) {
80118
// TODO: Implement this.
@@ -85,7 +123,7 @@ public Span[] tokenizePos(final String text) {
85123
public String[] tokenize(final String text) {
86124

87125
final List<String> tokens = new LinkedList<>();
88-
tokens.add(CLASSIFICATION_TOKEN);
126+
tokens.add(classificationToken);
89127

90128
// Put spaces around punctuation.
91129
final String spacedPunctuation = PUNCTUATION_PATTERN.matcher(text).replaceAll(" $0 ");
@@ -146,7 +184,7 @@ public String[] tokenize(final String text) {
146184
// If the word can't be represented by vocabulary pieces replace
147185
// it with a specified "unknown" token.
148186
if (!found) {
149-
tokens.add(UNKNOWN_TOKEN);
187+
tokens.add(unknownToken);
150188
break;
151189
}
152190

@@ -157,14 +195,14 @@ public String[] tokenize(final String text) {
157195

158196
} else {
159197

160-
// If the token's length is greater than the max length just add [UNK] instead.
161-
tokens.add(UNKNOWN_TOKEN);
198+
// If the token's length is greater than the max length just add unknown token instead.
199+
tokens.add(unknownToken);
162200

163201
}
164202

165203
}
166204

167-
tokens.add(SEPARATOR_TOKEN);
205+
tokens.add(separatorToken);
168206

169207
return tokens.toArray(new String[0]);
170208

opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/AbstractDL.java

Lines changed: 82 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,22 @@
1919

2020
import java.io.File;
2121
import java.io.IOException;
22+
import java.nio.charset.StandardCharsets;
2223
import java.nio.file.Files;
2324
import java.nio.file.Path;
2425
import java.util.HashMap;
2526
import java.util.Map;
2627
import java.util.concurrent.atomic.AtomicInteger;
28+
import java.util.regex.Matcher;
29+
import java.util.regex.Pattern;
2730
import java.util.stream.Stream;
2831

2932
import ai.onnxruntime.OrtEnvironment;
3033
import ai.onnxruntime.OrtException;
3134
import ai.onnxruntime.OrtSession;
3235

3336
import opennlp.tools.tokenize.Tokenizer;
37+
import opennlp.tools.tokenize.WordpieceTokenizer;
3438

3539
/**
3640
* Base class for OpenNLP deep-learning classes using ONNX Runtime.
@@ -46,21 +50,92 @@ public abstract class AbstractDL implements AutoCloseable {
4650
protected Tokenizer tokenizer;
4751
protected Map<String, Integer> vocab;
4852

53+
private static final Pattern JSON_ENTRY_PATTERN =
54+
Pattern.compile("\"((?:[^\"\\\\]|\\\\.)*)\"\\s*:\\s*(\\d+)");
55+
4956
/**
5057
* Loads a vocabulary {@link File} from disk.
58+
* Supports both plain text files (one token per
59+
* line) and JSON files mapping tokens to integer
60+
* IDs.
5161
*
5262
* @param vocabFile The vocabulary file.
53-
* @return A map of vocabulary words to integer IDs.
54-
* @throws IOException Thrown if the vocabulary file cannot be opened or read.
63+
* @return A map of vocabulary words to IDs.
64+
* @throws IOException Thrown if the vocabulary
65+
* file cannot be opened or read.
5566
*/
56-
public Map<String, Integer> loadVocab(final File vocabFile) throws IOException {
67+
public Map<String, Integer> loadVocab(
68+
final File vocabFile) throws IOException {
5769

58-
final Map<String, Integer> vocab = new HashMap<>();
59-
final AtomicInteger counter = new AtomicInteger(0);
70+
final Path vocabPath =
71+
Path.of(vocabFile.getPath());
72+
final String content = Files.readString(
73+
vocabPath, StandardCharsets.UTF_8);
74+
final String trimmed = content.trim();
75+
76+
// Detect JSON format by leading brace
77+
if (trimmed.startsWith("{")) {
78+
return loadJsonVocab(trimmed);
79+
}
80+
81+
final Map<String, Integer> vocab =
82+
new HashMap<>();
83+
final AtomicInteger counter =
84+
new AtomicInteger(0);
85+
86+
try (Stream<String> lines = Files.lines(
87+
vocabPath, StandardCharsets.UTF_8)) {
88+
lines.forEach(line ->
89+
vocab.put(line, counter.getAndIncrement())
90+
);
91+
}
92+
93+
return vocab;
94+
}
6095

61-
try (Stream<String> lines = Files.lines(Path.of(vocabFile.getPath()))) {
96+
/**
97+
* Creates a {@link WordpieceTokenizer} that uses the
98+
* appropriate special tokens based on the vocabulary.
99+
* If the vocabulary contains RoBERTa-style tokens,
100+
* those are used. Otherwise, the BERT defaults are
101+
* used.
102+
*
103+
* @param vocab The vocabulary map.
104+
* @return A configured {@link WordpieceTokenizer}.
105+
*/
106+
protected WordpieceTokenizer createTokenizer(
107+
final Map<String, Integer> vocab) {
108+
if (vocab.containsKey(
109+
WordpieceTokenizer.ROBERTA_CLS_TOKEN)
110+
&& vocab.containsKey(
111+
WordpieceTokenizer.ROBERTA_SEP_TOKEN)) {
112+
final String unk = vocab.containsKey(
113+
WordpieceTokenizer.ROBERTA_UNK_TOKEN)
114+
? WordpieceTokenizer.ROBERTA_UNK_TOKEN
115+
: WordpieceTokenizer.BERT_UNK_TOKEN;
116+
return new WordpieceTokenizer(
117+
vocab.keySet(),
118+
WordpieceTokenizer.ROBERTA_CLS_TOKEN,
119+
WordpieceTokenizer.ROBERTA_SEP_TOKEN,
120+
unk);
121+
}
122+
return new WordpieceTokenizer(vocab.keySet());
123+
}
124+
125+
private Map<String, Integer> loadJsonVocab(final String json) {
126+
127+
final Map<String, Integer> vocab = new HashMap<>();
128+
final Matcher matcher = JSON_ENTRY_PATTERN.matcher(json);
62129

63-
lines.forEach(line -> vocab.put(line, counter.getAndIncrement()));
130+
while (matcher.find()) {
131+
final String token = matcher.group(1)
132+
.replace("\\\"", "\"")
133+
.replace("\\\\", "\\")
134+
.replace("\\/", "/")
135+
.replace("\\n", "\n")
136+
.replace("\\t", "\t");
137+
final int id = Integer.parseInt(matcher.group(2));
138+
vocab.put(token, id);
64139
}
65140

66141
return vocab;

opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/doccat/DocumentCategorizerDL.java

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
import opennlp.dl.Tokens;
4646
import opennlp.dl.doccat.scoring.ClassificationScoringStrategy;
4747
import opennlp.tools.doccat.DocumentCategorizer;
48-
import opennlp.tools.tokenize.WordpieceTokenizer;
48+
4949

5050
/**
5151
* An implementation of {@link DocumentCategorizer} that performs document classification
@@ -90,7 +90,7 @@ public DocumentCategorizerDL(File model, File vocabulary, Map<Integer, String> c
9090

9191
this.session = env.createSession(model.getPath(), sessionOptions);
9292
this.vocab = loadVocab(vocabulary);
93-
this.tokenizer = new WordpieceTokenizer(vocab.keySet());
93+
this.tokenizer = createTokenizer(vocab);
9494
this.categories = categories;
9595
this.classificationScoringStrategy = classificationScoringStrategy;
9696
this.inferenceOptions = inferenceOptions;
@@ -125,7 +125,7 @@ public DocumentCategorizerDL(File model, File vocabulary, File config,
125125

126126
this.session = env.createSession(model.getPath(), sessionOptions);
127127
this.vocab = loadVocab(vocabulary);
128-
this.tokenizer = new WordpieceTokenizer(vocab.keySet());
128+
this.tokenizer = createTokenizer(vocab);
129129
this.categories = readCategoriesFromFile(config);
130130
this.classificationScoringStrategy = classificationScoringStrategy;
131131
this.inferenceOptions = inferenceOptions;
@@ -158,11 +158,22 @@ public double[] categorize(String[] strings) {
158158
LongBuffer.wrap(t.types()), new long[] {1, t.types().length}));
159159
}
160160

161-
// The outputs from the model.
162-
final float[][] v = (float[][]) session.run(inputs).get(0).getValue();
161+
// The outputs from the model. Some models return a 2D array (e.g. BERT),
162+
// while others return a 1D array (e.g. RoBERTa).
163+
final Object output = session.run(inputs).get(0).getValue();
164+
165+
final float[] rawScores;
166+
if (output instanceof float[][] v) {
167+
rawScores = v[0];
168+
} else if (output instanceof float[] v) {
169+
rawScores = v;
170+
} else {
171+
throw new IllegalStateException(
172+
"Unexpected model output type: " + output.getClass().getName());
173+
}
163174

164175
// Keep track of all scores.
165-
final double[] categoryScoresForTokens = softmax(v[0]);
176+
final double[] categoryScoresForTokens = softmax(rawScores);
166177
scores.add(categoryScoresForTokens);
167178

168179
}

opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/namefinder/NameFinderDL.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@
3939
import opennlp.dl.Tokens;
4040
import opennlp.tools.namefind.TokenNameFinder;
4141
import opennlp.tools.sentdetect.SentenceDetector;
42-
import opennlp.tools.tokenize.WordpieceTokenizer;
4342
import opennlp.tools.util.Span;
4443

4544
/**
@@ -104,7 +103,7 @@ public NameFinderDL(File model, File vocabulary, Map<Integer, String> ids2Labels
104103
this.session = env.createSession(model.getPath(), sessionOptions);
105104
this.ids2Labels = ids2Labels;
106105
this.vocab = loadVocab(vocabulary);
107-
this.tokenizer = new WordpieceTokenizer(vocab.keySet());
106+
this.tokenizer = createTokenizer(vocab);
108107
this.inferenceOptions = inferenceOptions;
109108
this.sentenceDetector = sentenceDetector;
110109

opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/vectors/SentenceVectorsDL.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
import opennlp.dl.AbstractDL;
3333
import opennlp.dl.Tokens;
3434
import opennlp.tools.tokenize.Tokenizer;
35-
import opennlp.tools.tokenize.WordpieceTokenizer;
35+
3636

3737
/**
3838
* Facilitates the generation of sentence vectors using
@@ -55,7 +55,7 @@ public SentenceVectorsDL(final File model, final File vocabulary)
5555
env = OrtEnvironment.getEnvironment();
5656
session = env.createSession(model.getPath(), new OrtSession.SessionOptions());
5757
vocab = loadVocab(new File(vocabulary.getPath()));
58-
tokenizer = new WordpieceTokenizer(vocab.keySet());
58+
tokenizer = createTokenizer(vocab);
5959

6060
}
6161

0 commit comments

Comments (0)