
Commit ddd6265

feat(llama.cu): optimize qwen3 parameter extraction
Signed-off-by: YdrMaster <ydrml@hotmail.com>
1 parent 26faf1b commit ddd6265

1 file changed

llama.cu/src/model/llama.rs

Lines changed: 14 additions & 23 deletions
@@ -12,9 +12,8 @@ impl GGufModel<'_> {
     pub fn llama(&self) -> nn::LLaMA<Tensor<&[u8], 2>> {
         let arch = meta![self => general_architecture];
         let dt_bias = match arch {
-            "llama" => None,
+            "llama" | "qwen3" => None,
             "qwen2" => Some(self.tensors["blk.0.attn_qkv.bias"].dt()),
-            "qwen3" => None,
             arch => panic!("unsupported arch {arch}"),
         };

@@ -24,12 +23,9 @@ impl GGufModel<'_> {
         let d = meta![self => llm_embedding_length];
         let nh = meta![self => llm_attention_head_count];
         let nkvh = meta![self => llm_attention_head_count_kv; nh];
-        let dh = match arch {
-            "qwen3" => self.tensors["blk.0.attn_qkv.weight"].shape()[0]
-                .checked_div(nh + nkvh + nkvh)
-                .unwrap(),
-            _ => meta![self => llm_rope_dimension_count; d / nh],
-        };
+        let dh = meta![self => llm_rope_dimension_count; d / nh];
+        let dk = meta![self => llm_attention_key_length; dh];
+        let dv = meta![self => llm_attention_value_length; dh];
         let di = meta![self => llm_feed_forward_length];
         let epsilon = meta![self => llm_attention_layer_norm_rms_epsilon; 1e-5];
         let dt_linear = self.tensors["blk.0.attn_qkv.weight"].dt();
@@ -70,7 +66,7 @@ impl GGufModel<'_> {
                 nkvh,
                 qkv: Linear::new(
                     dt_linear,
-                    [(nh + nkvh + nkvh) * dh, d],
+                    [(nh + nkvh) * dk + nkvh * dv, d],
                     get(&format!("blk.{iblk}.attn_qkv.weight")),
                     dt_bias.map(|dt| (dt, get(&format!("blk.{iblk}.attn_qkv.bias")))),
                 ),
@@ -79,7 +75,7 @@ impl GGufModel<'_> {
                     .contains_key(format!("blk.{iblk}.attn_q_norm.weight").as_str())
                 {
                     Some(Normalization {
-                        d: dh,
+                        d: dk,
                         epsilon: epsilon as _,
                         items: NormType::RmsNorm {
                             dt: out_norm.dt(),
@@ -94,7 +90,7 @@ impl GGufModel<'_> {
                     .contains_key(format!("blk.{iblk}.attn_k_norm.weight").as_str())
                 {
                     Some(Normalization {
-                        d: dh,
+                        d: dk,
                         epsilon: epsilon as _,
                         items: NormType::RmsNorm {
                             dt: out_norm.dt(),
@@ -112,7 +108,7 @@ impl GGufModel<'_> {
                 }),
                 output: Linear::new(
                     dt_linear,
-                    [d, nh * dh],
+                    [d, nh * dv],
                     get(&format!("blk.{iblk}.attn_output.weight")),
                     None,
                 ),
@@ -163,13 +159,8 @@ impl GGufModel<'_> {
         let nctx = meta![self => llm_context_length];
         let d = meta![self => llm_embedding_length];
         let nh = meta![self => llm_attention_head_count];
-        let nkvh = meta![self => llm_attention_head_count_kv; nh];
-        let dh = match arch {
-            "qwen3" => self.tensors["blk.0.attn_qkv.weight"].shape()[0]
-                .checked_div(nh + nkvh + nkvh)
-                .unwrap(),
-            _ => meta![self => llm_rope_dimension_count; d / nh],
-        };
+        let dh = meta![self => llm_rope_dimension_count; d / nh];
+        let dk = meta![self => llm_attention_key_length; dh];
         let theta = meta![self => llm_rope_freq_base; 1e4];

         let [sin, cos] = match self.get_str(&format!("{arch}.rope.scaling.type")) {
@@ -178,17 +169,17 @@ impl GGufModel<'_> {

                 let factors = &self.tensors["rope_factors_long.weight"];
                 assert_eq!(factors.dt(), types::F32);
-                assert_eq!(factors.shape(), [dh / 2]);
+                assert_eq!(factors.shape(), [dk / 2]);
                 let factors = unsafe {
-                    std::slice::from_raw_parts(factors.get().as_ptr().cast::<f32>(), dh / 2)
+                    std::slice::from_raw_parts(factors.get().as_ptr().cast::<f32>(), dk / 2)
                 };

                 info!("detected longrope, ctx scale = {ctx_scale}, scale factor = {factors:.2?}");
-                build_sin_cos(nctx, dh, theta, |pos, i| {
+                build_sin_cos(nctx, dk, theta, |pos, i| {
                     pos as f32 * ctx_scale / factors[i]
                 })
             }
-            Err(GGufMetaError::NotExist) => build_sin_cos(nctx, dh, theta, |pos, _| pos as _),
+            Err(GGufMetaError::NotExist) => build_sin_cos(nctx, dk, theta, |pos, _| pos as _),
             Ok(ty) => panic!("Unsupported rope scaling `{ty}`"),
             Err(e) => panic!("{e:?}"),
         };
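For readers skimming the diff, the change replaces the old qwen3 special case (which recovered the head size by dividing the QKV tensor height by nh + 2 * nkvh) with direct metadata reads: dk and dv come from the key/value-length keys when present and fall back to dh (rope_dimension_count, itself defaulting to d / nh). The fused QKV weight then has shape [(nh + nkvh) * dk + nkvh * dv, d] and the output projection [d, nh * dv]. Below is a minimal sketch of that arithmetic; it is not code from this repo, the names AttnMeta / attn_shapes and the config numbers are made up for illustration, and only the fallback chain and shape formulas mirror the patched llama() above.

// Illustrative sketch only: hypothetical names and numbers,
// mirroring the fallback chain and shape formulas from the diff.
struct AttnMeta {
    d: usize,                 // embedding length
    nh: usize,                // attention head count
    nkvh: usize,              // key/value head count
    rope_dim: Option<usize>,  // llm_rope_dimension_count, if present
    key_len: Option<usize>,   // llm_attention_key_length, if present
    value_len: Option<usize>, // llm_attention_value_length, if present
}

/// Returns (attn_qkv weight shape, attn_output weight shape).
fn attn_shapes(m: &AttnMeta) -> ([usize; 2], [usize; 2]) {
    // Same fallback chain as the diff: dh defaults to d / nh,
    // dk and dv fall back to dh when the metadata keys are absent.
    let dh = m.rope_dim.unwrap_or(m.d / m.nh);
    let dk = m.key_len.unwrap_or(dh);
    let dv = m.value_len.unwrap_or(dh);

    // Fused QKV: nh query heads and nkvh key heads of width dk,
    // plus nkvh value heads of width dv, all projected from d.
    let qkv = [(m.nh + m.nkvh) * dk + m.nkvh * dv, m.d];
    // Output projection maps the concatenated value heads back to d.
    let out = [m.d, m.nh * dv];
    (qkv, out)
}

fn main() {
    // Hypothetical qwen3-like config where the head size (128) is not d / nh (80);
    // the removed code recovered it by dividing the QKV tensor height by
    // nh + 2 * nkvh, while the new code reads it from metadata directly.
    let m = AttnMeta {
        d: 2560,
        nh: 32,
        nkvh: 8,
        rope_dim: Some(128),
        key_len: Some(128),
        value_len: Some(128),
    };
    let (qkv, out) = attn_shapes(&m);
    println!("qkv weight {qkv:?}, output weight {out:?}");
    // Prints: qkv weight [6144, 2560], output weight [2560, 4096]
}

With all three metadata keys absent, the same arithmetic reduces to dk = dv = dh = d / nh, i.e. the previous generic behaviour, which is why the qwen3-specific branch could be dropped.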
