|
13 | 13 | import com.example.model.Configuration;
|
14 | 14 | import com.example.model.Model;
|
15 | 15 | import com.example.model.ModelType;
|
16 |
| -import com.example.model.llama.LlamaConfiguration; |
17 |
| -import com.example.model.llama.Llama; |
18 |
| -import com.example.model.mistral.Mistral; |
19 |
| -import com.example.model.mistral.MistralConfiguration; |
20 | 16 | import com.example.inference.operation.RoPE;
|
21 |
| -import com.example.tokenizer.impl.LlamaTokenizer; |
22 |
| -import com.example.tokenizer.impl.MistralTokenizer; |
23 |
| -import com.example.tokenizer.impl.Tokenizer; |
24 |
| -import com.example.tokenizer.vocabulary.Vocabulary; |
25 | 17 | import uk.ac.manchester.tornado.api.types.HalfFloat;
|
26 | 18 | import uk.ac.manchester.tornado.api.types.arrays.ByteArray;
|
27 | 19 | import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
|
|
33 | 25 | import java.nio.channels.FileChannel;
|
34 | 26 | import java.nio.file.Path;
|
35 | 27 | import java.nio.file.StandardOpenOption;
|
36 |
| -import java.util.Arrays; |
37 |
| -import java.util.List; |
38 | 28 | import java.util.Map;
|
39 | 29 | import java.util.function.IntFunction;
|
40 |
| -import java.util.stream.Collectors; |
41 |
| -import java.util.stream.IntStream; |
42 | 30 |
|
43 | 31 | public final class ModelLoader {
|
44 | 32 | private static final String TOKENIZER_LLAMA_3_MODEL = "gpt2";
|
45 | 33 | private static final String TOKENIZER_MISTRAL_MODEL = "llama";
|
46 | 34 |
|
47 |
| - private static final String LLAMA_3_PATTERN = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"; |
48 |
| - private static final String MISTRAL_PATTERN = "\\S+|\\s+"; |
49 |
| - |
50 | 35 | private static ModelType detectModelType(Map<String, Object> metadata) {
|
51 | 36 | String name = (String) metadata.get("general.name");
|
52 | 37 | String tokenizerModel = (String) metadata.get("tokenizer.ggml.model");
|
@@ -232,37 +217,6 @@ private static Weights createStandardWeights(Map<String, GGMLTensorEntry> tensor
|
232 | 217 | FloatBuffer.wrap(ropeFreqs.first()), FloatBuffer.wrap(ropeFreqs.second()), loadQuantized(outputWeight), outputWeight.ggmlType());
|
233 | 218 | }
|
234 | 219 |
|
235 |
| - private static Tokenizer createLlama3Tokenizer(Map<String, Object> metadata, Vocabulary vocabulary) { |
236 |
| - String[] mergeLines = (String[]) metadata.get("tokenizer.ggml.merges"); |
237 |
| - List<Pair<Integer, Integer>> merges = Arrays.stream(mergeLines).map(line -> line.split(" ")) |
238 |
| - .map(parts -> new Pair<>(vocabulary.getIndex(parts[0]).orElseThrow(), vocabulary.getIndex(parts[1]).orElseThrow())).toList(); |
239 |
| - |
240 |
| - int allTokens = vocabulary.size(); |
241 |
| - int baseTokens = 128000; // assume all tokens after the base ones are special. |
242 |
| - int reservedSpecialTokens = allTokens - baseTokens; |
243 |
| - List<String> specialTokensList = Arrays.stream(vocabulary.tokens(), baseTokens, allTokens).toList(); |
244 |
| - |
245 |
| - assert specialTokensList.stream().allMatch(token -> vocabulary.getIndex(token).isPresent()); |
246 |
| - |
247 |
| - Map<String, Integer> specialTokens = IntStream.range(0, specialTokensList.size()).boxed().collect(Collectors.toMap(i -> specialTokensList.get(i), i -> baseTokens + i)); |
248 |
| - |
249 |
| - return new LlamaTokenizer(vocabulary, merges, LLAMA_3_PATTERN, specialTokens); |
250 |
| - |
251 |
| - } |
252 |
| - |
253 |
| - private static Tokenizer createMistralTokenizer(Map<String, Object> metadata, Vocabulary vocabulary) { |
254 |
| - int[] tokenTypes = (int[]) metadata.get("tokenizer.ggml.token_type"); |
255 |
| - List<Integer> specialTokensList = IntStream.range(0, vocabulary.size()).filter(t -> tokenTypes[t] != 1 && tokenTypes[t] != 6).boxed().toList(); |
256 |
| - Map<String, Integer> specialTokens = |
257 |
| - IntStream.range(0, specialTokensList.size()) |
258 |
| - .boxed() |
259 |
| - .collect(Collectors.toMap( |
260 |
| - t -> vocabulary.get(t), |
261 |
| - t -> t) |
262 |
| - ); |
263 |
| - return new MistralTokenizer(vocabulary, null, specialTokens, tokenTypes); |
264 |
| - } |
265 |
| - |
266 | 220 | public static FloatTensor loadQuantized(GGMLTensorEntry entry) {
|
267 | 221 | GGMLType ggmlType = entry.ggmlType();
|
268 | 222 | return switch (ggmlType) {
|
|
0 commit comments