11
11
12
12
public final class Qwen3State extends State {
13
13
14
- // Qwen3-specific field
15
- public final FloatTensor kq ;
16
-
17
- // Qwen3 temporary buffer for intermediate calculations, size adjusted for local workgroup size.
14
+ // Qwen3 specific fields
15
+ // Temporary buffers for intermediate calculations.
18
16
public FloatArray tempQcur ;
19
17
public FloatArray tempKcur ;
20
18
21
19
public Qwen3State (Configuration config , int batchsize ) {
22
20
super (config , batchsize );
23
- // Initialize Qwen3-specific field
21
+ // Initialize Qwen3-specific fields
24
22
Qwen3Configuration qwen3config = (Qwen3Configuration ) config ;
25
23
int nEmbdHead = qwen3config .numberOfHeads ();
26
- this .kq = ArrayFloatTensor .allocate (config .numberOfHeads (), 32 , 15 );
27
24
this .tempQcur = new FloatArray (nEmbdHead );
28
25
this .tempKcur = new FloatArray (nEmbdHead );
29
26
}
@@ -34,9 +31,7 @@ protected StateFields createStateFields(Configuration configuration) {
34
31
35
32
Qwen3Configuration config = (Qwen3Configuration ) configuration ;
36
33
37
- //localSize = 128;
38
-
39
- // Qwen3-specific calculations
34
+ // Qwen3-specific sizes
40
35
int nHeadKv = config .numberOfKeyValueHeads ();
41
36
int nEmbdHeadK = config .numberOfHeadsKey ();
42
37
int nEmbdKGqa = nEmbdHeadK * nHeadKv ;
@@ -51,8 +46,8 @@ protected StateFields createStateFields(Configuration configuration) {
51
46
fields .hb = ArrayFloatTensor .allocate (config .hiddenDim ());
52
47
fields .hb2 = ArrayFloatTensor .allocate (config .hiddenDim ());
53
48
fields .q = ArrayFloatTensor .allocate (nEmbdHeadK * config .numberOfHeads ());
54
- fields .k = ArrayFloatTensor .allocate (nEmbdKGqa ); // Different from Llama!
55
- fields .v = ArrayFloatTensor .allocate (nEmbdKGqa ); // Different from Llama!
49
+ fields .k = ArrayFloatTensor .allocate (nEmbdKGqa );
50
+ fields .v = ArrayFloatTensor .allocate (nEmbdKGqa );
56
51
fields .att = ArrayFloatTensor .allocate (config .numberOfHeads (), config .contextLength ());
57
52
fields .logits = ArrayFloatTensor .allocate (config .vocabularySize ());
58
53
@@ -64,14 +59,14 @@ protected StateFields createStateFields(Configuration configuration) {
64
59
65
60
// TornadoVM wrappers with Qwen3-specific sizes
66
61
fields .wrapX = new FloatArray (config .dim ());
67
- fields .wrapXb = new FloatArray (nEmbdHeadK * config .numberOfHeads ()); // Different from Llama!
62
+ fields .wrapXb = new FloatArray (nEmbdHeadK * config .numberOfHeads ());
68
63
fields .wrapXb2 = new FloatArray (config .dim ());
69
64
fields .wrapHb = new FloatArray (config .hiddenDim ());
70
65
fields .wrapHb2 = new FloatArray (config .hiddenDim ());
71
66
fields .wrapLogits = new FloatArray (config .vocabularySize ());
72
- fields .wrapQ = new FloatArray (nEmbdHeadK * config .numberOfHeads ()); // Different from Llama!
73
- fields .wrapK = new FloatArray (nEmbdKGqa ); // Different from Llama!
74
- fields .wrapV = new FloatArray (nEmbdKGqa ); // Different from Llama!
67
+ fields .wrapQ = new FloatArray (nEmbdHeadK * config .numberOfHeads ());
68
+ fields .wrapK = new FloatArray (nEmbdKGqa );
69
+ fields .wrapV = new FloatArray (nEmbdKGqa );
75
70
76
71
fields .wrapKeyCache = new FloatArray (config .contextLength () * nEmbdGqa * config .numberOfLayers ());
77
72
fields .wrapValueCache = new FloatArray (config .contextLength () * nEmbdGqa * config .numberOfLayers ());
0 commit comments