@@ -12,9 +12,8 @@ impl GGufModel<'_> {
1212 pub fn llama ( & self ) -> nn:: LLaMA < Tensor < & [ u8 ] , 2 > > {
1313 let arch = meta ! [ self => general_architecture] ;
1414 let dt_bias = match arch {
15- "llama" => None ,
15+ "llama" | "qwen3" => None ,
1616 "qwen2" => Some ( self . tensors [ "blk.0.attn_qkv.bias" ] . dt ( ) ) ,
17- "qwen3" => None ,
1817 arch => panic ! ( "unsupported arch {arch}" ) ,
1918 } ;
2019
@@ -24,12 +23,9 @@ impl GGufModel<'_> {
2423 let d = meta ! [ self => llm_embedding_length] ;
2524 let nh = meta ! [ self => llm_attention_head_count] ;
2625 let nkvh = meta ! [ self => llm_attention_head_count_kv; nh] ;
27- let dh = match arch {
28- "qwen3" => self . tensors [ "blk.0.attn_qkv.weight" ] . shape ( ) [ 0 ]
29- . checked_div ( nh + nkvh + nkvh)
30- . unwrap ( ) ,
31- _ => meta ! [ self => llm_rope_dimension_count; d / nh] ,
32- } ;
26+ let dh = meta ! [ self => llm_rope_dimension_count; d / nh] ;
27+ let dk = meta ! [ self => llm_attention_key_length; dh] ;
28+ let dv = meta ! [ self => llm_attention_value_length; dh] ;
3329 let di = meta ! [ self => llm_feed_forward_length] ;
3430 let epsilon = meta ! [ self => llm_attention_layer_norm_rms_epsilon; 1e-5 ] ;
3531 let dt_linear = self . tensors [ "blk.0.attn_qkv.weight" ] . dt ( ) ;
@@ -70,7 +66,7 @@ impl GGufModel<'_> {
7066 nkvh,
7167 qkv : Linear :: new (
7268 dt_linear,
73- [ ( nh + nkvh + nkvh) * dh , d] ,
69+ [ ( nh + nkvh) * dk + nkvh * dv , d] ,
7470 get ( & format ! ( "blk.{iblk}.attn_qkv.weight" ) ) ,
7571 dt_bias. map ( |dt| ( dt, get ( & format ! ( "blk.{iblk}.attn_qkv.bias" ) ) ) ) ,
7672 ) ,
@@ -79,7 +75,7 @@ impl GGufModel<'_> {
7975 . contains_key ( format ! ( "blk.{iblk}.attn_q_norm.weight" ) . as_str ( ) )
8076 {
8177 Some ( Normalization {
82- d : dh ,
78+ d : dk ,
8379 epsilon : epsilon as _ ,
8480 items : NormType :: RmsNorm {
8581 dt : out_norm. dt ( ) ,
@@ -94,7 +90,7 @@ impl GGufModel<'_> {
9490 . contains_key ( format ! ( "blk.{iblk}.attn_k_norm.weight" ) . as_str ( ) )
9591 {
9692 Some ( Normalization {
97- d : dh ,
93+ d : dk ,
9894 epsilon : epsilon as _ ,
9995 items : NormType :: RmsNorm {
10096 dt : out_norm. dt ( ) ,
@@ -112,7 +108,7 @@ impl GGufModel<'_> {
112108 } ) ,
113109 output : Linear :: new (
114110 dt_linear,
115- [ d, nh * dh ] ,
111+ [ d, nh * dv ] ,
116112 get ( & format ! ( "blk.{iblk}.attn_output.weight" ) ) ,
117113 None ,
118114 ) ,
@@ -163,13 +159,8 @@ impl GGufModel<'_> {
163159 let nctx = meta ! [ self => llm_context_length] ;
164160 let d = meta ! [ self => llm_embedding_length] ;
165161 let nh = meta ! [ self => llm_attention_head_count] ;
166- let nkvh = meta ! [ self => llm_attention_head_count_kv; nh] ;
167- let dh = match arch {
168- "qwen3" => self . tensors [ "blk.0.attn_qkv.weight" ] . shape ( ) [ 0 ]
169- . checked_div ( nh + nkvh + nkvh)
170- . unwrap ( ) ,
171- _ => meta ! [ self => llm_rope_dimension_count; d / nh] ,
172- } ;
162+ let dh = meta ! [ self => llm_rope_dimension_count; d / nh] ;
163+ let dk = meta ! [ self => llm_attention_key_length; dh] ;
173164 let theta = meta ! [ self => llm_rope_freq_base; 1e4 ] ;
174165
175166 let [ sin, cos] = match self . get_str ( & format ! ( "{arch}.rope.scaling.type" ) ) {
@@ -178,17 +169,17 @@ impl GGufModel<'_> {
178169
179170 let factors = & self . tensors [ "rope_factors_long.weight" ] ;
180171 assert_eq ! ( factors. dt( ) , types:: F32 ) ;
181- assert_eq ! ( factors. shape( ) , [ dh / 2 ] ) ;
172+ assert_eq ! ( factors. shape( ) , [ dk / 2 ] ) ;
182173 let factors = unsafe {
183- std:: slice:: from_raw_parts ( factors. get ( ) . as_ptr ( ) . cast :: < f32 > ( ) , dh / 2 )
174+ std:: slice:: from_raw_parts ( factors. get ( ) . as_ptr ( ) . cast :: < f32 > ( ) , dk / 2 )
184175 } ;
185176
186177 info ! ( "detected longrope, ctx scale = {ctx_scale}, scale factor = {factors:.2?}" ) ;
187- build_sin_cos ( nctx, dh , theta, |pos, i| {
178+ build_sin_cos ( nctx, dk , theta, |pos, i| {
188179 pos as f32 * ctx_scale / factors[ i]
189180 } )
190181 }
191- Err ( GGufMetaError :: NotExist ) => build_sin_cos ( nctx, dh , theta, |pos, _| pos as _ ) ,
182+ Err ( GGufMetaError :: NotExist ) => build_sin_cos ( nctx, dk , theta, |pos, _| pos as _ ) ,
192183 Ok ( ty) => panic ! ( "Unsupported rope scaling `{ty}`" ) ,
193184 Err ( e) => panic ! ( "{e:?}" ) ,
194185 } ;
0 commit comments