
Commit 8447af4 ("formatted")

1 parent f8168ed

4 files changed: +107 -59 lines changed

candle-examples/examples/smollm3/main.rs

Lines changed: 30 additions & 12 deletions
@@ -47,7 +47,7 @@ impl SmolLM3Model {
             num_attention_heads: cfg.num_attention_heads,
             num_key_value_heads: cfg.num_key_value_heads,
             rope_theta: cfg.rope_theta as f32, // Convert f64 to f32
-            eos_token_id: Some(128012),        // Default SmolLM3 EOS
+            eos_token_id: Some(128012), // Default SmolLM3 EOS
             no_rope_layers: None,
             no_rope_layer_interval: None,
         }
@@ -61,7 +61,10 @@ impl SmolLM3Model {
             num_key_value_heads: cfg.num_key_value_heads,
             rope_theta: cfg.rope_theta as f32, // Convert f64 to f32
             eos_token_id: cfg.eos_token_id,
-            no_rope_layers: cfg.no_rope_layers.as_ref().map(|v| v.iter().map(|&x| x as u32).collect()), // Convert Vec<usize> to Vec<u32>
+            no_rope_layers: cfg
+                .no_rope_layers
+                .as_ref()
+                .map(|v| v.iter().map(|&x| x as u32).collect()), // Convert Vec<usize> to Vec<u32>
             no_rope_layer_interval: cfg.no_rope_layer_interval,
         }
     }
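The reflowed `no_rope_layers` field is pure rustfmt line-breaking; the conversion itself is unchanged. As a minimal standalone sketch of the same pattern:

    // Convert an optional Vec<usize> into an optional Vec<u32>,
    // as the no_rope_layers field expects.
    fn to_u32_vec(v: Option<&Vec<usize>>) -> Option<Vec<u32>> {
        v.map(|v| v.iter().map(|&x| x as u32).collect())
    }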
@@ -313,13 +316,17 @@ fn format_prompt(prompt: &str, use_chat_template: bool, enable_thinking: bool) -
         let today_date = now.format("%d %B %Y").to_string();

         // Set reasoning mode based on thinking flag
-        let reasoning_mode = if enable_thinking { "/think" } else { "/no_think" };
+        let reasoning_mode = if enable_thinking {
+            "/think"
+        } else {
+            "/no_think"
+        };

         // Build the assistant start with or without thinking tags
         let assistant_start = if enable_thinking {
-            "<|im_start|>assistant\n<think>\n"          // Open for reasoning
+            "<|im_start|>assistant\n<think>\n" // Open for reasoning
         } else {
-            "<|im_start|>assistant\n<think>\n\n</think>\n"  // Empty = skip reasoning
+            "<|im_start|>assistant\n<think>\n\n</think>\n" // Empty = skip reasoning
         };

         format!(
@@ -337,10 +344,7 @@ You are a helpful AI assistant named SmolLM, trained by Hugging Face.\n\
 <|im_start|>user\n\
 {}<|im_end|>\n\
 {}",
-            today_date,
-            reasoning_mode,
-            prompt,
-            assistant_start
+            today_date, reasoning_mode, prompt, assistant_start
         )
     } else {
         prompt.to_string()
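For context, with thinking enabled the template assembled above renders roughly as follows (the middle of the system block is abbreviated; only the fragments visible in this diff are certain):

    <|im_start|>system
    You are a helpful AI assistant named SmolLM, trained by Hugging Face.
    ... today's date and the /think flag ...
    <|im_end|>
    <|im_start|>user
    {prompt}<|im_end|>
    <|im_start|>assistant
    <think>

With thinking disabled, the assistant block instead opens with an already-closed, empty <think></think> pair, which tells the model to skip the reasoning phase.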
@@ -381,8 +385,22 @@ fn run_generation(

     println!("\n=== Generation Settings ===");
     println!("Model type: {:?}", args.model_type);
-    println!("Chat template: {}", if use_chat_template { "enabled" } else { "disabled" });
-    println!("Thinking mode: {}", if args.thinking { "enabled (/think)" } else { "disabled (/no_think)" });
+    println!(
+        "Chat template: {}",
+        if use_chat_template {
+            "enabled"
+        } else {
+            "disabled"
+        }
+    );
+    println!(
+        "Thinking mode: {}",
+        if args.thinking {
+            "enabled (/think)"
+        } else {
+            "disabled (/no_think)"
+        }
+    );
     println!("Raw prompt: {}", prompt_str);

     // Encode prompt
@@ -597,4 +615,4 @@ fn main() -> Result<()> {
     run_generation(&mut model, tokenizer, &args, &device)?;

     Ok(())
-}
\ No newline at end of file
+}
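Assuming the example's CLI is derived by clap from the Args fields referenced above (args.thinking, args.model_type), a typical invocation would look like the following; the flag names are inferred, not confirmed by this diff:

    # Hypothetical flags, inferred from the field names; check --help.
    cargo run --release --example smollm3 -- --prompt "Hello" --thinking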

candle-transformers/src/models/smol/mod.rs

Lines changed: 1 addition & 1 deletion
@@ -63,5 +63,5 @@
 //! - [SmolLM3 Model Card](https://huggingface.co/HuggingFaceTB/SmolLM3-3B)
 //! - [NoPE Paper](https://arxiv.org/abs/2410.01926)

-pub mod smollm3;
 pub mod quantized_smollm3;
+pub mod smollm3;

candle-transformers/src/models/smol/quantized_smollm3.rs

Lines changed: 52 additions & 39 deletions
@@ -1,11 +1,11 @@
+use crate::models::with_tracing::QMatMul;
+use crate::quantized_var_builder::VarBuilder;
+use candle::quantized::gguf_file;
 use candle::{DType, Device, Module, Result, Tensor};
+use candle_nn::kv_cache::KvCache;
 use candle_nn::Activation;
-use candle::quantized::gguf_file;
-use crate::quantized_var_builder::VarBuilder;
-use std::sync::Arc;
 use std::io::Write;
-use crate::models::with_tracing::QMatMul;
-use candle_nn::kv_cache::KvCache;
+use std::sync::Arc;

 const MAX_SEQ_LEN: usize = 4096;
 use candle::IndexOp;
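The candle_nn::kv_cache::KvCache import above is candle's ready-made key/value cache. A minimal sketch of how it is typically used, assuming (b, heads, seq, head_dim) tensors so dimension 2 is the sequence axis:

    use candle::{Result, Tensor};
    use candle_nn::kv_cache::KvCache;

    // Append this step's k/v and get back the full cached k/v so far.
    // Construction elsewhere would be KvCache::new(2, MAX_SEQ_LEN).
    fn step(cache: &mut KvCache, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
        cache.append(k, v)
    }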
@@ -82,17 +82,23 @@ impl QuantizedConfig {

         // Helper to get required metadata
         let get_u32 = |key: &str| -> Result<usize> {
-            metadata.get(key)
+            metadata
+                .get(key)
                 .and_then(|v| v.to_u32().ok())
                 .map(|v| v as usize)
-                .ok_or_else(|| candle::Error::Msg(format!("Missing or invalid metadata key: {}", key)))
+                .ok_or_else(|| {
+                    candle::Error::Msg(format!("Missing or invalid metadata key: {}", key))
+                })
         };

         let get_f32 = |key: &str| -> Result<f64> {
-            metadata.get(key)
+            metadata
+                .get(key)
                 .and_then(|v| v.to_f32().ok())
                 .map(|v| v as f64)
-                .ok_or_else(|| candle::Error::Msg(format!("Missing or invalid metadata key: {}", key)))
+                .ok_or_else(|| {
+                    candle::Error::Msg(format!("Missing or invalid metadata key: {}", key))
+                })
         };

         Ok(Self {
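These closures wrap GGUF metadata lookups in a uniform error. A standalone equivalent of get_u32, for reference:

    use candle::quantized::gguf_file;
    use candle::Result;
    use std::collections::HashMap;

    // Fetch a required u32 metadata key or fail with a descriptive error,
    // mirroring the closure above.
    fn require_u32(md: &HashMap<String, gguf_file::Value>, key: &str) -> Result<usize> {
        md.get(key)
            .and_then(|v| v.to_u32().ok())
            .map(|v| v as usize)
            .ok_or_else(|| candle::Error::Msg(format!("Missing or invalid metadata key: {}", key)))
    }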
@@ -174,7 +180,12 @@ impl RotaryEmbedding {
         })
     }

-    pub fn apply_rotary_emb(&self, q: &Tensor, k: &Tensor, offset: usize) -> Result<(Tensor, Tensor)> {
+    pub fn apply_rotary_emb(
+        &self,
+        q: &Tensor,
+        k: &Tensor,
+        offset: usize,
+    ) -> Result<(Tensor, Tensor)> {
         let (_, _, seq_len, _) = q.dims4()?;
         let cos = self.cos.narrow(0, offset, seq_len)?;
         let sin = self.sin.narrow(0, offset, seq_len)?;
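The hunk ends before the rotation itself. Assuming the body continues with candle_nn's fused RoPE kernel (candle's usual pattern; not shown in this diff), the remaining steps would be:

    // Assumed continuation: rotate q and k with the sliced cos/sin tables.
    // candle_nn::rotary_emb::rope expects contiguous (b, heads, seq, head_dim).
    let q = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
    let k = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
    Ok((q, k))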
@@ -265,7 +276,7 @@ impl QuantizedAttention {
         let q_weight = q_weight.to_device(device)?; // Move to GPU

         // Re-quantize (now on GPU)
-        use candle::quantized::{QTensor, GgmlDType};
+        use candle::quantized::{GgmlDType, QTensor};
         let q_weight_qtensor = QTensor::quantize(&q_weight, GgmlDType::Q8_0)?;
         drop(q_weight_raw); // Explicitly free CPU memory
         drop(q_weight);
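GGUF tensors load on the CPU, so moving a quantized weight to an accelerator takes the dequantize/copy/re-quantize round trip shown above. A self-contained sketch of that pattern:

    use candle::quantized::{GgmlDType, QTensor};
    use candle::{Device, Result};

    // Dequantize on CPU, copy to the target device, re-quantize there.
    // Q8_0 matches the dtype chosen in the code above.
    fn requantize_on(q_cpu: &QTensor, device: &Device) -> Result<QTensor> {
        let full = q_cpu.dequantize(&Device::Cpu)?;
        let full = full.to_device(device)?;
        QTensor::quantize(&full, GgmlDType::Q8_0)
    }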
@@ -298,21 +309,22 @@ impl QuantizedAttention {
         })
     }

-    fn forward(
-        &mut self,
-        x: &Tensor,
-        mask: Option<&Tensor>,
-        offset: usize,
-    ) -> Result<Tensor> {
+    fn forward(&mut self, x: &Tensor, mask: Option<&Tensor>, offset: usize) -> Result<Tensor> {
         let (b, seq_len, _) = x.dims3()?;

-        let q = self.q_proj.forward(x)?
+        let q = self
+            .q_proj
+            .forward(x)?
             .reshape((b, seq_len, self.num_heads, self.head_dim))?
             .transpose(1, 2)?;
-        let k = self.k_proj.forward(x)?
+        let k = self
+            .k_proj
+            .forward(x)?
             .reshape((b, seq_len, self.num_kv_heads, self.head_dim))?
             .transpose(1, 2)?;
-        let v = self.v_proj.forward(x)?
+        let v = self
+            .v_proj
+            .forward(x)?
             .reshape((b, seq_len, self.num_kv_heads, self.head_dim))?
             .transpose(1, 2)?;
318330

@@ -375,22 +387,21 @@ impl QuantizedDecoderLayer {
             self_attn: QuantizedAttention::new(attn_vb.clone(), cfg, layer_idx, rotary_emb)?,
             mlp: QuantizedMLP::new(attn_vb.clone(), layer_idx)?,
             input_layernorm: RmsNorm::new(
-                attn_vb.get_no_shape("attn_norm.weight")?.dequantize(vb.device())?,
+                attn_vb
+                    .get_no_shape("attn_norm.weight")?
+                    .dequantize(vb.device())?,
                 cfg.rms_norm_eps,
             ),
             post_attention_layernorm: RmsNorm::new(
-                attn_vb.get_no_shape("ffn_norm.weight")?.dequantize(vb.device())?,
+                attn_vb
+                    .get_no_shape("ffn_norm.weight")?
+                    .dequantize(vb.device())?,
                 cfg.rms_norm_eps,
             ),
         })
     }

-    fn forward(
-        &mut self,
-        x: &Tensor,
-        mask: Option<&Tensor>,
-        offset: usize,
-    ) -> Result<Tensor> {
+    fn forward(&mut self, x: &Tensor, mask: Option<&Tensor>, offset: usize) -> Result<Tensor> {
         let residual = x;
         let x = self.input_layernorm.forward(x)?;
         let x = self.self_attn.forward(&x, mask, offset)?;
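The hunk cuts off mid-body; assuming the standard pre-norm layout (consistent with the residual captured above), the rest of forward would read:

    // Assumed continuation of QuantizedDecoderLayer::forward:
    let x = (x + residual)?;                            // attention residual
    let residual = &x;
    let x = self.post_attention_layernorm.forward(&x)?;
    let x = self.mlp.forward(&x)?;
    x + residual                                        // feed-forward residual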
@@ -419,7 +430,7 @@ pub struct QuantizedModelForCausalLM {

 impl QuantizedModelForCausalLM {
     pub fn from_gguf<P: AsRef<std::path::Path>>(path: P, device: &Device) -> Result<Self> {
-        use candle::quantized::{QTensor, GgmlDType};
+        use candle::quantized::{GgmlDType, QTensor};

         // Open file once to read metadata
         let mut file = std::fs::File::open(path.as_ref())?;
@@ -437,14 +448,9 @@ impl QuantizedModelForCausalLM {
         let embed_tokens = candle_nn::Embedding::new(embed_tensor_gpu, config.hidden_size);

         // Create rotary embedding if needed
-        let needs_rope = (0..config.num_hidden_layers)
-            .any(|i| !config.should_skip_rope(i));
+        let needs_rope = (0..config.num_hidden_layers).any(|i| !config.should_skip_rope(i));
         let rotary_emb = if needs_rope {
-            Some(Arc::new(RotaryEmbedding::new(
-                DType::F32,
-                &config,
-                device,
-            )?))
+            Some(Arc::new(RotaryEmbedding::new(DType::F32, &config, device)?))
         } else {
             None
         };
@@ -454,7 +460,11 @@
         println!("Loading {} decoder layers...", config.num_hidden_layers);
         for layer_idx in 0..config.num_hidden_layers {
             if layer_idx % 4 == 0 || layer_idx == config.num_hidden_layers - 1 {
-                print!("  Layer {}/{}...\r", layer_idx + 1, config.num_hidden_layers);
+                print!(
+                    "  Layer {}/{}...\r",
+                    layer_idx + 1,
+                    config.num_hidden_layers
+                );
                 std::io::stdout().flush().ok();
             }
             layers.push(QuantizedDecoderLayer::new(
@@ -464,7 +474,10 @@
                 rotary_emb.clone(),
             )?);
         }
-        println!("  Layer {}/{} - Done!   ", config.num_hidden_layers, config.num_hidden_layers);
+        println!(
+            "  Layer {}/{} - Done!   ",
+            config.num_hidden_layers, config.num_hidden_layers
+        );

         // Load output norm
         let norm = RmsNorm::new(
551564
pub fn config(&self) -> &QuantizedConfig {
552565
&self.config
553566
}
554-
}
567+
}
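For reference, a hedged caller sketch (the GGUF filename is a placeholder, and the snippet assumes a surrounding fn returning candle::Result):

    use candle::Device;

    let device = Device::cuda_if_available(0)?;
    let model = QuantizedModelForCausalLM::from_gguf("SmolLM3-3B.Q8_0.gguf", &device)?;
    println!("layers: {}", model.config().num_hidden_layers);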

candle-transformers/src/models/smol/smollm3.rs

Lines changed: 24 additions & 7 deletions
@@ -36,7 +36,6 @@ pub struct Config {
 }

 impl Config {
-
     pub fn should_skip_rope(&self, layer_idx: usize) -> bool {
         // Method 1: Explicit array (some model variants may provide this)
         if let Some(ref no_rope_layers) = self.no_rope_layers {
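The rest of the method falls outside this hunk. Assuming the usual interval fallback (SmolLM3 disables RoPE on every 4th layer when no_rope_layer_interval = Some(4)), the remainder would look like:

    // Method 2 (assumed): interval rule, e.g. layers 3, 7, 11, ... skip RoPE.
    if let Some(interval) = self.no_rope_layer_interval {
        return (layer_idx + 1) % interval == 0;
    }
    false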
@@ -112,9 +111,24 @@ impl SmolLM3MLP {
     pub(crate) fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
         let mlp_bias = cfg.mlp_bias.unwrap_or(false);
         Ok(Self {
-            gate_proj: linear_b(cfg.hidden_size, cfg.intermediate_size, mlp_bias, vb.pp("gate_proj"))?,
-            up_proj: linear_b(cfg.hidden_size, cfg.intermediate_size, mlp_bias, vb.pp("up_proj"))?,
-            down_proj: linear_b(cfg.intermediate_size, cfg.hidden_size, mlp_bias, vb.pp("down_proj"))?,
+            gate_proj: linear_b(
+                cfg.hidden_size,
+                cfg.intermediate_size,
+                mlp_bias,
+                vb.pp("gate_proj"),
+            )?,
+            up_proj: linear_b(
+                cfg.hidden_size,
+                cfg.intermediate_size,
+                mlp_bias,
+                vb.pp("up_proj"),
+            )?,
+            down_proj: linear_b(
+                cfg.intermediate_size,
+                cfg.hidden_size,
+                mlp_bias,
+                vb.pp("down_proj"),
+            )?,
             act_fn: cfg.hidden_act,
         })
     }
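These three projections are the standard SwiGLU trio. A sketch of the forward pass they feed; the real forward is not part of this diff, so the method name here is illustrative:

    use candle::{Module, Result, Tensor};

    impl SmolLM3MLP {
        // down( act(gate(x)) * up(x) )
        fn forward_sketch(&self, x: &Tensor) -> Result<Tensor> {
            let lhs = self.gate_proj.forward(x)?.apply(&self.act_fn)?;
            let rhs = self.up_proj.forward(x)?;
            self.down_proj.forward(&(lhs * rhs)?)
        }
    }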
@@ -350,7 +364,11 @@ impl Model {
         // Only create rotary embedding if at least one layer uses RoPE
         let needs_rope = (0..cfg.num_hidden_layers).any(|i| !cfg.should_skip_rope(i));
         let rotary = if needs_rope {
-            Some(Arc::new(SmolLM3RotaryEmbedding::new(vb.dtype(), cfg, vb.device())?))
+            Some(Arc::new(SmolLM3RotaryEmbedding::new(
+                vb.dtype(),
+                cfg,
+                vb.device(),
+            )?))
         } else {
             None
         };
@@ -444,10 +462,9 @@ impl ModelForCausalLM {
             .forward(input, offset)?
             .narrow(1, l - 1, 1)?
             .apply(&self.lm_head)
-
     }

     pub fn clear_kv_cache(&mut self) {
         self.base.clear_kv_cache();
     }
-}
\ No newline at end of file
+}
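Worth noting in the hunk above: narrow(1, l - 1, 1) keeps only the last position along the sequence axis, so lm_head projects a (b, 1, hidden_size) slice rather than the full sequence, which is all incremental decoding needs.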
