Skip to content

Commit 4029373

Browse files
Add loadWeights method for Qwen2
1 parent 4ef777c commit 4029373

File tree

1 file changed

+35
-0
lines changed

1 file changed

+35
-0
lines changed

src/main/java/org/beehive/gpullama3/model/loader/Qwen2ModelLoader.java

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
package org.beehive.gpullama3.model.loader;
22

3+
import org.beehive.gpullama3.LlamaApp;
34
import org.beehive.gpullama3.auxiliary.Timer;
5+
import org.beehive.gpullama3.core.model.GGMLType;
46
import org.beehive.gpullama3.core.model.GGUF;
57
import org.beehive.gpullama3.core.model.tensor.ArrayFloatTensor;
68
import org.beehive.gpullama3.core.model.tensor.GGMLTensorEntry;
79
import org.beehive.gpullama3.core.types.Pair;
10+
import org.beehive.gpullama3.inference.operation.RoPE;
811
import org.beehive.gpullama3.inference.weights.Weights;
912
import org.beehive.gpullama3.inference.weights.standard.Qwen2StandardWeights;
13+
import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights;
1014
import org.beehive.gpullama3.model.Configuration;
1115
import org.beehive.gpullama3.model.Model;
1216
import org.beehive.gpullama3.model.format.ChatFormat;
@@ -79,6 +83,32 @@ public Model loadModel() {
7983
}
8084
}
8185

86+
// @formatter:off
87+
@Override
88+
public Weights loadWeights(Map<String, GGMLTensorEntry> tensorEntries, Configuration config) {
89+
Pair<float[], float[]> ropeFreqs = RoPE.precomputeFreqsCis(
90+
config.contextLengthModel(),
91+
config.headSize(),
92+
config.ropeTheta(),
93+
false,
94+
8,
95+
1,
96+
3,
97+
8192
98+
);
99+
100+
GGMLTensorEntry tokenEmbeddings = tensorEntries.get("token_embd.weight");
101+
GGMLTensorEntry outputWeight = tensorEntries.getOrDefault("output.weight", tokenEmbeddings);
102+
103+
if (LlamaApp.USE_TORNADOVM) {
104+
System.out.println("Loading model weights in TornadoVM format (loading " + outputWeight.ggmlType() + " -> " + GGMLType.F16 + ")");
105+
return createTornadoVMWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
106+
} else {
107+
return createStandardWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
108+
}
109+
}
110+
// @formatter:on
111+
82112
@Override
83113
public Weights createStandardWeights(Map<String, GGMLTensorEntry> tensorEntries, Configuration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
84114
GGMLTensorEntry outputWeight) {
@@ -104,4 +134,9 @@ public Weights createStandardWeights(Map<String, GGMLTensorEntry> tensorEntries,
104134
loadQuantized(outputWeight),
105135
outputWeight.ggmlType());
106136
}
137+
138+
@Override
139+
}
140+
// @formatter:on
141+
107142
}

0 commit comments

Comments
 (0)