Skip to content

Commit 6adc02c

Browse files
Add class for qwen2
1 parent 5a3ab76 commit 6adc02c

File tree

1 file changed

+77
-0
lines changed

1 file changed

+77
-0
lines changed
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
package com.example.model.qwen2;
2+
3+
import com.example.inference.InferenceCore;
4+
import com.example.inference.InferenceEngine;
5+
import com.example.inference.sampler.Sampler;
6+
import com.example.inference.state.Qwen2State;
7+
import com.example.inference.state.State;
8+
import com.example.inference.weights.Weights;
9+
import com.example.model.AbstractModel;
10+
import com.example.model.ModelType;
11+
import com.example.model.format.ChatFormat;
12+
import com.example.tokenizer.impl.Qwen3Tokenizer;
13+
import com.example.tokenizer.impl.Tokenizer;
14+
import com.example.tornadovm.TornadoVMMasterPlan;
15+
16+
import java.util.List;
17+
import java.util.Set;
18+
import java.util.function.IntConsumer;
19+
20+
public class Qwen2 extends AbstractModel {
21+
22+
Qwen2Configuration configuration;
23+
24+
public Qwen2(Qwen2Configuration configuration, Tokenizer tokenizer, Weights weights, ChatFormat chatFormat) {
25+
super(tokenizer, weights, chatFormat, null);
26+
this.configuration = configuration;
27+
}
28+
29+
public Qwen2Configuration configuration() {
30+
return configuration;
31+
}
32+
33+
@Override
34+
public Tokenizer tokenizer() {
35+
return (Qwen3Tokenizer) tokenizer;
36+
}
37+
38+
@Override
39+
public ModelType getModelType() {
40+
return ModelType.QWEN_2;
41+
}
42+
43+
@Override
44+
public State createNewState() {
45+
State state = new Qwen2State(configuration(), -1);
46+
state.latestToken = tokenizer.getSpecialTokens().get(chatFormat.chatTokens().tStartHeader());
47+
return state;
48+
}
49+
50+
@Override
51+
public State createNewState(int batchsize) {
52+
State state = new Qwen2State(configuration(), batchsize);
53+
state.latestToken = tokenizer.getSpecialTokens().get(chatFormat.chatTokens().tStartHeader());
54+
return state;
55+
}
56+
57+
@Override
58+
public void forward(State state, int token, int position) {
59+
if (plan == null) {
60+
InferenceCore.forwardJavaQwen2(this, state, token, position);
61+
} else {
62+
InferenceCore.forwardTornadoVM(this, state, token, position, tornadoVMPlan());
63+
}
64+
}
65+
66+
@Override
67+
public List<Integer> generateTokens(State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo,
68+
IntConsumer onTokenGenerated) {
69+
return InferenceEngine.generateTokensQwen3(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated);
70+
}
71+
72+
@Override
73+
public List<Integer> generateTokensGPU(State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo,
74+
IntConsumer onTokenGenerated, TornadoVMMasterPlan tornadoVMPlan) {
75+
return InferenceEngine.generateTokensGPUQwen3(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated, tornadoVMPlan);
76+
}
77+
}

0 commit comments

Comments
 (0)