From ea162e87d3d50722f27569c8959150df5851c75f Mon Sep 17 00:00:00 2001 From: DrJesseGlass Date: Wed, 12 Nov 2025 13:42:12 -0500 Subject: [PATCH 1/9] quantized and full SmolLM3 --- candle-examples/Cargo.toml | 1 + candle-examples/examples/smollm3/README.md | 120 ++++ candle-examples/examples/smollm3/main.rs | 600 ++++++++++++++++++ candle-transformers/src/models/mod.rs | 1 + candle-transformers/src/models/smol/README.md | 118 ++++ candle-transformers/src/models/smol/mod.rs | 67 ++ .../src/models/smol/quantized_smollm3.rs | 554 ++++++++++++++++ .../src/models/smol/smollm3.rs | 453 +++++++++++++ 8 files changed, 1914 insertions(+) create mode 100644 candle-examples/examples/smollm3/README.md create mode 100644 candle-examples/examples/smollm3/main.rs create mode 100644 candle-transformers/src/models/smol/README.md create mode 100644 candle-transformers/src/models/smol/mod.rs create mode 100644 candle-transformers/src/models/smol/quantized_smollm3.rs create mode 100644 candle-transformers/src/models/smol/smollm3.rs diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml index e64619ae4c..201412c0ec 100644 --- a/candle-examples/Cargo.toml +++ b/candle-examples/Cargo.toml @@ -19,6 +19,7 @@ candle-flash-attn = { workspace = true, optional = true } candle-onnx = { workspace = true, optional = true } csv = "1.3.0" +crono = "0.4.0" cudarc = { workspace = true, optional = true } half = { workspace = true, optional = true } hf-hub = { workspace = true, features = ["tokio"] } diff --git a/candle-examples/examples/smollm3/README.md b/candle-examples/examples/smollm3/README.md new file mode 100644 index 0000000000..1051816b63 --- /dev/null +++ b/candle-examples/examples/smollm3/README.md @@ -0,0 +1,120 @@ +# SmolLM3 Unified Inference + +A unified Rust implementation for running SmolLM3 models using the Candle ML framework. Supports both quantized (GGUF) and full precision (safetensors) models with a single codebase. 
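+
+Both back ends are driven through one small dispatch enum, so the generation loop never needs to know which format was loaded. A simplified sketch of the dispatch used in `main.rs` below (error handling trimmed):
+
+```rust
+use candle::Tensor;
+use candle_transformers::models::smol::quantized_smollm3::QuantizedModelForCausalLM;
+use candle_transformers::models::smol::smollm3::{Config, ModelForCausalLM};
+
+// One wrapper, two backends: everything downstream just calls `forward`.
+enum SmolLM3Model {
+    Quantized(QuantizedModelForCausalLM),
+    Full(ModelForCausalLM, Config),
+}
+
+impl SmolLM3Model {
+    fn forward(&mut self, input: &Tensor, pos: usize) -> candle::Result<Tensor> {
+        match self {
+            Self::Quantized(m) => m.forward(input, pos),
+            Self::Full(m, _) => m.forward(input, pos),
+        }
+    }
+}
+```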
+ +## Features + +- **Dual Model Support**: Run either quantized or full precision models +- **Multiple Quantization Levels**: Q4_K_M (1.9GB), Q8_0 (3.3GB), F16 (6.2GB) +- **Chat Template Support**: Automatic formatting for instruction-tuned models +- **Thinking Mode**: Enable reasoning traces with `/think` mode +- **NoPE Architecture**: Supports SmolLM3's mixed RoPE/NoPE layer configuration +- **Auto-download**: Automatically fetches models from HuggingFace Hub + +## Quick Start + +### Quantized Model (Recommended) +```bash +cargo run --release --example smollm3 -- \ + --model-type quantized \ + --quantization q8_0 \ + --prompt "Explain Rust's ownership system" +``` + +### Full Precision Model +```bash +cargo run --release --example smollm3 -- \ + --model-type full \ + --dtype f16 \ + --prompt "Write a sorting algorithm in Rust" +``` + +## Command Line Options + +### Model Selection +- `--model-type `: Choose `quantized` or `full` (default: quantized) +- `--model `: Choose `3b` (instruct) or `3b-base` (default: 3b) +- `--quantization `: For quantized models - `q4_k_m`, `q8_0`, or `f16` (default: q8_0) +- `--dtype `: For full models - `f32`, `f16`, `bf16`, or `auto` (default: auto) + +### Generation Parameters +- `--prompt `: The prompt to generate from +- `-n, --sample-len `: Number of tokens to generate (default: 1000) +- `--temperature `: Sampling temperature, 0 for greedy (default: 0.8) +- `--top-p `: Nucleus sampling probability cutoff +- `--top-k `: Only sample among top K tokens +- `--repeat-penalty `: Penalty for repeating tokens (default: 1.1) +- `--repeat-last-n `: Context size for repeat penalty (default: 64) + +### Advanced Options +- `--no-chat-template`: Disable chat template formatting (use for base models) +- `--thinking`: Enable thinking/reasoning mode with `/think` tags +- `--split-prompt`: Process prompt tokens individually (for debugging) +- `--tracing`: Enable performance tracing (generates trace JSON) +- `--model-path `: Use local model file instead of auto-download +- `--tokenizer `: Use local tokenizer instead of auto-download + +## Quantization Comparison + +| Level | Size | Quality | Use Case | +|--------|-------|---------|----------| +| Q4_K_M | 1.9GB | Good | Fast inference, constrained environments | +| Q8_0 | 3.3GB | Better | Balanced quality and speed | +| F16 | 6.2GB | Best | Maximum quality in GGUF format | + +## Examples + +### Creative Writing with Thinking Mode +```bash +cargo run --release --example smollm3 -- \ + --thinking \ + --temperature 0.9 \ + --prompt "Write a short sci-fi story about AI" +``` + +### Code Generation (Base Model) +```bash +cargo run --release --example smollm3 -- \ + --model 3b-base \ + --no-chat-template \ + --temperature 0.2 \ + --prompt "def fibonacci(n):" +``` + +### High Quality Output +```bash +cargo run --release --example smollm3 -- \ + --model-type full \ + --dtype f16 \ + --temperature 0.7 \ + --prompt "Explain quantum entanglement" +``` + +## Model Architecture + +SmolLM3 uses a hybrid RoPE/NoPE architecture: +- **RoPE layers**: Standard rotary position embeddings (75% of layers) +- **NoPE layers**: No position embeddings (25% of layers - every 4th layer) + +This configuration is automatically detected and handled by the implementation. + +## Hardware Requirements + +- **Quantized Q4_K_M**: ~2.5GB RAM +- **Quantized Q8_0**: ~4GB RAM +- **Full F16**: ~7GB RAM +- **Full F32**: ~13GB RAM + +GPU acceleration supported via CUDA (with `cuda` feature) or Metal (macOS). 
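+
+## NoPE Layer Selection
+
+The RoPE/NoPE split described under Model Architecture comes down to a one-line index rule, mirroring the rule inside the model's `Config::should_skip_rope`. A minimal sketch (the layer count below is only illustrative; the real value comes from the model config):
+
+```rust
+/// True when `layer_idx` is a NoPE layer, i.e. rotary embeddings are skipped.
+/// With SmolLM3-3B's interval of 4 this selects every 4th layer: 3, 7, 11, ...
+fn is_nope_layer(layer_idx: usize, interval: usize) -> bool {
+    (layer_idx + 1) % interval == 0
+}
+
+fn main() {
+    let num_layers = 36; // illustrative
+    let nope: Vec<usize> = (0..num_layers).filter(|&i| is_nope_layer(i, 4)).collect();
+    println!("NoPE layers: {nope:?}"); // roughly 25% of layers skip RoPE
+}
+```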
+ +## Troubleshooting + +**Model download fails**: Check internet connection and HuggingFace Hub access + +**Out of memory**: Try a smaller quantization level or use `--sample-len` to reduce generation length + +**Compilation errors**: Ensure you're using the latest version of the Candle crate + +## License + +This implementation follows the Candle framework license. SmolLM3 models are available under Apache 2.0. \ No newline at end of file diff --git a/candle-examples/examples/smollm3/main.rs b/candle-examples/examples/smollm3/main.rs new file mode 100644 index 0000000000..cb2418f1fd --- /dev/null +++ b/candle-examples/examples/smollm3/main.rs @@ -0,0 +1,600 @@ +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +use anyhow::{Error as E, Result}; +use clap::{Parser, ValueEnum}; +use std::io::Write; + +use candle::{DType, Device, Tensor}; +use candle_examples::token_output_stream::TokenOutputStream; +use candle_nn::VarBuilder; +use candle_transformers::generation::{LogitsProcessor, Sampling}; +use hf_hub::{api::sync::Api, Repo, RepoType}; +use tokenizers::Tokenizer; + +// Import both model implementations +use candle_transformers::models::smol::quantized_smollm3::QuantizedModelForCausalLM; +use candle_transformers::models::smol::smollm3::{Config, ModelForCausalLM}; + +const DEFAULT_PROMPT: &str = "Write a Rust function to calculate the factorial of a given number."; + +// ==================== Model Type Enum ==================== + +enum SmolLM3Model { + Quantized(QuantizedModelForCausalLM), + Full(ModelForCausalLM, Config), // Store config alongside model +} + +impl SmolLM3Model { + fn forward(&mut self, input: &Tensor, pos: usize) -> Result { + match self { + Self::Quantized(model) => Ok(model.forward(input, pos)?), + Self::Full(model, _) => Ok(model.forward(input, pos)?), + } + } + + fn config(&self) -> ModelConfig { + match self { + Self::Quantized(model) => { + let cfg = model.config(); + ModelConfig { + vocab_size: cfg.vocab_size, + hidden_size: cfg.hidden_size, + num_hidden_layers: cfg.num_hidden_layers, + num_attention_heads: cfg.num_attention_heads, + num_key_value_heads: cfg.num_key_value_heads, + rope_theta: cfg.rope_theta as f32, // Convert f64 to f32 + eos_token_id: Some(128012), // Default SmolLM3 EOS + no_rope_layers: None, + no_rope_layer_interval: None, + } + } + Self::Full(_, cfg) => { + ModelConfig { + vocab_size: cfg.vocab_size, + hidden_size: cfg.hidden_size, + num_hidden_layers: cfg.num_hidden_layers, + num_attention_heads: cfg.num_attention_heads, + num_key_value_heads: cfg.num_key_value_heads, + rope_theta: cfg.rope_theta as f32, // Convert f64 to f32 + eos_token_id: cfg.eos_token_id, + no_rope_layers: cfg.no_rope_layers.as_ref().map(|v| v.iter().map(|&x| x as u32).collect()), // Convert Vec to Vec + no_rope_layer_interval: cfg.no_rope_layer_interval, + } + } + } + } +} + +// Unified config representation +struct ModelConfig { + vocab_size: usize, + hidden_size: usize, + num_hidden_layers: usize, + num_attention_heads: usize, + num_key_value_heads: usize, + rope_theta: f32, + eos_token_id: Option, + no_rope_layers: Option>, + no_rope_layer_interval: Option, +} + +impl ModelConfig { + fn head_dim(&self) -> usize { + self.hidden_size / self.num_attention_heads + } +} + +// ==================== CLI Arguments ==================== + +#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)] +enum ModelType { + /// Use quantized GGUF model (smaller, faster) + #[value(name = "quantized")] + Quantized, + /// 
Use full precision safetensors model (larger, more accurate) + #[value(name = "full")] + Full, +} + +#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)] +enum Quantization { + #[value(name = "q4_k_m")] + Q4KM, + #[value(name = "q8_0")] + Q8_0, + #[value(name = "f16")] + F16, +} + +impl Quantization { + fn filename_unsloth(&self) -> &'static str { + match self { + Self::Q4KM => "SmolLM3-3B-Q4_K_M.gguf", + Self::Q8_0 => "SmolLM3-3B-Q8_0.gguf", + Self::F16 => "SmolLM3-3B-F16.gguf", + } + } + + fn size_gb(&self) -> f32 { + match self { + Self::Q4KM => 1.92, + Self::Q8_0 => 3.28, + Self::F16 => 6.16, + } + } +} + +#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)] +enum WhichModel { + #[value(name = "3b")] + W3b, + #[value(name = "3b-base")] + W3bBase, +} + +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +struct Args { + /// Model type: 'quantized' for GGUF or 'full' for safetensors + #[arg(long, default_value = "quantized")] + model_type: ModelType, + + /// Which model variant to use + #[arg(long, default_value = "3b")] + model: WhichModel, + + /// Quantization level (only for quantized models) + /// Q8_0: 3.3GB, best quality | Q4_K_M: 1.9GB, good balance | F16: 6.2GB, full precision + #[arg(long, default_value = "q8_0")] + quantization: Quantization, + + /// Data type (only for full models: f32, f16, bf16, or auto) + #[arg(long, default_value = "auto")] + dtype: String, + + /// Path to model file (optional, will auto-download if not provided) + #[arg(long)] + model_path: Option, + + /// Path to tokenizer file (optional, will auto-download if not provided) + #[arg(long)] + tokenizer: Option, + + /// The initial prompt + #[arg(long)] + prompt: Option, + + /// The length of the sample to generate (in tokens) + #[arg(short = 'n', long, default_value_t = 1000)] + sample_len: usize, + + /// The temperature used to generate samples, use 0 for greedy sampling + #[arg(long, default_value_t = 0.8)] + temperature: f64, + + /// Nucleus sampling probability cutoff + #[arg(long)] + top_p: Option, + + /// Only sample among the top K samples + #[arg(long)] + top_k: Option, + + /// The seed to use when generating random samples + #[arg(long, default_value_t = 299792458)] + seed: u64, + + /// Penalty to be applied for repeating tokens, 1. means no penalty + #[arg(long, default_value_t = 1.1)] + repeat_penalty: f32, + + /// The context size to consider for the repeat penalty + #[arg(long, default_value_t = 64)] + repeat_last_n: usize, + + /// Skip chat template formatting (use raw prompt, like base model) + #[arg(long)] + no_chat_template: bool, + + /// Enable thinking/reasoning mode (allows model to show its reasoning process) + #[arg(long)] + thinking: bool, + + /// Process prompt elements separately (slower, for debugging) + #[arg(long)] + split_prompt: bool, + + /// Enable tracing (generates a trace-timestamp.json file) + #[arg(long)] + tracing: bool, +} + +impl Args { + fn get_tokenizer(&self) -> Result { + let tokenizer_path = match &self.tokenizer { + Some(path) => std::path::PathBuf::from(path), + None => { + let api = Api::new()?; + let api = api.model("HuggingFaceTB/SmolLM3-3B".to_string()); + api.get("tokenizer.json")? 
+ } + }; + Tokenizer::from_file(tokenizer_path).map_err(E::msg) + } + + fn should_use_chat_template(&self) -> bool { + matches!(self.model, WhichModel::W3b) && !self.no_chat_template + } +} + +// ==================== Model Loading ==================== + +fn load_quantized_model(args: &Args, device: &Device) -> Result { + let model_path = match &args.model_path { + Some(path) => std::path::PathBuf::from(path), + None => { + let filename = args.quantization.filename_unsloth(); + let repo_id = "unsloth/SmolLM3-3B-GGUF"; + let api = Api::new()?; + println!( + "Downloading {} from {} (~{:.2}GB)...", + filename, + repo_id, + args.quantization.size_gb() + ); + api.repo(Repo::with_revision( + repo_id.to_string(), + RepoType::Model, + "main".to_string(), + )) + .get(filename)? + } + }; + + println!("Loading quantized model from {:?}...", model_path); + let model = QuantizedModelForCausalLM::from_gguf(&model_path, device)?; + Ok(SmolLM3Model::Quantized(model)) +} + +fn load_full_model(args: &Args, device: &Device) -> Result { + let api = Api::new()?; + let model_id = match args.model { + WhichModel::W3b => "HuggingFaceTB/SmolLM3-3B", + WhichModel::W3bBase => "HuggingFaceTB/SmolLM3-3B-Base", + }; + + println!("Loading full model from: {}", model_id); + let repo = api.repo(Repo::with_revision( + model_id.to_string(), + RepoType::Model, + "main".to_string(), + )); + + let filenames = match &args.model_path { + Some(path) => vec![std::path::PathBuf::from(path)], + None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?, + }; + + let config_file = repo.get("config.json")?; + let config: Config = serde_json::from_slice(&std::fs::read(config_file)?)?; + + let dtype = match args.dtype.as_str() { + "f16" => DType::F16, + "bf16" => DType::BF16, + "f32" => DType::F32, + "auto" => { + if device.is_cuda() || device.is_metal() { + DType::BF16 + } else { + DType::F32 + } + } + other => anyhow::bail!("Unsupported dtype: {}, use f16, bf16, f32, or auto", other), + }; + + println!("Using dtype: {:?}", dtype); + + let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, device)? 
}; + let model = ModelForCausalLM::new(&config, vb)?; + + Ok(SmolLM3Model::Full(model, config)) +} + +// ==================== Text Generation ==================== + +fn format_prompt(prompt: &str, use_chat_template: bool, enable_thinking: bool) -> String { + if use_chat_template { + // Generate current date dynamically + let now = chrono::Local::now(); + let today_date = now.format("%d %B %Y").to_string(); + + // Set reasoning mode based on thinking flag + let reasoning_mode = if enable_thinking { "/think" } else { "/no_think" }; + + // Build the assistant start with or without thinking tags + let assistant_start = if enable_thinking { + "<|im_start|>assistant\n\n" // Open for reasoning + } else { + "<|im_start|>assistant\n\n\n\n" // Empty = skip reasoning + }; + + format!( + "<|im_start|>system\n\ +## Metadata\n\ +\n\ +Knowledge Cutoff Date: June 2025\n\ +Today Date: {}\n\ +Reasoning Mode: {}\n\ +\n\ +## Custom Instructions\n\ +\n\ +You are a helpful AI assistant named SmolLM, trained by Hugging Face.\n\ +\n\ +<|im_start|>user\n\ +{}<|im_end|>\n\ +{}", + today_date, + reasoning_mode, + prompt, + assistant_start + ) + } else { + prompt.to_string() + } +} + +fn get_eos_token(tokenizer: &Tokenizer, config: &ModelConfig) -> u32 { + if let Some(eos_id) = config.eos_token_id { + return eos_id; + } + + let vocab = tokenizer.get_vocab(true); + if let Some(&eos_id) = vocab.get("<|im_end|>") { + return eos_id; + } + if let Some(&eos_id) = vocab.get("<|endoftext|>") { + return eos_id; + } + + 128012 // Default SmolLM3 EOS token +} + +fn run_generation( + model: &mut SmolLM3Model, + tokenizer: Tokenizer, + args: &Args, + device: &Device, +) -> Result<()> { + let mut tos = TokenOutputStream::new(tokenizer); + + // Prepare prompt + let prompt_str = args + .prompt + .clone() + .unwrap_or_else(|| DEFAULT_PROMPT.to_string()); + let use_chat_template = args.should_use_chat_template(); + let formatted_prompt = format_prompt(&prompt_str, use_chat_template, args.thinking); + + println!("\n=== Generation Settings ==="); + println!("Model type: {:?}", args.model_type); + println!("Chat template: {}", if use_chat_template { "enabled" } else { "disabled" }); + println!("Thinking mode: {}", if args.thinking { "enabled (/think)" } else { "disabled (/no_think)" }); + println!("Raw prompt: {}", prompt_str); + + // Encode prompt + let tokens = tos + .tokenizer() + .encode(formatted_prompt.as_str(), false) + .map_err(E::msg)?; + let tokens = tokens.get_ids(); + println!("Encoded {} tokens", tokens.len()); + + // Setup logits processor + let sampling = if args.temperature <= 0.0 { + Sampling::ArgMax + } else { + match (args.top_k, args.top_p) { + (None, None) => Sampling::All { + temperature: args.temperature, + }, + (Some(k), None) => Sampling::TopK { + k, + temperature: args.temperature, + }, + (None, Some(p)) => Sampling::TopP { + p, + temperature: args.temperature, + }, + (Some(k), Some(p)) => Sampling::TopKThenTopP { + k, + p, + temperature: args.temperature, + }, + } + }; + let mut logits_processor = LogitsProcessor::from_sampling(args.seed, sampling); + + // Process prompt + let start_prompt = std::time::Instant::now(); + let mut next_token = if !args.split_prompt { + let input = Tensor::new(tokens, device)?.unsqueeze(0)?; + let logits = model.forward(&input, 0)?; + let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?; + logits_processor.sample(&logits)? 
+ } else { + let mut next_token = 0; + for (pos, &token) in tokens.iter().enumerate() { + let input = Tensor::new(&[token], device)?.unsqueeze(0)?; + let logits = model.forward(&input, pos)?; + let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?; + next_token = logits_processor.sample(&logits)?; + } + next_token + }; + let prompt_dt = start_prompt.elapsed(); + + // Get EOS token + let config = model.config(); + let eos_token = get_eos_token(tos.tokenizer(), &config); + + // Generate tokens + let mut all_tokens = vec![next_token]; + print!("\n=== Output ===\n"); + if let Some(t) = tos.next_token(next_token)? { + print!("{t}"); + std::io::stdout().flush()?; + } + + let start_generation = std::time::Instant::now(); + let to_sample = args.sample_len.saturating_sub(1); + let mut sampled = 0; + + for index in 0..to_sample { + let input = Tensor::new(&[next_token], device)?.unsqueeze(0)?; + let logits = model.forward(&input, tokens.len() + index)?; + let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?; + + let logits = if args.repeat_penalty == 1.0 { + logits + } else { + let start_at = all_tokens.len().saturating_sub(args.repeat_last_n); + candle_transformers::utils::apply_repeat_penalty( + &logits, + args.repeat_penalty, + &all_tokens[start_at..], + )? + }; + + next_token = logits_processor.sample(&logits)?; + all_tokens.push(next_token); + + if let Some(t) = tos.next_token(next_token)? { + print!("{t}"); + std::io::stdout().flush()?; + } + + sampled += 1; + if next_token == eos_token { + break; + } + } + + if let Some(rest) = tos.decode_rest().map_err(E::msg)? { + print!("{rest}"); + } + + let generation_dt = start_generation.elapsed(); + + // Print statistics + println!( + "\n\n=== Statistics ===\n\ + {:4} prompt tokens processed: {:.2} token/s\n\ + {:4} tokens generated: {:.2} token/s", + tokens.len(), + tokens.len() as f64 / prompt_dt.as_secs_f64(), + sampled, + sampled as f64 / generation_dt.as_secs_f64(), + ); + + Ok(()) +} + +// ==================== Main ==================== + +fn print_model_info(config: &ModelConfig) { + println!("\n=== Model Configuration ==="); + println!("Vocab size: {}", config.vocab_size); + println!("Hidden size: {}", config.hidden_size); + println!("Num layers: {}", config.num_hidden_layers); + println!("Num attention heads: {}", config.num_attention_heads); + println!("Num KV heads: {}", config.num_key_value_heads); + println!("Head dim: {}", config.head_dim()); + println!("RoPE theta: {:.0}", config.rope_theta); + + // Print RoPE/NoPE layer info for full models + if let Some(ref no_rope_layers) = config.no_rope_layers { + let num_rope_layers = no_rope_layers.iter().filter(|&&x| x == 1).count(); + let num_nope_layers = no_rope_layers.iter().filter(|&&x| x == 0).count(); + println!("\nLayer Configuration:"); + println!( + " RoPE layers: {} ({}%)", + num_rope_layers, + num_rope_layers * 100 / config.num_hidden_layers + ); + println!( + " NoPE layers: {} ({}%)", + num_nope_layers, + num_nope_layers * 100 / config.num_hidden_layers + ); + } else if let Some(interval) = config.no_rope_layer_interval { + let num_nope_layers = config.num_hidden_layers / interval; + let num_rope_layers = config.num_hidden_layers - num_nope_layers; + println!("\nLayer Configuration:"); + println!( + " RoPE layers: {} ({}%)", + num_rope_layers, + num_rope_layers * 100 / config.num_hidden_layers + ); + println!( + " NoPE layers: {} ({}%) - every {}th layer", + num_nope_layers, + num_nope_layers * 100 / config.num_hidden_layers, + interval + ); + } +} + +fn 
main() -> Result<()> { + use tracing_chrome::ChromeLayerBuilder; + use tracing_subscriber::prelude::*; + + let args = Args::parse(); + + let _guard = if args.tracing { + let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); + tracing_subscriber::registry().with(chrome_layer).init(); + Some(guard) + } else { + None + }; + + println!("=== SmolLM3 Unified Inference ==="); + println!( + "avx: {}, neon: {}, simd128: {}, f16c: {}", + candle::utils::with_avx(), + candle::utils::with_neon(), + candle::utils::with_simd128(), + candle::utils::with_f16c() + ); + println!( + "temp: {:.2}, repeat-penalty: {:.2}, repeat-last-n: {}", + args.temperature, args.repeat_penalty, args.repeat_last_n + ); + + let start = std::time::Instant::now(); + let device = candle_examples::device(false)?; + + // Load model + let mut model = match args.model_type { + ModelType::Quantized => load_quantized_model(&args, &device)?, + ModelType::Full => load_full_model(&args, &device)?, + }; + + println!("Model loaded in {:.2}s", start.elapsed().as_secs_f32()); + + // Print model info + let config = model.config(); + print_model_info(&config); + + // Load tokenizer + let tokenizer = args.get_tokenizer()?; + + // Run generation + run_generation(&mut model, tokenizer, &args, &device)?; + + Ok(()) +} \ No newline at end of file diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index e77ba4a36f..b087553fb8 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -112,6 +112,7 @@ pub mod rwkv_v6; pub mod segformer; pub mod segment_anything; pub mod siglip; +pub mod smol; pub mod snac; pub mod stable_diffusion; pub mod stable_lm; diff --git a/candle-transformers/src/models/smol/README.md b/candle-transformers/src/models/smol/README.md new file mode 100644 index 0000000000..5862c4424f --- /dev/null +++ b/candle-transformers/src/models/smol/README.md @@ -0,0 +1,118 @@ +# SmolLM Model Family + +This directory contains implementations for the SmolLM family of models +developed by HuggingFace. + +## Models + +### SmolLM2 (see `models/llama`) +SmolLM2 models (135M, 360M, 1.7B) use the standard Llama3 architecture +and are implemented in `models/llama.rs`. No separate implementation +is needed. + +**Variants:** +- HuggingFaceTB/SmolLM2-135M +- HuggingFaceTB/SmolLM2-360M +- HuggingFaceTB/SmolLM2-1.7B + +### SmolLM3 +SmolLM3-3B introduces NoPE (No Positional Encoding) which requires +a custom implementation in `smollm3.rs`. + +**Key innovations:** +- Hybrid RoPE/NoPE (3:1 ratio - every 4th layer uses NoPE) +- GQA with 4 groups (32 attention heads, 8 KV heads) +- Very high rope_theta (5M vs typical 10k-500k) +- Long context support (64k-128k tokens) +- Thinking mode support with `` tags + +**Implementations:** +- `smollm3.rs` - Full precision model (safetensors) +- `quantized_smollm3.rs` - Quantized GGUF model with weight reconstruction + +**Available Models:** +- HuggingFaceTB/SmolLM3-3B (Instruct-tuned) +- HuggingFaceTB/SmolLM3-3B-Base (Base model) +- unsloth/SmolLM3-3B-GGUF (Quantized: Q4_K_M, Q8_0, F16) + +### SmolVLM (planned) +Vision-language model variant, to be implemented. 
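+
+The grouped-query attention noted above (query heads grouped over a smaller set of KV heads) is handled in both implementations by a small `repeat_kv` step that expands the cached K/V tensors to the query-head count before the attention matmul. A shape-level sketch mirroring the helper in `quantized_smollm3.rs`; the head counts in `main` are placeholders, the real values come from the model config:
+
+```rust
+use candle::{DType, Device, Result, Tensor};
+
+/// Repeat each KV head `n_rep` times so K/V line up with the query heads.
+fn repeat_kv(x: Tensor, n_rep: usize) -> Result<Tensor> {
+    if n_rep == 1 {
+        return Ok(x);
+    }
+    let (b, n_kv_heads, seq_len, head_dim) = x.dims4()?;
+    x.unsqueeze(2)?
+        .expand(&[b, n_kv_heads, n_rep, seq_len, head_dim])?
+        .reshape(&[b, n_kv_heads * n_rep, seq_len, head_dim])
+}
+
+fn main() -> Result<()> {
+    let dev = Device::Cpu;
+    let (num_heads, num_kv_heads, head_dim) = (32, 8, 128); // placeholder GQA shapes
+    let k = Tensor::zeros((1, num_kv_heads, 16, head_dim), DType::F32, &dev)?;
+    let k = repeat_kv(k, num_heads / num_kv_heads)?;
+    assert_eq!(k.dims(), &[1, num_heads, 16, head_dim]);
+    Ok(())
+}
+```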
+ +## Implementation Details + +### NoPE Architecture +SmolLM3 uses a mixed approach to positional encoding: +```rust +pub fn should_skip_rope(&self, layer_idx: usize) -> bool { + // Method 1: Explicit array from config + if let Some(ref no_rope_layers) = self.no_rope_layers { + if layer_idx < no_rope_layers.len() { + return no_rope_layers[layer_idx] == 0; + } + } + + // Method 2: Interval pattern (SmolLM3-3B default) + // Every 4th layer (indices 3, 7, 11, ...) skips RoPE + if let Some(interval) = self.no_rope_layer_interval { + return (layer_idx + 1) % interval == 0; + } + + false // Default: use RoPE +} +``` + +### Quantized Weight Reconstruction +The quantized implementation includes special handling for Q/K weight +reconstruction to maintain compatibility with the GGUF format's +interleaved weight storage. + +### Thinking Mode +SmolLM3 supports explicit reasoning with thinking tags: +- **Enabled**: `<|im_start|>assistant\n\n` (model generates reasoning) +- **Disabled**: `<|im_start|>assistant\n\n\n\n` (skip to answer) + +## Usage Example + +See `examples/smollm3/main.rs` for a unified implementation that supports +both quantized and full precision models with a single codebase. + +```bash +# Quantized model (recommended) +cargo run --release --example smollm3 -- \ + --model-type quantized \ + --quantization q8_0 \ + --prompt "Explain Rust's ownership system" + +# Full precision model +cargo run --release --example smollm3 -- \ + --model-type full \ + --dtype f16 \ + --prompt "Write a sorting algorithm" + +# Enable thinking mode +cargo run --release --example smollm3 -- \ + --thinking \ + --prompt "Solve this logic puzzle step by step" +``` + +## Performance Characteristics + +| Model Type | Size | Speed | Quality | Use Case | +|------------|-------|-------|---------|----------| +| Q4_K_M | 1.9GB | Fast | Good | Resource-constrained | +| Q8_0 | 3.3GB | Fast | Better | Balanced | +| F16 (GGUF) | 6.2GB | Med | Best | High quality GGUF | +| F16 (Safe) | 6.2GB | Med | Best | Maximum quality | +| F32 (Safe) | 12GB | Slow | Best | Research/debugging | + +## Related Models + +### Granite-Docling +Document understanding VLM that originally used SmolLM-2 but now uses +Granite 165M as its language backbone. See IBM's Docling project. + +## References + +- [SmolLM Blog Post](https://huggingface.co/blog/smollm) +- [SmolLM3 Announcement](https://huggingface.co/blog/smollm3) +- [NoPE Paper](https://arxiv.org/abs/2410.01926) - "Length Generalization of Causal Transformers without Position Encoding" diff --git a/candle-transformers/src/models/smol/mod.rs b/candle-transformers/src/models/smol/mod.rs new file mode 100644 index 0000000000..9d68260256 --- /dev/null +++ b/candle-transformers/src/models/smol/mod.rs @@ -0,0 +1,67 @@ +//! SmolLM model family implementations. +//! +//! The SmolLM family consists of efficient language models developed by HuggingFace: +//! - **SmolLM2** (135M, 360M, 1.7B): Uses standard Llama architecture (see `models::llama`) +//! - **SmolLM3** (3B): Introduces hybrid RoPE/NoPE architecture (implemented here) +//! +//! # SmolLM3 Architecture +//! +//! SmolLM3-3B introduces NoPE (No Positional Encoding) as a key innovation: +//! - 3:1 RoPE/NoPE ratio: every 4th layer skips positional encoding +//! - Grouped Query Attention: 32 attention heads, 8 KV heads (4 groups) +//! - High RoPE theta: 5,000,000 (vs typical 10,000-500,000) +//! - Extended context: 64k-128k tokens +//! +//! # Module Structure +//! +//! - [`smollm3`]: Full precision model implementation (safetensors) +//! 
- [`quantized_smollm3`]: Quantized model implementation (GGUF) +//! +//! # Example Usage +//! +//! ```rust,no_run +//! use candle_transformers::models::smol::smollm3::{Config, ModelForCausalLM}; +//! use candle_transformers::models::smol::quantized_smollm3::QuantizedModelForCausalLM; +//! use candle::{Device, Tensor}; +//! use candle_nn::VarBuilder; +//! +//! # fn main() -> anyhow::Result<()> { +//! let device = Device::Cpu; +//! +//! // Load full precision model +//! let vb = VarBuilder::zeros(candle::DType::F32, &device); +//! let config = Config::default(); +//! let model = ModelForCausalLM::new(&config, vb)?; +//! +//! // Or load quantized model +//! // let model = QuantizedModelForCausalLM::from_gguf(path, &device)?; +//! +//! // Run inference +//! let input = Tensor::new(&[1u32, 2, 3], &device)?.unsqueeze(0)?; +//! let logits = model.forward(&input, 0)?; +//! # Ok(()) +//! # } +//! ``` +//! +//! # Thinking Mode +//! +//! SmolLM3 supports explicit reasoning via thinking tags in chat templates: +//! - Thinking enabled: `<|im_start|>assistant\n\n` (model generates reasoning) +//! - Thinking disabled: `<|im_start|>assistant\n\n\n\n` (skip to answer) +//! +//! # Performance Considerations +//! +//! | Format | Size | Inference Speed | Quality | +//! |--------|-------|-----------------|---------| +//! | Q4_K_M | 1.9GB | Fastest | Good | +//! | Q8_0 | 3.3GB | Fast | Better | +//! | F16 | 6.2GB | Medium | Best | +//! | F32 | 12GB | Slow | Best | +//! +//! # References +//! +//! - [SmolLM3 Model Card](https://huggingface.co/HuggingFaceTB/SmolLM3-3B) +//! - [NoPE Paper](https://arxiv.org/abs/2410.01926) + +pub mod smollm3; +pub mod quantized_smollm3; diff --git a/candle-transformers/src/models/smol/quantized_smollm3.rs b/candle-transformers/src/models/smol/quantized_smollm3.rs new file mode 100644 index 0000000000..4a521ca896 --- /dev/null +++ b/candle-transformers/src/models/smol/quantized_smollm3.rs @@ -0,0 +1,554 @@ +use candle::{DType, Device, Module, Result, Tensor}; +use candle_nn::Activation; +use candle::quantized::gguf_file; +use crate::quantized_var_builder::VarBuilder; +use std::sync::Arc; +use std::io::Write; +use crate::models::with_tracing::QMatMul; +use candle_nn::kv_cache::KvCache; + +const MAX_SEQ_LEN: usize = 4096; +use candle::IndexOp; + +// ===== RECONSTRUCTION FUNCTION ===== +fn reconstruct_qk_weights(gguf_weight: &Tensor, num_heads: usize) -> Result { + let total_rows = gguf_weight.dim(0)?; + let half_rows = total_rows / 2; + let chunk_size = 128; + let chunks_per_half = half_rows / chunk_size; + + let mut heads = Vec::new(); + + // First half + for chunk_idx in 0..chunks_per_half { + let chunk_start = chunk_idx * chunk_size; + + // Even rows + let mut head_even = Vec::new(); + for i in (chunk_start..chunk_start + chunk_size).step_by(2) { + head_even.push(gguf_weight.i(i)?); + } + heads.push(Tensor::stack(&head_even, 0)?); + + // Odd rows + let mut head_odd = Vec::new(); + for i in (chunk_start + 1..chunk_start + chunk_size).step_by(2) { + head_odd.push(gguf_weight.i(i)?); + } + heads.push(Tensor::stack(&head_odd, 0)?); + } + + // Second half + for chunk_idx in 0..chunks_per_half { + let chunk_start = half_rows + chunk_idx * chunk_size; + + // Even rows + let mut head_even = Vec::new(); + for i in (chunk_start..chunk_start + chunk_size).step_by(2) { + head_even.push(gguf_weight.i(i)?); + } + heads.push(Tensor::stack(&head_even, 0)?); + + // Odd rows + let mut head_odd = Vec::new(); + for i in (chunk_start + 1..chunk_start + chunk_size).step_by(2) { + 
head_odd.push(gguf_weight.i(i)?); + } + heads.push(Tensor::stack(&head_odd, 0)?); + } + + Ok(Tensor::cat(&heads, 0)?) +} + +#[derive(Debug, Clone)] +pub struct QuantizedConfig { + pub vocab_size: usize, + pub hidden_size: usize, + pub intermediate_size: usize, + pub num_hidden_layers: usize, + pub num_attention_heads: usize, + pub num_key_value_heads: usize, + pub max_position_embeddings: usize, + pub rope_theta: f64, + pub rms_norm_eps: f64, + pub rope_dimension_count: usize, + pub no_rope_layer_interval: Option, +} + +impl QuantizedConfig { + /// Load config from GGUF metadata + pub fn from_gguf(ct: &gguf_file::Content) -> Result { + let metadata = &ct.metadata; + + // Helper to get required metadata + let get_u32 = |key: &str| -> Result { + metadata.get(key) + .and_then(|v| v.to_u32().ok()) + .map(|v| v as usize) + .ok_or_else(|| candle::Error::Msg(format!("Missing or invalid metadata key: {}", key))) + }; + + let get_f32 = |key: &str| -> Result { + metadata.get(key) + .and_then(|v| v.to_f32().ok()) + .map(|v| v as f64) + .ok_or_else(|| candle::Error::Msg(format!("Missing or invalid metadata key: {}", key))) + }; + + Ok(Self { + vocab_size: get_u32("smollm3.vocab_size")?, + hidden_size: get_u32("smollm3.embedding_length")?, + intermediate_size: get_u32("smollm3.feed_forward_length")?, + num_hidden_layers: get_u32("smollm3.block_count")?, + num_attention_heads: get_u32("smollm3.attention.head_count")?, + num_key_value_heads: get_u32("smollm3.attention.head_count_kv")?, + max_position_embeddings: get_u32("smollm3.context_length").unwrap_or(MAX_SEQ_LEN), + rope_theta: get_f32("smollm3.rope.freq_base")?, + rms_norm_eps: get_f32("smollm3.attention.layer_norm_rms_epsilon")?, + rope_dimension_count: get_u32("smollm3.rope.dimension_count")?, + no_rope_layer_interval: Some(4), + }) + } + + pub fn should_skip_rope(&self, layer_idx: usize) -> bool { + if let Some(interval) = self.no_rope_layer_interval { + return (layer_idx + 1) % interval == 0; + } + false + } + + pub fn head_dim(&self) -> usize { + self.rope_dimension_count + } +} + +#[derive(Debug, Clone)] +struct RmsNorm { + weight: Tensor, + eps: f64, +} + +impl RmsNorm { + fn new(weight: Tensor, eps: f64) -> Self { + Self { weight, eps } + } + + fn forward(&self, x: &Tensor) -> Result { + let x_dtype = x.dtype(); + let internal_dtype = match x_dtype { + DType::F16 | DType::BF16 => DType::F32, + d => d, + }; + let hidden_size = x.dim(candle::D::Minus1)?; + let x = x.to_dtype(internal_dtype)?; + let norm_x = (x.sqr()?.sum_keepdim(candle::D::Minus1)? / hidden_size as f64)?; + let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?; + let result = x_normed.broadcast_mul(&self.weight)?; + result.to_dtype(x_dtype) + } +} + +#[derive(Debug, Clone)] +pub struct RotaryEmbedding { + sin: Tensor, + cos: Tensor, +} + +impl RotaryEmbedding { + pub fn new(dtype: DType, cfg: &QuantizedConfig, dev: &Device) -> Result { + let dim = cfg.head_dim(); + let max_seq_len = cfg.max_position_embeddings; + let inv_freq: Vec<_> = (0..dim) + .step_by(2) + .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32) + .collect(); + let inv_freq_len = inv_freq.len(); + let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(DType::F32)?; + let t = Tensor::arange(0u32, max_seq_len as u32, dev)? + .to_dtype(DType::F32)? 
+ .reshape((max_seq_len, 1))?; + let freqs = t.matmul(&inv_freq)?; + Ok(Self { + sin: freqs.sin()?.to_dtype(dtype)?, + cos: freqs.cos()?.to_dtype(dtype)?, + }) + } + + pub fn apply_rotary_emb(&self, q: &Tensor, k: &Tensor, offset: usize) -> Result<(Tensor, Tensor)> { + let (_, _, seq_len, _) = q.dims4()?; + let cos = self.cos.narrow(0, offset, seq_len)?; + let sin = self.sin.narrow(0, offset, seq_len)?; + let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?; + let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?; + Ok((q_embed, k_embed)) + } +} + +fn repeat_kv(x: Tensor, n_rep: usize) -> Result { + if n_rep == 1 { + Ok(x) + } else { + let (b, n_kv_heads, seq_len, head_dim) = x.dims4()?; + x.unsqueeze(2)? + .expand(&[b, n_kv_heads, n_rep, seq_len, head_dim])? + .reshape(&[b, n_kv_heads * n_rep, seq_len, head_dim]) + } +} + +#[derive(Debug, Clone)] +struct QuantizedMLP { + gate_proj: QMatMul, + up_proj: QMatMul, + down_proj: QMatMul, +} + +impl QuantizedMLP { + fn new(vb: VarBuilder, _layer_idx: usize) -> Result { + // VarBuilder.get_no_shape() returns Arc which QMatMul::from_weights expects + let gate_proj = QMatMul::from_weights(vb.get_no_shape("ffn_gate.weight")?)?; + let up_proj = QMatMul::from_weights(vb.get_no_shape("ffn_up.weight")?)?; + let down_proj = QMatMul::from_weights(vb.get_no_shape("ffn_down.weight")?)?; + + Ok(Self { + gate_proj, + up_proj, + down_proj, + }) + } + + fn forward(&self, x: &Tensor) -> Result { + let gate = self.gate_proj.forward(x)?.apply(&Activation::Silu)?; + let up = self.up_proj.forward(x)?; + self.down_proj.forward(&(gate * up)?) + } +} + +#[derive(Debug, Clone)] +struct QuantizedAttention { + q_proj: QMatMul, + k_proj: QMatMul, + v_proj: QMatMul, + o_proj: QMatMul, + num_heads: usize, + num_kv_heads: usize, + num_kv_groups: usize, + head_dim: usize, + rotary_emb: Option>, + skip_rope: bool, + kv_cache: KvCache, +} + +impl QuantizedAttention { + fn new( + vb: VarBuilder, + cfg: &QuantizedConfig, + layer_idx: usize, + rotary_emb: Option>, + ) -> Result { + let head_dim = cfg.head_dim(); + let num_heads = cfg.num_attention_heads; + let num_kv_heads = cfg.num_key_value_heads; + + // For v and o weights, use directly from VarBuilder (already quantized) + // VarBuilder.get_no_shape() returns Arc + let v_proj = QMatMul::from_weights(vb.get_no_shape("attn_v.weight")?)?; + let o_proj = QMatMul::from_weights(vb.get_no_shape("attn_output.weight")?)?; + + // For q and k weights, we need to dequantize, reconstruct, then re-quantize + // IMPORTANT: Do reconstruction on CPU to avoid VRAM exhaustion during model loading + let device = vb.device(); + let cpu = Device::Cpu; + + let q_weight_qtensor = vb.get_no_shape("attn_q.weight")?; + let q_weight_raw = q_weight_qtensor.dequantize(&cpu)?; // Dequantize to CPU + let q_weight = reconstruct_qk_weights(&q_weight_raw, num_heads)?; // Reconstruct on CPU + let q_weight = q_weight.to_device(device)?; // Move to GPU + + // Re-quantize (now on GPU) + use candle::quantized::{QTensor, GgmlDType}; + let q_weight_qtensor = QTensor::quantize(&q_weight, GgmlDType::Q8_0)?; + drop(q_weight_raw); // Explicitly free CPU memory + drop(q_weight); + + let k_weight_qtensor = vb.get_no_shape("attn_k.weight")?; + let k_weight_raw = k_weight_qtensor.dequantize(&cpu)?; // Dequantize to CPU + let k_weight = reconstruct_qk_weights(&k_weight_raw, num_kv_heads)?; // Reconstruct on CPU + let k_weight = k_weight.to_device(device)?; // Move to GPU + + // Re-quantize (now on GPU) + let k_weight_qtensor = 
QTensor::quantize(&k_weight, GgmlDType::Q8_0)?; + drop(k_weight_raw); // Explicitly free CPU memory + drop(k_weight); + + let q_proj = QMatMul::from_weights(Arc::new(q_weight_qtensor))?; + let k_proj = QMatMul::from_weights(Arc::new(k_weight_qtensor))?; + + Ok(Self { + q_proj, + k_proj, + v_proj, + o_proj, + num_heads, + num_kv_heads, + num_kv_groups: num_heads / num_kv_heads, + head_dim, + rotary_emb, + skip_rope: cfg.should_skip_rope(layer_idx), + kv_cache: KvCache::new(2, 512), + }) + } + + fn forward( + &mut self, + x: &Tensor, + mask: Option<&Tensor>, + offset: usize, + ) -> Result { + let (b, seq_len, _) = x.dims3()?; + + let q = self.q_proj.forward(x)? + .reshape((b, seq_len, self.num_heads, self.head_dim))? + .transpose(1, 2)?; + let k = self.k_proj.forward(x)? + .reshape((b, seq_len, self.num_kv_heads, self.head_dim))? + .transpose(1, 2)?; + let v = self.v_proj.forward(x)? + .reshape((b, seq_len, self.num_kv_heads, self.head_dim))? + .transpose(1, 2)?; + + let (q, k) = if self.skip_rope { + (q, k) + } else if let Some(rope) = &self.rotary_emb { + rope.apply_rotary_emb(&q, &k, offset)? + } else { + (q, k) + }; + + // can remove this continguous call if using ConcatKV-Cache https://github.com/huggingface/candle/pull/3143 + let (k, v) = self.kv_cache.append(&k.contiguous()?, &v.contiguous()?)?; + + let k = repeat_kv(k, self.num_kv_groups)?; + let v = repeat_kv(v, self.num_kv_groups)?; + + let scale = 1.0 / (self.head_dim as f64).sqrt(); + // Make q contiguous before matmul to avoid stride mismatch + let q = q.contiguous()?; + let attn_weights = (q.matmul(&k.t()?)? * scale)?; + + let mut attn_weights = match mask { + Some(mask) => attn_weights.broadcast_add(mask)?, + None => attn_weights, + }; + + attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?; + let attn_output = attn_weights.matmul(&v)?; + + attn_output + .transpose(1, 2)? + .reshape((b, seq_len, self.num_heads * self.head_dim))? 
+ .apply(&self.o_proj) + } + + fn clear_kv_cache(&mut self) { + self.kv_cache.reset(); + } +} + +#[derive(Debug, Clone)] +struct QuantizedDecoderLayer { + self_attn: QuantizedAttention, + mlp: QuantizedMLP, + input_layernorm: RmsNorm, + post_attention_layernorm: RmsNorm, +} + +impl QuantizedDecoderLayer { + fn new( + vb: VarBuilder, + cfg: &QuantizedConfig, + layer_idx: usize, + rotary_emb: Option>, + ) -> Result { + let attn_vb = vb.pp(&format!("blk.{layer_idx}")); + + Ok(Self { + self_attn: QuantizedAttention::new(attn_vb.clone(), cfg, layer_idx, rotary_emb)?, + mlp: QuantizedMLP::new(attn_vb.clone(), layer_idx)?, + input_layernorm: RmsNorm::new( + attn_vb.get_no_shape("attn_norm.weight")?.dequantize(vb.device())?, + cfg.rms_norm_eps, + ), + post_attention_layernorm: RmsNorm::new( + attn_vb.get_no_shape("ffn_norm.weight")?.dequantize(vb.device())?, + cfg.rms_norm_eps, + ), + }) + } + + fn forward( + &mut self, + x: &Tensor, + mask: Option<&Tensor>, + offset: usize, + ) -> Result { + let residual = x; + let x = self.input_layernorm.forward(x)?; + let x = self.self_attn.forward(&x, mask, offset)?; + let x = (residual + x)?; + + let residual = &x; + let x = self.post_attention_layernorm.forward(&x)?; + let x = self.mlp.forward(&x)?; + residual + x + } + + fn clear_kv_cache(&mut self) { + self.self_attn.clear_kv_cache(); + } +} + +#[derive(Debug, Clone)] +pub struct QuantizedModelForCausalLM { + embed_tokens: candle_nn::Embedding, + layers: Vec, + norm: RmsNorm, + lm_head: QMatMul, + device: Device, + config: QuantizedConfig, +} + +impl QuantizedModelForCausalLM { + pub fn from_gguf>(path: P, device: &Device) -> Result { + use candle::quantized::{QTensor, GgmlDType}; + + // Open file once to read metadata + let mut file = std::fs::File::open(path.as_ref())?; + let content = gguf_file::Content::read(&mut file)?; + let config = QuantizedConfig::from_gguf(&content)?; + + // Create VarBuilder for tensor loading + let vb = VarBuilder::from_gguf(path, device)?; + + // Load embedding tensor - dequantize on CPU first to save VRAM + // (will be used for both embed_tokens and lm_head - tied embeddings) + let cpu = Device::Cpu; + let embed_tensor = vb.get_no_shape("token_embd.weight")?.dequantize(&cpu)?; + let embed_tensor_gpu = embed_tensor.to_device(device)?; // Move to GPU for embedding layer + let embed_tokens = candle_nn::Embedding::new(embed_tensor_gpu, config.hidden_size); + + // Create rotary embedding if needed + let needs_rope = (0..config.num_hidden_layers) + .any(|i| !config.should_skip_rope(i)); + let rotary_emb = if needs_rope { + Some(Arc::new(RotaryEmbedding::new( + DType::F32, + &config, + device, + )?)) + } else { + None + }; + + // Load decoder layers + let mut layers = Vec::with_capacity(config.num_hidden_layers); + println!("Loading {} decoder layers...", config.num_hidden_layers); + for layer_idx in 0..config.num_hidden_layers { + if layer_idx % 4 == 0 || layer_idx == config.num_hidden_layers - 1 { + print!(" Layer {}/{}...\r", layer_idx + 1, config.num_hidden_layers); + std::io::stdout().flush().ok(); + } + layers.push(QuantizedDecoderLayer::new( + vb.clone(), + &config, + layer_idx, + rotary_emb.clone(), + )?); + } + println!(" Layer {}/{} - Done! 
", config.num_hidden_layers, config.num_hidden_layers); + + // Load output norm + let norm = RmsNorm::new( + vb.get_no_shape("output_norm.weight")?.dequantize(device)?, + config.rms_norm_eps, + ); + + // Load LM head - move CPU embedding tensor to GPU, then quantize + let embed_tensor_for_lm = embed_tensor.to_device(device)?; + let embed_qtensor = QTensor::quantize(&embed_tensor_for_lm, GgmlDType::Q8_0)?; + let lm_head = QMatMul::from_weights(Arc::new(embed_qtensor))?; + drop(embed_tensor); // Free CPU memory + drop(embed_tensor_for_lm); + + Ok(Self { + embed_tokens, + layers, + norm, + lm_head, + device: device.clone(), + config, + }) + } + + pub fn forward(&mut self, input_ids: &Tensor, offset: usize) -> Result { + let (batch_size, seq_len) = input_ids.dims2()?; + + // Embed tokens + let mut hidden_states = self.embed_tokens.forward(input_ids)?; + + // Create causal mask if needed + let mask = if seq_len > 1 { + Some(self.create_causal_mask(batch_size, seq_len, offset)?) + } else { + None + }; + + // Forward through decoder layers + for layer in &mut self.layers { + hidden_states = layer.forward(&hidden_states, mask.as_ref(), offset)?; + } + + // Final norm + hidden_states = self.norm.forward(&hidden_states)?; + + // LM head (only last token for generation) + let last_hidden = hidden_states.narrow(1, seq_len - 1, 1)?; + let logits = last_hidden.apply(&self.lm_head)?; + + Ok(logits) + } + + fn create_causal_mask( + &self, + batch_size: usize, + tgt_len: usize, + offset: usize, + ) -> Result { + let mask: Vec<_> = (0..tgt_len) + .flat_map(|i| { + (0..tgt_len + offset).map(move |j| { + if j <= i + offset { + 0f32 + } else { + f32::NEG_INFINITY + } + }) + }) + .collect(); + + Tensor::from_slice( + &mask, + (batch_size, 1, tgt_len, tgt_len + offset), + &self.device, + ) + } + + pub fn clear_kv_cache(&mut self) { + for layer in &mut self.layers { + layer.clear_kv_cache(); + } + } + + pub fn config(&self) -> &QuantizedConfig { + &self.config + } +} \ No newline at end of file diff --git a/candle-transformers/src/models/smol/smollm3.rs b/candle-transformers/src/models/smol/smollm3.rs new file mode 100644 index 0000000000..917045daa8 --- /dev/null +++ b/candle-transformers/src/models/smol/smollm3.rs @@ -0,0 +1,453 @@ +use crate::{ + models::with_tracing::{linear_b, linear_no_bias, Linear, RmsNorm}, + utils::repeat_kv, +}; +use candle::{DType, Device, Module, Result, Tensor}; +use candle_nn::{kv_cache::KvCache, Activation, VarBuilder}; +use std::sync::Arc; + +#[derive(Debug, Clone, PartialEq, serde::Deserialize)] +pub struct Config { + pub vocab_size: usize, + pub hidden_size: usize, + pub intermediate_size: usize, + pub num_hidden_layers: usize, + pub num_attention_heads: usize, + pub num_key_value_heads: usize, + pub max_position_embeddings: usize, + pub tie_word_embeddings: bool, + pub rope_theta: f64, + pub rms_norm_eps: f64, + pub hidden_act: Activation, + // Optional fields + pub attention_bias: Option, + pub attention_dropout: Option, + pub mlp_bias: Option, + pub sliding_window: Option, + pub use_sliding_window: Option, + pub rope_scaling: Option, + pub bos_token_id: Option, + pub eos_token_id: Option, + pub pad_token_id: Option, + pub max_window_layers: Option, + // SmolLM3-specific: NoPE configuration + pub no_rope_layers: Option>, + pub no_rope_layer_interval: Option, +} + +impl Config { + + pub fn should_skip_rope(&self, layer_idx: usize) -> bool { + // Method 1: Explicit array (some model variants may provide this) + if let Some(ref no_rope_layers) = self.no_rope_layers { + if 
layer_idx < no_rope_layers.len() { + // 0 = skip RoPE (NoPE), 1 = use RoPE + return no_rope_layers[layer_idx] == 0; + } + } + + // Method 2: Interval pattern (SmolLM3-3B uses this) + // With interval=4: layers 0,1,2 use RoPE; layer 3 skips RoPE (NoPE) + // Pattern: every 4th layer (3,7,11...) skips RoPE + if let Some(interval) = self.no_rope_layer_interval { + return (layer_idx + 1) % interval == 0; + } + + // Default: use RoPE on all layers (standard Llama behavior) + false + } + + /// Calculates head_dim from hidden_size and num_attention_heads + pub fn head_dim(&self) -> usize { + self.hidden_size / self.num_attention_heads + } +} + +#[derive(Debug, Clone)] +pub(crate) struct SmolLM3RotaryEmbedding { + sin: Tensor, + cos: Tensor, +} + +impl SmolLM3RotaryEmbedding { + pub(crate) fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result { + let dim = cfg.head_dim(); + let max_seq_len = cfg.max_position_embeddings; + let inv_freq: Vec<_> = (0..dim) + .step_by(2) + .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32) + .collect(); + let inv_freq_len = inv_freq.len(); + let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(DType::F32)?; + let t = Tensor::arange(0u32, max_seq_len as u32, dev)? + .to_dtype(DType::F32)? + .reshape((max_seq_len, 1))?; + let freqs = t.matmul(&inv_freq)?; + Ok(Self { + sin: freqs.sin()?.to_dtype(dtype)?, + cos: freqs.cos()?.to_dtype(dtype)?, + }) + } + + /// Apply RoPE (q, k shape: B x H x L x D) + pub(crate) fn apply(&self, q: &Tensor, k: &Tensor, offset: usize) -> Result<(Tensor, Tensor)> { + let (_, _, seq_len, _) = q.dims4()?; + let cos = self.cos.narrow(0, offset, seq_len)?; + let sin = self.sin.narrow(0, offset, seq_len)?; + let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?; + let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?; + Ok((q_embed, k_embed)) + } +} + +#[derive(Debug, Clone)] +pub(crate) struct SmolLM3MLP { + gate_proj: Linear, + up_proj: Linear, + down_proj: Linear, + act_fn: Activation, +} + +impl SmolLM3MLP { + pub(crate) fn new(cfg: &Config, vb: VarBuilder) -> Result { + let mlp_bias = cfg.mlp_bias.unwrap_or(false); + Ok(Self { + gate_proj: linear_b(cfg.hidden_size, cfg.intermediate_size, mlp_bias, vb.pp("gate_proj"))?, + up_proj: linear_b(cfg.hidden_size, cfg.intermediate_size, mlp_bias, vb.pp("up_proj"))?, + down_proj: linear_b(cfg.intermediate_size, cfg.hidden_size, mlp_bias, vb.pp("down_proj"))?, + act_fn: cfg.hidden_act, + }) + } +} + +impl Module for SmolLM3MLP { + fn forward(&self, x: &Tensor) -> Result { + let lhs = x.apply(&self.gate_proj)?.apply(&self.act_fn)?; + let rhs = x.apply(&self.up_proj)?; + (lhs * rhs)?.apply(&self.down_proj) + } +} + +#[derive(Debug, Clone)] +pub(crate) struct SmolLM3Attention { + // projections + q_proj: Linear, + k_proj: Linear, + v_proj: Linear, + o_proj: Linear, + // hyper params + num_heads: usize, + num_kv_heads: usize, + num_kv_groups: usize, + head_dim: usize, + hidden_size: usize, + // utils + rotary_emb: Option>, + kv_cache: KvCache, + // NoPE flag + skip_rope: bool, +} + +impl SmolLM3Attention { + pub(crate) fn new( + cfg: &Config, + layer_idx: usize, + rotary_emb: Option>, + vb: VarBuilder, + ) -> Result { + let use_sliding_window = cfg.use_sliding_window.unwrap_or(false); + if use_sliding_window { + candle::bail!("sliding window is not supported") + } + + let head_dim = cfg.head_dim(); + let num_heads = cfg.num_attention_heads; + let num_kv_heads = cfg.num_key_value_heads; + let num_kv_groups = num_heads / 
num_kv_heads; + + let attention_bias = cfg.attention_bias.unwrap_or(false); + + let q_proj = linear_b( + cfg.hidden_size, + num_heads * head_dim, + attention_bias, + vb.pp("q_proj"), + )?; + + let k_proj = linear_b( + cfg.hidden_size, + num_kv_heads * head_dim, + attention_bias, + vb.pp("k_proj"), + )?; + + let v_proj = linear_b( + cfg.hidden_size, + num_kv_heads * head_dim, + attention_bias, + vb.pp("v_proj"), + )?; + let o_proj = linear_b( + num_heads * head_dim, + cfg.hidden_size, + attention_bias, + vb.pp("o_proj"), + )?; + + // Necessary because the hidden_size in the config isn't always accurate + let hidden_size = head_dim * cfg.num_attention_heads; + + // Initialize KV cache with 512 tokens capacity to reduce initial memory allocation. + // The cache will grow in chunks of 512 tokens when needed. + let kv_cache = KvCache::new(2, 512); + + // Check if this layer should skip RoPE (NoPE) + let skip_rope = cfg.should_skip_rope(layer_idx); + + Ok(Self { + q_proj, + k_proj, + v_proj, + o_proj, + num_heads, + num_kv_heads, + num_kv_groups, + head_dim, + hidden_size, + rotary_emb, + kv_cache, + skip_rope, + }) + } + + pub(crate) fn forward( + &mut self, + x: &Tensor, + attn_mask: Option<&Tensor>, + offset: usize, + ) -> Result { + let (b, l, _) = x.dims3()?; + + // 1. Proj + let q = self.q_proj.forward(x)?; + let k = self.k_proj.forward(x)?; + let v = self.v_proj.forward(x)?; + + // 2. Reshape: (B, L, H, D) -> (B, H, L, D) + let q = q + .reshape((b, l, self.num_heads, self.head_dim))? + .transpose(1, 2)?; + let k = k + .reshape((b, l, self.num_kv_heads, self.head_dim))? + .transpose(1, 2)?; + let v = v + .reshape((b, l, self.num_kv_heads, self.head_dim))? + .transpose(1, 2)?; + + // 3. RoPE - only apply if this layer should use RoPE (not NoPE) + let (q, k) = if self.skip_rope { + // NoPE: Skip rotary embeddings, but ensure tensors are contiguous + (q.contiguous()?, k.contiguous()?) + } else { + // Apply RoPE + if let Some(ref rope) = self.rotary_emb { + rope.apply(&q, &k, offset)? + } else { + (q, k) + } + }; + + // 4. Accumulate KV cache + // Reset KV cache if we're at the first position + if offset == 0 { + self.kv_cache.reset(); + } + let (k, v) = self.kv_cache.append(&k.contiguous()?, &v.contiguous()?)?; + + // 5. GQA repeat_kv + let k = repeat_kv(k, self.num_kv_groups)?; + let v = repeat_kv(v, self.num_kv_groups)?; + + // 6. Attention score + let scale = 1.0 / (self.head_dim as f64).sqrt(); + let mut scores = (q.matmul(&k.transpose(2, 3)?)? * scale)?; + if let Some(m) = attn_mask { + scores = scores.broadcast_add(m)?; + } + let probs = candle_nn::ops::softmax_last_dim(&scores)?; + let ctx = probs.matmul(&v)?; // (B, H, L, D) + + // 7. Output proj + ctx.transpose(1, 2)? + .reshape((b, l, self.hidden_size))? 
+ .apply(&self.o_proj) + } + + pub fn clear_kv_cache(&mut self) { + self.kv_cache.reset(); + } +} + +#[derive(Debug, Clone)] +struct DecoderLayer { + self_attn: SmolLM3Attention, + mlp: SmolLM3MLP, + ln1: RmsNorm, + ln2: RmsNorm, +} + +impl DecoderLayer { + fn new( + cfg: &Config, + layer_idx: usize, + rotary: Option>, + vb: VarBuilder, + ) -> Result { + let self_attn = SmolLM3Attention::new(cfg, layer_idx, rotary, vb.pp("self_attn"))?; + let mlp = SmolLM3MLP::new(cfg, vb.pp("mlp"))?; + let ln1 = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?; + let ln2 = RmsNorm::new( + cfg.hidden_size, + cfg.rms_norm_eps, + vb.pp("post_attention_layernorm"), + )?; + Ok(Self { + self_attn, + mlp, + ln1, + ln2, + }) + } + + fn forward(&mut self, x: &Tensor, mask: Option<&Tensor>, offset: usize) -> Result { + let h = self.ln1.forward(x)?; + let h = self.self_attn.forward(&h, mask, offset)?; + let x = (x + h)?; + let h2 = self.ln2.forward(&x)?; + let h2 = h2.apply(&self.mlp)?; + x + h2 + } + + pub fn clear_kv_cache(&mut self) { + self.self_attn.clear_kv_cache(); + } +} + +#[derive(Debug, Clone)] +pub struct Model { + pub(crate) embed_tokens: candle_nn::Embedding, + pub(crate) layers: Vec, + pub(crate) norm: RmsNorm, + device: Device, + dtype: DType, +} + +impl Model { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result { + let embed_tokens = + candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("model.embed_tokens"))?; + + // Only create rotary embedding if at least one layer uses RoPE + let needs_rope = (0..cfg.num_hidden_layers).any(|i| !cfg.should_skip_rope(i)); + let rotary = if needs_rope { + Some(Arc::new(SmolLM3RotaryEmbedding::new(vb.dtype(), cfg, vb.device())?)) + } else { + None + }; + + let mut layers = Vec::with_capacity(cfg.num_hidden_layers); + let vb_l = vb.pp("model.layers"); + for i in 0..cfg.num_hidden_layers { + layers.push(DecoderLayer::new(cfg, i, rotary.clone(), vb_l.pp(i))?); + } + Ok(Self { + embed_tokens, + layers, + norm: RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?, + device: vb.device().clone(), + dtype: vb.dtype(), + }) + } + + pub fn clear_kv_cache(&mut self) { + for l in &mut self.layers { + l.clear_kv_cache(); + } + } + + fn causal_mask( + &self, + b: usize, + tgt: usize, + offset: usize, + sw: Option, + ) -> Result { + let minf = f32::NEG_INFINITY; + let mask: Vec<_> = (0..tgt) + .flat_map(|i| { + (0..(tgt + offset)).map(move |j| { + let past_ok = j <= i + offset; + let sw_ok = match sw { + Some(w) => (i + offset) as i64 - j as i64 <= w as i64, + None => true, + }; + if past_ok && sw_ok { + 0. + } else { + minf + } + }) + }) + .collect(); + Tensor::from_slice(&mask, (b, 1, tgt, tgt + offset), &self.device)?.to_dtype(self.dtype) + } + + pub fn forward(&mut self, input: &Tensor, offset: usize) -> Result { + let (b, l) = input.dims2()?; + + let mut h = self.embed_tokens.forward(input)?; + + let causal = if l == 1 { + None + } else { + Some(self.causal_mask(b, l, offset, None)?) + }; + + for layer in &mut self.layers { + h = layer.forward(&h, causal.as_ref(), offset)?; + } + self.norm.forward(&h) + } +} + +#[derive(Debug, Clone)] +pub struct ModelForCausalLM { + base: Model, + lm_head: Linear, +} + +impl ModelForCausalLM { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result { + let base = Model::new(cfg, vb.clone())?; + let lm_head = if cfg.tie_word_embeddings { + Linear::from_weights(base.embed_tokens.embeddings().clone(), None) + } else { + linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))? 
+ }; + Ok(Self { base, lm_head }) + } + + pub fn forward(&mut self, input: &Tensor, offset: usize) -> Result { + let (_, l) = input.dims2()?; + + self.base + .forward(input, offset)? + .narrow(1, l - 1, 1)? + .apply(&self.lm_head) + + } + + pub fn clear_kv_cache(&mut self) { + self.base.clear_kv_cache(); + } +} \ No newline at end of file From ebfc4562c110ba0d9cfe4adff95f1ec4d8a5ffab Mon Sep 17 00:00:00 2001 From: DrJesseGlass Date: Wed, 12 Nov 2025 13:43:47 -0500 Subject: [PATCH 2/9] include chrono for prompt --- candle-examples/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml index 201412c0ec..44ef1c9ac6 100644 --- a/candle-examples/Cargo.toml +++ b/candle-examples/Cargo.toml @@ -18,8 +18,8 @@ candle-transformers = { workspace = true } candle-flash-attn = { workspace = true, optional = true } candle-onnx = { workspace = true, optional = true } +chrono = "0.4" csv = "1.3.0" -crono = "0.4.0" cudarc = { workspace = true, optional = true } half = { workspace = true, optional = true } hf-hub = { workspace = true, features = ["tokio"] } From f8168ed65da802b93bad442ef43a13ab58a28b80 Mon Sep 17 00:00:00 2001 From: DrJesseGlass Date: Wed, 12 Nov 2025 14:09:25 -0500 Subject: [PATCH 3/9] resolve pub consist and unused var --- candle-transformers/src/models/smol/quantized_smollm3.rs | 2 +- candle-transformers/src/models/smol/smollm3.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/candle-transformers/src/models/smol/quantized_smollm3.rs b/candle-transformers/src/models/smol/quantized_smollm3.rs index 4a521ca896..5e3013e788 100644 --- a/candle-transformers/src/models/smol/quantized_smollm3.rs +++ b/candle-transformers/src/models/smol/quantized_smollm3.rs @@ -11,7 +11,7 @@ const MAX_SEQ_LEN: usize = 4096; use candle::IndexOp; // ===== RECONSTRUCTION FUNCTION ===== -fn reconstruct_qk_weights(gguf_weight: &Tensor, num_heads: usize) -> Result { +fn reconstruct_qk_weights(gguf_weight: &Tensor, _num_heads: usize) -> Result { let total_rows = gguf_weight.dim(0)?; let half_rows = total_rows / 2; let chunk_size = 128; diff --git a/candle-transformers/src/models/smol/smollm3.rs b/candle-transformers/src/models/smol/smollm3.rs index 917045daa8..0670ed0e0b 100644 --- a/candle-transformers/src/models/smol/smollm3.rs +++ b/candle-transformers/src/models/smol/smollm3.rs @@ -289,7 +289,7 @@ impl SmolLM3Attention { } #[derive(Debug, Clone)] -struct DecoderLayer { +pub(crate) struct DecoderLayer { self_attn: SmolLM3Attention, mlp: SmolLM3MLP, ln1: RmsNorm, From 8447af4fb7aefcc4d9fc5459fe015080f3e4cc6f Mon Sep 17 00:00:00 2001 From: DrJesseGlass Date: Wed, 12 Nov 2025 14:24:21 -0500 Subject: [PATCH 4/9] formatted --- candle-examples/examples/smollm3/main.rs | 42 ++++++--- candle-transformers/src/models/smol/mod.rs | 2 +- .../src/models/smol/quantized_smollm3.rs | 91 +++++++++++-------- .../src/models/smol/smollm3.rs | 31 +++++-- 4 files changed, 107 insertions(+), 59 deletions(-) diff --git a/candle-examples/examples/smollm3/main.rs b/candle-examples/examples/smollm3/main.rs index cb2418f1fd..397417121e 100644 --- a/candle-examples/examples/smollm3/main.rs +++ b/candle-examples/examples/smollm3/main.rs @@ -47,7 +47,7 @@ impl SmolLM3Model { num_attention_heads: cfg.num_attention_heads, num_key_value_heads: cfg.num_key_value_heads, rope_theta: cfg.rope_theta as f32, // Convert f64 to f32 - eos_token_id: Some(128012), // Default SmolLM3 EOS + eos_token_id: Some(128012), // Default SmolLM3 EOS no_rope_layers: None, 
no_rope_layer_interval: None, } @@ -61,7 +61,10 @@ impl SmolLM3Model { num_key_value_heads: cfg.num_key_value_heads, rope_theta: cfg.rope_theta as f32, // Convert f64 to f32 eos_token_id: cfg.eos_token_id, - no_rope_layers: cfg.no_rope_layers.as_ref().map(|v| v.iter().map(|&x| x as u32).collect()), // Convert Vec to Vec + no_rope_layers: cfg + .no_rope_layers + .as_ref() + .map(|v| v.iter().map(|&x| x as u32).collect()), // Convert Vec to Vec no_rope_layer_interval: cfg.no_rope_layer_interval, } } @@ -313,13 +316,17 @@ fn format_prompt(prompt: &str, use_chat_template: bool, enable_thinking: bool) - let today_date = now.format("%d %B %Y").to_string(); // Set reasoning mode based on thinking flag - let reasoning_mode = if enable_thinking { "/think" } else { "/no_think" }; + let reasoning_mode = if enable_thinking { + "/think" + } else { + "/no_think" + }; // Build the assistant start with or without thinking tags let assistant_start = if enable_thinking { - "<|im_start|>assistant\n\n" // Open for reasoning + "<|im_start|>assistant\n\n" // Open for reasoning } else { - "<|im_start|>assistant\n\n\n\n" // Empty = skip reasoning + "<|im_start|>assistant\n\n\n\n" // Empty = skip reasoning }; format!( @@ -337,10 +344,7 @@ You are a helpful AI assistant named SmolLM, trained by Hugging Face.\n\ <|im_start|>user\n\ {}<|im_end|>\n\ {}", - today_date, - reasoning_mode, - prompt, - assistant_start + today_date, reasoning_mode, prompt, assistant_start ) } else { prompt.to_string() @@ -381,8 +385,22 @@ fn run_generation( println!("\n=== Generation Settings ==="); println!("Model type: {:?}", args.model_type); - println!("Chat template: {}", if use_chat_template { "enabled" } else { "disabled" }); - println!("Thinking mode: {}", if args.thinking { "enabled (/think)" } else { "disabled (/no_think)" }); + println!( + "Chat template: {}", + if use_chat_template { + "enabled" + } else { + "disabled" + } + ); + println!( + "Thinking mode: {}", + if args.thinking { + "enabled (/think)" + } else { + "disabled (/no_think)" + } + ); println!("Raw prompt: {}", prompt_str); // Encode prompt @@ -597,4 +615,4 @@ fn main() -> Result<()> { run_generation(&mut model, tokenizer, &args, &device)?; Ok(()) -} \ No newline at end of file +} diff --git a/candle-transformers/src/models/smol/mod.rs b/candle-transformers/src/models/smol/mod.rs index 9d68260256..b3744385e0 100644 --- a/candle-transformers/src/models/smol/mod.rs +++ b/candle-transformers/src/models/smol/mod.rs @@ -63,5 +63,5 @@ //! - [SmolLM3 Model Card](https://huggingface.co/HuggingFaceTB/SmolLM3-3B) //! 
- [NoPE Paper](https://arxiv.org/abs/2410.01926) -pub mod smollm3; pub mod quantized_smollm3; +pub mod smollm3; diff --git a/candle-transformers/src/models/smol/quantized_smollm3.rs b/candle-transformers/src/models/smol/quantized_smollm3.rs index 5e3013e788..7bbc88f7c3 100644 --- a/candle-transformers/src/models/smol/quantized_smollm3.rs +++ b/candle-transformers/src/models/smol/quantized_smollm3.rs @@ -1,11 +1,11 @@ +use crate::models::with_tracing::QMatMul; +use crate::quantized_var_builder::VarBuilder; +use candle::quantized::gguf_file; use candle::{DType, Device, Module, Result, Tensor}; +use candle_nn::kv_cache::KvCache; use candle_nn::Activation; -use candle::quantized::gguf_file; -use crate::quantized_var_builder::VarBuilder; -use std::sync::Arc; use std::io::Write; -use crate::models::with_tracing::QMatMul; -use candle_nn::kv_cache::KvCache; +use std::sync::Arc; const MAX_SEQ_LEN: usize = 4096; use candle::IndexOp; @@ -82,17 +82,23 @@ impl QuantizedConfig { // Helper to get required metadata let get_u32 = |key: &str| -> Result { - metadata.get(key) + metadata + .get(key) .and_then(|v| v.to_u32().ok()) .map(|v| v as usize) - .ok_or_else(|| candle::Error::Msg(format!("Missing or invalid metadata key: {}", key))) + .ok_or_else(|| { + candle::Error::Msg(format!("Missing or invalid metadata key: {}", key)) + }) }; let get_f32 = |key: &str| -> Result { - metadata.get(key) + metadata + .get(key) .and_then(|v| v.to_f32().ok()) .map(|v| v as f64) - .ok_or_else(|| candle::Error::Msg(format!("Missing or invalid metadata key: {}", key))) + .ok_or_else(|| { + candle::Error::Msg(format!("Missing or invalid metadata key: {}", key)) + }) }; Ok(Self { @@ -174,7 +180,12 @@ impl RotaryEmbedding { }) } - pub fn apply_rotary_emb(&self, q: &Tensor, k: &Tensor, offset: usize) -> Result<(Tensor, Tensor)> { + pub fn apply_rotary_emb( + &self, + q: &Tensor, + k: &Tensor, + offset: usize, + ) -> Result<(Tensor, Tensor)> { let (_, _, seq_len, _) = q.dims4()?; let cos = self.cos.narrow(0, offset, seq_len)?; let sin = self.sin.narrow(0, offset, seq_len)?; @@ -265,7 +276,7 @@ impl QuantizedAttention { let q_weight = q_weight.to_device(device)?; // Move to GPU // Re-quantize (now on GPU) - use candle::quantized::{QTensor, GgmlDType}; + use candle::quantized::{GgmlDType, QTensor}; let q_weight_qtensor = QTensor::quantize(&q_weight, GgmlDType::Q8_0)?; drop(q_weight_raw); // Explicitly free CPU memory drop(q_weight); @@ -298,21 +309,22 @@ impl QuantizedAttention { }) } - fn forward( - &mut self, - x: &Tensor, - mask: Option<&Tensor>, - offset: usize, - ) -> Result { + fn forward(&mut self, x: &Tensor, mask: Option<&Tensor>, offset: usize) -> Result { let (b, seq_len, _) = x.dims3()?; - let q = self.q_proj.forward(x)? + let q = self + .q_proj + .forward(x)? .reshape((b, seq_len, self.num_heads, self.head_dim))? .transpose(1, 2)?; - let k = self.k_proj.forward(x)? + let k = self + .k_proj + .forward(x)? .reshape((b, seq_len, self.num_kv_heads, self.head_dim))? .transpose(1, 2)?; - let v = self.v_proj.forward(x)? + let v = self + .v_proj + .forward(x)? .reshape((b, seq_len, self.num_kv_heads, self.head_dim))? .transpose(1, 2)?; @@ -375,22 +387,21 @@ impl QuantizedDecoderLayer { self_attn: QuantizedAttention::new(attn_vb.clone(), cfg, layer_idx, rotary_emb)?, mlp: QuantizedMLP::new(attn_vb.clone(), layer_idx)?, input_layernorm: RmsNorm::new( - attn_vb.get_no_shape("attn_norm.weight")?.dequantize(vb.device())?, + attn_vb + .get_no_shape("attn_norm.weight")? 
+ .dequantize(vb.device())?, cfg.rms_norm_eps, ), post_attention_layernorm: RmsNorm::new( - attn_vb.get_no_shape("ffn_norm.weight")?.dequantize(vb.device())?, + attn_vb + .get_no_shape("ffn_norm.weight")? + .dequantize(vb.device())?, cfg.rms_norm_eps, ), }) } - fn forward( - &mut self, - x: &Tensor, - mask: Option<&Tensor>, - offset: usize, - ) -> Result { + fn forward(&mut self, x: &Tensor, mask: Option<&Tensor>, offset: usize) -> Result { let residual = x; let x = self.input_layernorm.forward(x)?; let x = self.self_attn.forward(&x, mask, offset)?; @@ -419,7 +430,7 @@ pub struct QuantizedModelForCausalLM { impl QuantizedModelForCausalLM { pub fn from_gguf>(path: P, device: &Device) -> Result { - use candle::quantized::{QTensor, GgmlDType}; + use candle::quantized::{GgmlDType, QTensor}; // Open file once to read metadata let mut file = std::fs::File::open(path.as_ref())?; @@ -437,14 +448,9 @@ impl QuantizedModelForCausalLM { let embed_tokens = candle_nn::Embedding::new(embed_tensor_gpu, config.hidden_size); // Create rotary embedding if needed - let needs_rope = (0..config.num_hidden_layers) - .any(|i| !config.should_skip_rope(i)); + let needs_rope = (0..config.num_hidden_layers).any(|i| !config.should_skip_rope(i)); let rotary_emb = if needs_rope { - Some(Arc::new(RotaryEmbedding::new( - DType::F32, - &config, - device, - )?)) + Some(Arc::new(RotaryEmbedding::new(DType::F32, &config, device)?)) } else { None }; @@ -454,7 +460,11 @@ impl QuantizedModelForCausalLM { println!("Loading {} decoder layers...", config.num_hidden_layers); for layer_idx in 0..config.num_hidden_layers { if layer_idx % 4 == 0 || layer_idx == config.num_hidden_layers - 1 { - print!(" Layer {}/{}...\r", layer_idx + 1, config.num_hidden_layers); + print!( + " Layer {}/{}...\r", + layer_idx + 1, + config.num_hidden_layers + ); std::io::stdout().flush().ok(); } layers.push(QuantizedDecoderLayer::new( @@ -464,7 +474,10 @@ impl QuantizedModelForCausalLM { rotary_emb.clone(), )?); } - println!(" Layer {}/{} - Done! ", config.num_hidden_layers, config.num_hidden_layers); + println!( + " Layer {}/{} - Done! 
", + config.num_hidden_layers, config.num_hidden_layers + ); // Load output norm let norm = RmsNorm::new( @@ -551,4 +564,4 @@ impl QuantizedModelForCausalLM { pub fn config(&self) -> &QuantizedConfig { &self.config } -} \ No newline at end of file +} diff --git a/candle-transformers/src/models/smol/smollm3.rs b/candle-transformers/src/models/smol/smollm3.rs index 0670ed0e0b..f006cdd797 100644 --- a/candle-transformers/src/models/smol/smollm3.rs +++ b/candle-transformers/src/models/smol/smollm3.rs @@ -36,7 +36,6 @@ pub struct Config { } impl Config { - pub fn should_skip_rope(&self, layer_idx: usize) -> bool { // Method 1: Explicit array (some model variants may provide this) if let Some(ref no_rope_layers) = self.no_rope_layers { @@ -112,9 +111,24 @@ impl SmolLM3MLP { pub(crate) fn new(cfg: &Config, vb: VarBuilder) -> Result { let mlp_bias = cfg.mlp_bias.unwrap_or(false); Ok(Self { - gate_proj: linear_b(cfg.hidden_size, cfg.intermediate_size, mlp_bias, vb.pp("gate_proj"))?, - up_proj: linear_b(cfg.hidden_size, cfg.intermediate_size, mlp_bias, vb.pp("up_proj"))?, - down_proj: linear_b(cfg.intermediate_size, cfg.hidden_size, mlp_bias, vb.pp("down_proj"))?, + gate_proj: linear_b( + cfg.hidden_size, + cfg.intermediate_size, + mlp_bias, + vb.pp("gate_proj"), + )?, + up_proj: linear_b( + cfg.hidden_size, + cfg.intermediate_size, + mlp_bias, + vb.pp("up_proj"), + )?, + down_proj: linear_b( + cfg.intermediate_size, + cfg.hidden_size, + mlp_bias, + vb.pp("down_proj"), + )?, act_fn: cfg.hidden_act, }) } @@ -350,7 +364,11 @@ impl Model { // Only create rotary embedding if at least one layer uses RoPE let needs_rope = (0..cfg.num_hidden_layers).any(|i| !cfg.should_skip_rope(i)); let rotary = if needs_rope { - Some(Arc::new(SmolLM3RotaryEmbedding::new(vb.dtype(), cfg, vb.device())?)) + Some(Arc::new(SmolLM3RotaryEmbedding::new( + vb.dtype(), + cfg, + vb.device(), + )?)) } else { None }; @@ -444,10 +462,9 @@ impl ModelForCausalLM { .forward(input, offset)? .narrow(1, l - 1, 1)? .apply(&self.lm_head) - } pub fn clear_kv_cache(&mut self) { self.base.clear_kv_cache(); } -} \ No newline at end of file +} From fcb22b4bcaba208dc3cebc8643f6d40c08d2d8cc Mon Sep 17 00:00:00 2001 From: DrJesseGlass Date: Wed, 12 Nov 2025 14:26:06 -0500 Subject: [PATCH 5/9] last spacing in format --- candle-transformers/src/models/smol/smollm3.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/candle-transformers/src/models/smol/smollm3.rs b/candle-transformers/src/models/smol/smollm3.rs index f006cdd797..f6b3120d78 100644 --- a/candle-transformers/src/models/smol/smollm3.rs +++ b/candle-transformers/src/models/smol/smollm3.rs @@ -463,7 +463,7 @@ impl ModelForCausalLM { .narrow(1, l - 1, 1)? 
.apply(&self.lm_head) } - + pub fn clear_kv_cache(&mut self) { self.base.clear_kv_cache(); } From ce922dcb5e4ddce9ad0a10c7f49fa70ebdba7396 Mon Sep 17 00:00:00 2001 From: DrJesseGlass Date: Wed, 12 Nov 2025 15:40:39 -0500 Subject: [PATCH 6/9] add credits --- candle-transformers/src/models/smol/README.md | 157 +++++++++++++++++- .../src/models/smol/smollm3.rs | 2 +- 2 files changed, 150 insertions(+), 9 deletions(-) diff --git a/candle-transformers/src/models/smol/README.md b/candle-transformers/src/models/smol/README.md index 5862c4424f..5a9e260c9b 100644 --- a/candle-transformers/src/models/smol/README.md +++ b/candle-transformers/src/models/smol/README.md @@ -105,14 +105,155 @@ cargo run --release --example smollm3 -- \ | F16 (Safe) | 6.2GB | Med | Best | Maximum quality | | F32 (Safe) | 12GB | Slow | Best | Research/debugging | -## Related Models +# Credits & Attribution -### Granite-Docling -Document understanding VLM that originally used SmolLM-2 but now uses -Granite 165M as its language backbone. See IBM's Docling project. +## SmolLM3 Model -## References +### Developers +**HuggingFace Team (HuggingFaceTB)** -- [SmolLM Blog Post](https://huggingface.co/blog/smollm) -- [SmolLM3 Announcement](https://huggingface.co/blog/smollm3) -- [NoPE Paper](https://arxiv.org/abs/2410.01926) - "Length Generalization of Causal Transformers without Position Encoding" +The SmolLM family of models represents cutting-edge work in efficient language models, demonstrating that small models can achieve impressive capabilities when trained on high-quality data. + +### Resources +- **Model Card**: https://huggingface.co/HuggingFaceTB/SmolLM3-3B +- **Model Card (Base)**: https://huggingface.co/HuggingFaceTB/SmolLM3-3B-Base +- **Collection**: https://huggingface.co/collections/HuggingFaceTB/smollm3-6723884a9c35673e4f9b74a2 +- **Blog Post**: https://huggingface.co/blog/smollm3 +- **GitHub Repository**: https://github.com/huggingface/smollm +- **License**: Apache 2.0 + +### Key Contributors +The SmolLM project is developed by the HuggingFace team with contributions from researchers focused on efficient LLM architectures and training methods. + +## NoPE Architecture + +### Research Paper +**Title**: "Length Generalization of Causal Transformers without Position Encoding" + +**Authors**: +- Jie Wang (Fudan University) +- Tao Ji (Fudan University) +- Yuanbin Wu (Fudan University) +- Hang Yan (Fudan University) +- Tao Gui (Fudan University) +- Qi Zhang (Fudan University) +- Xuanjing Huang (Fudan University) +- Xiaoling Wang (Fudan University) + +**Published**: NeurIPS 2024 (Thirty-Eighth Annual Conference on Neural Information Processing Systems) + +**Abstract Summary**: The paper demonstrates that removing positional encoding from selected layers (NoPE - No Positional Encoding) can improve length generalization in causal transformers while maintaining or improving performance. SmolLM3 implements this with a 3:1 RoPE/NoPE ratio. + +**Resources**: +- **arXiv**: https://arxiv.org/abs/2410.01926 +- **Conference**: NeurIPS 2024 + +### Key Innovation +The hybrid approach uses: +- **RoPE layers** (75%): Standard rotary positional embeddings for local context +- **NoPE layers** (25%): No positional encoding for improved length generalization +- **Pattern**: Every 4th layer uses NoPE (layers 3, 7, 11, 15, etc.) + +This architecture enables SmolLM3 to handle much longer contexts (64k-128k tokens) while maintaining efficiency. 
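+
+As a rough sketch (not the crate's API — the patch exposes this check through `Config::should_skip_rope`), the "every 4th layer" pattern above can be reproduced in a few lines of Rust; the helper name and the 36-layer count used here are assumptions made for the example:
+
+```rust
+// Hypothetical helper mirroring the 3:1 RoPE/NoPE pattern described above;
+// with interval = 4, NoPE falls on 0-indexed layers 3, 7, 11, 15, ...
+fn is_nope_layer(layer_idx: usize, interval: usize) -> bool {
+    (layer_idx + 1) % interval == 0
+}
+
+fn main() {
+    // Assuming 36 decoder layers, 9 of them (25%) skip rotary embeddings.
+    let nope_layers: Vec<usize> = (0..36).filter(|&i| is_nope_layer(i, 4)).collect();
+    println!("{nope_layers:?}"); // [3, 7, 11, 15, 19, 23, 27, 31, 35]
+}
+```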
+ +## Quantized Models + +### Unsloth +Quantized GGUF models are provided by **Unsloth**, a team focused on making LLM inference and fine-tuning more accessible. + +**Resources**: +- **GGUF Repository**: https://huggingface.co/unsloth/SmolLM3-3B-GGUF +- **Available Quantizations**: Q4_K_M, Q8_0, F16 +- **Website**: https://unsloth.ai/ + +The quantization work enables running SmolLM3 efficiently on consumer hardware with minimal quality loss. + +## Implementation Credits + +### This Candle Implementation +**Implemented for**: Candle ML Framework +**Implementation Date**: Nov 2025 +**Features**: +- Full precision model (F32/F16/BF16) +- Quantized model (Q4_K_M/Q8_0/F16 GGUF) +- Unified example supporting both +- Verified against reference implementations + +**Verification**: +- Full precision: Validated against HuggingFace Transformers Python implementation +- Quantized: Validated against llama.cpp implementation + +### Related Tools & Frameworks + +**Candle**: Minimalist ML framework in Rust by HuggingFace +- GitHub: https://github.com/huggingface/candle + +**llama.cpp**: Efficient LLM inference in C/C++ +- GitHub: https://github.com/ggerganov/llama.cpp +- Used for quantized model verification + +**HuggingFace Transformers**: Reference Python implementation +- GitHub: https://github.com/huggingface/transformers +- Used for full model verification + +## Acknowledgments + +Special thanks to: + +1. **HuggingFace Team** - For developing SmolLM3 and making it openly available under Apache 2.0 license +2. **NoPE Researchers** - For advancing the field with novel positional encoding approaches +3. **Unsloth** - For providing optimized quantized versions +4. **Candle Contributors** - For building an excellent ML framework in Rust +5. **Open Source Community** - For tools like llama.cpp that enable verification and benchmarking + +## Citation + +If you use SmolLM3 in your research or applications, please cite: + +### SmolLM3 Model +```bibtex +@misc{smollm3, + title={SmolLM3}, + author={HuggingFace Team}, + year={2024}, + publisher={HuggingFace}, + howpublished={\url{https://huggingface.co/HuggingFaceTB/SmolLM3-3B}} +} +``` + +### NoPE Paper +```bibtex +@inproceedings{wang2024length, + title={Length Generalization of Causal Transformers without Position Encoding}, + author={Wang, Jie and Ji, Tao and Wu, Yuanbin and Yan, Hang and Gui, Tao and Zhang, Qi and Huang, Xuanjing and Wang, Xiaoling}, + booktitle={Thirty-Eighth Annual Conference on Neural Information Processing Systems}, + year={2024} +} +``` + +### Candle Framework +```bibtex +@software{candle, + title={Candle: Minimalist ML Framework}, + author={HuggingFace}, + year={2024}, + url={https://github.com/huggingface/candle} +} +``` + +## License + +- **SmolLM3 Model**: Apache 2.0 +- **This Implementation**: Follows Candle framework license +- **Candle Framework**: Apache 2.0 and MIT dual-licensed + +## Further Reading + +- **SmolLM Blog Series**: https://huggingface.co/blog/smollm and https://huggingface.co/blog/smollm3 +- **Model Card Details**: https://huggingface.co/HuggingFaceTB/SmolLM3-3B +- **NoPE Paper**: https://arxiv.org/abs/2410.01926 +- **Candle Documentation**: https://huggingface.github.io/candle/ + +--- + +This implementation stands on the shoulders of giants. Thank you to all the researchers, engineers, and open source contributors who make this work possible. 
diff --git a/candle-transformers/src/models/smol/smollm3.rs b/candle-transformers/src/models/smol/smollm3.rs index f6b3120d78..f006cdd797 100644 --- a/candle-transformers/src/models/smol/smollm3.rs +++ b/candle-transformers/src/models/smol/smollm3.rs @@ -463,7 +463,7 @@ impl ModelForCausalLM { .narrow(1, l - 1, 1)? .apply(&self.lm_head) } - + pub fn clear_kv_cache(&mut self) { self.base.clear_kv_cache(); } From 9621cbf2b84b7268d2b50b9486d1710589af97ed Mon Sep 17 00:00:00 2001 From: DrJesseGlass Date: Wed, 3 Dec 2025 12:42:55 -0500 Subject: [PATCH 7/9] chat template --- candle-examples/Cargo.toml | 1 + candle-examples/src/chat_template.rs | 532 +++++++++++++++++++++++++++ candle-examples/src/lib.rs | 1 + 3 files changed, 534 insertions(+) create mode 100644 candle-examples/src/chat_template.rs diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml index 44ef1c9ac6..b8b57dc155 100644 --- a/candle-examples/Cargo.toml +++ b/candle-examples/Cargo.toml @@ -26,6 +26,7 @@ hf-hub = { workspace = true, features = ["tokio"] } image = { workspace = true } intel-mkl-src = { workspace = true, optional = true } num-traits = { workspace = true } +minijinja = { version = "2", features = ["loader"] } palette = { version = "0.7.6", optional = true } enterpolation = { version = "0.2.1", optional = true } pyo3 = { version = "0.27", features = [ diff --git a/candle-examples/src/chat_template.rs b/candle-examples/src/chat_template.rs new file mode 100644 index 0000000000..a131ca9858 --- /dev/null +++ b/candle-examples/src/chat_template.rs @@ -0,0 +1,532 @@ +//! Chat template support for LLM examples +//! +//! This module provides Jinja-based chat template rendering compatible with +//! HuggingFace's `tokenizer.apply_chat_template()` functionality. +//! +//! # Example +//! +//! ```no_run +//! use candle_examples::chat_template::{ChatTemplate, Message, Conversation}; +//! +//! // Load template from a model's tokenizer_config.json +//! let template = ChatTemplate::from_tokenizer_config("path/to/tokenizer_config.json")?; +//! +//! // Or use a preset for known models +//! let template = ChatTemplate::chatml(); // SmolLM, Qwen, etc. +//! +//! // Single-turn +//! let messages = vec![ +//! Message::system("You are helpful."), +//! Message::user("Hello!"), +//! ]; +//! let prompt = template.apply(&messages, true)?; +//! +//! // Multi-turn conversation +//! let mut conv = Conversation::new(template, "You are helpful."); +//! let prompt = conv.user_turn("Hello!")?; +//! // ... generate response ... +//! conv.assistant_response("Hi there!"); +//! let prompt = conv.user_turn("How are you?")?; +//! 
``` + +use minijinja::{context, Environment}; +use serde::{Deserialize, Serialize}; +use std::path::Path; + +/// A chat message with role and content +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Message { + pub role: String, + pub content: String, +} + +impl Message { + pub fn new(role: impl Into, content: impl Into) -> Self { + Self { + role: role.into(), + content: content.into(), + } + } + + pub fn system(content: impl Into) -> Self { + Self::new("system", content) + } + + pub fn user(content: impl Into) -> Self { + Self::new("user", content) + } + + pub fn assistant(content: impl Into) -> Self { + Self::new("assistant", content) + } +} + +/// Options for applying a chat template +#[derive(Debug, Clone, Default)] +pub struct ChatTemplateOptions { + /// Add tokens that prompt the model to generate an assistant response + pub add_generation_prompt: bool, + /// Continue the final message instead of starting a new one (for prefilling) + pub continue_final_message: bool, + /// Enable thinking/reasoning mode (adds tags) + pub enable_thinking: bool, + /// Custom variables to pass to the template + pub extra_context: std::collections::HashMap, +} + +impl ChatTemplateOptions { + pub fn for_generation() -> Self { + Self { + add_generation_prompt: true, + ..Default::default() + } + } + + pub fn for_training() -> Self { + Self { + add_generation_prompt: false, + ..Default::default() + } + } + + pub fn with_thinking(mut self) -> Self { + self.enable_thinking = true; + self + } +} + +/// Token configuration loaded from tokenizer_config.json +#[derive(Debug, Clone, Default, Deserialize)] +pub struct TokenConfig { + #[serde(default)] + pub bos_token: Option, + #[serde(default)] + pub eos_token: Option, + #[serde(default)] + pub unk_token: Option, + #[serde(default)] + pub pad_token: Option, + #[serde(default)] + pub chat_template: Option, +} + +/// Handle both string and object token formats in tokenizer_config.json +#[derive(Debug, Clone, Deserialize)] +#[serde(untagged)] +pub enum StringOrToken { + String(String), + Token { content: String }, +} + +impl StringOrToken { + pub fn as_str(&self) -> &str { + match self { + StringOrToken::String(s) => s, + StringOrToken::Token { content } => content, + } + } +} + +impl Default for StringOrToken { + fn default() -> Self { + StringOrToken::String(String::new()) + } +} + +/// Chat template can be a single string or multiple named templates +#[derive(Debug, Clone, Deserialize)] +#[serde(untagged)] +pub enum ChatTemplateConfig { + Single(String), + Multiple(Vec), +} + +#[derive(Debug, Clone, Deserialize)] +pub struct NamedTemplate { + pub name: String, + pub template: String, +} + +/// Chat template renderer using MiniJinja +pub struct ChatTemplate { + env: Environment<'static>, + bos_token: String, + eos_token: String, +} + +impl ChatTemplate { + /// Create from a Jinja template string + pub fn new( + template: impl Into, + bos_token: impl Into, + eos_token: impl Into, + ) -> Result { + let mut env = Environment::new(); + // Add the raise_exception function that HF templates use + env.add_function("raise_exception", |msg: String| -> Result { + Err(minijinja::Error::new( + minijinja::ErrorKind::InvalidOperation, + msg, + )) + }); + + env.add_template_owned("chat".to_string(), template.into()) + .map_err(|e| ChatTemplateError::TemplateError(e.to_string()))?; + + Ok(Self { + env, + bos_token: bos_token.into(), + eos_token: eos_token.into(), + }) + } + + /// Load chat template from a tokenizer_config.json file + pub fn from_tokenizer_config(path: 
impl AsRef) -> Result { + let content = std::fs::read_to_string(path.as_ref()) + .map_err(|e| ChatTemplateError::IoError(e.to_string()))?; + + Self::from_tokenizer_config_str(&content) + } + + /// Load chat template from tokenizer_config.json content + pub fn from_tokenizer_config_str(json: &str) -> Result { + let config: TokenConfig = serde_json::from_str(json) + .map_err(|e| ChatTemplateError::ParseError(e.to_string()))?; + + let template = match config.chat_template { + Some(ChatTemplateConfig::Single(t)) => t, + Some(ChatTemplateConfig::Multiple(templates)) => { + // Use "default" template if available, otherwise first one + templates + .iter() + .find(|t| t.name == "default") + .or_else(|| templates.first()) + .map(|t| t.template.clone()) + .ok_or_else(|| ChatTemplateError::NoTemplate)? + } + None => return Err(ChatTemplateError::NoTemplate), + }; + + let bos = config + .bos_token + .map(|t| t.as_str().to_string()) + .unwrap_or_default(); + let eos = config + .eos_token + .map(|t| t.as_str().to_string()) + .unwrap_or_default(); + + Self::new(template, bos, eos) + } + + /// ChatML template used by SmolLM, Qwen, and many other models + pub fn chatml() -> Self { + let template = r#" +{%- for message in messages %} +{{- '<|im_start|>' + message.role + '\n' + message.content | trim + '<|im_end|>\n' }} +{%- endfor %} +{%- if add_generation_prompt %} +{{- '<|im_start|>assistant\n' }} +{%- endif %} +"#; + Self::new(template, "", "<|im_end|>").unwrap() + } + + /// ChatML template with thinking/reasoning support + pub fn chatml_with_thinking() -> Self { + let template = r#" +{%- for message in messages %} +{{- '<|im_start|>' + message.role + '\n' + message.content | trim + '<|im_end|>\n' }} +{%- endfor %} +{%- if add_generation_prompt %} +{%- if enable_thinking %} +{{- '<|im_start|>assistant\n\n' }} +{%- else %} +{{- '<|im_start|>assistant\n' }} +{%- endif %} +{%- endif %} +"#; + Self::new(template, "", "<|im_end|>").unwrap() + } + + /// Llama 2 chat template + pub fn llama2() -> Self { + let template = r#" +{%- if messages[0]['role'] == 'system' %} + {%- set system_message = '<>\n' + messages[0]['content'] + '\n<>\n\n' %} + {%- set messages = messages[1:] %} +{%- else %} + {%- set system_message = '' %} +{%- endif %} +{%- for message in messages %} + {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %} + {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }} + {%- endif %} + {%- if loop.index0 == 0 %} + {{- bos_token + '[INST] ' + system_message + message['content'] + ' [/INST]' }} + {%- elif message['role'] == 'user' %} + {{- bos_token + '[INST] ' + message['content'] + ' [/INST]' }} + {%- elif message['role'] == 'assistant' %} + {{- ' ' + message['content'] + ' ' + eos_token }} + {%- endif %} +{%- endfor %} +"#; + Self::new(template, "", "").unwrap() + } + + /// Llama 3 / 3.1 chat template + pub fn llama3() -> Self { + let template = r#" +{%- set loop_messages = messages %} +{%- for message in loop_messages %} + {%- set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] | trim + '<|eot_id|>' %} + {%- if loop.index0 == 0 %} + {{- bos_token + content }} + {%- else %} + {{- content }} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }} +{%- endif %} +"#; + Self::new(template, "<|begin_of_text|>", "<|eot_id|>").unwrap() + } + + /// Mistral Instruct template + pub fn mistral() -> Self { + let template = r#" +{{- bos_token }} +{%- 
for message in messages %} + {%- if message['role'] == 'user' %} + {{- '[INST] ' + message['content'] + ' [/INST]' }} + {%- elif message['role'] == 'assistant' %} + {{- ' ' + message['content'] + eos_token }} + {%- endif %} +{%- endfor %} +"#; + Self::new(template, "", "").unwrap() + } + + /// Gemma template + pub fn gemma() -> Self { + let template = r#" +{%- for message in messages %} + {%- if message['role'] == 'user' %} + {{- 'user\n' + message['content'] + '\n' }} + {%- elif message['role'] == 'assistant' %} + {{- 'model\n' + message['content'] + '\n' }} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- 'model\n' }} +{%- endif %} +"#; + Self::new(template, "", "").unwrap() + } + + /// Apply the chat template to messages + pub fn apply( + &self, + messages: &[Message], + options: &ChatTemplateOptions, + ) -> Result { + let template = self + .env + .get_template("chat") + .map_err(|e| ChatTemplateError::TemplateError(e.to_string()))?; + + let result = template + .render(context! { + messages => messages, + add_generation_prompt => options.add_generation_prompt, + continue_final_message => options.continue_final_message, + enable_thinking => options.enable_thinking, + bos_token => &self.bos_token, + eos_token => &self.eos_token, + }) + .map_err(|e| ChatTemplateError::RenderError(e.to_string()))?; + + Ok(result.trim_start().to_string()) + } + + /// Convenience method: apply with add_generation_prompt=true + pub fn apply_for_generation(&self, messages: &[Message]) -> Result { + self.apply(messages, &ChatTemplateOptions::for_generation()) + } +} + +/// Multi-turn conversation manager +pub struct Conversation { + messages: Vec, + template: ChatTemplate, + options: ChatTemplateOptions, +} + +impl Conversation { + /// Create a new conversation with a system prompt + pub fn new(template: ChatTemplate, system_prompt: impl Into) -> Self { + Self { + messages: vec![Message::system(system_prompt)], + template, + options: ChatTemplateOptions::for_generation(), + } + } + + /// Create without a system prompt + pub fn without_system(template: ChatTemplate) -> Self { + Self { + messages: Vec::new(), + template, + options: ChatTemplateOptions::for_generation(), + } + } + + /// Set options (e.g., enable thinking mode) + pub fn with_options(mut self, options: ChatTemplateOptions) -> Self { + self.options = options; + self + } + + /// Add a user message and return the formatted prompt for generation + pub fn user_turn(&mut self, content: impl Into) -> Result { + self.messages.push(Message::user(content)); + self.template.apply(&self.messages, &self.options) + } + + /// Record the assistant's response after generation + pub fn assistant_response(&mut self, content: impl Into) { + self.messages.push(Message::assistant(content)); + } + + /// Add a message with a custom role + pub fn add_message(&mut self, message: Message) { + self.messages.push(message); + } + + /// Get the conversation history + pub fn messages(&self) -> &[Message] { + &self.messages + } + + /// Clear conversation history (keeps system prompt if present) + pub fn clear(&mut self) { + if let Some(first) = self.messages.first() { + if first.role == "system" { + let system = self.messages.remove(0); + self.messages.clear(); + self.messages.push(system); + return; + } + } + self.messages.clear(); + } + + /// Format entire conversation for display (no generation prompt) + pub fn format_history(&self) -> Result { + self.template + .apply(&self.messages, &ChatTemplateOptions::for_training()) + } +} + +/// Errors that can occur 
with chat templates +#[derive(Debug)] +pub enum ChatTemplateError { + IoError(String), + ParseError(String), + TemplateError(String), + RenderError(String), + NoTemplate, +} + +impl std::fmt::Display for ChatTemplateError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::IoError(e) => write!(f, "IO error: {}", e), + Self::ParseError(e) => write!(f, "Parse error: {}", e), + Self::TemplateError(e) => write!(f, "Template error: {}", e), + Self::RenderError(e) => write!(f, "Render error: {}", e), + Self::NoTemplate => write!(f, "No chat_template found in config"), + } + } +} + +impl std::error::Error for ChatTemplateError {} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_chatml_basic() { + let template = ChatTemplate::chatml(); + let messages = vec![ + Message::system("You are helpful."), + Message::user("Hello"), + ]; + + let result = template.apply_for_generation(&messages).unwrap(); + + assert!(result.contains("<|im_start|>system\nYou are helpful.<|im_end|>")); + assert!(result.contains("<|im_start|>user\nHello<|im_end|>")); + assert!(result.ends_with("<|im_start|>assistant\n")); + } + + #[test] + fn test_multi_turn_conversation() { + let mut conv = Conversation::new(ChatTemplate::chatml(), "You are helpful."); + + let prompt1 = conv.user_turn("Hi").unwrap(); + assert!(prompt1.contains("Hi")); + + conv.assistant_response("Hello!"); + + let prompt2 = conv.user_turn("How are you?").unwrap(); + assert!(prompt2.contains("Hi")); + assert!(prompt2.contains("Hello!")); + assert!(prompt2.contains("How are you?")); + } + + #[test] + fn test_thinking_mode() { + let template = ChatTemplate::chatml_with_thinking(); + let messages = vec![Message::user("Think about this")]; + + let result = template + .apply(&messages, &ChatTemplateOptions::for_generation().with_thinking()) + .unwrap(); + + assert!(result.contains("")); + } + + #[test] + fn test_llama3_format() { + let template = ChatTemplate::llama3(); + let messages = vec![ + Message::system("You are helpful."), + Message::user("Hello"), + ]; + + let result = template.apply_for_generation(&messages).unwrap(); + + assert!(result.contains("<|begin_of_text|>")); + assert!(result.contains("<|start_header_id|>system<|end_header_id|>")); + assert!(result.contains("<|start_header_id|>user<|end_header_id|>")); + assert!(result.contains("<|eot_id|>")); + } + + #[test] + fn test_from_json_config() { + let json = r#"{ + "bos_token": "", + "eos_token": "", + "chat_template": "{% for m in messages %}{{ m.role }}: {{ m.content }}\n{% endfor %}" + }"#; + + let template = ChatTemplate::from_tokenizer_config_str(json).unwrap(); + let messages = vec![Message::user("test")]; + let result = template.apply_for_generation(&messages).unwrap(); + + assert!(result.contains("user: test")); + } +} \ No newline at end of file diff --git a/candle-examples/src/lib.rs b/candle-examples/src/lib.rs index af49ab5928..757eb44c07 100644 --- a/candle-examples/src/lib.rs +++ b/candle-examples/src/lib.rs @@ -1,5 +1,6 @@ pub mod audio; pub mod bs1770; +pub mod chat_template; pub mod coco_classes; pub mod imagenet; pub mod token_output_stream; From 526ab5aa20f62521215ada8b4786f7f0f6763534 Mon Sep 17 00:00:00 2001 From: DrJesseGlass Date: Wed, 3 Dec 2025 13:53:04 -0500 Subject: [PATCH 8/9] integrate new chat template for smollm3 example --- candle-examples/examples/smollm3/main.rs | 68 +++++++++++------------- 1 file changed, 32 insertions(+), 36 deletions(-) diff --git a/candle-examples/examples/smollm3/main.rs 
b/candle-examples/examples/smollm3/main.rs index 397417121e..3c35a534e9 100644 --- a/candle-examples/examples/smollm3/main.rs +++ b/candle-examples/examples/smollm3/main.rs @@ -10,6 +10,8 @@ use std::io::Write; use candle::{DType, Device, Tensor}; use candle_examples::token_output_stream::TokenOutputStream; +use candle_examples::chat_template::{ChatTemplate, Message, ChatTemplateOptions}; + use candle_nn::VarBuilder; use candle_transformers::generation::{LogitsProcessor, Sampling}; use hf_hub::{api::sync::Api, Repo, RepoType}; @@ -310,45 +312,39 @@ fn load_full_model(args: &Args, device: &Device) -> Result { // ==================== Text Generation ==================== fn format_prompt(prompt: &str, use_chat_template: bool, enable_thinking: bool) -> String { - if use_chat_template { - // Generate current date dynamically - let now = chrono::Local::now(); - let today_date = now.format("%d %B %Y").to_string(); - - // Set reasoning mode based on thinking flag - let reasoning_mode = if enable_thinking { - "/think" - } else { - "/no_think" - }; + if !use_chat_template { + return prompt.to_string(); + } - // Build the assistant start with or without thinking tags - let assistant_start = if enable_thinking { - "<|im_start|>assistant\n\n" // Open for reasoning - } else { - "<|im_start|>assistant\n\n\n\n" // Empty = skip reasoning - }; + let template = ChatTemplate::chatml_with_thinking(); + + // Build system message with SmolLM3's metadata format + let now = chrono::Local::now(); + let today_date = now.format("%d %B %Y").to_string(); + let reasoning_mode = if enable_thinking { "/think" } else { "/no_think" }; + + let system_content = format!( + "## Metadata\n\n\ + Knowledge Cutoff Date: June 2025\n\ + Today Date: {}\n\ + Reasoning Mode: {}\n\n\ + ## Custom Instructions\n\n\ + You are a helpful AI assistant named SmolLM, trained by Hugging Face.", + today_date, reasoning_mode + ); + + let messages = vec![ + Message::system(system_content), + Message::user(prompt), + ]; - format!( - "<|im_start|>system\n\ -## Metadata\n\ -\n\ -Knowledge Cutoff Date: June 2025\n\ -Today Date: {}\n\ -Reasoning Mode: {}\n\ -\n\ -## Custom Instructions\n\ -\n\ -You are a helpful AI assistant named SmolLM, trained by Hugging Face.\n\ -\n\ -<|im_start|>user\n\ -{}<|im_end|>\n\ -{}", - today_date, reasoning_mode, prompt, assistant_start - ) + let options = if enable_thinking { + ChatTemplateOptions::for_generation().with_thinking() } else { - prompt.to_string() - } + ChatTemplateOptions::for_generation() + }; + + template.apply(&messages, &options).unwrap() } fn get_eos_token(tokenizer: &Tokenizer, config: &ModelConfig) -> u32 { From e9cf0e3d993b2391bd71e23a769f394ead4142b1 Mon Sep 17 00:00:00 2001 From: DrJesseGlass Date: Fri, 5 Dec 2025 15:45:56 -0500 Subject: [PATCH 9/9] fmt and clippy --- candle-examples/examples/smollm3/main.rs | 13 ++++++----- candle-examples/src/chat_template.rs | 23 ++++++++----------- .../src/models/smol/quantized_smollm3.rs | 4 ++-- 3 files changed, 19 insertions(+), 21 deletions(-) diff --git a/candle-examples/examples/smollm3/main.rs b/candle-examples/examples/smollm3/main.rs index 3c35a534e9..93abf6b673 100644 --- a/candle-examples/examples/smollm3/main.rs +++ b/candle-examples/examples/smollm3/main.rs @@ -9,8 +9,8 @@ use clap::{Parser, ValueEnum}; use std::io::Write; use candle::{DType, Device, Tensor}; +use candle_examples::chat_template::{ChatTemplate, ChatTemplateOptions, Message}; use candle_examples::token_output_stream::TokenOutputStream; -use 
candle_examples::chat_template::{ChatTemplate, Message, ChatTemplateOptions}; use candle_nn::VarBuilder; use candle_transformers::generation::{LogitsProcessor, Sampling}; @@ -321,7 +321,11 @@ fn format_prompt(prompt: &str, use_chat_template: bool, enable_thinking: bool) - // Build system message with SmolLM3's metadata format let now = chrono::Local::now(); let today_date = now.format("%d %B %Y").to_string(); - let reasoning_mode = if enable_thinking { "/think" } else { "/no_think" }; + let reasoning_mode = if enable_thinking { + "/think" + } else { + "/no_think" + }; let system_content = format!( "## Metadata\n\n\ @@ -333,10 +337,7 @@ fn format_prompt(prompt: &str, use_chat_template: bool, enable_thinking: bool) - today_date, reasoning_mode ); - let messages = vec![ - Message::system(system_content), - Message::user(prompt), - ]; + let messages = vec![Message::system(system_content), Message::user(prompt)]; let options = if enable_thinking { ChatTemplateOptions::for_generation().with_thinking() diff --git a/candle-examples/src/chat_template.rs b/candle-examples/src/chat_template.rs index a131ca9858..81988dc0ec 100644 --- a/candle-examples/src/chat_template.rs +++ b/candle-examples/src/chat_template.rs @@ -190,8 +190,8 @@ impl ChatTemplate { /// Load chat template from tokenizer_config.json content pub fn from_tokenizer_config_str(json: &str) -> Result { - let config: TokenConfig = serde_json::from_str(json) - .map_err(|e| ChatTemplateError::ParseError(e.to_string()))?; + let config: TokenConfig = + serde_json::from_str(json).map_err(|e| ChatTemplateError::ParseError(e.to_string()))?; let template = match config.chat_template { Some(ChatTemplateConfig::Single(t)) => t, @@ -202,7 +202,7 @@ impl ChatTemplate { .find(|t| t.name == "default") .or_else(|| templates.first()) .map(|t| t.template.clone()) - .ok_or_else(|| ChatTemplateError::NoTemplate)? + .ok_or(ChatTemplateError::NoTemplate)? 
} None => return Err(ChatTemplateError::NoTemplate), }; @@ -460,10 +460,7 @@ mod tests { #[test] fn test_chatml_basic() { let template = ChatTemplate::chatml(); - let messages = vec![ - Message::system("You are helpful."), - Message::user("Hello"), - ]; + let messages = vec![Message::system("You are helpful."), Message::user("Hello")]; let result = template.apply_for_generation(&messages).unwrap(); @@ -493,7 +490,10 @@ mod tests { let messages = vec![Message::user("Think about this")]; let result = template - .apply(&messages, &ChatTemplateOptions::for_generation().with_thinking()) + .apply( + &messages, + &ChatTemplateOptions::for_generation().with_thinking(), + ) .unwrap(); assert!(result.contains("")); @@ -502,10 +502,7 @@ mod tests { #[test] fn test_llama3_format() { let template = ChatTemplate::llama3(); - let messages = vec![ - Message::system("You are helpful."), - Message::user("Hello"), - ]; + let messages = vec![Message::system("You are helpful."), Message::user("Hello")]; let result = template.apply_for_generation(&messages).unwrap(); @@ -529,4 +526,4 @@ mod tests { assert!(result.contains("user: test")); } -} \ No newline at end of file +} diff --git a/candle-transformers/src/models/smol/quantized_smollm3.rs b/candle-transformers/src/models/smol/quantized_smollm3.rs index 7bbc88f7c3..371ef263ab 100644 --- a/candle-transformers/src/models/smol/quantized_smollm3.rs +++ b/candle-transformers/src/models/smol/quantized_smollm3.rs @@ -57,7 +57,7 @@ fn reconstruct_qk_weights(gguf_weight: &Tensor, _num_heads: usize) -> Result>, ) -> Result { - let attn_vb = vb.pp(&format!("blk.{layer_idx}")); + let attn_vb = vb.pp(format!("blk.{layer_idx}")); Ok(Self { self_attn: QuantizedAttention::new(attn_vb.clone(), cfg, layer_idx, rotary_emb)?,