|
| 1 | +package org.beehive.gpullama3.inference.state; |
| 2 | + |
| 3 | +import org.beehive.gpullama3.core.model.tensor.ArrayFloatTensor; |
| 4 | +import org.beehive.gpullama3.core.model.tensor.FloatTensor; |
| 5 | +import org.beehive.gpullama3.model.Configuration; |
| 6 | +import org.beehive.gpullama3.model.qwen2.Qwen2Configuration; |
| 7 | + |
| 8 | +import java.util.stream.Stream; |
| 9 | + |
| 10 | +public class Qwen2State extends State { |
| 11 | + |
| 12 | + //Qwen2 specific fields TODO |
| 13 | + |
| 14 | + public Qwen2State(Configuration config, int batchsize) { |
| 15 | + super(config, batchsize); |
| 16 | + // Initialize Qwen2-specific fields TODO |
| 17 | + Qwen2Configuration qwen2Config = (Qwen2Configuration) config; |
| 18 | + } |
| 19 | + @Override |
| 20 | + protected StateFields createStateFields(Configuration configuration) { |
| 21 | + StateFields fields = new StateFields(); |
| 22 | + |
| 23 | + Qwen2Configuration config = (Qwen2Configuration) configuration; |
| 24 | + |
| 25 | + int nEmbdGqa = config.kvDim(); |
| 26 | + |
| 27 | + // with Qwen2-specific sizes |
| 28 | + fields.x = ArrayFloatTensor.allocate(config.dim()); |
| 29 | + fields.xb = ArrayFloatTensor.allocate(config.dim()); |
| 30 | + fields.xb2 = ArrayFloatTensor.allocate(config.dim()); |
| 31 | + fields.hb = ArrayFloatTensor.allocate(config.hiddenDim()); |
| 32 | + fields.hb2 = ArrayFloatTensor.allocate(config.hiddenDim()); |
| 33 | + fields.q = ArrayFloatTensor.allocate(config.dim()); |
| 34 | + fields.k = ArrayFloatTensor.allocate(config.kvDim()); |
| 35 | + fields.v = ArrayFloatTensor.allocate(config.kvDim()); |
| 36 | + fields.att = ArrayFloatTensor.allocate(config.numberOfHeads(), config.contextLength()); |
| 37 | + fields.logits = ArrayFloatTensor.allocate(config.vocabularySize()); |
| 38 | + |
| 39 | + // Key-value cache with Qwen2 dimensions |
| 40 | + fields.keyCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength(), nEmbdGqa)).limit(config.numberOfLayers()).toArray(FloatTensor[]::new); |
| 41 | + fields.valueCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength(), nEmbdGqa)).limit(config.numberOfLayers()).toArray(FloatTensor[]::new); |
| 42 | + |
| 43 | + return fields; |
| 44 | + |
| 45 | + } |
| 46 | +} |
0 commit comments