Skip to content

Commit 09f1b4d

Browse files
Introduce TornadoWeights for Qwen2
1 parent 1e1ec8a commit 09f1b4d

File tree

3 files changed

+68
-1
lines changed

3 files changed

+68
-1
lines changed

src/main/java/org/beehive/gpullama3/inference/weights/standard/Qwen2StandardWeights.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import org.beehive.gpullama3.inference.weights.Weights;
77

88
public class Qwen2StandardWeights extends StandardWeights {
9+
// Qwen2-specific weights
910
public final FloatTensor[] q_bias, k_bias, v_bias;
1011

1112
public Qwen2StandardWeights(
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
package org.beehive.gpullama3.inference.weights.tornado;
2+
3+
import org.beehive.gpullama3.core.model.GGMLType;
4+
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
5+
import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
6+
7+
public class Qwen2TornadoWeights extends TornadoWeights {
8+
9+
// Qwen2-specific tornado weights
10+
FloatArray[] q_biasLayered;
11+
FloatArray[] k_biasLayered;
12+
FloatArray[] v_biasLayered;
13+
14+
public Qwen2TornadoWeights(FloatArray tokenEmbeddingTable, FloatArray[] rms_att_weightLayered, HalfFloatArray[] wqLayered, HalfFloatArray[] wkLayered, HalfFloatArray[] wvLayered,
15+
FloatArray[] wqBiasLayered,
16+
FloatArray[] wkBiasLayered,
17+
FloatArray[] wvBiasLayered,
18+
HalfFloatArray[] woLayered, FloatArray[] rms_ffn_weightLayered, HalfFloatArray[] w1Layered,
19+
HalfFloatArray[] w2Layered, HalfFloatArray[] w3Layered, FloatArray rms_final_weight_as_floatArray, FloatArray freq_cis_realFlat, FloatArray freq_cis_imagFlat, HalfFloatArray wclsByteArray,
20+
GGMLType weightType) {
21+
// call to TornadoWeights constructor
22+
super(tokenEmbeddingTable,
23+
rms_att_weightLayered,
24+
wqLayered,
25+
wkLayered,
26+
wvLayered,
27+
woLayered,
28+
rms_ffn_weightLayered,
29+
w1Layered,
30+
w2Layered,
31+
w3Layered,
32+
rms_final_weight_as_floatArray,
33+
freq_cis_realFlat,
34+
freq_cis_imagFlat,
35+
wclsByteArray,
36+
weightType);
37+
// init qwen2-specific fields
38+
this.q_biasLayered = wqBiasLayered;
39+
this.k_biasLayered = wkBiasLayered;
40+
this.v_biasLayered = wvBiasLayered;
41+
}
42+
}

src/main/java/org/beehive/gpullama3/model/loader/Qwen2ModelLoader.java

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import org.beehive.gpullama3.tokenizer.impl.Qwen3Tokenizer;
2121
import org.beehive.gpullama3.tokenizer.impl.Tokenizer;
2222
import org.beehive.gpullama3.tokenizer.vocabulary.Vocabulary;
23+
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
2324

2425
import java.io.IOException;
2526
import java.nio.channels.FileChannel;
@@ -112,7 +113,6 @@ public Weights loadWeights(Map<String, GGMLTensorEntry> tensorEntries, Configura
112113
return createStandardWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
113114
}
114115
}
115-
// @formatter:on
116116

117117
@Override
118118
public Weights createStandardWeights(Map<String, GGMLTensorEntry> tensorEntries, Configuration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
@@ -141,6 +141,30 @@ public Weights createStandardWeights(Map<String, GGMLTensorEntry> tensorEntries,
141141
}
142142

143143
@Override
144+
public Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntries, Configuration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
145+
GGMLTensorEntry outputWeight) {
146+
return new Qwen2TornadoWeights(
147+
loadTensorAsFloatArray(tokenEmbeddings),
148+
loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
149+
loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
150+
loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.weight")),
151+
loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.weight")),
152+
// Qwen2-specific: qkv bias
153+
loadArrayAsFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.bias")),
154+
loadArrayAsFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.bias")),
155+
loadArrayAsFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.bias")),
156+
157+
loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")),
158+
loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")),
159+
loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")), // w1
160+
loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), // w2
161+
loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), // w3
162+
floatBufferToFloatArray(tensorEntries.get("output_norm.weight")),
163+
FloatArray.fromArray(ropeFreqs.first()),
164+
FloatArray.fromArray(ropeFreqs.second()),
165+
loadTensorAsHalfFloatArray(outputWeight),
166+
outputWeight.ggmlType()
167+
);
144168
}
145169
// @formatter:on
146170

0 commit comments

Comments
 (0)