diff --git a/Cargo.toml b/Cargo.toml
index 91df43d86..6d6bea322 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -24,7 +24,6 @@ members = [
     "mlx-lm",
     "mlx-lm-utils",
     "mlx-tests",
-    "examples/*",
     "xtask",
 ]
 
diff --git a/mlx-lm/Cargo.toml b/mlx-lm/Cargo.toml
index 4dcdfe42a..ef91118ac 100644
--- a/mlx-lm/Cargo.toml
+++ b/mlx-lm/Cargo.toml
@@ -13,6 +13,7 @@ description = "Rust implementation of mlx-lm"
 [dependencies]
 # Local dependencies
 mlx-rs.workspace = true
+mlx-macros.workspace = true
 mlx-lm-utils.workspace = true
 
 # External dependencies
@@ -23,4 +24,7 @@ clap = { version = "4", features = ["derive"] }
 idna_adapter = "1.2"
 thiserror = "2"
 serde_json = "1"
-minijinja = "2"
\ No newline at end of file
+minijinja = "2"
+
+[dev-dependencies]
+lazy_static = "1"
\ No newline at end of file
diff --git a/mlx-lm/src/models/llama.rs b/mlx-lm/src/models/llama.rs
new file mode 100644
index 000000000..25909a792
--- /dev/null
+++ b/mlx-lm/src/models/llama.rs
@@ -0,0 +1,775 @@
+use std::{
+    collections::{HashMap, HashSet},
+    path::Path,
+};
+
+use mlx_rs::{
+    argmax_axis, array,
+    builder::Builder,
+    categorical,
+    error::Exception,
+    macros::{ModuleParameters, Quantizable},
+    module::{Module, ModuleParametersExt},
+    nn,
+    ops::indexing::{IndexOp, NewAxis},
+    quantization::MaybeQuantized,
+    Array,
+};
+use serde::Deserialize;
+use serde_json::Value;
+use tokenizers::Tokenizer;
+
+use crate::{
+    cache::KeyValueCache,
+    error::Error,
+    utils::rope::{initialize_rope, FloatOrString, RopeVariant},
+};
+
+#[derive(Debug, Clone, Deserialize)]
+pub struct ModelArgs {
+    pub model_type: String,
+    pub hidden_size: i32,
+    pub num_hidden_layers: i32,
+    pub intermediate_size: i32,
+    pub num_attention_heads: i32,
+    pub rms_norm_eps: f32,
+    pub vocab_size: i32,
+    pub num_key_value_heads: i32,
+    pub max_position_embeddings: i32,
+    pub rope_theta: f32,
+    pub head_dim: i32,
+    #[serde(default = "default_true")]
+    pub tie_word_embeddings: bool,
+    #[serde(default)]
+    pub attention_bias: bool,
+    #[serde(default)]
+    pub mlp_bias: bool,
+    pub rope_scaling: Option<HashMap<String, FloatOrString>>,
+}
+
+fn default_true() -> bool {
+    true
+}
+
+#[derive(Debug, Clone, ModuleParameters, Quantizable)]
+pub struct Attention {
+    pub n_heads: i32,
+    pub n_kv_heads: i32,
+    pub scale: f32,
+
+    #[quantizable]
+    #[param]
+    pub q_proj: MaybeQuantized<nn::Linear>,
+    #[quantizable]
+    #[param]
+    pub k_proj: MaybeQuantized<nn::Linear>,
+    #[quantizable]
+    #[param]
+    pub v_proj: MaybeQuantized<nn::Linear>,
+    #[quantizable]
+    #[param]
+    pub o_proj: MaybeQuantized<nn::Linear>,
+    #[param]
+    pub rope: RopeVariant,
+}
+
+impl Attention {
+    pub fn new(args: &ModelArgs) -> Result<Self, Exception> {
+        let dim = args.hidden_size;
+        let n_heads = args.num_attention_heads;
+        let n_kv_heads = args.num_key_value_heads;
+
+        let head_dim = args.head_dim;
+        let scale = (head_dim as f32).sqrt().recip();
+
+        let q_proj = nn::LinearBuilder::new(dim, n_heads * head_dim)
+            .bias(args.attention_bias)
+            .build()?;
+        let k_proj = nn::LinearBuilder::new(dim, n_kv_heads * head_dim)
+            .bias(args.attention_bias)
+            .build()?;
+        let v_proj = nn::LinearBuilder::new(dim, n_kv_heads * head_dim)
+            .bias(args.attention_bias)
+            .build()?;
+        let o_proj = nn::LinearBuilder::new(n_heads * head_dim, dim)
+            .bias(args.attention_bias)
+            .build()?;
+
+        let rope = initialize_rope(
+            head_dim,
+            args.rope_theta,
+            false,
+            &args.rope_scaling,
+            args.max_position_embeddings,
+        )?;
+
+        Ok(Self {
+            n_heads,
+            n_kv_heads,
+            scale,
+            q_proj: MaybeQuantized::Original(q_proj),
+            k_proj: MaybeQuantized::Original(k_proj),
+            v_proj: MaybeQuantized::Original(v_proj),
+            o_proj: MaybeQuantized::Original(o_proj),
+            rope,
+        })
+    }
+}
+
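+/// Input to an attention layer or transformer block: the hidden states `x`,
+/// an optional additive attention mask, and an optional per-layer key-value
+/// cache that is updated in place as tokens are decoded.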
+pub struct AttentionInput<'a, C> {
+    pub x: &'a Array,
+    pub mask: Option<&'a Array>,
+    pub cache: Option<&'a mut C>,
+}
+
+impl<C> Module<AttentionInput<'_, C>> for Attention
+where
+    C: KeyValueCache,
+{
+    type Output = Array;
+
+    type Error = Exception;
+
+    #[allow(non_snake_case)]
+    fn forward(&mut self, input: AttentionInput<'_, C>) -> Result<Self::Output, Self::Error> {
+        let AttentionInput { x, mask, mut cache } = input;
+
+        let shape = x.shape();
+        let B = shape[0];
+        let L = shape[1];
+
+        let queries = self.q_proj.forward(x)?;
+        let keys = self.k_proj.forward(x)?;
+        let values = self.v_proj.forward(x)?;
+
+        let mut queries = queries
+            .reshape(&[B, L, self.n_heads, -1])?
+            .transpose_axes(&[0, 2, 1, 3])?;
+        let mut keys = keys
+            .reshape(&[B, L, self.n_kv_heads, -1])?
+            .transpose_axes(&[0, 2, 1, 3])?;
+        let mut values = values
+            .reshape(&[B, L, self.n_kv_heads, -1])?
+            .transpose_axes(&[0, 2, 1, 3])?;
+
+        if let Some(cache) = cache.as_mut() {
+            let q_input = nn::RopeInputBuilder::new(&queries)
+                .offset(cache.offset())
+                .build()?;
+            queries = self.rope.forward(q_input)?;
+            let k_input = nn::RopeInputBuilder::new(&keys)
+                .offset(cache.offset())
+                .build()?;
+            keys = self.rope.forward(k_input)?;
+
+            (keys, values) = cache.update_and_fetch(keys, values)?;
+        } else {
+            queries = self.rope.forward(nn::RopeInput::new(&queries))?;
+            keys = self.rope.forward(nn::RopeInput::new(&keys))?;
+        }
+
+        let output = crate::utils::scaled_dot_product_attention(
+            queries, keys, values, cache, self.scale, mask,
+        )?
+        .transpose_axes(&[0, 2, 1, 3])?
+        .reshape(&[B, L, -1])?;
+
+        self.o_proj.forward(&output)
+    }
+
+    fn training_mode(&mut self, mode: bool) {
+        self.q_proj.training_mode(mode);
+        self.k_proj.training_mode(mode);
+        self.v_proj.training_mode(mode);
+        self.o_proj.training_mode(mode);
+        <RopeVariant as Module<nn::RopeInput>>::training_mode(&mut self.rope, mode);
+    }
+}
+
+#[derive(Debug, Clone, ModuleParameters, Quantizable)]
+pub struct Mlp {
+    #[quantizable]
+    #[param]
+    pub gate_proj: MaybeQuantized<nn::Linear>,
+
+    #[quantizable]
+    #[param]
+    pub down_proj: MaybeQuantized<nn::Linear>,
+
+    #[quantizable]
+    #[param]
+    pub up_proj: MaybeQuantized<nn::Linear>,
+}
+
+impl Mlp {
+    pub fn new(dim: i32, hidden_dim: i32, mlp_bias: bool) -> Result<Self, Exception> {
+        let gate_proj = nn::LinearBuilder::new(dim, hidden_dim)
+            .bias(mlp_bias)
+            .build()?;
+        let down_proj = nn::LinearBuilder::new(hidden_dim, dim)
+            .bias(mlp_bias)
+            .build()?;
+        let up_proj = nn::LinearBuilder::new(dim, hidden_dim)
+            .bias(mlp_bias)
+            .build()?;
+
+        Ok(Self {
+            gate_proj: MaybeQuantized::Original(gate_proj),
+            down_proj: MaybeQuantized::Original(down_proj),
+            up_proj: MaybeQuantized::Original(up_proj),
+        })
+    }
+}
+
+impl Module<&Array> for Mlp {
+    type Output = Array;
+
+    type Error = Exception;
+
+    fn forward(&mut self, input: &Array) -> Result<Self::Output, Self::Error> {
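+        // SwiGLU-style feed-forward: down_proj(silu(gate_proj(x)) * up_proj(x)).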
+        let down_proj_input =
+            nn::silu(self.gate_proj.forward(input)?)?.multiply(self.up_proj.forward(input)?)?;
+        self.down_proj.forward(&down_proj_input)
+    }
+
+    fn training_mode(&mut self, mode: bool) {
+        self.gate_proj.training_mode(mode);
+        self.down_proj.training_mode(mode);
+        self.up_proj.training_mode(mode);
+    }
+}
+
+#[derive(Debug, Clone, ModuleParameters, Quantizable)]
+pub struct TransformerBlock {
+    pub num_attention_heads: i32,
+    pub hidden_size: i32,
+
+    #[quantizable]
+    #[param]
+    pub self_attn: Attention,
+
+    #[quantizable]
+    #[param]
+    pub mlp: Mlp,
+
+    #[param]
+    pub input_layernorm: nn::RmsNorm,
+
+    #[param]
+    pub post_attention_layernorm: nn::RmsNorm,
+}
+
+impl TransformerBlock {
+    pub fn new(args: &ModelArgs) -> Result<Self, Exception> {
+        let num_attention_heads = args.num_attention_heads;
+        let hidden_size = args.hidden_size;
+
+        let self_attn = Attention::new(args)?;
+        let mlp = Mlp::new(args.hidden_size, args.intermediate_size, args.mlp_bias)?;
+        let input_layernorm = nn::RmsNormBuilder::new(args.hidden_size)
+            .eps(args.rms_norm_eps)
+            .build()?;
+        let post_attention_layernorm = nn::RmsNormBuilder::new(args.hidden_size)
+            .eps(args.rms_norm_eps)
+            .build()?;
+
+        Ok(Self {
+            num_attention_heads,
+            hidden_size,
+            self_attn,
+            mlp,
+            input_layernorm,
+            post_attention_layernorm,
+        })
+    }
+}
+
+impl<C> Module<AttentionInput<'_, C>> for TransformerBlock
+where
+    C: KeyValueCache,
+{
+    type Output = Array;
+
+    type Error = Exception;
+
+    fn forward(&mut self, input: AttentionInput<'_, C>) -> Result<Self::Output, Self::Error> {
+        let AttentionInput { x, mask, cache } = input;
+
+        let self_attn_input = AttentionInput {
+            x: &self.input_layernorm.forward(x)?,
+            mask,
+            cache,
+        };
+        let r = self.self_attn.forward(self_attn_input)?;
+        let h = x.add(r)?;
+
+        let r = self
+            .mlp
+            .forward(&self.post_attention_layernorm.forward(&h)?)?;
+        h.add(r)
+    }
+
+    fn training_mode(&mut self, mode: bool) {
+        <Attention as Module<AttentionInput<'_, C>>>::training_mode(&mut self.self_attn, mode);
+        self.mlp.training_mode(mode);
+        self.input_layernorm.training_mode(mode);
+        self.post_attention_layernorm.training_mode(mode);
+    }
+}
+
+#[derive(Debug, Clone, ModuleParameters, Quantizable)]
+pub struct LlamaModel {
+    pub vocab_size: i32,
+    pub num_hidden_layers: i32,
+
+    #[quantizable]
+    #[param]
+    pub embed_tokens: MaybeQuantized<nn::Embedding>,
+
+    #[quantizable]
+    #[param]
+    pub layers: Vec<TransformerBlock>,
+
+    #[param]
+    pub norm: nn::RmsNorm,
+}
+
+impl LlamaModel {
+    pub fn new(args: &ModelArgs) -> Result<Self, Exception> {
+        assert!(args.vocab_size.is_positive());
+
+        let vocab_size = args.vocab_size;
+        let num_hidden_layers = args.num_hidden_layers;
+
+        let embed_tokens = nn::Embedding::new(args.vocab_size, args.hidden_size)?;
+        let layers = (0..num_hidden_layers)
+            .map(|_| TransformerBlock::new(args))
+            .collect::<Result<Vec<_>, _>>()?;
+        let norm = nn::RmsNormBuilder::new(args.hidden_size)
+            .eps(args.rms_norm_eps)
+            .build()?;
+
+        Ok(Self {
+            vocab_size,
+            num_hidden_layers,
+            embed_tokens: MaybeQuantized::Original(embed_tokens),
+            layers,
+            norm,
+        })
+    }
+}
+
+pub struct ModelInput<'a, C> {
+    pub inputs: &'a Array,
+    pub mask: Option<&'a Array>,
+    pub cache: &'a mut Vec<Option<C>>,
+}
+
+impl<C> Module<ModelInput<'_, C>> for LlamaModel
+where
+    C: KeyValueCache + Default,
+{
+    type Output = Array;
+
+    type Error = Exception;
+
+    fn forward(&mut self, input: ModelInput<'_, C>) -> Result<Self::Output, Self::Error> {
+        let ModelInput {
+            inputs,
+            mask,
+            cache,
+        } = input;
+
+        let mut h = self.embed_tokens.forward(inputs)?;
+
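+        // A causal mask is only needed when more than one token is processed
+        // at once (prefill); single-token decode steps need no mask.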
+        let mask = match mask {
+            Some(mask) => Some(mask.clone()),
+            None => {
+                if h.shape()[1] > 1 {
+                    let m =
+                        nn::MultiHeadAttention::create_additive_causal_mask::<f32>(h.shape()[1])?;
+                    Some(m.as_dtype(h.dtype())?)
+                } else {
+                    None
+                }
+            }
+        };
+
+        if cache.is_empty() {
+            *cache = (0..self.layers.len()).map(|_| Some(C::default())).collect();
+        }
+
+        for (layer, c) in self.layers.iter_mut().zip(cache.iter_mut()) {
+            let layer_input = AttentionInput {
+                x: &h,
+                mask: mask.as_ref(),
+                cache: c.as_mut(),
+            };
+            h = layer.forward(layer_input)?;
+        }
+
+        self.norm.forward(&h)
+    }
+
+    fn training_mode(&mut self, mode: bool) {
+        self.embed_tokens.training_mode(mode);
+        for layer in &mut self.layers {
+            <TransformerBlock as Module<AttentionInput<'_, C>>>::training_mode(layer, mode);
+        }
+        self.norm.training_mode(mode);
+    }
+}
+
+#[derive(Debug, Clone, ModuleParameters, Quantizable)]
+pub struct Model {
+    pub args: ModelArgs,
+
+    #[quantizable]
+    #[param]
+    pub model: LlamaModel,
+
+    #[quantizable]
+    #[param]
+    pub lm_head: Option<MaybeQuantized<nn::Linear>>,
+}
+
+impl Model {
+    pub fn new(args: ModelArgs) -> Result<Self, Exception> {
+        let model = LlamaModel::new(&args)?;
+        let lm_head = if !args.tie_word_embeddings {
+            Some(MaybeQuantized::Original(
+                nn::LinearBuilder::new(args.hidden_size, args.vocab_size)
+                    .bias(false)
+                    .build()?,
+            ))
+        } else {
+            None
+        };
+
+        Ok(Self {
+            args,
+            model,
+            lm_head,
+        })
+    }
+
+    pub fn model_type(&self) -> &str {
+        &self.args.model_type
+    }
+}
+
+impl<C> Module<ModelInput<'_, C>> for Model
+where
+    C: KeyValueCache + Default,
+{
+    type Output = Array;
+
+    type Error = Exception;
+
+    fn forward(&mut self, input: ModelInput<'_, C>) -> Result<Self::Output, Self::Error> {
+        let out = self.model.forward(input)?;
+
+        match self.lm_head.as_mut() {
+            Some(lm_head) => lm_head.forward(&out),
+            None => match &mut self.model.embed_tokens {
+                MaybeQuantized::Original(embed_tokens) => embed_tokens.as_linear(&out),
+                MaybeQuantized::Quantized(q_embed_tokens) => q_embed_tokens.as_linear(&out),
+            },
+        }
+    }
+
+    fn training_mode(&mut self, mode: bool) {
+        <LlamaModel as Module<ModelInput<'_, C>>>::training_mode(&mut self.model, mode);
+        if let Some(lm_head) = &mut self.lm_head {
+            lm_head.training_mode(mode);
+        }
+    }
+}
+
+pub fn load_llama_tokenizer(model_dir: impl AsRef<Path>) -> Result<Tokenizer, Error> {
+    let file = model_dir.as_ref().join("tokenizer.json");
+    Tokenizer::from_file(file).map_err(Into::into)
+}
+
+pub fn get_llama_model_args(model_dir: impl AsRef<Path>) -> Result<ModelArgs, Error> {
+    let model_args_filename = model_dir.as_ref().join("config.json");
+    let file = std::fs::File::open(model_args_filename)?;
+    let model_args: ModelArgs = serde_json::from_reader(file)?;
+
+    Ok(model_args)
+}
+
+#[derive(Debug, Clone, Deserialize)]
+pub struct WeightMap {
+    pub metadata: HashMap<String, Value>,
+    pub weight_map: HashMap<String, String>,
+}
+
+pub fn load_llama_model(model_dir: impl AsRef<Path>) -> Result<Model, Error> {
+    let model_dir = model_dir.as_ref();
+    let model_args = get_llama_model_args(model_dir)?;
+    let mut model = Model::new(model_args)?;
+
+    let weights_index = model_dir.join("model.safetensors.index.json");
+    if weights_index.exists() {
+        // Sharded weights: read the index to find all weight files
+        let json = std::fs::read_to_string(weights_index)?;
+        let weight_map: WeightMap = serde_json::from_str(&json)?;
+
+        let weight_files: HashSet<&String> = weight_map.weight_map.values().collect();
+        for weight_file in weight_files {
+            let weights_filename = model_dir.join(weight_file);
+            model.load_safetensors(weights_filename)?;
+        }
+    } else {
+        // Single weight file
+        let weights_filename = model_dir.join("model.safetensors");
+        model.load_safetensors(weights_filename)?;
+    }
+
+    Ok(model)
+}
+
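+/// Sample a token from the last-position logits: greedy argmax when `temp` is
+/// zero, otherwise a categorical draw over the temperature-scaled logits.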
+pub fn sample(logits: &Array, temp: f32) -> Result<Array, Exception> {
+    match temp {
+        0.0 => argmax_axis!(logits, -1),
+        _ => {
+            let logits = logits.multiply(array!(1.0 / temp))?;
+            categorical!(logits)
+        }
+    }
+}
+
+pub struct Generate<'a, C> {
+    model: &'a mut Model,
+    cache: &'a mut Vec<Option<C>>,
+    temp: f32,
+    state: GenerateState<'a>,
+}
+
+impl<'a, C> Generate<'a, C>
+where
+    C: KeyValueCache + Default,
+{
+    pub fn new(
+        model: &'a mut Model,
+        cache: &'a mut Vec<Option<C>>,
+        temp: f32,
+        prompt_token: &'a Array,
+    ) -> Self {
+        Self {
+            model,
+            cache,
+            temp,
+            state: GenerateState::Prefill { prompt_token },
+        }
+    }
+}
+
+pub enum GenerateState<'a> {
+    Prefill { prompt_token: &'a Array },
+    Decode { y: Array },
+}
+
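+/// Like `?`, but for `Iterator::next`: yields errors as `Some(Err(..))`
+/// instead of returning them from the surrounding function.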
+macro_rules! tri {
+    ($expr:expr) => {
+        match $expr {
+            Ok(val) => val,
+            Err(e) => return Some(Err(e.into())),
+        }
+    };
+}
+
+impl<'a, C> Iterator for Generate<'a, C>
+where
+    C: KeyValueCache + Default,
+{
+    type Item = Result<Array, Exception>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        match &self.state {
+            GenerateState::Prefill { prompt_token } => {
+                let input = ModelInput {
+                    inputs: prompt_token,
+                    mask: None,
+                    cache: self.cache,
+                };
+                let logits = tri!(self.model.forward(input));
+                let y = tri!(sample(&logits.index((.., -1, ..)), self.temp));
+                self.state = GenerateState::Decode { y: y.clone() };
+
+                Some(Ok(y))
+            }
+            GenerateState::Decode { y } => {
+                let inputs = y.index((.., NewAxis));
+                let input = ModelInput {
+                    inputs: &inputs,
+                    mask: None,
+                    cache: self.cache,
+                };
+                let logits = tri!(self.model.forward(input));
+                let y = tri!(sample(&logits.index((.., -1, ..)), self.temp));
+
+                self.state = GenerateState::Decode { y: y.clone() };
+
+                Some(Ok(y))
+            }
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::{env::home_dir, fs};
+
+    use lazy_static::lazy_static;
+    use mlx_rs::{
+        ops::indexing::{IndexOp, NewAxis},
+        transforms::eval,
+        Array,
+    };
+
+    use crate::{
+        cache::ConcatKeyValueCache,
+        models::llama::{load_llama_model, load_llama_tokenizer},
+    };
+
+    /// Resolve the HuggingFace cache directory to the actual snapshot path.
+    /// The structure is:
+    /// models--<org>--<model>/
+    ///     refs/
+    ///         main            (contains the commit hash)
+    ///     snapshots/
+    ///         <commit-hash>/  (actual model files)
+    fn resolve_hf_cache_dir(model_cache_dir: &str) -> String {
+        let refs_main = std::path::Path::new(model_cache_dir)
+            .join("refs")
+            .join("main");
+        let commit_hash = fs::read_to_string(&refs_main)
+            .unwrap_or_default()
+            .trim()
+            .to_string();
+        std::path::Path::new(model_cache_dir)
+            .join("snapshots")
+            .join(commit_hash)
+            .to_string_lossy()
+            .into_owned()
+    }
+
+    lazy_static! {
+        static ref CACHED_TEST_MODEL_DIR: String = {
+            let cache_dir = home_dir()
+                .map(|p| {
+                    p.join(".cache")
+                        .join("huggingface")
+                        .join("hub")
+                        .join("models--meta-llama--Llama-3.2-1B-Instruct")
+                        .to_string_lossy()
+                        .into_owned()
+                })
+                .unwrap_or_default();
+
+            resolve_hf_cache_dir(&cache_dir)
+        };
+    }
+
+    #[test]
+    #[ignore = "requires local model files"]
+    fn test_load_llama_model() {
+        use mlx_rs::module::ModuleParameters;
+
+        let model_dir = CACHED_TEST_MODEL_DIR.as_str();
+        let model_args = super::get_llama_model_args(model_dir).unwrap();
+        let model = super::Model::new(model_args).unwrap();
+
+        // Print some model parameter keys
+        let params = model.parameters().flatten();
+        let mut param_keys: Vec<_> = params.keys().map(|k| k.to_string()).collect();
+        param_keys.sort();
+        println!("=== Model parameter keys (first 20) ===");
+        for key in param_keys.iter().take(20) {
+            println!("  {key}");
+        }
+
+        // Print some safetensor keys
+        let weights_path = std::path::Path::new(model_dir).join("model.safetensors");
+        let loaded = mlx_rs::Array::load_safetensors(&weights_path).unwrap();
+        let mut weight_keys: Vec<_> = loaded.keys().map(|k| k.to_string()).collect();
+        weight_keys.sort();
+        println!("=== Safetensor weight keys (first 20) ===");
+        for key in weight_keys.iter().take(20) {
+            println!("  {key}");
+        }
+
+        // Find unmatched keys
+        let param_set: std::collections::HashSet<_> = param_keys.iter().collect();
+        let weight_set: std::collections::HashSet<_> = weight_keys.iter().collect();
+        let unloaded: Vec<_> = weight_set.difference(&param_set).collect();
+        let missing: Vec<_> = param_set.difference(&weight_set).collect();
+        println!(
+            "=== Weight keys NOT in model params ({}) ===",
+            unloaded.len()
+        );
+        for key in unloaded.iter().take(10) {
+            println!("  {key}");
+        }
+        println!(
+            "=== Model param keys NOT in weights ({}) ===",
+            missing.len()
+        );
+        for key in missing.iter().take(10) {
+            println!("  {key}");
+        }
+        println!(
+            "Total model params: {}, Total weight keys: {}",
+            param_keys.len(),
+            weight_keys.len()
+        );
+    }
+
+    #[test]
+    #[ignore = "requires local model files"]
+    fn test_load_tokenizer() {
+        let tokenizer = load_llama_tokenizer(CACHED_TEST_MODEL_DIR.as_str()).unwrap();
+
+        let _encoding = tokenizer.encode("Hello, world!", true).unwrap();
+    }
+
+    #[test]
+    #[ignore = "requires local model files"]
+    fn test_load_and_run_llama_with_concat_cache() {
+        let tokenizer = load_llama_tokenizer(CACHED_TEST_MODEL_DIR.as_str()).unwrap();
+        let mut model = load_llama_model(CACHED_TEST_MODEL_DIR.as_str()).unwrap();
+
+        let prompt = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nWhat is the capital of France?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n";
+        let encoding = tokenizer.encode(prompt, false).unwrap();
+        let prompt_tokens = Array::from(encoding.get_ids()).index(NewAxis);
+        let mut cache = Vec::new();
+
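+        // Llama 3 stop tokens: 128001 is <|end_of_text|>, 128009 is <|eot_id|>.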
+        let eos_token_id = 128001u32;
+        let eot_token_id = 128009u32;
+
+        let mut token_ids = Vec::new();
+        let generate = super::Generate::<ConcatKeyValueCache>::new(
+            &mut model,
+            &mut cache,
+            0.0,
+            &prompt_tokens,
+        );
+        for (token, _ntoks) in generate.zip(0..50) {
+            let token = token.unwrap();
+            eval([&token]).unwrap();
+            let token_id = token.item::<u32>();
+            print!("[{}]", token_id);
+            if token_id == eos_token_id || token_id == eot_token_id {
+                break;
+            }
+            token_ids.push(token_id);
+        }
+        println!();
+
+        let output = tokenizer.decode(&token_ids, true).unwrap();
+        println!("Response: {output}");
+        println!("------");
+    }
+}
diff --git a/mlx-lm/src/models/mod.rs b/mlx-lm/src/models/mod.rs
index 2e427b96b..37d46974f 100644
--- a/mlx-lm/src/models/mod.rs
+++ b/mlx-lm/src/models/mod.rs
@@ -1 +1,2 @@
+pub mod llama;
 pub mod qwen3;
diff --git a/mlx-lm/src/models/qwen3.rs b/mlx-lm/src/models/qwen3.rs
index 132a1fb90..f422f55bc 100644
--- a/mlx-lm/src/models/qwen3.rs
+++ b/mlx-lm/src/models/qwen3.rs
@@ -24,7 +24,7 @@ use crate::{
     error::Error,
     utils::{
         create_attention_mask,
-        rope::{initialize_rope, FloatOrString},
+        rope::{initialize_rope, FloatOrString, RopeVariant},
         AttentionMask,
     },
 };
@@ -69,7 +69,7 @@ pub struct Attention {
     #[param]
     pub k_norm: nn::RmsNorm,
     #[param]
-    pub rope: nn::Rope,
+    pub rope: RopeVariant,
 }
 
 impl Attention {
@@ -197,7 +197,7 @@ where
         self.o_proj.training_mode(mode);
         self.q_norm.training_mode(mode);
         self.k_norm.training_mode(mode);
-        <nn::Rope as Module<nn::RopeInput>>::training_mode(&mut self.rope, mode);
+        <RopeVariant as Module<nn::RopeInput>>::training_mode(&mut self.rope, mode);
     }
 }
 
diff --git a/mlx-lm/src/utils/rope.rs b/mlx-lm/src/utils/rope.rs
index 85585fbdb..668dc2a4d 100644
--- a/mlx-lm/src/utils/rope.rs
+++ b/mlx-lm/src/utils/rope.rs
@@ -1,6 +1,14 @@
 use std::collections::HashMap;
 
-use mlx_rs::{builder::Builder, error::Exception, nn};
+use mlx_macros::ModuleParameters;
+use mlx_rs::{
+    builder::Builder,
+    error::Exception,
+    module::Module,
+    nn,
+    ops::{arange, which},
+    Array,
+};
 use serde::Deserialize;
 
 #[derive(Debug, Clone, PartialEq)]
@@ -9,8 +17,9 @@ pub enum FloatOrStr<'a> {
     Float(f32),
     Str(&'a str),
 }
 
-// TODO: check if additionl serde attributes are needed
+// TODO: check if additional serde attributes are needed
 #[derive(Debug, Clone, Deserialize)]
+#[serde(untagged)]
 pub enum FloatOrString {
     Float(f32),
     String(String),
@@ -25,13 +34,206 @@ impl FloatOrString {
     }
 }
 
+/// Get a numeric float value from a scaling config by key.
+///
+/// Note: str variants in the config are not always floats — values like "default" or "linear"
+/// are also valid for non-numeric fields. This function should only be called for keys that
+/// are expected to hold numeric values.
+fn get_numeric_from_config(
+    config: &HashMap<String, FloatOrString>,
+    key: &str,
+) -> Result<f32, Exception> {
+    match config
+        .get(key)
+        .map(FloatOrString::borrowed)
+        .ok_or_else(|| {
+            Exception::custom(format!(r#"key "{key}" is not found in scaling config"#))
+        })? {
+        FloatOrStr::Float(f) => Ok(f),
+        FloatOrStr::Str(s) => s
+            .parse::<f32>()
+            .map_err(|_| Exception::custom(format!(r#"key "{key}" is not a valid number"#))),
+    }
+}
+
+/// Llama3-style RoPE with frequency scaling.
+///
+/// Applies piecewise frequency scaling based on wavelength cutoffs derived from
+/// `low_freq_factor`, `high_freq_factor`, `factor`, and `original_max_position_embeddings`.
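+///
+/// A typical `rope_scaling` config entry (values here are the ones shipped
+/// with the Llama 3.2 configs) looks like:
+/// `{"rope_type": "llama3", "factor": 32.0, "low_freq_factor": 1.0,
+/// "high_freq_factor": 4.0, "original_max_position_embeddings": 8192}`.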
+// TODO: support derive ModuleParameters for structs with non-param Array fields
+#[derive(Debug, Clone, ModuleParameters)]
+pub struct Llama3Rope {
+    pub dimensions: i32,
+    pub traditional: bool,
+    pub scale: f32,
+    /// Pre-computed scaled frequencies. Not a module parameter.
+    pub freqs: Array,
+}
+
+impl Llama3Rope {
+    pub fn new(
+        dims: i32,
+        traditional: bool,
+        original_max_position_embeddings: i32,
+        base: f32,
+        factor: f32,
+        low_freq_factor: f32,
+        high_freq_factor: f32,
+    ) -> Result<Self, Exception> {
+        let half_dims = dims / 2;
+
+        // Compute freqs using MLX ops, matching Python:
+        //     freqs = base ** (mx.arange(0, dims, 2) / dims)
+        // which equals base^(2i/dims) for i in 0..half_dims
+        let indices = arange::<_, f32>(None, half_dims, None)?;
+        let exponents = indices.multiply(Array::from_f32(2.0 / dims as f32))?;
+        let freqs = Array::from_f32(base).power(&exponents)?;
+
+        let old_context_len = original_max_position_embeddings as f32;
+        let low_freq_wavelen = old_context_len / low_freq_factor;
+        let high_freq_wavelen = old_context_len / high_freq_factor;
+
+        // wavelens = 2 * pi * freqs
+        // Apply piecewise scaling matching Python exactly:
+        //     freqs = where(wavelens > low_freq_wavelen, freqs * factor, freqs)
+        //     is_medium = (wavelens > high_freq_wavelen) & (wavelens < low_freq_wavelen)
+        //     smooth_factors = (old_context_len / wavelens - low_freq_factor) / (high - low)
+        //     smooth_freqs = freqs / ((1 - smooth_factors) / factor + smooth_factors)
+        //     freqs = where(is_medium, smooth_freqs, freqs)
+        let two_pi = Array::from_f32(2.0 * std::f32::consts::PI);
+        let wavelens = freqs.multiply(&two_pi)?;
+
+        // First pass: scale low frequencies (long wavelengths) by factor
+        let is_low = wavelens.gt(Array::from_f32(low_freq_wavelen))?;
+        let freqs = which(&is_low, &freqs.multiply(Array::from_f32(factor))?, &freqs)?;
+
+        // Second pass: smooth interpolation for medium frequencies
+        let is_medium = wavelens
+            .gt(Array::from_f32(high_freq_wavelen))?
+            .logical_and(&wavelens.lt(Array::from_f32(low_freq_wavelen))?)?;
+
+        let smooth_factors = wavelens
+            .reciprocal()?
+            .multiply(Array::from_f32(old_context_len))?
+            .subtract(Array::from_f32(low_freq_factor))?
+            .divide(Array::from_f32(high_freq_factor - low_freq_factor))?;
+
+        // smooth_freqs = freqs / ((1 - smooth_factors) / factor + smooth_factors)
+        let one_minus_smooth = Array::from_f32(1.0).subtract(&smooth_factors)?;
+        let denom = one_minus_smooth
+            .divide(Array::from_f32(factor))?
+            .add(&smooth_factors)?;
+        let smooth_freqs = freqs.divide(&denom)?;
+
+        let freqs = which(&is_medium, &smooth_freqs, &freqs)?;
+
+        Ok(Self {
+            dimensions: dims,
+            traditional,
+            scale: 1.0,
+            freqs,
+        })
+    }
+}
+
+impl<'a, Input> Module<Input> for Llama3Rope
+where
+    Input: Into<nn::RopeInput<'a>>,
+{
+    type Error = Exception;
+    type Output = Array;
+
+    fn forward(&mut self, input: Input) -> Result<Self::Output, Self::Error> {
+        let nn::RopeInput { x, offset } = input.into();
+        let shape = x.shape();
+        let x = x.reshape(&[-1, x.dim(-2), x.dim(-1)])?;
+        let x = mlx_rs::fast::rope(
+            x,
+            self.dimensions,
+            self.traditional,
+            None::<f32>,
+            self.scale,
+            offset,
+            &self.freqs,
+        )?;
+        x.reshape(shape)
+    }
+
+    fn training_mode(&mut self, _mode: bool) {}
+}
+
+/// Enum wrapping different RoPE variants so that `initialize_rope` can return
+/// either a standard RoPE or a Llama3 RoPE.
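+///
+/// The hand-written `ModuleParameters` impl below deliberately reports no
+/// parameters: both variants hold only precomputed, non-trainable state.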
+#[derive(Debug, Clone)]
+pub enum RopeVariant {
+    Default(nn::Rope),
+    Llama3(Llama3Rope),
+}
+
+// TODO: support derive ModuleParameters for enum
+impl mlx_rs::module::ModuleParameters for RopeVariant {
+    fn num_parameters(&self) -> usize {
+        0
+    }
+
+    fn freeze_parameters(&mut self, _recursive: bool) {}
+
+    fn unfreeze_parameters(&mut self, _recursive: bool) {}
+
+    fn parameters(&self) -> mlx_rs::module::ModuleParamRef<'_> {
+        mlx_rs::nested::NestedHashMap::new()
+    }
+
+    fn parameters_mut(&mut self) -> mlx_rs::module::ModuleParamMut<'_> {
+        mlx_rs::nested::NestedHashMap::new()
+    }
+
+    fn trainable_parameters(&self) -> mlx_rs::module::ModuleParamRef<'_> {
+        mlx_rs::nested::NestedHashMap::new()
+    }
+
+    fn all_frozen(&self) -> Option<bool> {
+        None
+    }
+
+    fn any_frozen(&self) -> Option<bool> {
+        None
+    }
+}
+
+impl<'a, Input> Module<Input> for RopeVariant
+where
+    Input: Into<nn::RopeInput<'a>>,
+{
+    type Error = Exception;
+    type Output = Array;
+
+    fn forward(&mut self, input: Input) -> Result<Self::Output, Self::Error> {
+        match self {
+            RopeVariant::Default(rope) => rope.forward(input),
+            RopeVariant::Llama3(rope) => rope.forward(input),
+        }
+    }
+
+    fn training_mode(&mut self, mode: bool) {
+        match self {
+            RopeVariant::Default(rope) => {
+                <nn::Rope as Module<nn::RopeInput>>::training_mode(rope, mode)
+            }
+            RopeVariant::Llama3(rope) => {
+                <Llama3Rope as Module<nn::RopeInput>>::training_mode(rope, mode)
+            }
+        }
+    }
+}
+
 pub fn initialize_rope(
     dims: i32,
     base: f32, // rope_theta
     traditional: bool,
     scaling_config: &Option<HashMap<String, FloatOrString>>,
-    _max_position_embeddings: i32, // TODO: implement other RoPE
-) -> Result<nn::Rope, Exception> {
+    _max_position_embeddings: i32,
+) -> Result<RopeVariant, Exception> {
     let rope_type = scaling_config
         .as_ref()
         .and_then(|config| {
@@ -44,19 +246,7 @@ pub fn initialize_rope(
     if rope_type == FloatOrStr::Str("default") || rope_type == FloatOrStr::Str("linear") {
         let scale = if rope_type == FloatOrStr::Str("linear") {
-            let den = match scaling_config
-                .as_ref()
-                .and_then(|config| config.get("factor"))
-                .map(FloatOrString::borrowed)
-                .ok_or_else(|| {
-                    Exception::custom(r#"key "factor" is not found in scaling config"#)
-                })? {
-                FloatOrStr::Float(f) => f,
-                FloatOrStr::Str(s) => s
-                    .parse::<f32>()
-                    .map_err(|_| Exception::custom(r#"key "factor" is not a valid float"#))?,
-            };
-
+            let den = get_numeric_from_config(scaling_config.as_ref().unwrap(), "factor")?;
             1.0 / den
         } else {
             1.0
@@ -68,9 +258,28 @@ pub fn initialize_rope(
             .scale(scale)
             .build()
             .expect("Infallible");
-        return Ok(rope);
+        return Ok(RopeVariant::Default(rope));
     } else if rope_type == FloatOrStr::Str("llama3") {
-        todo!()
+        let config = scaling_config
+            .as_ref()
+            .ok_or_else(|| Exception::custom("scaling_config is required for llama3 RoPE"))?;
+
+        let factor = get_numeric_from_config(config, "factor")?;
+        let low_freq_factor = get_numeric_from_config(config, "low_freq_factor")?;
+        let high_freq_factor = get_numeric_from_config(config, "high_freq_factor")?;
+        let original_max_position_embeddings =
+            get_numeric_from_config(config, "original_max_position_embeddings")? as i32;
+
+        let rope = Llama3Rope::new(
+            dims,
+            traditional,
+            original_max_position_embeddings,
+            base,
+            factor,
+            low_freq_factor,
+            high_freq_factor,
+        )?;
+        return Ok(RopeVariant::Llama3(rope));
     } else if rope_type == FloatOrStr::Str("yarn") {
         todo!()
     } else if rope_type == FloatOrStr::Str("longrope") {
@@ -81,11 +290,3 @@
         "Unsupported RoPE type {rope_type:?}"
     )))
 }
-
-// // TODO
-// #[derive(Debug, Clone, ModuleParameters)]
-// pub struct Llama3Rope {
-//     dims: i32,
-//     max_position_embeddings: i32,
-//     traditional: bool,
-// }