@@ -1,11 +1,11 @@
+use crate::models::with_tracing::QMatMul;
+use crate::quantized_var_builder::VarBuilder;
+use candle::quantized::gguf_file;
 use candle::{DType, Device, Module, Result, Tensor};
+use candle_nn::kv_cache::KvCache;
 use candle_nn::Activation;
-use candle::quantized::gguf_file;
-use crate::quantized_var_builder::VarBuilder;
-use std::sync::Arc;
 use std::io::Write;
-use crate::models::with_tracing::QMatMul;
-use candle_nn::kv_cache::KvCache;
+use std::sync::Arc;

 const MAX_SEQ_LEN: usize = 4096;
 use candle::IndexOp;
@@ -82,17 +82,23 @@ impl QuantizedConfig {

         // Helper to get required metadata
         let get_u32 = |key: &str| -> Result<usize> {
-            metadata.get(key)
+            metadata
+                .get(key)
                 .and_then(|v| v.to_u32().ok())
                 .map(|v| v as usize)
-                .ok_or_else(|| candle::Error::Msg(format!("Missing or invalid metadata key: {}", key)))
+                .ok_or_else(|| {
+                    candle::Error::Msg(format!("Missing or invalid metadata key: {}", key))
+                })
         };

         let get_f32 = |key: &str| -> Result<f64> {
-            metadata.get(key)
+            metadata
+                .get(key)
                 .and_then(|v| v.to_f32().ok())
                 .map(|v| v as f64)
-                .ok_or_else(|| candle::Error::Msg(format!("Missing or invalid metadata key: {}", key)))
+                .ok_or_else(|| {
+                    candle::Error::Msg(format!("Missing or invalid metadata key: {}", key))
+                })
         };

         Ok(Self {
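The two closures in this hunk pull required keys out of the GGUF metadata map and collapse any failure into one descriptive error. A standalone sketch of the same pattern, assuming a hypothetical helper name and key (`metadata` is the `HashMap<String, gguf_file::Value>` carried by `gguf_file::Content`):

    use candle::quantized::gguf_file;
    use std::collections::HashMap;

    // Look up one required u32 key, widening to usize and mapping any failure
    // (missing key or wrong value type) to a single descriptive error.
    fn required_u32(metadata: &HashMap<String, gguf_file::Value>, key: &str) -> candle::Result<usize> {
        metadata
            .get(key)
            .and_then(|v| v.to_u32().ok())
            .map(|v| v as usize)
            .ok_or_else(|| candle::Error::Msg(format!("Missing or invalid metadata key: {}", key)))
    }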
@@ -174,7 +180,12 @@ impl RotaryEmbedding {
         })
     }

-    pub fn apply_rotary_emb(&self, q: &Tensor, k: &Tensor, offset: usize) -> Result<(Tensor, Tensor)> {
+    pub fn apply_rotary_emb(
+        &self,
+        q: &Tensor,
+        k: &Tensor,
+        offset: usize,
+    ) -> Result<(Tensor, Tensor)> {
         let (_, _, seq_len, _) = q.dims4()?;
         let cos = self.cos.narrow(0, offset, seq_len)?;
         let sin = self.sin.narrow(0, offset, seq_len)?;
@@ -265,7 +276,7 @@ impl QuantizedAttention {
         let q_weight = q_weight.to_device(device)?; // Move to GPU

         // Re-quantize (now on GPU)
-        use candle::quantized::{QTensor, GgmlDType};
+        use candle::quantized::{GgmlDType, QTensor};
         let q_weight_qtensor = QTensor::quantize(&q_weight, GgmlDType::Q8_0)?;
         drop(q_weight_raw); // Explicitly free CPU memory
         drop(q_weight);
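The context lines around this hunk show the CPU-to-GPU re-quantization trick the file relies on: dequantize the CPU-resident tensor onto the target device, re-quantize it there as Q8_0, and drop the CPU buffers. A minimal sketch of that round trip, assuming a `QTensor` already loaded on CPU (the helper name is hypothetical; both calls appear in the diff itself):

    use candle::quantized::{GgmlDType, QTensor};
    use candle::Device;

    // Dequantize onto the target device, then re-quantize there as Q8_0.
    fn requantize_on(qt: &QTensor, device: &Device) -> candle::Result<QTensor> {
        let dense = qt.dequantize(device)?; // f32 copy materialized on `device`
        QTensor::quantize(&dense, GgmlDType::Q8_0) // original CPU buffers can then be dropped
    }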
@@ -298,21 +309,22 @@ impl QuantizedAttention {
         })
     }

-    fn forward(
-        &mut self,
-        x: &Tensor,
-        mask: Option<&Tensor>,
-        offset: usize,
-    ) -> Result<Tensor> {
+    fn forward(&mut self, x: &Tensor, mask: Option<&Tensor>, offset: usize) -> Result<Tensor> {
         let (b, seq_len, _) = x.dims3()?;

-        let q = self.q_proj.forward(x)?
+        let q = self
+            .q_proj
+            .forward(x)?
             .reshape((b, seq_len, self.num_heads, self.head_dim))?
             .transpose(1, 2)?;
-        let k = self.k_proj.forward(x)?
+        let k = self
+            .k_proj
+            .forward(x)?
             .reshape((b, seq_len, self.num_kv_heads, self.head_dim))?
             .transpose(1, 2)?;
-        let v = self.v_proj.forward(x)?
+        let v = self
+            .v_proj
+            .forward(x)?
             .reshape((b, seq_len, self.num_kv_heads, self.head_dim))?
             .transpose(1, 2)?;

@@ -375,22 +387,21 @@ impl QuantizedDecoderLayer {
             self_attn: QuantizedAttention::new(attn_vb.clone(), cfg, layer_idx, rotary_emb)?,
             mlp: QuantizedMLP::new(attn_vb.clone(), layer_idx)?,
             input_layernorm: RmsNorm::new(
-                attn_vb.get_no_shape("attn_norm.weight")?.dequantize(vb.device())?,
+                attn_vb
+                    .get_no_shape("attn_norm.weight")?
+                    .dequantize(vb.device())?,
                 cfg.rms_norm_eps,
             ),
             post_attention_layernorm: RmsNorm::new(
-                attn_vb.get_no_shape("ffn_norm.weight")?.dequantize(vb.device())?,
+                attn_vb
+                    .get_no_shape("ffn_norm.weight")?
+                    .dequantize(vb.device())?,
                 cfg.rms_norm_eps,
             ),
         })
     }

-    fn forward(
-        &mut self,
-        x: &Tensor,
-        mask: Option<&Tensor>,
-        offset: usize,
-    ) -> Result<Tensor> {
+    fn forward(&mut self, x: &Tensor, mask: Option<&Tensor>, offset: usize) -> Result<Tensor> {
         let residual = x;
         let x = self.input_layernorm.forward(x)?;
         let x = self.self_attn.forward(&x, mask, offset)?;
@@ -419,7 +430,7 @@ pub struct QuantizedModelForCausalLM {

 impl QuantizedModelForCausalLM {
     pub fn from_gguf<P: AsRef<std::path::Path>>(path: P, device: &Device) -> Result<Self> {
-        use candle::quantized::{QTensor, GgmlDType};
+        use candle::quantized::{GgmlDType, QTensor};

         // Open file once to read metadata
         let mut file = std::fs::File::open(path.as_ref())?;
@@ -437,14 +448,9 @@ impl QuantizedModelForCausalLM {
         let embed_tokens = candle_nn::Embedding::new(embed_tensor_gpu, config.hidden_size);

         // Create rotary embedding if needed
-        let needs_rope = (0..config.num_hidden_layers)
-            .any(|i| !config.should_skip_rope(i));
+        let needs_rope = (0..config.num_hidden_layers).any(|i| !config.should_skip_rope(i));
         let rotary_emb = if needs_rope {
-            Some(Arc::new(RotaryEmbedding::new(
-                DType::F32,
-                &config,
-                device,
-            )?))
+            Some(Arc::new(RotaryEmbedding::new(DType::F32, &config, device)?))
         } else {
             None
         };
@@ -454,7 +460,11 @@ impl QuantizedModelForCausalLM {
         println!("Loading {} decoder layers...", config.num_hidden_layers);
         for layer_idx in 0..config.num_hidden_layers {
             if layer_idx % 4 == 0 || layer_idx == config.num_hidden_layers - 1 {
-                print!(" Layer {}/{}...\r", layer_idx + 1, config.num_hidden_layers);
+                print!(
+                    " Layer {}/{}...\r",
+                    layer_idx + 1,
+                    config.num_hidden_layers
+                );
                 std::io::stdout().flush().ok();
             }
             layers.push(QuantizedDecoderLayer::new(
@@ -464,7 +474,10 @@ impl QuantizedModelForCausalLM {
                 rotary_emb.clone(),
             )?);
         }
-        println!(" Layer {}/{} - Done! ", config.num_hidden_layers, config.num_hidden_layers);
+        println!(
+            " Layer {}/{} - Done! ",
+            config.num_hidden_layers, config.num_hidden_layers
+        );

         // Load output norm
         let norm = RmsNorm::new(
@@ -551,4 +564,4 @@ impl QuantizedModelForCausalLM {
     pub fn config(&self) -> &QuantizedConfig {
         &self.config
     }
-}
+}
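The whole commit is a formatting and import-ordering pass, so the public surface stays as shown: `from_gguf` loads a GGUF checkpoint onto a device and `config` exposes the parsed metadata. A usage sketch under those assumptions (the file path is a placeholder):

    use candle::Device;

    fn main() -> candle::Result<()> {
        // Prefer a CUDA device when one is present, otherwise fall back to CPU.
        let device = Device::cuda_if_available(0)?;
        let model = QuantizedModelForCausalLM::from_gguf("model.gguf", &device)?;
        println!("decoder layers: {}", model.config().num_hidden_layers);
        Ok(())
    }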