1
1
package org .beehive .gpullama3 .model .loader ;
2
2
3
+ import org .beehive .gpullama3 .LlamaApp ;
3
4
import org .beehive .gpullama3 .auxiliary .Timer ;
5
+ import org .beehive .gpullama3 .core .model .GGMLType ;
4
6
import org .beehive .gpullama3 .core .model .GGUF ;
5
7
import org .beehive .gpullama3 .core .model .tensor .ArrayFloatTensor ;
6
8
import org .beehive .gpullama3 .core .model .tensor .GGMLTensorEntry ;
7
9
import org .beehive .gpullama3 .core .types .Pair ;
10
+ import org .beehive .gpullama3 .inference .operation .RoPE ;
8
11
import org .beehive .gpullama3 .inference .weights .Weights ;
9
12
import org .beehive .gpullama3 .inference .weights .standard .Qwen2StandardWeights ;
13
+ import org .beehive .gpullama3 .inference .weights .tornado .Qwen2TornadoWeights ;
10
14
import org .beehive .gpullama3 .model .Configuration ;
11
15
import org .beehive .gpullama3 .model .Model ;
12
16
import org .beehive .gpullama3 .model .format .ChatFormat ;
@@ -79,6 +83,32 @@ public Model loadModel() {
79
83
}
80
84
}
81
85
86
+ // @formatter:off
87
+ @ Override
88
+ public Weights loadWeights (Map <String , GGMLTensorEntry > tensorEntries , Configuration config ) {
89
+ Pair <float [], float []> ropeFreqs = RoPE .precomputeFreqsCis (
90
+ config .contextLengthModel (),
91
+ config .headSize (),
92
+ config .ropeTheta (),
93
+ false ,
94
+ 8 ,
95
+ 1 ,
96
+ 3 ,
97
+ 8192
98
+ );
99
+
100
+ GGMLTensorEntry tokenEmbeddings = tensorEntries .get ("token_embd.weight" );
101
+ GGMLTensorEntry outputWeight = tensorEntries .getOrDefault ("output.weight" , tokenEmbeddings );
102
+
103
+ if (LlamaApp .USE_TORNADOVM ) {
104
+ System .out .println ("Loading model weights in TornadoVM format (loading " + outputWeight .ggmlType () + " -> " + GGMLType .F16 + ")" );
105
+ return createTornadoVMWeights (tensorEntries , config , ropeFreqs , tokenEmbeddings , outputWeight );
106
+ } else {
107
+ return createStandardWeights (tensorEntries , config , ropeFreqs , tokenEmbeddings , outputWeight );
108
+ }
109
+ }
110
+ // @formatter:on
111
+
82
112
@ Override
83
113
public Weights createStandardWeights (Map <String , GGMLTensorEntry > tensorEntries , Configuration config , Pair <float [], float []> ropeFreqs , GGMLTensorEntry tokenEmbeddings ,
84
114
GGMLTensorEntry outputWeight ) {
@@ -104,4 +134,9 @@ public Weights createStandardWeights(Map<String, GGMLTensorEntry> tensorEntries,
104
134
loadQuantized (outputWeight ),
105
135
outputWeight .ggmlType ());
106
136
}
137
+
138
+ @ Override
139
+ }
140
+ // @formatter:on
141
+
107
142
}
0 commit comments