
Commit eead727

Add missing pieces for Qwen2.5 & Deepseek-r1-distill-qwen with tornado
1 parent abc1b2b commit eead727

6 files changed (+244 / -18 lines)


src/main/java/org/beehive/gpullama3/inference/state/Qwen2State.java

Lines changed: 27 additions & 4 deletions
@@ -4,17 +4,16 @@
 import org.beehive.gpullama3.core.model.tensor.FloatTensor;
 import org.beehive.gpullama3.model.Configuration;
 import org.beehive.gpullama3.model.qwen2.Qwen2Configuration;
+import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
+import uk.ac.manchester.tornado.api.types.arrays.IntArray;
 
 import java.util.stream.Stream;
 
 public class Qwen2State extends State {
 
-    //Qwen2 specific fields TODO
-
     public Qwen2State(Configuration config, int batchsize) {
         super(config, batchsize);
-        // Initialize Qwen2-specific fields TODO
-        Qwen2Configuration qwen2Config = (Qwen2Configuration) config;
+        this.localSize = 32;
     }
     @Override
     protected StateFields createStateFields(Configuration configuration) {
@@ -40,6 +39,30 @@ protected StateFields createStateFields(Configuration configuration) {
         fields.keyCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength(), nEmbdGqa)).limit(config.numberOfLayers()).toArray(FloatTensor[]::new);
         fields.valueCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength(), nEmbdGqa)).limit(config.numberOfLayers()).toArray(FloatTensor[]::new);
 
+        // TornadoVM wrappers with Qwen2 dimensions
+        fields.wrapX = new FloatArray(config.dim());
+        fields.wrapXb = new FloatArray(config.dim());
+        fields.wrapXb2 = new FloatArray(config.dim());
+        fields.wrapHb = new FloatArray(config.hiddenDim());
+        fields.wrapHb2 = new FloatArray(config.hiddenDim());
+
+        fields.wrapLogits = new FloatArray(config.vocabularySize());
+        fields.wrapQ = new FloatArray(config.dim());
+        fields.wrapK = new FloatArray(config.kvDim());
+        fields.wrapV = new FloatArray(config.kvDim());
+
+        fields.wrapKeyCache = new FloatArray(config.contextLength() * nEmbdGqa * config.numberOfLayers());
+        fields.wrapValueCache = new FloatArray(config.contextLength() * nEmbdGqa * config.numberOfLayers());
+        fields.wrapValueCache.init(0.f);
+        fields.wrapKeyCache.init(0.f);
+        fields.wrapAtt = new FloatArray(config.numberOfHeads() * config.contextLength());
+        fields.positionHolder = new IntArray(1);
+
+        // Temporary arrays
+        fields.temp = new FloatArray(1 + ((config.dim() + localSize - 1) / localSize));
+        fields.tempFFN = new FloatArray(1 + ((config.dim() + localSize - 1) / localSize));
+        fields.tempLogits = new FloatArray(1 + ((config.dim() + localSize - 1) / localSize));
+
         return fields;
 
     }
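Note on the sizing above: the three temp arrays hold 1 + ceil(dim / localSize) floats, which reads as one slot for the final reduced value plus one partial result per work group of localSize threads (set to 32 in the constructor), and the key/value caches are flattened into a single array of contextLength * nEmbdGqa * numberOfLayers floats. A plain-Java sketch with made-up dimensions; the layer-major cache ordering is an assumption, only the total size appears in the diff:

    public class Qwen2StateSizingSketch {
        public static void main(String[] args) {
            // Made-up dimensions; the real values come from Qwen2Configuration.
            int dim = 3584, contextLength = 4096, numberOfLayers = 28, nEmbdGqa = 512;
            int localSize = 32;

            // One final slot plus one partial sum per work group (ceiling division over dim).
            int tempSize = 1 + ((dim + localSize - 1) / localSize);
            System.out.println("temp/tempFFN/tempLogits size = " + tempSize);

            // Flattened cache: [layer][position][channel] collapsed into one array.
            int layer = 3, position = 17, channel = 5;
            int cacheIndex = (layer * contextLength + position) * nEmbdGqa + channel;
            System.out.println("cache index " + cacheIndex + " of "
                    + (contextLength * nEmbdGqa * numberOfLayers));
        }
    }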

src/main/java/org/beehive/gpullama3/inference/weights/tornado/Qwen2TornadoWeights.java

Lines changed: 3 additions & 3 deletions
@@ -7,9 +7,9 @@
 public class Qwen2TornadoWeights extends TornadoWeights {
 
     // Qwen2-specific tornado weights
-    FloatArray[] q_biasLayered;
-    FloatArray[] k_biasLayered;
-    FloatArray[] v_biasLayered;
+    public FloatArray[] q_biasLayered;
+    public FloatArray[] k_biasLayered;
+    public FloatArray[] v_biasLayered;
 
     public Qwen2TornadoWeights(FloatArray tokenEmbeddingTable, FloatArray[] rms_att_weightLayered, HalfFloatArray[] wqLayered, HalfFloatArray[] wkLayered, HalfFloatArray[] wvLayered,
             FloatArray[] wqBiasLayered,
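The bias arrays are widened to public, presumably so the TornadoVM task-graph setup in other packages can hand the per-layer QKV biases to the attention kernels directly. A plain-Java sketch (not the repo's kernel code) of what consuming such a per-layer bias looks like:

    public class QkvBiasSketch {
        // Adds a per-layer bias to a projected vector in place.
        static void addInPlace(float[] vec, float[] bias) {
            for (int i = 0; i < vec.length; i++) {
                vec[i] += bias[i];
            }
        }

        public static void main(String[] args) {
            float[][] qBiasLayered = { { 0.25f, -0.5f, 0.75f } };  // one layer, made-up values
            float[] q = { 1.0f, 2.0f, 3.0f };                      // projected query for that layer
            addInPlace(q, qBiasLayered[0]);
            System.out.println(java.util.Arrays.toString(q));      // [1.25, 1.5, 3.75]
        }
    }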

src/main/java/org/beehive/gpullama3/model/loader/Qwen2ModelLoader.java

Lines changed: 3 additions & 3 deletions
@@ -150,9 +150,9 @@ public Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntries
                 loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.weight")),
                 loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.weight")),
                 // Qwen2-specific: qkv bias
-                loadArrayAsFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.bias")),
-                loadArrayAsFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.bias")),
-                loadArrayAsFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.bias")),
+                loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.bias")),
+                loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.bias")),
+                loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.bias")),
 
                 loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")),
                 loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")),
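Switching the three bias tensors from loadArrayAsFloatArray to loadArrayAsFloatArrayFromBuffer matches the surrounding norm weights, presumably because these small tensors are stored as plain F32 in the GGUF file and can be copied straight out of their buffer instead of going through the quantized-tensor path. A hypothetical sketch of that buffer-copy idea; FloatBufferSource stands in for whatever accessor GGMLTensorEntry actually exposes:

    import java.nio.FloatBuffer;
    import java.util.function.IntFunction;

    public class BiasLoadingSketch {
        // Assumption: not the real GGMLTensorEntry API, just a stand-in for "gives me raw F32 data".
        interface FloatBufferSource {
            FloatBuffer asFloatBuffer();
        }

        // One float[] per layer, copied directly from the raw buffer (no dequantization step).
        static float[][] loadLayeredFromBuffer(int layers, IntFunction<FloatBufferSource> entry) {
            float[][] out = new float[layers][];
            for (int i = 0; i < layers; i++) {
                FloatBuffer buf = entry.apply(i).asFloatBuffer();
                float[] data = new float[buf.remaining()];
                buf.get(data);
                out[i] = data;
            }
            return out;
        }
    }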

src/main/java/org/beehive/gpullama3/model/qwen2/Qwen2Configuration.java

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ public int kvDim() {
 
     @Override
     public int kvMul() {
-        throw new UnsupportedOperationException("Not supported for Qwen2.");
+        return numberOfHeads / numberOfKeyValueHeads;
    }
 
     @Override
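kvMul is the grouped-query-attention ratio, i.e. how many query heads share each key/value head, so it is well defined for Qwen2 rather than unsupported. A quick worked example with illustrative head counts (not a claim about any specific checkpoint):

    public class KvMulExample {
        public static void main(String[] args) {
            int numberOfHeads = 28;          // query heads (illustrative value)
            int numberOfKeyValueHeads = 4;   // key/value heads (illustrative value)
            int kvMul = numberOfHeads / numberOfKeyValueHeads;
            System.out.println("kvMul = " + kvMul);  // 7 query heads per KV head
        }
    }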
