Skip to content

Commit 177476a

Browse files
Add weights for qwen2
1 parent c878541 commit 177476a

File tree

1 file changed

+56
-0
lines changed

1 file changed

+56
-0
lines changed
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
package org.beehive.gpullama3.inference.weights.standard;
2+
3+
import org.beehive.gpullama3.core.model.GGMLType;
4+
import org.beehive.gpullama3.core.model.tensor.ArrayFloatTensor;
5+
import org.beehive.gpullama3.core.model.tensor.FloatTensor;
6+
import org.beehive.gpullama3.inference.weights.Weights;
7+
8+
public class Qwen2StandardWeights extends StandardWeights {
9+
public final FloatTensor[] q_bias, k_bias, v_bias;
10+
11+
public Qwen2StandardWeights(
12+
FloatTensor token_embedding_table,
13+
FloatTensor[] rms_att_weight,
14+
FloatTensor[] wq,
15+
FloatTensor[] wk,
16+
FloatTensor[] wv,
17+
FloatTensor[] q_bias,
18+
FloatTensor[] k_bias,
19+
FloatTensor[] v_bias,
20+
FloatTensor[] wo,
21+
FloatTensor[] rms_ffn_weight,
22+
FloatTensor[] w1,
23+
FloatTensor[] w2,
24+
FloatTensor[] w3,
25+
FloatTensor rms_final_weight,
26+
ArrayFloatTensor freq_cis_real,
27+
ArrayFloatTensor freq_cis_imag,
28+
FloatTensor wcls,
29+
GGMLType weightType) {
30+
// call to StandardWeights constructor
31+
super(token_embedding_table,
32+
rms_att_weight,
33+
wq,
34+
wk,
35+
wv,
36+
wo,
37+
rms_ffn_weight,
38+
w1,
39+
w2,
40+
w3,
41+
rms_final_weight,
42+
freq_cis_real,
43+
freq_cis_imag,
44+
wcls,
45+
weightType);
46+
// init Qwen2-specific fields
47+
this.q_bias = q_bias;
48+
this.k_bias = k_bias;
49+
this.v_bias = v_bias;
50+
}
51+
52+
@Override
53+
public GGMLType getWeightType() {
54+
return weightType;
55+
}
56+
}

0 commit comments

Comments
 (0)