|
20 | 20 | import org.beehive.gpullama3.tokenizer.impl.Qwen3Tokenizer;
|
21 | 21 | import org.beehive.gpullama3.tokenizer.impl.Tokenizer;
|
22 | 22 | import org.beehive.gpullama3.tokenizer.vocabulary.Vocabulary;
|
| 23 | +import uk.ac.manchester.tornado.api.types.arrays.FloatArray; |
23 | 24 |
|
24 | 25 | import java.io.IOException;
|
25 | 26 | import java.nio.channels.FileChannel;
|
@@ -112,7 +113,6 @@ public Weights loadWeights(Map<String, GGMLTensorEntry> tensorEntries, Configura
|
112 | 113 | return createStandardWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
|
113 | 114 | }
|
114 | 115 | }
|
115 |
| - // @formatter:on |
116 | 116 |
|
117 | 117 | @Override
|
118 | 118 | public Weights createStandardWeights(Map<String, GGMLTensorEntry> tensorEntries, Configuration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
|
@@ -141,6 +141,30 @@ public Weights createStandardWeights(Map<String, GGMLTensorEntry> tensorEntries,
|
141 | 141 | }
|
142 | 142 |
|
143 | 143 | @Override
|
| 144 | + public Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntries, Configuration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings, |
| 145 | + GGMLTensorEntry outputWeight) { |
| 146 | + return new Qwen2TornadoWeights( |
| 147 | + loadTensorAsFloatArray(tokenEmbeddings), |
| 148 | + loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), |
| 149 | + loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.weight")), |
| 150 | + loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.weight")), |
| 151 | + loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.weight")), |
| 152 | + // Qwen2-specific: qkv bias |
| 153 | + loadArrayAsFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.bias")), |
| 154 | + loadArrayAsFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.bias")), |
| 155 | + loadArrayAsFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.bias")), |
| 156 | + |
| 157 | + loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")), |
| 158 | + loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")), |
| 159 | + loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")), // w1 |
| 160 | + loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), // w2 |
| 161 | + loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), // w3 |
| 162 | + floatBufferToFloatArray(tensorEntries.get("output_norm.weight")), |
| 163 | + FloatArray.fromArray(ropeFreqs.first()), |
| 164 | + FloatArray.fromArray(ropeFreqs.second()), |
| 165 | + loadTensorAsHalfFloatArray(outputWeight), |
| 166 | + outputWeight.ggmlType() |
| 167 | + ); |
144 | 168 | }
|
145 | 169 | // @formatter:on
|
146 | 170 |
|
|
0 commit comments