Skip to content

Commit c878541

Browse files
Add model loader for Qwen2
1 parent 85f1875 commit c878541

File tree

1 file changed

+107
-0
lines changed

1 file changed

+107
-0
lines changed
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
package org.beehive.gpullama3.model.loader;
2+
3+
import org.beehive.gpullama3.auxiliary.Timer;
4+
import org.beehive.gpullama3.core.model.GGUF;
5+
import org.beehive.gpullama3.core.model.tensor.ArrayFloatTensor;
6+
import org.beehive.gpullama3.core.model.tensor.GGMLTensorEntry;
7+
import org.beehive.gpullama3.core.types.Pair;
8+
import org.beehive.gpullama3.inference.weights.Weights;
9+
import org.beehive.gpullama3.inference.weights.standard.Qwen2StandardWeights;
10+
import org.beehive.gpullama3.model.Configuration;
11+
import org.beehive.gpullama3.model.Model;
12+
import org.beehive.gpullama3.model.format.ChatFormat;
13+
import org.beehive.gpullama3.model.format.ChatFormat.ChatTokens;
14+
import org.beehive.gpullama3.model.qwen2.Qwen2;
15+
import org.beehive.gpullama3.model.qwen2.Qwen2Configuration;
16+
import org.beehive.gpullama3.tokenizer.impl.Qwen3Tokenizer;
17+
import org.beehive.gpullama3.tokenizer.impl.Tokenizer;
18+
import org.beehive.gpullama3.tokenizer.vocabulary.Vocabulary;
19+
20+
import java.io.IOException;
21+
import java.nio.channels.FileChannel;
22+
import java.util.Map;
23+
24+
import static org.beehive.gpullama3.tokenizer.vocabulary.Vocabulary.loadQwen3Vocabulary;
25+
26+
public class Qwen2ModelLoader extends ModelLoader {
27+
28+
public Qwen2ModelLoader(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights) {
29+
super(fileChannel, gguf, contextLength, loadWeights);
30+
}
31+
32+
@Override
33+
public Model loadModel() {
34+
try (var ignored = Timer.log("Load Qwen2 model")) {
35+
Map<String, Object> metadata = gguf.getMetadata();
36+
37+
// reuse method of Qwen3
38+
Vocabulary vocabulary = loadQwen3Vocabulary(metadata);
39+
boolean isDeepSeekR1DistillQwen = "DeepSeek-R1-Distill-Qwen".equals(metadata.get("general.basename"));
40+
Tokenizer tokenizer = new Qwen3Tokenizer(metadata, vocabulary, isDeepSeekR1DistillQwen);
41+
42+
int modelContextLength = (int) metadata.get("qwen2.context_length");
43+
if (contextLength < 0 || modelContextLength < contextLength) {
44+
contextLength = modelContextLength;
45+
}
46+
47+
int numberOfKeyValueHeads = metadata.containsKey("qwen2.attention.head_count_kv")
48+
? (int) metadata.get("qwen2.attention.head_count_kv")
49+
: (int) metadata.get("qwen2.attention.head_count");
50+
Qwen2Configuration config = new Qwen2Configuration(
51+
(int) metadata.get("qwen2.embedding_length"), // dim
52+
(int) metadata.get("qwen2.feed_forward_length"), // hiddendim
53+
(int) metadata.get("qwen2.block_count"), // numberOfLayers
54+
(int) metadata.get("qwen2.attention.head_count"), // numberOfHeads
55+
56+
numberOfKeyValueHeads, // numberOfKeyValueHeads
57+
numberOfKeyValueHeads, // numberOfHeadsKey
58+
numberOfKeyValueHeads, // numberOfHeadsValue
59+
60+
vocabulary.size(),
61+
modelContextLength, contextLength,
62+
false,
63+
(float) metadata.get("qwen2.attention.layer_norm_rms_epsilon"),
64+
(float) metadata.get("qwen2.rope.freq_base")
65+
);
66+
67+
Weights weights = null;
68+
if (loadWeights) {
69+
Map<String, GGMLTensorEntry> tensorEntries = GGUF.loadTensors(fileChannel, gguf.getTensorDataOffset(), gguf.getTensorInfos());
70+
weights = loadWeights(tensorEntries, config);
71+
}
72+
// Qwen2.5-Coder uses <|endoftext|> as stop-token.
73+
ChatTokens chatTokens = isDeepSeekR1DistillQwen ?
74+
new ChatTokens( "<|begin▁of▁sentence|>", "", "", "<|end▁of▁sentence|>", "") :
75+
new ChatTokens( "<|im_start|>", "<|im_end|>", "", "<|end_of_text|>", "<|endoftext|>");
76+
return new Qwen2(config, tokenizer, weights, ChatFormat.create(tokenizer, chatTokens));
77+
} catch (IOException e) {
78+
throw new RuntimeException(e);
79+
}
80+
}
81+
82+
@Override
83+
public Weights createStandardWeights(Map<String, GGMLTensorEntry> tensorEntries, Configuration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
84+
GGMLTensorEntry outputWeight) {
85+
return new Qwen2StandardWeights(
86+
loadQuantized(tokenEmbeddings),
87+
loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
88+
loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
89+
loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.weight")),
90+
loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.weight")),
91+
92+
loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.bias")),
93+
loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.bias")),
94+
loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.bias")),
95+
96+
loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")),
97+
loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")),
98+
loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")),
99+
loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")),
100+
loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")),
101+
loadQuantized(tensorEntries.get("output_norm.weight")),
102+
new ArrayFloatTensor(ropeFreqs.first()),
103+
new ArrayFloatTensor(ropeFreqs.second()),
104+
loadQuantized(outputWeight),
105+
outputWeight.ggmlType());
106+
}
107+
}

0 commit comments

Comments
 (0)