4
4
import org .beehive .gpullama3 .core .model .tensor .FloatTensor ;
5
5
import org .beehive .gpullama3 .model .Configuration ;
6
6
import org .beehive .gpullama3 .model .qwen2 .Qwen2Configuration ;
7
+ import uk .ac .manchester .tornado .api .types .arrays .FloatArray ;
8
+ import uk .ac .manchester .tornado .api .types .arrays .IntArray ;
7
9
8
10
import java .util .stream .Stream ;
9
11
10
12
public class Qwen2State extends State {
11
13
12
- //Qwen2 specific fields TODO
13
-
14
14
public Qwen2State (Configuration config , int batchsize ) {
15
15
super (config , batchsize );
16
- // Initialize Qwen2-specific fields TODO
17
- Qwen2Configuration qwen2Config = (Qwen2Configuration ) config ;
16
+ this .localSize = 32 ;
18
17
}
19
18
@ Override
20
19
protected StateFields createStateFields (Configuration configuration ) {
@@ -40,6 +39,30 @@ protected StateFields createStateFields(Configuration configuration) {
40
39
fields .keyCache = Stream .generate (() -> ArrayFloatTensor .allocate (config .contextLength (), nEmbdGqa )).limit (config .numberOfLayers ()).toArray (FloatTensor []::new );
41
40
fields .valueCache = Stream .generate (() -> ArrayFloatTensor .allocate (config .contextLength (), nEmbdGqa )).limit (config .numberOfLayers ()).toArray (FloatTensor []::new );
42
41
42
+ // TornadoVM wrappers with Qwen2 dimensions
43
+ fields .wrapX = new FloatArray (config .dim ());
44
+ fields .wrapXb = new FloatArray (config .dim ());
45
+ fields .wrapXb2 = new FloatArray (config .dim ());
46
+ fields .wrapHb = new FloatArray (config .hiddenDim ());
47
+ fields .wrapHb2 = new FloatArray (config .hiddenDim ());
48
+
49
+ fields .wrapLogits = new FloatArray (config .vocabularySize ());
50
+ fields .wrapQ = new FloatArray (config .dim ());
51
+ fields .wrapK = new FloatArray (config .kvDim ());
52
+ fields .wrapV = new FloatArray (config .kvDim ());
53
+
54
+ fields .wrapKeyCache = new FloatArray (config .contextLength () * nEmbdGqa * config .numberOfLayers ());
55
+ fields .wrapValueCache = new FloatArray (config .contextLength () * nEmbdGqa * config .numberOfLayers ());
56
+ fields .wrapValueCache .init (0.f );
57
+ fields .wrapKeyCache .init (0.f );
58
+ fields .wrapAtt = new FloatArray (config .numberOfHeads () * config .contextLength ());
59
+ fields .positionHolder = new IntArray (1 );
60
+
61
+ // Temporary arrays
62
+ fields .temp = new FloatArray (1 + ((config .dim () + localSize - 1 ) / localSize ));
63
+ fields .tempFFN = new FloatArray (1 + ((config .dim () + localSize - 1 ) / localSize ));
64
+ fields .tempLogits = new FloatArray (1 + ((config .dim () + localSize - 1 ) / localSize ));
65
+
43
66
return fields ;
44
67
45
68
}
0 commit comments