diff --git a/examples/embedding_gemma.bak/Cargo.toml b/examples/embedding_gemma.bak/Cargo.toml new file mode 100644 index 000000000..ecfa655cc --- /dev/null +++ b/examples/embedding_gemma.bak/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "embedding_gemma" +version.workspace = true +edition.workspace = true +authors.workspace = true +repository.workspace = true +keywords.workspace = true +categories.workspace = true +license.workspace = true +documentation.workspace = true +rust-version.workspace = true + +[dependencies] +mlx-rs.workspace = true +mlx-lm.workspace = true +mlx-lm-utils.workspace = true + +anyhow = "1" \ No newline at end of file diff --git a/examples/embedding_gemma.bak/download.sh b/examples/embedding_gemma.bak/download.sh new file mode 100644 index 000000000..9fa628721 --- /dev/null +++ b/examples/embedding_gemma.bak/download.sh @@ -0,0 +1,5 @@ +#!/bin/bash + +model_id="mlx-community/embeddinggemma-300m-bf16" + +huggingface-cli download $model_id --local-dir ./cache/embeddinggemma-300m-bf16 \ No newline at end of file diff --git a/examples/embedding_gemma.bak/src/main.rs b/examples/embedding_gemma.bak/src/main.rs new file mode 100644 index 000000000..35262858c --- /dev/null +++ b/examples/embedding_gemma.bak/src/main.rs @@ -0,0 +1,84 @@ +use std::path::Path; + +use mlx_lm::{cache::ConcatKeyValueCache, models::gemma::load_embedding_gemma_model}; +use mlx_lm_utils::tokenizer::{ + load_model_chat_template_from_file, ApplyChatTemplateArgs, Conversation, Role, Tokenizer, +}; +use mlx_rs::{ + ops::indexing::{IndexOp, NewAxis}, + transforms::eval, + Array, +}; + +const CACHED_TEST_MODEL_DIR: &str = "./cache/embeddinggemma-300m-bf16"; + +fn qwen3() -> anyhow::Result<()> { + let model_dir = Path::new(CACHED_TEST_MODEL_DIR); + + let model_id = "mlx-community/embeddinggemma-300m-bf16".to_string(); + let tokenizer_file = model_dir.join("tokenizer.json"); + let tokenizer_config_file = model_dir.join("tokenizer_config.json"); + let mut tokenizer = + Tokenizer::from_file(tokenizer_file).map_err(|e| anyhow::anyhow!("{:?}", e))?; + let model_chat_template = load_model_chat_template_from_file(tokenizer_config_file)? 
+ .expect("Model chat template not found"); + + let conversations = vec![Conversation { + role: Role::User, + content: "what's your name?", + }]; + let args = ApplyChatTemplateArgs { + conversations: vec![conversations.into()], + documents: None, + model_id: &model_id, + chat_template_id: None, + add_generation_prompt: None, + continue_final_message: None, + }; + let encodings = tokenizer.apply_chat_template_and_encode(model_chat_template, args)?; + let prompt: Vec = encodings + .iter() + .flat_map(|encoding| encoding.get_ids()) + .copied() + .collect(); + let prompt_tokens = Array::from(&prompt[..]).index(NewAxis); + + let mut cache = Vec::new(); + let mut model = load_qwen3_model(model_dir)?; + let generate = mlx_lm::models::qwen3::Generate::::new( + &mut model, + &mut cache, + 0.2, + &prompt_tokens, + ); + + let mut tokens = Vec::new(); + for (token, ntoks) in generate.zip(0..256) { + let token = token.unwrap(); + tokens.push(token.clone()); + + if ntoks == 0 { + eval(&tokens).unwrap(); + } + + if tokens.len() % 20 == 0 { + eval(&tokens).unwrap(); + let slice: Vec = tokens.drain(..).map(|t| t.item::()).collect(); + let s = tokenizer.decode(&slice, true).unwrap(); + print!("{s}"); + } + } + + eval(&tokens).unwrap(); + let slice: Vec = tokens.drain(..).map(|t| t.item::()).collect(); + let s = tokenizer.decode(&slice, true).unwrap(); + println!("{s}"); + + println!("------"); + + Ok(()) +} + +fn main() -> anyhow::Result<()> { + qwen3() +} diff --git a/examples/gemma3/Cargo.toml b/examples/gemma3/Cargo.toml new file mode 100644 index 000000000..3457f19b9 --- /dev/null +++ b/examples/gemma3/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "gemma3" +version.workspace = true +edition.workspace = true +authors.workspace = true +repository.workspace = true +keywords.workspace = true +categories.workspace = true +license.workspace = true +documentation.workspace = true +rust-version.workspace = true + +[dependencies] +mlx-rs.workspace = true +mlx-lm.workspace = true +mlx-lm-utils.workspace = true + +anyhow = "1" \ No newline at end of file diff --git a/examples/gemma3/download.sh b/examples/gemma3/download.sh new file mode 100644 index 000000000..0e0e32dd1 --- /dev/null +++ b/examples/gemma3/download.sh @@ -0,0 +1,5 @@ +#!/bin/bash + +model_id="mlx-community/gemma-3-270m-bf16" + +huggingface-cli download $model_id --local-dir ./cache/gemma-3-270m-bf16 \ No newline at end of file diff --git a/examples/gemma3/src/main.rs b/examples/gemma3/src/main.rs new file mode 100644 index 000000000..6b40207aa --- /dev/null +++ b/examples/gemma3/src/main.rs @@ -0,0 +1,88 @@ +use std::path::Path; + +use mlx_lm::{cache::ConcatKeyValueCache, models::gemma::gemma3::load_gemma3_model}; +use mlx_lm_utils::tokenizer::{ + load_gemma_chat_template_from_file, ApplyChatTemplateArgs, Conversation, Role, Tokenizer, +}; +use mlx_rs::{ + ops::indexing::{IndexOp, NewAxis}, + transforms::eval, + Array, +}; + +const CACHED_TEST_MODEL_DIR: &str = "./cache/gemma-3-270m-bf16"; + +fn gemma3() -> anyhow::Result<()> { + let model_dir = Path::new(CACHED_TEST_MODEL_DIR); + + let model_id = "mlx-community/gemma-3-270m-bf16".to_string(); + let tokenizer_file = model_dir.join("tokenizer.json"); + let chat_template_jinja_file = model_dir.join("chat_template.jinja"); + let mut tokenizer = + Tokenizer::from_file(tokenizer_file).map_err(|e| anyhow::anyhow!("{:?}", e))?; + let model_chat_template = load_gemma_chat_template_from_file(chat_template_jinja_file)?; + + let conversations = vec![Conversation { + role: Role::User, + content: "what's your 
name?", + }]; + println!("Conversations: {:?}", conversations); + + let args = ApplyChatTemplateArgs { + conversations: vec![conversations.into()], + documents: None, + model_id: &model_id, + chat_template_id: None, + add_generation_prompt: Some(true), + continue_final_message: None, + add_special_tokens: Some(true), + }; + let encodings = tokenizer.apply_chat_template_and_encode(model_chat_template, args)?; + let prompt: Vec = encodings + .iter() + .flat_map(|encoding| encoding.get_ids()) + .copied() + .collect(); + println!("Prompt tokens (raw): {:?}", prompt); + let prompt_tokens = Array::from(&prompt[..]).index(NewAxis); + println!("Prompt tokens (array): {:?}", prompt_tokens); + + let mut cache = Vec::new(); + let mut model = load_gemma3_model(model_dir)?; + let generate = mlx_lm::models::gemma::gemma3::Generate::::new( + &mut model, + &mut cache, + 0.0, + &prompt_tokens, + ); + + let mut tokens = Vec::new(); + for (token, ntoks) in generate.zip(0..256) { + let token = token.unwrap(); + tokens.push(token.clone()); + + if ntoks == 0 { + eval(&tokens).unwrap(); + } + + if tokens.len() % 20 == 0 { + eval(&tokens).unwrap(); + let slice: Vec = tokens.drain(..).map(|t| t.item::()).collect(); + let s = tokenizer.decode(&slice, true).unwrap(); + print!("{s}"); + } + } + + eval(&tokens).unwrap(); + let slice: Vec = tokens.drain(..).map(|t| t.item::()).collect(); + let s = tokenizer.decode(&slice, true).unwrap(); + println!("{s}"); + + println!("------"); + + Ok(()) +} + +fn main() -> anyhow::Result<()> { + gemma3() +} diff --git a/examples/lm/download.sh b/examples/lm/download.sh new file mode 100644 index 000000000..c9ff9cd91 --- /dev/null +++ b/examples/lm/download.sh @@ -0,0 +1,5 @@ +#!/bin/bash + +model_id="mlx-community/Qwen3-4B-bf16" + +huggingface-cli download $model_id --local-dir ./cache/Qwen3-4B-bf16 \ No newline at end of file diff --git a/mlx-lm-utils/src/tokenizer.rs b/mlx-lm-utils/src/tokenizer.rs index a2675e00c..7fe1c9f28 100644 --- a/mlx-lm-utils/src/tokenizer.rs +++ b/mlx-lm-utils/src/tokenizer.rs @@ -137,9 +137,10 @@ impl Tokenizer { { let Self { inner, env } = self; + let add_special_tokens = args.add_special_tokens.unwrap_or(false); let rendered_chats = apply_chat_template(env, model_template, args)?; inner - .encode_batch(rendered_chats, false) + .encode_batch(rendered_chats, add_special_tokens) .map_err(Into::into) } } @@ -237,6 +238,7 @@ where pub chat_template_id: Option<&'a str>, pub add_generation_prompt: Option, pub continue_final_message: Option, + pub add_special_tokens: Option, } pub fn load_model_chat_template_from_str(content: &str) -> std::io::Result> { @@ -257,6 +259,11 @@ pub fn load_model_chat_template_from_file( load_model_chat_template_from_str(&content) } +pub fn load_gemma_chat_template_from_file(file: impl AsRef) -> std::io::Result { + let content = read_to_string(file)?; + Ok(content) +} + // chat_template = self.get_chat_template(chat_template, tools) // if isinstance(conversation, (list, tuple)) and ( @@ -445,6 +452,7 @@ where chat_template_id, add_generation_prompt, continue_final_message, + add_special_tokens: _, } = args; let add_generation_prompt = add_generation_prompt.unwrap_or(false); @@ -635,6 +643,7 @@ mod tests { chat_template_id: None, add_generation_prompt: None, continue_final_message: None, + add_special_tokens: Some(false), }; let encodings = tokenizer diff --git a/mlx-lm/src/models/gemma/embedding_gemma.rs.bak b/mlx-lm/src/models/gemma/embedding_gemma.rs.bak new file mode 100644 index 000000000..12ff993d3 --- /dev/null +++ 
b/mlx-lm/src/models/gemma/embedding_gemma.rs.bak @@ -0,0 +1,695 @@ +use std::{ + collections::{HashMap, HashSet}, + path::Path, +}; + +use mlx_rs::{ + argmax_axis, array, + builder::Builder, + categorical, + error::Exception, + macros::{ModuleParameters, Quantizable}, + module::{Module, ModuleParametersExt}, + nn, + ops::indexing::{IndexOp, NewAxis}, + quantization::MaybeQuantized, + Array, +}; +use serde::Deserialize; +use serde_json::Value; +use tokenizers::Tokenizer; + +use crate::{ + cache::KeyValueCache, + error::Error, + utils::{ + create_attention_mask, + rope::{initialize_rope, FloatOrString}, + AttentionMask, + }, +}; + +#[derive(Debug, Clone, Deserialize)] +pub struct ModelArgs { + pub model_type: String, + pub hidden_size: i32, + pub num_hidden_layers: i32, + pub intermediate_size: i32, + pub num_attention_heads: i32, + pub rms_norm_eps: f32, + pub vocab_size: i32, + pub num_key_value_heads: i32, + pub max_position_embeddings: i32, + pub rope_theta: f32, + pub head_dim: i32, + pub tie_word_embeddings: bool, + pub rope_scaling: Option>, +} + +#[derive(Debug, Clone, ModuleParameters, Quantizable)] +pub struct Attention { + pub n_heads: i32, + pub n_kv_heads: i32, + pub scale: f32, + + #[quantizable] + #[param] + pub q_proj: MaybeQuantized, + #[quantizable] + #[param] + pub k_proj: MaybeQuantized, + #[quantizable] + #[param] + pub v_proj: MaybeQuantized, + #[quantizable] + #[param] + pub o_proj: MaybeQuantized, + #[param] + pub q_norm: nn::RmsNorm, + #[param] + pub k_norm: nn::RmsNorm, + #[param] + pub rope: nn::Rope, +} + +impl Attention { + pub fn new(args: &ModelArgs) -> Result { + let dim = args.hidden_size; + let n_heads = args.num_attention_heads; + let n_kv_heads = args.num_key_value_heads; + + let head_dim = args.head_dim; + let scale = (head_dim as f32).sqrt().recip(); + + let q_proj = nn::LinearBuilder::new(dim, n_heads * head_dim) + .bias(false) + .build()?; + let k_proj = nn::LinearBuilder::new(dim, n_kv_heads * head_dim) + .bias(false) + .build()?; + let v_proj = nn::LinearBuilder::new(dim, n_kv_heads * head_dim) + .bias(false) + .build()?; + let o_proj = nn::LinearBuilder::new(n_heads * head_dim, dim) + .bias(false) + .build()?; + + let q_norm = nn::RmsNormBuilder::new(head_dim) + .eps(args.rms_norm_eps) + .build()?; + let k_norm = nn::RmsNormBuilder::new(head_dim) + .eps(args.rms_norm_eps) + .build()?; + + let rope = initialize_rope( + head_dim, + args.rope_theta, + false, + &args.rope_scaling, + args.max_position_embeddings, + )?; + + Ok(Self { + n_heads, + n_kv_heads, + scale, + q_proj: MaybeQuantized::Original(q_proj), + k_proj: MaybeQuantized::Original(k_proj), + v_proj: MaybeQuantized::Original(v_proj), + o_proj: MaybeQuantized::Original(o_proj), + q_norm, + k_norm, + rope, + }) + } +} + +// TODO: check if this input can be generic for other attention modules +pub struct AttentionInput<'a, C> { + pub x: &'a Array, + pub mask: Option<&'a Array>, + pub cache: Option<&'a mut C>, +} + +impl Module> for Attention +where + C: KeyValueCache, +{ + type Output = Array; + + type Error = Exception; + + #[allow(non_snake_case)] + fn forward(&mut self, input: AttentionInput<'_, C>) -> Result { + let AttentionInput { x, mask, mut cache } = input; + + let shape = x.shape(); + let B = shape[0]; + let L = shape[1]; + + let queries = self.q_proj.forward(x)?; + let keys = self.k_proj.forward(x)?; + let values = self.v_proj.forward(x)?; + + let mut queries = self.q_norm.forward( + &queries + .reshape(&[B, L, self.n_heads, -1])? 
+ .transpose_axes(&[0, 2, 1, 3])?, + )?; + let mut keys = self.k_norm.forward( + &keys + .reshape(&[B, L, self.n_kv_heads, -1])? + .transpose_axes(&[0, 2, 1, 3])?, + )?; + let mut values = values + .reshape(&[B, L, self.n_kv_heads, -1])? + .transpose_axes(&[0, 2, 1, 3])?; + + if let Some(cache) = cache.as_mut() { + let q_input = nn::RopeInputBuilder::new(&queries) + .offset(cache.offset()) + .build()?; + queries = self.rope.forward(q_input)?; + let k_input = nn::RopeInputBuilder::new(&keys) + .offset(cache.offset()) + .build()?; + keys = self.rope.forward(k_input)?; + + (keys, values) = cache.update_and_fetch(keys, values)?; + } else { + queries = self.rope.forward(nn::RopeInput::new(&queries))?; + keys = self.rope.forward(nn::RopeInput::new(&keys))?; + } + + let output = crate::utils::scaled_dot_product_attention( + queries, keys, values, cache, self.scale, mask, + )? + .transpose_axes(&[0, 2, 1, 3])? + .reshape(&[B, L, -1])?; + + self.o_proj.forward(&output) + } + + fn training_mode(&mut self, mode: bool) { + self.q_proj.training_mode(mode); + self.k_proj.training_mode(mode); + self.v_proj.training_mode(mode); + self.o_proj.training_mode(mode); + self.q_norm.training_mode(mode); + self.k_norm.training_mode(mode); + >::training_mode(&mut self.rope, mode); + } +} + +#[derive(Debug, Clone, ModuleParameters, Quantizable)] +pub struct Mlp { + #[quantizable] + #[param] + pub gate_proj: MaybeQuantized, + + #[quantizable] + #[param] + pub down_proj: MaybeQuantized, + + #[quantizable] + #[param] + pub up_proj: MaybeQuantized, +} + +impl Mlp { + pub fn new(dim: i32, hidden_dim: i32) -> Result { + let gate_proj = nn::LinearBuilder::new(dim, hidden_dim) + .bias(false) + .build()?; + let down_proj = nn::LinearBuilder::new(hidden_dim, dim) + .bias(false) + .build()?; + let up_proj = nn::LinearBuilder::new(dim, hidden_dim) + .bias(false) + .build()?; + + Ok(Self { + gate_proj: MaybeQuantized::Original(gate_proj), + down_proj: MaybeQuantized::Original(down_proj), + up_proj: MaybeQuantized::Original(up_proj), + }) + } +} + +impl Module<&Array> for Mlp { + type Output = Array; + + type Error = Exception; + + fn forward(&mut self, input: &Array) -> Result { + let down_proj_input = + nn::silu(self.gate_proj.forward(input)?)?.multiply(self.up_proj.forward(input)?)?; + self.down_proj.forward(&down_proj_input) + } + + fn training_mode(&mut self, mode: bool) { + self.gate_proj.training_mode(mode); + self.down_proj.training_mode(mode); + self.up_proj.training_mode(mode); + } +} + +#[derive(Debug, Clone, ModuleParameters, Quantizable)] +pub struct TransformerBlock { + pub num_attention_heads: i32, + pub hidden_size: i32, + + #[quantizable] + #[param] + pub self_attn: Attention, + + #[quantizable] + #[param] + pub mlp: Mlp, + + #[param] + pub input_layernorm: nn::RmsNorm, + + #[param] + pub post_attention_layernorm: nn::RmsNorm, +} + +impl TransformerBlock { + pub fn new(args: &ModelArgs) -> Result { + let num_attention_heads = args.num_attention_heads; + let hidden_size = args.hidden_size; + + let self_attn = Attention::new(args)?; + let mlp = Mlp::new(args.hidden_size, args.intermediate_size)?; + let input_layernorm = nn::RmsNormBuilder::new(args.hidden_size) + .eps(args.rms_norm_eps) + .build()?; + let post_attention_layernorm = nn::RmsNormBuilder::new(args.hidden_size) + .eps(args.rms_norm_eps) + .build()?; + + Ok(Self { + num_attention_heads, + hidden_size, + self_attn, + mlp, + input_layernorm, + post_attention_layernorm, + }) + } +} + +impl Module> for TransformerBlock +where + C: KeyValueCache, +{ + type 
Output = Array; + + type Error = Exception; + + fn forward(&mut self, input: AttentionInput<'_, C>) -> Result { + let AttentionInput { x, mask, cache } = input; + + let self_attn_input = AttentionInput { + x: &self.input_layernorm.forward(x)?, + mask, + cache, + }; + let r = self.self_attn.forward(self_attn_input)?; + let h = x.add(r)?; + + let r = self + .mlp + .forward(&self.post_attention_layernorm.forward(&h)?)?; + h.add(r) + } + + fn training_mode(&mut self, mode: bool) { + >>::training_mode(&mut self.self_attn, mode); + self.mlp.training_mode(mode); + self.input_layernorm.training_mode(mode); + self.post_attention_layernorm.training_mode(mode); + } +} + +#[derive(Debug, Clone, ModuleParameters, Quantizable)] +pub struct Qwen3Model { + pub vocab_size: i32, + pub num_hidden_layers: i32, + + #[quantizable] + #[param] + pub embed_tokens: MaybeQuantized, + + #[quantizable] + #[param] + pub layers: Vec, + + #[param] + pub norm: nn::RmsNorm, +} + +impl Qwen3Model { + pub fn new(args: &ModelArgs) -> Result { + assert!(args.vocab_size.is_positive()); + + let vocab_size = args.vocab_size; + let num_hidden_layers = args.num_hidden_layers; + + let embed_tokens = nn::Embedding::new(args.vocab_size, args.hidden_size)?; + let layers = (0..num_hidden_layers) + .map(|_| TransformerBlock::new(args)) + .collect::, _>>()?; + let norm = nn::RmsNormBuilder::new(args.hidden_size) + .eps(args.rms_norm_eps) + .build()?; + + Ok(Self { + vocab_size, + num_hidden_layers, + embed_tokens: MaybeQuantized::Original(embed_tokens), + layers, + norm, + }) + } +} + +pub struct ModelInput<'a, C> { + pub inputs: &'a Array, + pub mask: Option<&'a Array>, + pub cache: &'a mut Vec>, +} + +impl Module> for Qwen3Model +where + C: KeyValueCache, +{ + type Output = Array; + + type Error = Exception; + + fn forward(&mut self, input: ModelInput<'_, C>) -> Result { + let ModelInput { + inputs, + mask, + cache, + } = input; + + let mut h = self.embed_tokens.forward(inputs)?; + + let mask = match mask { + Some(mask) => Some(mask.clone()), + None => match create_attention_mask(&h, cache, Some(true))? 
{ + Some(AttentionMask::Array(a)) => Some(a), + Some(AttentionMask::Causal) => { + return Err(Exception::custom("Only `Array` mask is supported")) + } + None => None, + }, + }; + + if cache.is_empty() { + *cache = (0..self.layers.len()).map(|_| None).collect(); + } + + for (layer, c) in self.layers.iter_mut().zip(cache.iter_mut()) { + let layer_input = AttentionInput { + x: &h, + mask: mask.as_ref(), + cache: c.as_mut(), + }; + h = layer.forward(layer_input)?; + } + + self.norm.forward(&h) + } + + fn training_mode(&mut self, mode: bool) { + self.embed_tokens.training_mode(mode); + for layer in &mut self.layers { + >>::training_mode(layer, mode); + } + self.norm.training_mode(mode); + } +} + +#[derive(Debug, Clone, ModuleParameters, Quantizable)] +pub struct Model { + pub args: ModelArgs, + + #[quantizable] + #[param] + pub model: Qwen3Model, + + #[quantizable] + #[param] + pub lm_head: Option>, +} + +impl Model { + pub fn new(args: ModelArgs) -> Result { + let model = Qwen3Model::new(&args)?; + let lm_head = if !args.tie_word_embeddings { + Some(MaybeQuantized::Original( + nn::LinearBuilder::new(args.hidden_size, args.vocab_size) + .bias(false) + .build()?, + )) + } else { + None + }; + + Ok(Self { + args, + model, + lm_head, + }) + } + + pub fn model_type(&self) -> &str { + &self.args.model_type + } +} + +impl Module> for Model +where + C: KeyValueCache, +{ + type Output = Array; + + type Error = Exception; + + fn forward(&mut self, input: ModelInput<'_, C>) -> Result { + let out = self.model.forward(input)?; + + match self.lm_head.as_mut() { + Some(lm_head) => lm_head.forward(&out), + None => match &mut self.model.embed_tokens { + MaybeQuantized::Original(embed_tokens) => embed_tokens.as_linear(&out), + MaybeQuantized::Quantized(q_embed_tokens) => q_embed_tokens.as_linear(&out), + }, + } + } + + fn training_mode(&mut self, mode: bool) { + >>::training_mode(&mut self.model, mode); + if let Some(lm_head) = &mut self.lm_head { + lm_head.training_mode(mode); + } + } +} + +pub fn load_qwen3_tokenizer(model_dir: impl AsRef) -> Result { + let file = model_dir.as_ref().join("tokenizer.json"); + Tokenizer::from_file(file).map_err(Into::into) +} + +pub fn get_qwen3_model_args(model_dir: impl AsRef) -> Result { + let model_args_filename = model_dir.as_ref().join("config.json"); + let file = std::fs::File::open(model_args_filename)?; + let model_args: ModelArgs = serde_json::from_reader(file)?; + + Ok(model_args) +} + +#[derive(Debug, Clone, Deserialize)] +pub struct WeightMap { + pub metadata: HashMap, + pub weight_map: HashMap, +} + +pub fn load_qwen3_model(model_dir: impl AsRef) -> Result { + let model_dir = model_dir.as_ref(); + let model_args = get_qwen3_model_args(model_dir)?; + let mut model = Model::new(model_args)?; + + let weights_index = model_dir.join("model.safetensors.index.json"); + let json = std::fs::read_to_string(weights_index)?; + let weight_map: WeightMap = serde_json::from_str(&json)?; + + let weight_files: HashSet<&String> = weight_map.weight_map.values().collect(); + + for weight_file in weight_files { + let weights_filename = model_dir.join(weight_file); + model.load_safetensors(weights_filename)?; + } + + Ok(model) +} + +pub fn sample(logits: &Array, temp: f32) -> Result { + match temp { + 0.0 => argmax_axis!(logits, -1).map_err(Into::into), + _ => { + let logits = logits.multiply(array!(1.0 / temp))?; + categorical!(logits).map_err(Into::into) + } + } +} + +pub struct Generate<'a, C> { + model: &'a mut Model, + cache: &'a mut Vec>, + temp: f32, + state: GenerateState<'a>, 
+} + +impl<'a, C> Generate<'a, C> +where + C: KeyValueCache, +{ + pub fn new( + model: &'a mut Model, + cache: &'a mut Vec>, + temp: f32, + prompt_token: &'a Array, + ) -> Self { + Self { + model, + cache, + temp, + state: GenerateState::Prefill { prompt_token }, + } + } +} + +pub enum GenerateState<'a> { + Prefill { prompt_token: &'a Array }, + Decode { y: Array }, +} + +macro_rules! tri { + ($expr:expr) => { + match $expr { + Ok(val) => val, + Err(e) => return Some(Err(e.into())), + } + }; +} + +impl<'a, C> Iterator for Generate<'a, C> +where + C: KeyValueCache, +{ + type Item = Result; + + fn next(&mut self) -> Option { + match &self.state { + GenerateState::Prefill { prompt_token } => { + let input = ModelInput { + inputs: prompt_token, + mask: None, + cache: self.cache, + }; + let logits = tri!(self.model.forward(input)); + let y = tri!(sample(&logits.index((.., -1, ..)), self.temp)); + self.state = GenerateState::Decode { y: y.clone() }; + + Some(Ok(y)) + } + GenerateState::Decode { y } => { + let inputs = y.index((.., NewAxis)); + let input = ModelInput { + inputs: &inputs, + mask: None, + cache: self.cache, + }; + let logits = tri!(self.model.forward(input)); + let y = tri!(sample(&logits, self.temp)); + + self.state = GenerateState::Decode { y: y.clone() }; + + Some(Ok(y)) + } + } + } +} + +#[cfg(test)] +mod tests { + use mlx_rs::{ + ops::indexing::{IndexOp, NewAxis}, + transforms::eval, + Array, + }; + + use crate::{ + cache::ConcatKeyValueCache, + models::qwen3::{load_qwen3_model, load_qwen3_tokenizer}, + }; + + const CACHED_TEST_MODEL_DIR: &str = "../cache/Qwen3-4B-bf16"; + + #[test] + fn test_load_qwen3_model() { + let _model = super::load_qwen3_model(CACHED_TEST_MODEL_DIR).unwrap(); + } + + #[test] + fn test_load_tokenizer() { + let tokenizer = load_qwen3_tokenizer(CACHED_TEST_MODEL_DIR).unwrap(); + + let _encoding = tokenizer.encode("Hello, world!", true).unwrap(); + } + + #[test] + fn test_load_and_run_qwen3_with_concat_cache() { + let tokenizer = load_qwen3_tokenizer(CACHED_TEST_MODEL_DIR).unwrap(); + + let mut model = load_qwen3_model(CACHED_TEST_MODEL_DIR).unwrap(); + + let encoding = tokenizer.encode("hello", true).unwrap(); + let prompt_tokens = Array::from(encoding.get_ids()).index(NewAxis); + let mut cache = Vec::new(); + + let mut tokens = Vec::new(); + let generate = super::Generate::::new( + &mut model, + &mut cache, + 0.0, + &prompt_tokens, + ); + for (token, ntoks) in generate.zip(0..10) { + let token = token.unwrap(); + tokens.push(token.clone()); + + if ntoks == 0 { + eval(&tokens).unwrap(); + } + + if tokens.len() % 20 == 0 { + eval(&tokens).unwrap(); + let slice: Vec = tokens.drain(..).map(|t| t.item::()).collect(); + let s = tokenizer.decode(&slice, true).unwrap(); + print!("{s}"); + } + } + + eval(&tokens).unwrap(); + let slice: Vec = tokens.drain(..).map(|t| t.item::()).collect(); + let s = tokenizer.decode(&slice, true).unwrap(); + println!("{s}"); + + println!("------"); + } +} diff --git a/mlx-lm/src/models/gemma/gemma3.rs b/mlx-lm/src/models/gemma/gemma3.rs new file mode 100644 index 000000000..72bbae963 --- /dev/null +++ b/mlx-lm/src/models/gemma/gemma3.rs @@ -0,0 +1,805 @@ +use std::{ + collections::{HashMap, HashSet}, + path::Path, +}; + +use mlx_rs::{ + argmax_axis, array, + builder::Builder, + categorical, + error::Exception, + macros::{ModuleParameters, Quantizable}, + module::{Module, ModuleParametersExt}, + nn, + ops::indexing::{IndexOp, NewAxis}, + quantization::MaybeQuantized, + Array, +}; +use serde::Deserialize; +use serde_json::Value; 
+use tokenizers::Tokenizer; + +use crate::{ + cache::KeyValueCache, + error::Error, + utils::{ + create_attention_mask, + rope::{initialize_rope, FloatOrString}, + AttentionMask, + }, +}; + +#[derive(Debug, Clone, Deserialize)] +pub struct ModelArgs { + pub model_type: String, + pub hidden_size: i32, + pub num_hidden_layers: i32, + pub intermediate_size: i32, + pub num_attention_heads: i32, + pub rms_norm_eps: f32, + pub vocab_size: i32, + pub num_key_value_heads: i32, + pub max_position_embeddings: i32, + pub rope_theta: f32, + pub rope_local_base_freq: f32, + pub head_dim: i32, + #[serde(default = "default_tie_word_embeddings")] + pub tie_word_embeddings: bool, + pub rope_scaling: Option>, + pub _sliding_window_pattern: i32, + pub sliding_window: i32, + pub use_bidirectional_attention: bool, + pub query_pre_attn_scalar: i32, + pub attn_logit_softcapping: Option, + pub final_logit_softcapping: Option, + pub layer_types: Vec, +} + +fn default_tie_word_embeddings() -> bool { + false +} + +#[derive(Debug, Clone, ModuleParameters, Quantizable)] +pub struct Attention { + pub n_heads: i32, + pub n_kv_heads: i32, + pub scale: f32, + pub is_sliding: bool, + + #[quantizable] + #[param] + pub q_proj: MaybeQuantized, + #[quantizable] + #[param] + pub k_proj: MaybeQuantized, + #[quantizable] + #[param] + pub v_proj: MaybeQuantized, + #[quantizable] + #[param] + pub o_proj: MaybeQuantized, + #[param] + pub q_norm: nn::RmsNorm, + #[param] + pub k_norm: nn::RmsNorm, + #[param] + pub rope: nn::Rope, + #[param] + pub rope_local: nn::Rope, +} + +impl Attention { + pub fn new(args: &ModelArgs, layer_idx: i32) -> Result { + let is_sliding = args.layer_types[layer_idx as usize] == "sliding_window"; + let dim = args.hidden_size; + let n_heads = args.num_attention_heads; + let n_kv_heads = args.num_key_value_heads; + + let head_dim = args.head_dim; + let query_pre_attn_scalar = args.query_pre_attn_scalar; + let scale = (query_pre_attn_scalar as f32).sqrt().recip(); + + let q_proj = nn::LinearBuilder::new(dim, n_heads * head_dim) + .bias(false) + .build()?; + let k_proj = nn::LinearBuilder::new(dim, n_kv_heads * head_dim) + .bias(false) + .build()?; + let v_proj = nn::LinearBuilder::new(dim, n_kv_heads * head_dim) + .bias(false) + .build()?; + let o_proj = nn::LinearBuilder::new(n_heads * head_dim, dim) + .bias(false) + .build()?; + + let q_norm = nn::RmsNormBuilder::new(head_dim) + .eps(args.rms_norm_eps) + .build()?; + let k_norm = nn::RmsNormBuilder::new(head_dim) + .eps(args.rms_norm_eps) + .build()?; + + let rope = initialize_rope( + head_dim, + args.rope_theta, + false, + &args.rope_scaling, + args.max_position_embeddings, + )?; + + let rope_local = initialize_rope( + head_dim, + args.rope_local_base_freq, + false, + &args.rope_scaling, + args.max_position_embeddings, + )?; + + Ok(Self { + n_heads, + n_kv_heads, + scale, + q_proj: MaybeQuantized::Original(q_proj), + k_proj: MaybeQuantized::Original(k_proj), + v_proj: MaybeQuantized::Original(v_proj), + o_proj: MaybeQuantized::Original(o_proj), + q_norm, + k_norm, + is_sliding, + rope, + rope_local, + }) + } +} + +// TODO: check if this input can be generic for other attention modules +pub struct AttentionInput<'a, C> { + pub x: &'a Array, + pub mask: Option<&'a Array>, + pub cache: Option<&'a mut C>, + pub position_embeddings_global: Option<&'a Array>, + pub position_embeddings_local: Option<&'a Array>, +} + +impl Module> for Attention +where + C: KeyValueCache, +{ + type Output = Array; + + type Error = Exception; + + #[allow(non_snake_case)] + fn 
forward(&mut self, input: AttentionInput<'_, C>) -> Result { + let AttentionInput { + x, + mask, + mut cache, + position_embeddings_global, + position_embeddings_local, + } = input; + + let shape = x.shape(); + let B = shape[0]; + let L = shape[1]; + + let queries = self.q_proj.forward(x)?; + let keys = self.k_proj.forward(x)?; + let values = self.v_proj.forward(x)?; + + let mut queries = self.q_norm.forward( + &queries + .reshape(&[B, L, self.n_heads, -1])? + .transpose_axes(&[0, 2, 1, 3])?, + )?; + let mut keys = self.k_norm.forward( + &keys + .reshape(&[B, L, self.n_kv_heads, -1])? + .transpose_axes(&[0, 2, 1, 3])?, + )?; + let mut values = values + .reshape(&[B, L, self.n_kv_heads, -1])? + .transpose_axes(&[0, 2, 1, 3])?; + + if let Some(cache) = cache.as_mut() { + let q_input = nn::RopeInputBuilder::new(&queries) + .offset(cache.offset()) + .build()?; + queries = self.rope.forward(q_input)?; + let k_input = nn::RopeInputBuilder::new(&keys) + .offset(cache.offset()) + .build()?; + keys = self.rope.forward(k_input)?; + + (keys, values) = cache.update_and_fetch(keys, values)?; + } else { + queries = self.rope.forward(nn::RopeInput::new(&queries))?; + keys = self.rope.forward(nn::RopeInput::new(&keys))?; + } + + let output = crate::utils::scaled_dot_product_attention( + queries, keys, values, cache, self.scale, mask, + )? + .transpose_axes(&[0, 2, 1, 3])? + .reshape(&[B, L, -1])?; + + self.o_proj.forward(&output) + } + + fn training_mode(&mut self, mode: bool) { + self.q_proj.training_mode(mode); + self.k_proj.training_mode(mode); + self.v_proj.training_mode(mode); + self.o_proj.training_mode(mode); + self.q_norm.training_mode(mode); + self.k_norm.training_mode(mode); + >::training_mode(&mut self.rope, mode); + >::training_mode(&mut self.rope_local, mode); + } +} + +#[derive(Debug, Clone, ModuleParameters, Quantizable)] +pub struct Mlp { + #[quantizable] + #[param] + pub gate_proj: MaybeQuantized, + + #[quantizable] + #[param] + pub down_proj: MaybeQuantized, + + #[quantizable] + #[param] + pub up_proj: MaybeQuantized, +} + +impl Mlp { + pub fn new(dim: i32, hidden_dim: i32) -> Result { + let gate_proj = nn::LinearBuilder::new(dim, hidden_dim) + .bias(false) + .build()?; + let down_proj = nn::LinearBuilder::new(hidden_dim, dim) + .bias(false) + .build()?; + let up_proj = nn::LinearBuilder::new(dim, hidden_dim) + .bias(false) + .build()?; + + Ok(Self { + gate_proj: MaybeQuantized::Original(gate_proj), + down_proj: MaybeQuantized::Original(down_proj), + up_proj: MaybeQuantized::Original(up_proj), + }) + } +} + +impl Module<&Array> for Mlp { + type Output = Array; + + type Error = Exception; + + fn forward(&mut self, input: &Array) -> Result { + let down_proj_input = nn::gelu_approximate(self.gate_proj.forward(input)?)? 
+ .multiply(self.up_proj.forward(input)?)?; + self.down_proj.forward(&down_proj_input) + } + + fn training_mode(&mut self, mode: bool) { + self.gate_proj.training_mode(mode); + self.down_proj.training_mode(mode); + self.up_proj.training_mode(mode); + } +} + +#[derive(Debug, Clone, ModuleParameters, Quantizable)] +pub struct TransformerBlock { + pub num_attention_heads: i32, + pub hidden_size: i32, + + #[quantizable] + #[param] + pub self_attn: Attention, + + #[quantizable] + #[param] + pub mlp: Mlp, + + #[param] + pub input_layernorm: nn::RmsNorm, + + #[param] + pub post_attention_layernorm: nn::RmsNorm, + + #[param] + pub pre_feedforward_layernorm: nn::RmsNorm, + + #[param] + pub post_feedforward_layernorm: nn::RmsNorm, +} + +impl TransformerBlock { + pub fn new(args: &ModelArgs, layer_idx: i32) -> Result { + let num_attention_heads = args.num_attention_heads; + let hidden_size = args.hidden_size; + + let self_attn = Attention::new(args, layer_idx)?; + let mlp = Mlp::new(args.hidden_size, args.intermediate_size)?; + let input_layernorm = nn::RmsNormBuilder::new(args.hidden_size) + .eps(args.rms_norm_eps) + .build()?; + let post_attention_layernorm = nn::RmsNormBuilder::new(args.hidden_size) + .eps(args.rms_norm_eps) + .build()?; + let pre_feedforward_layernorm = nn::RmsNormBuilder::new(args.hidden_size) + .eps(args.rms_norm_eps) + .build()?; + let post_feedforward_layernorm = nn::RmsNormBuilder::new(args.hidden_size) + .eps(args.rms_norm_eps) + .build()?; + + Ok(Self { + num_attention_heads, + hidden_size, + self_attn, + mlp, + input_layernorm, + post_attention_layernorm, + pre_feedforward_layernorm, + post_feedforward_layernorm, + }) + } +} + +impl Module> for TransformerBlock +where + C: KeyValueCache, +{ + type Output = Array; + + type Error = Exception; + + fn forward(&mut self, input: AttentionInput<'_, C>) -> Result { + let AttentionInput { + x, + mask, + cache, + position_embeddings_global, + position_embeddings_local, + } = input; + + let self_attn_input = AttentionInput { + x: &self.input_layernorm.forward(x)?, + mask, + cache, + position_embeddings_global, + position_embeddings_local, + }; + let r = self + .post_attention_layernorm + .forward(&self.self_attn.forward(self_attn_input)?)?; + let h = x.add(r)?; + + let r = self.post_feedforward_layernorm.forward( + &self + .mlp + .forward(&self.pre_feedforward_layernorm.forward(&h)?)?, + )?; + h.add(r) + } + + fn training_mode(&mut self, mode: bool) { + >>::training_mode(&mut self.self_attn, mode); + self.mlp.training_mode(mode); + self.input_layernorm.training_mode(mode); + self.post_attention_layernorm.training_mode(mode); + self.pre_feedforward_layernorm.training_mode(mode); + self.post_feedforward_layernorm.training_mode(mode); + } +} + +#[derive(Debug, Clone, ModuleParameters, Quantizable)] +pub struct Gemma3Model { + pub vocab_size: i32, + pub num_hidden_layers: i32, + + #[quantizable] + #[param] + pub embed_tokens: MaybeQuantized, + pub embed_scale: f32, + + #[quantizable] + #[param] + pub layers: Vec, + + #[param] + pub norm: nn::RmsNorm, + + #[param] + pub rope: nn::Rope, + #[param] + pub rope_local: nn::Rope, +} + +impl Gemma3Model { + pub fn new(args: &ModelArgs) -> Result { + assert!(args.vocab_size.is_positive()); + + let vocab_size = args.vocab_size; + let num_hidden_layers = args.num_hidden_layers; + + let embed_tokens = nn::Embedding::new(args.vocab_size, args.hidden_size)?; + let embed_scale = (args.hidden_size as f32).sqrt(); + let layers = (0..num_hidden_layers) + .map(|layer_idx| TransformerBlock::new(args, 
layer_idx)) + .collect::, _>>()?; + let norm = nn::RmsNormBuilder::new(args.hidden_size) + .eps(args.rms_norm_eps) + .build()?; + let rope = initialize_rope( + args.hidden_size, + args.rope_theta, + false, + &args.rope_scaling, + args.max_position_embeddings, + )?; + let rope_local = initialize_rope( + args.hidden_size, + args.rope_local_base_freq, + false, + &args.rope_scaling, + args.max_position_embeddings, + )?; + + Ok(Self { + vocab_size, + num_hidden_layers, + embed_tokens: MaybeQuantized::Original(embed_tokens), + embed_scale, + layers, + norm, + rope, + rope_local, + }) + } +} + +pub struct ModelInput<'a, C> { + pub inputs: &'a Array, + pub mask: Option<&'a Array>, + pub cache: &'a mut Vec>, + pub position_ids: Option<&'a Array>, +} + +impl Module> for Gemma3Model +where + C: KeyValueCache, +{ + type Output = Array; + + type Error = Exception; + + fn forward(&mut self, input: ModelInput<'_, C>) -> Result { + let ModelInput { + inputs, + mask, + cache, + position_ids, + } = input; + + let mut h = self + .embed_tokens + .forward(inputs)? + .multiply(array!(self.embed_scale))?; + + let mask = match mask { + Some(mask) => Some(mask.clone()), + None => match create_attention_mask(&h, cache, Some(true))? { + Some(AttentionMask::Array(a)) => Some(a), + Some(AttentionMask::Causal) => { + return Err(Exception::custom("Only `Array` mask is supported")) + } + None => None, + }, + }; + + if cache.is_empty() { + *cache = (0..self.layers.len()).map(|_| None).collect(); + } + + let position_embeddings_global = if let Some(position_ids) = position_ids { + Some(self.rope.forward(nn::RopeInput::new(position_ids))?) + } else { + None + }; + let position_embeddings_local = if let Some(position_ids) = position_ids { + Some(self.rope_local.forward(nn::RopeInput::new(position_ids))?) 
+ } else { + None + }; + + for (layer, c) in self.layers.iter_mut().zip(cache.iter_mut()) { + let layer_input = AttentionInput { + x: &h, + mask: mask.as_ref(), + cache: c.as_mut(), + position_embeddings_global: position_embeddings_global.as_ref(), + position_embeddings_local: position_embeddings_local.as_ref(), + }; + h = layer.forward(layer_input)?; + } + + self.norm.forward(&h) + } + + fn training_mode(&mut self, mode: bool) { + self.embed_tokens.training_mode(mode); + for layer in &mut self.layers { + >>::training_mode(layer, mode); + } + self.norm.training_mode(mode); + } +} + +#[derive(Debug, Clone, ModuleParameters, Quantizable)] +pub struct Model { + pub args: ModelArgs, + + #[quantizable] + #[param] + pub model: Gemma3Model, + + #[quantizable] + #[param] + pub lm_head: Option>, +} + +impl Model { + pub fn new(args: ModelArgs) -> Result { + let model = Gemma3Model::new(&args)?; + let lm_head = if !args.tie_word_embeddings { + Some(MaybeQuantized::Original( + nn::LinearBuilder::new(args.hidden_size, args.vocab_size) + .bias(false) + .build()?, + )) + } else { + None + }; + + Ok(Self { + args, + model, + lm_head, + }) + } + + pub fn model_type(&self) -> &str { + &self.args.model_type + } +} + +impl Module> for Model +where + C: KeyValueCache, +{ + type Output = Array; + + type Error = Exception; + + fn forward(&mut self, input: ModelInput<'_, C>) -> Result { + let out = self.model.forward(input)?; + + match self.lm_head.as_mut() { + Some(lm_head) => lm_head.forward(&out), + None => match &mut self.model.embed_tokens { + MaybeQuantized::Original(embed_tokens) => embed_tokens.as_linear(&out), + MaybeQuantized::Quantized(q_embed_tokens) => q_embed_tokens.as_linear(&out), + }, + } + } + + fn training_mode(&mut self, mode: bool) { + >>::training_mode(&mut self.model, mode); + if let Some(lm_head) = &mut self.lm_head { + lm_head.training_mode(mode); + } + } +} + +pub fn load_gemma3_tokenizer(model_dir: impl AsRef) -> Result { + let file = model_dir.as_ref().join("tokenizer.json"); + Tokenizer::from_file(file).map_err(Into::into) +} + +pub fn get_gemma3_model_args(model_dir: impl AsRef) -> Result { + let model_args_filename = model_dir.as_ref().join("config.json"); + let file = std::fs::File::open(model_args_filename)?; + let model_args: ModelArgs = serde_json::from_reader(file)?; + + Ok(model_args) +} + +#[derive(Debug, Clone, Deserialize)] +pub struct WeightMap { + pub metadata: HashMap, + pub weight_map: HashMap, +} + +pub fn load_gemma3_model(model_dir: impl AsRef) -> Result { + let model_dir = model_dir.as_ref(); + let model_args = get_gemma3_model_args(model_dir)?; + println!("Model args: {:?}", model_args); + let mut model = Model::new(model_args)?; + + let weights_index = model_dir.join("model.safetensors.index.json"); + let json = std::fs::read_to_string(weights_index)?; + let weight_map: WeightMap = serde_json::from_str(&json)?; + + let weight_files: HashSet<&String> = weight_map.weight_map.values().collect(); + + for weight_file in weight_files { + let weights_filename = model_dir.join(weight_file); + model.load_safetensors(weights_filename)?; + } + + Ok(model) +} + +pub fn sample(logits: &Array, temp: f32) -> Result { + match temp { + 0.0 => argmax_axis!(logits, -1).map_err(Into::into), + _ => { + let logits = logits.multiply(array!(1.0 / temp))?; + categorical!(logits).map_err(Into::into) + } + } +} + +pub struct Generate<'a, C> { + model: &'a mut Model, + cache: &'a mut Vec>, + temp: f32, + state: GenerateState<'a>, +} + +impl<'a, C> Generate<'a, C> +where + C: KeyValueCache, 
+{ + pub fn new( + model: &'a mut Model, + cache: &'a mut Vec>, + temp: f32, + prompt_token: &'a Array, + ) -> Self { + Self { + model, + cache, + temp, + state: GenerateState::Prefill { prompt_token }, + } + } +} + +pub enum GenerateState<'a> { + Prefill { prompt_token: &'a Array }, + Decode { y: Array }, +} + +macro_rules! tri { + ($expr:expr) => { + match $expr { + Ok(val) => val, + Err(e) => return Some(Err(e.into())), + } + }; +} + +impl<'a, C> Iterator for Generate<'a, C> +where + C: KeyValueCache, +{ + type Item = Result; + + fn next(&mut self) -> Option { + match &self.state { + GenerateState::Prefill { prompt_token } => { + let input = ModelInput { + inputs: prompt_token, + mask: None, + cache: self.cache, + position_ids: None, + }; + let logits = tri!(self.model.forward(input)); + let y = tri!(sample(&logits.index((.., -1, ..)), self.temp)); + self.state = GenerateState::Decode { y: y.clone() }; + + Some(Ok(y)) + } + GenerateState::Decode { y } => { + let inputs = y.index((.., NewAxis)); + let input = ModelInput { + inputs: &inputs, + mask: None, + cache: self.cache, + position_ids: None, + }; + let logits = tri!(self.model.forward(input)); + let y = tri!(sample(&logits, self.temp)); + + self.state = GenerateState::Decode { y: y.clone() }; + + Some(Ok(y)) + } + } + } +} + +#[cfg(test)] +mod tests { + use mlx_rs::{ + ops::indexing::{IndexOp, NewAxis}, + transforms::eval, + Array, + }; + + use crate::{ + cache::ConcatKeyValueCache, + models::gemma::gemma3::{load_gemma3_model, load_gemma3_tokenizer}, + }; + + const CACHED_TEST_MODEL_DIR: &str = "../cache/gemma-3-270m-bf16"; + + #[test] + fn test_load_gemma3_model() { + let _model = super::load_gemma3_model(CACHED_TEST_MODEL_DIR).unwrap(); + } + + #[test] + fn test_load_tokenizer() { + let tokenizer = load_gemma3_tokenizer(CACHED_TEST_MODEL_DIR).unwrap(); + + let _encoding = tokenizer.encode("Hello, world!", true).unwrap(); + } + + #[test] + fn test_load_and_run_gemma3_with_concat_cache() { + let tokenizer = load_gemma3_tokenizer(CACHED_TEST_MODEL_DIR).unwrap(); + + let mut model = load_gemma3_model(CACHED_TEST_MODEL_DIR).unwrap(); + + let encoding = tokenizer.encode("hello", true).unwrap(); + let prompt_tokens = Array::from(encoding.get_ids()).index(NewAxis); + let mut cache = Vec::new(); + + let mut tokens = Vec::new(); + let generate = super::Generate::::new( + &mut model, + &mut cache, + 0.0, + &prompt_tokens, + ); + for (token, ntoks) in generate.zip(0..10) { + let token = token.unwrap(); + tokens.push(token.clone()); + + if ntoks == 0 { + eval(&tokens).unwrap(); + } + + if tokens.len() % 20 == 0 { + eval(&tokens).unwrap(); + let slice: Vec = tokens.drain(..).map(|t| t.item::()).collect(); + let s = tokenizer.decode(&slice, true).unwrap(); + print!("{s}"); + } + } + + eval(&tokens).unwrap(); + let slice: Vec = tokens.drain(..).map(|t| t.item::()).collect(); + let s = tokenizer.decode(&slice, true).unwrap(); + println!("{s}"); + + println!("------"); + } +} diff --git a/mlx-lm/src/models/gemma/mod.rs b/mlx-lm/src/models/gemma/mod.rs new file mode 100644 index 000000000..7dd4f1916 --- /dev/null +++ b/mlx-lm/src/models/gemma/mod.rs @@ -0,0 +1 @@ +pub mod gemma3; diff --git a/mlx-lm/src/models/mod.rs b/mlx-lm/src/models/mod.rs index 2e427b96b..127c94047 100644 --- a/mlx-lm/src/models/mod.rs +++ b/mlx-lm/src/models/mod.rs @@ -1 +1,2 @@ +pub mod gemma; pub mod qwen3; diff --git a/mlx-rs/src/nn/normalization.rs b/mlx-rs/src/nn/normalization.rs index 41360cc45..1091268a8 100644 --- a/mlx-rs/src/nn/normalization.rs +++ 
b/mlx-rs/src/nn/normalization.rs
@@ -264,7 +264,8 @@ impl Module<&Array> for RmsNorm {
     type Output = Array;
 
     fn forward(&mut self, x: &Array) -> Result<Self::Output, Self::Error> {
-        let weight = self.weight.as_ref();
+        // Gemma checkpoints store the RMSNorm scale zero-centered, so the effective scale is (1 + weight).
+        let weight = self.weight.as_ref().add(array!(1.0))?;
         let eps = self.eps;
         crate::fast::rms_norm(x, weight, eps)
     }
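For reference, the `RmsNorm` change in the last hunk switches the learned scale to the zero-centered convention Gemma checkpoints use. With $w$ the stored weight vector, $d$ the feature dimension, and $\epsilon$ the layer's `eps`, the intended computation is (a sketch of the math, not of the fused `fast::rms_norm` kernel):

$$\operatorname{RmsNorm}(x) = \frac{x}{\sqrt{\tfrac{1}{d}\sum_{i=1}^{d} x_i^2 + \epsilon}} \odot (1 + w)$$

Since `nn::RmsNorm` lives in shared `mlx-rs` code, every model built on it (including the existing Qwen3 path) now applies the $(1 + w)$ scaling.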