Skip to content

Commit 480a1f0

Browse files
Additional cleanup
1 parent 7b5f052 commit 480a1f0

File tree

2 files changed

+10
-16
lines changed

2 files changed

+10
-16
lines changed

src/main/java/com/example/inference/state/Qwen3State.java

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,19 +11,16 @@
1111

1212
public final class Qwen3State extends State {
1313

14-
// Qwen3-specific field
15-
public final FloatTensor kq;
16-
17-
// Qwen3 temporary buffer for intermediate calculations, size adjusted for local workgroup size.
14+
// Qwen3 specific fields
15+
// Temporary buffers for intermediate calculations.
1816
public FloatArray tempQcur;
1917
public FloatArray tempKcur;
2018

2119
public Qwen3State(Configuration config, int batchsize) {
2220
super(config, batchsize);
23-
// Initialize Qwen3-specific field
21+
// Initialize Qwen3-specific fields
2422
Qwen3Configuration qwen3config = (Qwen3Configuration) config;
2523
int nEmbdHead = qwen3config.numberOfHeads();
26-
this.kq = ArrayFloatTensor.allocate(config.numberOfHeads(), 32, 15);
2724
this.tempQcur = new FloatArray(nEmbdHead);
2825
this.tempKcur = new FloatArray(nEmbdHead);
2926
}
@@ -34,9 +31,7 @@ protected StateFields createStateFields(Configuration configuration) {
3431

3532
Qwen3Configuration config = (Qwen3Configuration) configuration;
3633

37-
//localSize = 128;
38-
39-
// Qwen3-specific calculations
34+
// Qwen3-specific sizes
4035
int nHeadKv = config.numberOfKeyValueHeads();
4136
int nEmbdHeadK = config.numberOfHeadsKey();
4237
int nEmbdKGqa = nEmbdHeadK * nHeadKv;
@@ -51,8 +46,8 @@ protected StateFields createStateFields(Configuration configuration) {
5146
fields.hb = ArrayFloatTensor.allocate(config.hiddenDim());
5247
fields.hb2 = ArrayFloatTensor.allocate(config.hiddenDim());
5348
fields.q = ArrayFloatTensor.allocate(nEmbdHeadK * config.numberOfHeads());
54-
fields.k = ArrayFloatTensor.allocate(nEmbdKGqa); // Different from Llama!
55-
fields.v = ArrayFloatTensor.allocate(nEmbdKGqa); // Different from Llama!
49+
fields.k = ArrayFloatTensor.allocate(nEmbdKGqa);
50+
fields.v = ArrayFloatTensor.allocate(nEmbdKGqa);
5651
fields.att = ArrayFloatTensor.allocate(config.numberOfHeads(), config.contextLength());
5752
fields.logits = ArrayFloatTensor.allocate(config.vocabularySize());
5853

@@ -64,14 +59,14 @@ protected StateFields createStateFields(Configuration configuration) {
6459

6560
// TornadoVM wrappers with Qwen3-specific sizes
6661
fields.wrapX = new FloatArray(config.dim());
67-
fields.wrapXb = new FloatArray(nEmbdHeadK * config.numberOfHeads()); // Different from Llama!
62+
fields.wrapXb = new FloatArray(nEmbdHeadK * config.numberOfHeads());
6863
fields.wrapXb2 = new FloatArray(config.dim());
6964
fields.wrapHb = new FloatArray(config.hiddenDim());
7065
fields.wrapHb2 = new FloatArray(config.hiddenDim());
7166
fields.wrapLogits = new FloatArray(config.vocabularySize());
72-
fields.wrapQ = new FloatArray(nEmbdHeadK * config.numberOfHeads()); // Different from Llama!
73-
fields.wrapK = new FloatArray(nEmbdKGqa); // Different from Llama!
74-
fields.wrapV = new FloatArray(nEmbdKGqa); // Different from Llama!
67+
fields.wrapQ = new FloatArray(nEmbdHeadK * config.numberOfHeads());
68+
fields.wrapK = new FloatArray(nEmbdKGqa);
69+
fields.wrapV = new FloatArray(nEmbdKGqa);
7570

7671
fields.wrapKeyCache = new FloatArray(config.contextLength() * nEmbdGqa * config.numberOfLayers());
7772
fields.wrapValueCache = new FloatArray(config.contextLength() * nEmbdGqa * config.numberOfLayers());

src/main/java/com/example/tornadovm/TornadoVMMasterPlan.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,6 @@ public static TornadoVMMasterPlan initializeTornadoVMPlan(State state, Model mod
9696
* Dispatcher method to select the TornadoVMLayerPlanner for the model.
9797
*/
9898
TornadoVMLayerPlanner createPlanner(State state, Model model) {
99-
System.out.println("Creating TornadoVM layer planner : " + model.getModelType() );
10099
return switch (model.getModelType()) {
101100
case LLAMA_3, MISTRAL -> new TornadoVMLayerPlanner(state, model);
102101
case QWEN_3 -> new Qwen3TornadoVMLayerPlanner((Qwen3State) state, model);

0 commit comments

Comments
 (0)