|
1 | 1 | package com.example.loader.weights;
|
2 | 2 |
|
3 | 3 | import com.example.LlamaApp;
|
4 |
| -import com.example.auxiliary.Timer; |
5 | 4 | import com.example.core.model.GGMLType;
|
6 | 5 | import com.example.core.model.GGUF;
|
7 | 6 | import com.example.core.model.tensor.F16FloatTensor;
|
@@ -70,89 +69,10 @@ public static Model loadModel(Path ggufPath, int contextLength, boolean loadWeig
|
70 | 69 | // initial load of metadata from gguf file
|
71 | 70 | GGUF gguf = GGUF.loadModel(ggufPath);
|
72 | 71 | FileChannel fileChannel = FileChannel.open(ggufPath, StandardOpenOption.READ);
|
73 |
| - |
74 | 72 | // detect model type
|
75 | 73 | ModelType modelType = detectModelType(gguf.getMetadata());
|
76 |
| - System.out.println("Detected model type: " + modelType); |
77 |
| - |
78 |
| - // load model (vocabulary, tokenizer, configuration, tensors, weights) |
79 |
| - return switch (modelType) { |
80 |
| - case LLAMA_3 -> loadLlamaModel(fileChannel, gguf, contextLength, loadWeights); |
81 |
| - case MISTRAL -> loadMistralModel(fileChannel, gguf, contextLength, loadWeights); |
82 |
| - default -> throw new UnsupportedOperationException("Unsupported model type: " + modelType); |
83 |
| - }; |
84 |
| - } |
85 |
| - |
86 |
| - public static Llama loadLlamaModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights) throws IOException { |
87 |
| - try (var ignored = Timer.log("Load LlaMa model")) { |
88 |
| - Map<String, Object> metadata = gguf.getMetadata(); |
89 |
| - |
90 |
| - Vocabulary vocabulary = Vocabulary.loadLlamaVocabulary(metadata); |
91 |
| - Tokenizer tokenizer = createLlama3Tokenizer(metadata, vocabulary); |
92 |
| - |
93 |
| - LlamaConfiguration config = new LlamaConfiguration( |
94 |
| - (int) metadata.get("llama.embedding_length"), |
95 |
| - (int) metadata.get("llama.feed_forward_length"), |
96 |
| - (int) metadata.get("llama.block_count"), |
97 |
| - (int) metadata.get("llama.attention.head_count"), |
98 |
| - |
99 |
| - metadata.containsKey("llama.attention.head_count_kv") ? |
100 |
| - (int) metadata.get("llama.attention.head_count_kv") : |
101 |
| - (int) metadata.get("llama.attention.head_count"), |
102 |
| - |
103 |
| - vocabulary.size(), |
104 |
| - (int) metadata.get("llama.context_length"), |
105 |
| - (float) metadata.getOrDefault("llama.attention.layer_norm_rms_epsilon", 1e-5f), |
106 |
| - (float) metadata.getOrDefault("llama.rope.freq_base", 10000f) |
107 |
| - ).withContextLength(contextLength); |
108 |
| - |
109 |
| - Weights weights = null; |
110 |
| - if (loadWeights) { |
111 |
| - Map<String, GGMLTensorEntry> tensorEntries = GGUF.loadTensors(fileChannel, gguf.getTensorDataOffset(), gguf.getTensorInfos()); |
112 |
| - weights = loadWeights(tensorEntries, config); |
113 |
| - } |
114 |
| - return new Llama(config, tokenizer, weights); |
115 |
| - } |
116 |
| - } |
117 |
| - |
118 |
| - public static Mistral loadMistralModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights) { |
119 |
| - try (var ignored = Timer.log("Load Mistral model")) { |
120 |
| - Map<String, Object> metadata = gguf.getMetadata(); |
121 |
| - |
122 |
| - Vocabulary vocabulary = Vocabulary.loadMistralVocabulary(metadata); |
123 |
| - Tokenizer tokenizer = createMistralTokenizer(metadata, vocabulary); |
124 |
| - |
125 |
| - int modelContextLength = (int) metadata.get("llama.context_length"); |
126 |
| - if (contextLength < 0 || modelContextLength < contextLength) { |
127 |
| - contextLength = modelContextLength; |
128 |
| - } |
129 |
| - |
130 |
| - MistralConfiguration config = new MistralConfiguration( |
131 |
| - (int) metadata.get("llama.embedding_length"), |
132 |
| - (int) metadata.get("llama.feed_forward_length"), |
133 |
| - (int) metadata.get("llama.block_count"), |
134 |
| - (int) metadata.get("llama.attention.head_count"), |
135 |
| - |
136 |
| - metadata.containsKey("llama.attention.head_count_kv") |
137 |
| - ? (int) metadata.get("llama.attention.head_count_kv") |
138 |
| - : (int) metadata.get("llama.attention.head_count"), |
139 |
| - |
140 |
| - vocabulary.size(), |
141 |
| - contextLength, |
142 |
| - false, |
143 |
| - (float) metadata.getOrDefault("llama.attention.layer_norm_rms_epsilon", 1e-5f), |
144 |
| - (float) metadata.getOrDefault("llama.rope.freq_base", 10000f) |
145 |
| - ); |
146 |
| - |
147 |
| - Weights weights = null; |
148 |
| - if (loadWeights) { |
149 |
| - Map<String, GGMLTensorEntry> tensorEntries = GGUF.loadTensors(fileChannel, gguf.getTensorDataOffset(), gguf.getTensorInfos()); |
150 |
| - weights = loadWeights(tensorEntries, config); |
151 |
| - } |
152 |
| - return new Mistral(config, tokenizer, weights); |
153 |
| - } catch (IOException e) { |
154 |
| - throw new RuntimeException(e); |
155 |
| - } |
| 74 | + // model type-specific load |
| 75 | + return modelType.loadModel(fileChannel, gguf, contextLength, loadWeights); |
156 | 76 | }
|
157 | 77 |
|
158 | 78 | public static Weights loadWeights(Map<String, GGMLTensorEntry> tensorEntries, Configuration config) {
|
|
0 commit comments