Skip to content

Commit 5413e38

Browse files
committed
Merge pull request #47 from mikepapadim/minor-fix
Minor cleanup
2 parents 93ce172 + 968acb3 commit 5413e38

File tree

1 file changed

+12
-21
lines changed

src/main/java/org/beehive/gpullama3/model/loader/Qwen2ModelLoader.java

Lines changed: 12 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
package org.beehive.gpullama3.model.loader;
22

3-
import org.beehive.gpullama3.LlamaApp;
43
import org.beehive.gpullama3.Options;
5-
import org.beehive.gpullama3.auxiliary.Timer;
64
import org.beehive.gpullama3.core.model.GGMLType;
75
import org.beehive.gpullama3.core.model.GGUF;
86
import org.beehive.gpullama3.core.model.tensor.ArrayFloatTensor;
@@ -21,6 +19,7 @@
2119
import org.beehive.gpullama3.tokenizer.impl.Qwen3Tokenizer;
2220
import org.beehive.gpullama3.tokenizer.impl.Tokenizer;
2321
import org.beehive.gpullama3.tokenizer.vocabulary.Vocabulary;
22+
import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan;
2423
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
2524

2625
import java.io.IOException;
@@ -40,11 +39,9 @@ public Model loadModel() {
4039
Map<String, Object> metadata = gguf.getMetadata();
4140
String basename = (String) metadata.get("general.basename");
4241

43-
String modelName = "DeepSeek-R1-Distill-Qwen".equals(basename)
44-
? "DeepSeek-R1-Distill-Qwen"
45-
: "Qwen2.5";
42+
String modelName = "DeepSeek-R1-Distill-Qwen".equals(basename) ? "DeepSeek-R1-Distill-Qwen" : "Qwen2.5";
4643

47-
try (var ignored = Timer.log("Load " + modelName + " model")) {
44+
try {
4845
// reuse method of Qwen3
4946
Vocabulary vocabulary = loadQwen3Vocabulary(metadata);
5047
boolean isDeepSeekR1DistillQwen = "DeepSeek-R1-Distill-Qwen".equals(metadata.get("general.basename"));
@@ -55,11 +52,8 @@ public Model loadModel() {
5552
contextLength = modelContextLength;
5653
}
5754

58-
int numberOfKeyValueHeads = metadata.containsKey("qwen2.attention.head_count_kv")
59-
? (int) metadata.get("qwen2.attention.head_count_kv")
60-
: (int) metadata.get("qwen2.attention.head_count");
61-
Qwen2Configuration config = new Qwen2Configuration(
62-
(int) metadata.get("qwen2.embedding_length"), // dim
55+
int numberOfKeyValueHeads = metadata.containsKey("qwen2.attention.head_count_kv") ? (int) metadata.get("qwen2.attention.head_count_kv") : (int) metadata.get("qwen2.attention.head_count");
56+
Qwen2Configuration config = new Qwen2Configuration((int) metadata.get("qwen2.embedding_length"), // dim
6357
(int) metadata.get("qwen2.feed_forward_length"), // hiddendim
6458
(int) metadata.get("qwen2.block_count"), // numberOfLayers
6559
(int) metadata.get("qwen2.attention.head_count"), // numberOfHeads
@@ -68,22 +62,17 @@ public Model loadModel() {
6862
numberOfKeyValueHeads, // numberOfHeadsKey
6963
numberOfKeyValueHeads, // numberOfHeadsValue
7064

71-
vocabulary.size(),
72-
modelContextLength, contextLength,
73-
false,
74-
(float) metadata.get("qwen2.attention.layer_norm_rms_epsilon"),
75-
(float) metadata.get("qwen2.rope.freq_base")
76-
);
65+
vocabulary.size(), modelContextLength, contextLength, false, (float) metadata.get("qwen2.attention.layer_norm_rms_epsilon"), (float) metadata.get("qwen2.rope.freq_base"));
7766

7867
Weights weights = null;
7968
if (loadWeights) {
8069
Map<String, GGMLTensorEntry> tensorEntries = GGUF.loadTensors(fileChannel, gguf.getTensorDataOffset(), gguf.getTensorInfos());
8170
weights = loadWeights(tensorEntries, config);
8271
}
8372
// Qwen2.5-Coder uses <|endoftext|> as stop-token.
84-
ChatTokens chatTokens = isDeepSeekR1DistillQwen ?
85-
new ChatTokens( "<|begin▁of▁sentence|>", "", "", "<|end▁of▁sentence|>", "") :
86-
new ChatTokens( "<|im_start|>", "<|im_end|>", "", "<|end_of_text|>", "<|endoftext|>");
73+
ChatTokens chatTokens = isDeepSeekR1DistillQwen
74+
? new ChatTokens("<|begin▁of▁sentence|>", "", "", "<|end▁of▁sentence|>", "")
75+
: new ChatTokens("<|im_start|>", "<|im_end|>", "", "<|end_of_text|>", "<|endoftext|>");
8776
return new Qwen2(config, tokenizer, weights, ChatFormat.create(tokenizer, chatTokens));
8877
} catch (IOException e) {
8978
throw new RuntimeException(e);
@@ -108,7 +97,9 @@ public Weights loadWeights(Map<String, GGMLTensorEntry> tensorEntries, Configura
10897
GGMLTensorEntry outputWeight = tensorEntries.getOrDefault("output.weight", tokenEmbeddings);
10998

11099
if (Options.getDefaultOptions().useTornadovm()) {
111-
System.out.println("Loading model weights in TornadoVM format (loading " + outputWeight.ggmlType() + " -> " + GGMLType.F16 + ")");
100+
if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) {
101+
System.out.println("Loading model weights in TornadoVM format (loading " + outputWeight.ggmlType() + " -> " + GGMLType.F16 + ")");
102+
}
112103
return createTornadoVMWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
113104
} else {
114105
return createStandardWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);

0 commit comments

Comments (0)