
Commit 340b35e

Merge createTokenizer methods into Tokenizer constructors
1 parent 733815b commit 340b35e

3 files changed (+34, -51 lines)
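At a glance: the two private factory methods on ModelLoader are deleted, and each tokenizer now builds itself from GGUF metadata. Only the signatures below, all taken from the diffs that follow, change at the API surface:

    // before: private factories on ModelLoader
    private static Tokenizer createLlama3Tokenizer(Map<String, Object> metadata, Vocabulary vocabulary)
    private static Tokenizer createMistralTokenizer(Map<String, Object> metadata, Vocabulary vocabulary)

    // after: public constructors on the tokenizer classes
    public LlamaTokenizer(Map<String, Object> metadata, Vocabulary vocabulary)
    public MistralTokenizer(Map<String, Object> metadata, Vocabulary vocabulary)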

src/main/java/com/example/loader/weights/ModelLoader.java

Lines changed: 0 additions & 46 deletions
@@ -13,15 +13,7 @@
 import com.example.model.Configuration;
 import com.example.model.Model;
 import com.example.model.ModelType;
-import com.example.model.llama.LlamaConfiguration;
-import com.example.model.llama.Llama;
-import com.example.model.mistral.Mistral;
-import com.example.model.mistral.MistralConfiguration;
 import com.example.inference.operation.RoPE;
-import com.example.tokenizer.impl.LlamaTokenizer;
-import com.example.tokenizer.impl.MistralTokenizer;
-import com.example.tokenizer.impl.Tokenizer;
-import com.example.tokenizer.vocabulary.Vocabulary;
 import uk.ac.manchester.tornado.api.types.HalfFloat;
 import uk.ac.manchester.tornado.api.types.arrays.ByteArray;
 import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
@@ -33,20 +25,13 @@
 import java.nio.channels.FileChannel;
 import java.nio.file.Path;
 import java.nio.file.StandardOpenOption;
-import java.util.Arrays;
-import java.util.List;
 import java.util.Map;
 import java.util.function.IntFunction;
-import java.util.stream.Collectors;
-import java.util.stream.IntStream;

 public final class ModelLoader {
     private static final String TOKENIZER_LLAMA_3_MODEL = "gpt2";
     private static final String TOKENIZER_MISTRAL_MODEL = "llama";

-    private static final String LLAMA_3_PATTERN = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+";
-    private static final String MISTRAL_PATTERN = "\\S+|\\s+";
-
     private static ModelType detectModelType(Map<String, Object> metadata) {
         String name = (String) metadata.get("general.name");
         String tokenizerModel = (String) metadata.get("tokenizer.ggml.model");
@@ -232,37 +217,6 @@ private static Weights createStandardWeights(Map<String, GGMLTensorEntry> tensor
                 FloatBuffer.wrap(ropeFreqs.first()), FloatBuffer.wrap(ropeFreqs.second()), loadQuantized(outputWeight), outputWeight.ggmlType());
     }

-    private static Tokenizer createLlama3Tokenizer(Map<String, Object> metadata, Vocabulary vocabulary) {
-        String[] mergeLines = (String[]) metadata.get("tokenizer.ggml.merges");
-        List<Pair<Integer, Integer>> merges = Arrays.stream(mergeLines).map(line -> line.split(" "))
-                .map(parts -> new Pair<>(vocabulary.getIndex(parts[0]).orElseThrow(), vocabulary.getIndex(parts[1]).orElseThrow())).toList();
-
-        int allTokens = vocabulary.size();
-        int baseTokens = 128000; // assume all tokens after the base ones are special.
-        int reservedSpecialTokens = allTokens - baseTokens;
-        List<String> specialTokensList = Arrays.stream(vocabulary.tokens(), baseTokens, allTokens).toList();
-
-        assert specialTokensList.stream().allMatch(token -> vocabulary.getIndex(token).isPresent());
-
-        Map<String, Integer> specialTokens = IntStream.range(0, specialTokensList.size()).boxed().collect(Collectors.toMap(i -> specialTokensList.get(i), i -> baseTokens + i));
-
-        return new LlamaTokenizer(vocabulary, merges, LLAMA_3_PATTERN, specialTokens);
-
-    }
-
-    private static Tokenizer createMistralTokenizer(Map<String, Object> metadata, Vocabulary vocabulary) {
-        int[] tokenTypes = (int[]) metadata.get("tokenizer.ggml.token_type");
-        List<Integer> specialTokensList = IntStream.range(0, vocabulary.size()).filter(t -> tokenTypes[t] != 1 && tokenTypes[t] != 6).boxed().toList();
-        Map<String, Integer> specialTokens =
-                IntStream.range(0, specialTokensList.size())
-                        .boxed()
-                        .collect(Collectors.toMap(
-                                t -> vocabulary.get(t),
-                                t -> t)
-                        );
-        return new MistralTokenizer(vocabulary, null, specialTokens, tokenTypes);
-    }
-
     public static FloatTensor loadQuantized(GGMLTensorEntry entry) {
         GGMLType ggmlType = entry.ggmlType();
         return switch (ggmlType) {
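Note that the tokenizer imports disappear from ModelLoader entirely, so tokenizer construction presumably moves to whichever code selects the model; that call site is not part of this commit. A minimal sketch of what it might look like, assuming a GGUF metadata map and a loaded Vocabulary are already in hand:

    // Hypothetical call site, not shown in this commit. How metadata and
    // vocabulary are obtained is an assumption; only the two constructor
    // signatures below are confirmed by the diffs.
    Tokenizer llamaTokenizer   = new LlamaTokenizer(metadata, vocabulary);   // Llama 3 models
    Tokenizer mistralTokenizer = new MistralTokenizer(metadata, vocabulary); // Mistral models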

src/main/java/com/example/tokenizer/impl/LlamaTokenizer.java

Lines changed: 17 additions & 2 deletions
@@ -26,6 +26,7 @@
  * <a href="https://github.com/openai/gpt-2/blob/master/src/encoder.py">GPT 2 tokenizer</a>
  */
 public class LlamaTokenizer implements Tokenizer {
+    private static final String LLAMA_3_PATTERN = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+";
     // general fields
     private final Pattern compiledPattern;
     private final Vocabulary vocabulary;
@@ -55,9 +56,23 @@ public boolean shouldDisplayToken(int token) {
         return !isSpecialToken(token);
     }

-    public LlamaTokenizer(Vocabulary vocabulary, List<Pair<Integer, Integer>> merges, String regexPattern, Map<String, Integer> specialTokens) {
+    public LlamaTokenizer(Map<String, Object> metadata, Vocabulary vocabulary) {
+        // load from metadata
+        String[] mergeLines = (String[]) metadata.get("tokenizer.ggml.merges");
+        List<Pair<Integer, Integer>> merges = Arrays.stream(mergeLines).map(line -> line.split(" "))
+                .map(parts -> new Pair<>(vocabulary.getIndex(parts[0]).orElseThrow(), vocabulary.getIndex(parts[1]).orElseThrow())).toList();
+        int allTokens = vocabulary.size();
+        int baseTokens = 128000; // assume all tokens after the base ones are special.
+        int reservedSpecialTokens = allTokens - baseTokens;
+        List<String> specialTokensList = Arrays.stream(vocabulary.tokens(), baseTokens, allTokens).toList();
+
+        assert specialTokensList.stream().allMatch(token -> vocabulary.getIndex(token).isPresent());
+
+        Map<String, Integer> specialTokens = IntStream.range(0, specialTokensList.size()).boxed().collect(Collectors.toMap(i -> specialTokensList.get(i), i -> baseTokens + i));
+
+        // init tokenizer object fields
         this.vocabulary = vocabulary;
-        this.compiledPattern = regexPattern != null ? Pattern.compile(regexPattern) : null;
+        this.compiledPattern = Pattern.compile(LLAMA_3_PATTERN);
         this.specialTokens = new HashMap<>(specialTokens);
         this.merges = new HashMap<>();
         for (Pair<Integer, Integer> pair : merges) {
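To make the constructor's special-token bookkeeping concrete, here is a toy-scale, self-contained replay of the IntStream/Collectors.toMap step above. The token strings and the tiny baseTokens cutoff are invented for illustration; in the real constructor the cutoff is 128000:

    import java.util.List;
    import java.util.Map;
    import java.util.stream.Collectors;
    import java.util.stream.IntStream;

    public class SpecialTokenMapDemo {
        public static void main(String[] args) {
            // Stand-in vocabulary: ids 0..2 are base tokens, everything after is special.
            List<String> tokens = List.of("a", "b", "c", "<|eot|>", "<|pad|>");
            int baseTokens = 3; // plays the role of the 128000 cutoff above
            List<String> specialTokensList = tokens.subList(baseTokens, tokens.size());
            // Same idiom as the constructor: position i in the special slice
            // maps back to its absolute vocabulary id, baseTokens + i.
            Map<String, Integer> specialTokens = IntStream.range(0, specialTokensList.size()).boxed()
                    .collect(Collectors.toMap(specialTokensList::get, i -> baseTokens + i));
            System.out.println(specialTokens); // {<|eot|>=3, <|pad|>=4} (map order may vary)
        }
    }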

src/main/java/com/example/tokenizer/impl/MistralTokenizer.java

Lines changed: 17 additions & 3 deletions
@@ -5,6 +5,8 @@
 import java.nio.charset.StandardCharsets;
 import java.util.*;
 import java.util.regex.Pattern;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;

 /**
  * TikToken-style BPE tokenizer with byte fallback.
@@ -23,6 +25,7 @@
  * This guarantees reversibility: every string can be tokenized and decoded back exactly.
  */
 public class MistralTokenizer implements Tokenizer {
+    private static final String MISTRAL_PATTERN = "\\S+|\\s+";
     // general fields
     private final Pattern compiledPattern;
     private final Vocabulary vocabulary;
@@ -58,11 +61,22 @@ public int getTokenType(int tokenIndex) {
         return tokenType[tokenIndex];
     }

-    public MistralTokenizer(Vocabulary vocabulary, String regexPattern, Map<String, Integer> specialTokens, int[] tokenType) {
+    public MistralTokenizer(Map<String, Object> metadata, Vocabulary vocabulary) {
+        // load from metadata
+        int[] tokenTypes = (int[]) metadata.get("tokenizer.ggml.token_type");
+        List<Integer> specialTokensList = IntStream.range(0, vocabulary.size()).filter(t -> tokenTypes[t] != 1 && tokenTypes[t] != 6).boxed().toList();
+        Map<String, Integer> specialTokens =
+                IntStream.range(0, specialTokensList.size())
+                        .boxed()
+                        .collect(Collectors.toMap(
+                                t -> vocabulary.get(t),
+                                t -> t)
+                        );
+        // init tokenizer object fields
         this.vocabulary = vocabulary;
-        this.compiledPattern = regexPattern != null ? Pattern.compile(regexPattern) : null;
+        this.compiledPattern = null;
         this.specialTokens = new HashMap<>(specialTokens);
-        this.tokenType = tokenType;
+        this.tokenType = tokenTypes;
         this.byte0 = vocabulary.getIndex("<0x00>").orElseThrow();
     }
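Unlike the Llama constructor, the Mistral one derives special tokens from GGUF token types rather than from vocabulary position. A toy, self-contained replay of that filter; the type codes follow GGUF's convention as I understand it (1 = normal, 6 = byte), and the sample tokens are invented:

    import java.util.List;
    import java.util.stream.IntStream;

    public class MistralSpecialTokenDemo {
        public static void main(String[] args) {
            // Invented sample vocabulary with GGUF-style token types.
            String[] tokens = {"<s>", "</s>", "hello", "<0x00>"};
            int[] tokenTypes = {3, 3, 1, 6}; // 3 = control, 1 = normal, 6 = byte (assumed GGUF codes)
            // Same filter as the constructor: anything that is neither a
            // normal token (1) nor a byte-fallback token (6) counts as special.
            List<Integer> specialTokensList = IntStream.range(0, tokens.length)
                    .filter(t -> tokenTypes[t] != 1 && tokenTypes[t] != 6)
                    .boxed().toList();
            System.out.println(specialTokensList); // [0, 1], i.e. "<s>" and "</s>"
        }
    }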