|
| 1 | +package com.example.model.qwen2; |
| 2 | + |
| 3 | +import com.example.inference.InferenceCore; |
| 4 | +import com.example.inference.InferenceEngine; |
| 5 | +import com.example.inference.sampler.Sampler; |
| 6 | +import com.example.inference.state.Qwen2State; |
| 7 | +import com.example.inference.state.State; |
| 8 | +import com.example.inference.weights.Weights; |
| 9 | +import com.example.model.AbstractModel; |
| 10 | +import com.example.model.ModelType; |
| 11 | +import com.example.model.format.ChatFormat; |
| 12 | +import com.example.tokenizer.impl.Qwen3Tokenizer; |
| 13 | +import com.example.tokenizer.impl.Tokenizer; |
| 14 | +import com.example.tornadovm.TornadoVMMasterPlan; |
| 15 | + |
| 16 | +import java.util.List; |
| 17 | +import java.util.Set; |
| 18 | +import java.util.function.IntConsumer; |
| 19 | + |
| 20 | +public class Qwen2 extends AbstractModel { |
| 21 | + |
| 22 | + Qwen2Configuration configuration; |
| 23 | + |
| 24 | + public Qwen2(Qwen2Configuration configuration, Tokenizer tokenizer, Weights weights, ChatFormat chatFormat) { |
| 25 | + super(tokenizer, weights, chatFormat, null); |
| 26 | + this.configuration = configuration; |
| 27 | + } |
| 28 | + |
| 29 | + public Qwen2Configuration configuration() { |
| 30 | + return configuration; |
| 31 | + } |
| 32 | + |
| 33 | + @Override |
| 34 | + public Tokenizer tokenizer() { |
| 35 | + return (Qwen3Tokenizer) tokenizer; |
| 36 | + } |
| 37 | + |
| 38 | + @Override |
| 39 | + public ModelType getModelType() { |
| 40 | + return ModelType.QWEN_2; |
| 41 | + } |
| 42 | + |
| 43 | + @Override |
| 44 | + public State createNewState() { |
| 45 | + State state = new Qwen2State(configuration(), -1); |
| 46 | + state.latestToken = tokenizer.getSpecialTokens().get(chatFormat.chatTokens().tStartHeader()); |
| 47 | + return state; |
| 48 | + } |
| 49 | + |
| 50 | + @Override |
| 51 | + public State createNewState(int batchsize) { |
| 52 | + State state = new Qwen2State(configuration(), batchsize); |
| 53 | + state.latestToken = tokenizer.getSpecialTokens().get(chatFormat.chatTokens().tStartHeader()); |
| 54 | + return state; |
| 55 | + } |
| 56 | + |
| 57 | + @Override |
| 58 | + public void forward(State state, int token, int position) { |
| 59 | + if (plan == null) { |
| 60 | + InferenceCore.forwardJavaQwen2(this, state, token, position); |
| 61 | + } else { |
| 62 | + InferenceCore.forwardTornadoVM(this, state, token, position, tornadoVMPlan()); |
| 63 | + } |
| 64 | + } |
| 65 | + |
| 66 | + @Override |
| 67 | + public List<Integer> generateTokens(State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo, |
| 68 | + IntConsumer onTokenGenerated) { |
| 69 | + return InferenceEngine.generateTokensQwen3(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated); |
| 70 | + } |
| 71 | + |
| 72 | + @Override |
| 73 | + public List<Integer> generateTokensGPU(State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo, |
| 74 | + IntConsumer onTokenGenerated, TornadoVMMasterPlan tornadoVMPlan) { |
| 75 | + return InferenceEngine.generateTokensGPUQwen3(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated, tornadoVMPlan); |
| 76 | + } |
| 77 | +} |
0 commit comments