Skip to content

Commit 5a3ab76

Browse files
Add state for qwen2
1 parent 177476a commit 5a3ab76

File tree

1 file changed

+46
-0
lines changed

1 file changed

+46
-0
lines changed
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
package org.beehive.gpullama3.inference.state;
2+
3+
import org.beehive.gpullama3.core.model.tensor.ArrayFloatTensor;
4+
import org.beehive.gpullama3.core.model.tensor.FloatTensor;
5+
import org.beehive.gpullama3.model.Configuration;
6+
import org.beehive.gpullama3.model.qwen2.Qwen2Configuration;
7+
8+
import java.util.stream.Stream;
9+
10+
public class Qwen2State extends State {
11+
12+
//Qwen2 specific fields TODO
13+
14+
public Qwen2State(Configuration config, int batchsize) {
15+
super(config, batchsize);
16+
// Initialize Qwen2-specific fields TODO
17+
Qwen2Configuration qwen2Config = (Qwen2Configuration) config;
18+
}
19+
@Override
20+
protected StateFields createStateFields(Configuration configuration) {
21+
StateFields fields = new StateFields();
22+
23+
Qwen2Configuration config = (Qwen2Configuration) configuration;
24+
25+
int nEmbdGqa = config.kvDim();
26+
27+
// with Qwen2-specific sizes
28+
fields.x = ArrayFloatTensor.allocate(config.dim());
29+
fields.xb = ArrayFloatTensor.allocate(config.dim());
30+
fields.xb2 = ArrayFloatTensor.allocate(config.dim());
31+
fields.hb = ArrayFloatTensor.allocate(config.hiddenDim());
32+
fields.hb2 = ArrayFloatTensor.allocate(config.hiddenDim());
33+
fields.q = ArrayFloatTensor.allocate(config.dim());
34+
fields.k = ArrayFloatTensor.allocate(config.kvDim());
35+
fields.v = ArrayFloatTensor.allocate(config.kvDim());
36+
fields.att = ArrayFloatTensor.allocate(config.numberOfHeads(), config.contextLength());
37+
fields.logits = ArrayFloatTensor.allocate(config.vocabularySize());
38+
39+
// Key-value cache with Qwen2 dimensions
40+
fields.keyCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength(), nEmbdGqa)).limit(config.numberOfLayers()).toArray(FloatTensor[]::new);
41+
fields.valueCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength(), nEmbdGqa)).limit(config.numberOfLayers()).toArray(FloatTensor[]::new);
42+
43+
return fields;
44+
45+
}
46+
}

0 commit comments

Comments
 (0)