Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions examples/embedding_gemma.bak/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Example crate: run the EmbeddingGemma model through mlx-rs / mlx-lm.
[package]
name = "embedding_gemma"
# Package metadata is inherited from the workspace root manifest.
version.workspace = true
edition.workspace = true
authors.workspace = true
repository.workspace = true
keywords.workspace = true
categories.workspace = true
license.workspace = true
documentation.workspace = true
rust-version.workspace = true

[dependencies]
# Core MLX bindings plus the language-model helpers from this workspace.
mlx-rs.workspace = true
mlx-lm.workspace = true
mlx-lm-utils.workspace = true

# Ad-hoc error handling for the example binary.
anyhow = "1"
5 changes: 5 additions & 0 deletions examples/embedding_gemma.bak/download.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#!/bin/bash
# Download the EmbeddingGemma model snapshot into the local example cache.
# Requires the `huggingface-cli` tool (pip install -U "huggingface_hub[cli]").

# Abort on errors, unset variables, and failed pipeline stages.
set -euo pipefail

model_id="mlx-community/embeddinggemma-300m-bf16"

# Quote the expansion so the model id is passed as a single argument
# (unquoted expansion is subject to word splitting and globbing).
huggingface-cli download "$model_id" --local-dir ./cache/embeddinggemma-300m-bf16
84 changes: 84 additions & 0 deletions examples/embedding_gemma.bak/src/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
use std::path::Path;

use mlx_lm::{cache::ConcatKeyValueCache, models::gemma::load_embedding_gemma_model};
use mlx_lm_utils::tokenizer::{
load_model_chat_template_from_file, ApplyChatTemplateArgs, Conversation, Role, Tokenizer,
};
use mlx_rs::{
ops::indexing::{IndexOp, NewAxis},
transforms::eval,
Array,
};

/// Local model directory; matches the `--local-dir` target used by this
/// example's `download.sh`, so run that script first.
const CACHED_TEST_MODEL_DIR: &str = "./cache/embeddinggemma-300m-bf16";

fn qwen3() -> anyhow::Result<()> {
let model_dir = Path::new(CACHED_TEST_MODEL_DIR);

let model_id = "mlx-community/embeddinggemma-300m-bf16".to_string();
let tokenizer_file = model_dir.join("tokenizer.json");
let tokenizer_config_file = model_dir.join("tokenizer_config.json");
let mut tokenizer =
Tokenizer::from_file(tokenizer_file).map_err(|e| anyhow::anyhow!("{:?}", e))?;
let model_chat_template = load_model_chat_template_from_file(tokenizer_config_file)?
.expect("Model chat template not found");

let conversations = vec![Conversation {
role: Role::User,
content: "what's your name?",
}];
let args = ApplyChatTemplateArgs {
conversations: vec![conversations.into()],
documents: None,
model_id: &model_id,
chat_template_id: None,
add_generation_prompt: None,
continue_final_message: None,
};
let encodings = tokenizer.apply_chat_template_and_encode(model_chat_template, args)?;
let prompt: Vec<u32> = encodings
.iter()
.flat_map(|encoding| encoding.get_ids())
.copied()
.collect();
let prompt_tokens = Array::from(&prompt[..]).index(NewAxis);

let mut cache = Vec::new();
let mut model = load_qwen3_model(model_dir)?;
let generate = mlx_lm::models::qwen3::Generate::<ConcatKeyValueCache>::new(
&mut model,
&mut cache,
0.2,
&prompt_tokens,
);

let mut tokens = Vec::new();
for (token, ntoks) in generate.zip(0..256) {
let token = token.unwrap();
tokens.push(token.clone());

if ntoks == 0 {
eval(&tokens).unwrap();
}

if tokens.len() % 20 == 0 {
eval(&tokens).unwrap();
let slice: Vec<u32> = tokens.drain(..).map(|t| t.item::<u32>()).collect();
let s = tokenizer.decode(&slice, true).unwrap();
print!("{s}");
}
}

eval(&tokens).unwrap();
let slice: Vec<u32> = tokens.drain(..).map(|t| t.item::<u32>()).collect();
let s = tokenizer.decode(&slice, true).unwrap();
println!("{s}");

println!("------");

Ok(())
}

/// Entry point: run the example and propagate any error to the process exit.
fn main() -> anyhow::Result<()> {
    qwen3()?;
    Ok(())
}
18 changes: 18 additions & 0 deletions examples/gemma3/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Example crate: run the Gemma 3 (270M) model through mlx-rs / mlx-lm.
[package]
name = "gemma3"
# Package metadata is inherited from the workspace root manifest.
version.workspace = true
edition.workspace = true
authors.workspace = true
repository.workspace = true
keywords.workspace = true
categories.workspace = true
license.workspace = true
documentation.workspace = true
rust-version.workspace = true

[dependencies]
# Core MLX bindings plus the language-model helpers from this workspace.
mlx-rs.workspace = true
mlx-lm.workspace = true
mlx-lm-utils.workspace = true

# Ad-hoc error handling for the example binary.
anyhow = "1"
5 changes: 5 additions & 0 deletions examples/gemma3/download.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#!/bin/bash
# Download the Gemma 3 (270M) model snapshot into the local example cache.
# Requires the `huggingface-cli` tool (pip install -U "huggingface_hub[cli]").

# Abort on errors, unset variables, and failed pipeline stages.
set -euo pipefail

model_id="mlx-community/gemma-3-270m-bf16"

# Quote the expansion so the model id is passed as a single argument
# (unquoted expansion is subject to word splitting and globbing).
huggingface-cli download "$model_id" --local-dir ./cache/gemma-3-270m-bf16
88 changes: 88 additions & 0 deletions examples/gemma3/src/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
use std::path::Path;

use mlx_lm::{cache::ConcatKeyValueCache, models::gemma::gemma3::load_gemma3_model};
use mlx_lm_utils::tokenizer::{
load_gemma_chat_template_from_file, ApplyChatTemplateArgs, Conversation, Role, Tokenizer,
};
use mlx_rs::{
ops::indexing::{IndexOp, NewAxis},
transforms::eval,
Array,
};

/// Local model directory; matches the `--local-dir` target used by this
/// example's `download.sh`, so run that script first.
const CACHED_TEST_MODEL_DIR: &str = "./cache/gemma-3-270m-bf16";

/// Gemma-3 text-generation example: render a one-turn chat prompt with the
/// model's Jinja chat template, then stream the reply to stdout in 20-token
/// chunks.
///
/// # Errors
///
/// Returns an error if the tokenizer, chat template, or model files cannot
/// be loaded, or if applying the chat template fails.
fn gemma3() -> anyhow::Result<()> {
    let model_dir = Path::new(CACHED_TEST_MODEL_DIR);
    let model_id = String::from("mlx-community/gemma-3-270m-bf16");

    // Tokenizer and chat template both live in the cached model directory.
    let mut tokenizer = Tokenizer::from_file(model_dir.join("tokenizer.json"))
        .map_err(|e| anyhow::anyhow!("{:?}", e))?;
    let template = load_gemma_chat_template_from_file(model_dir.join("chat_template.jinja"))?;

    // A single user turn; the template expands it into the full prompt.
    let turns = vec![Conversation {
        role: Role::User,
        content: "what's your name?",
    }];
    println!("Conversations: {:?}", turns);

    let template_args = ApplyChatTemplateArgs {
        conversations: vec![turns.into()],
        documents: None,
        model_id: &model_id,
        chat_template_id: None,
        add_generation_prompt: Some(true),
        continue_final_message: None,
        add_special_tokens: Some(true),
    };
    let encoded = tokenizer.apply_chat_template_and_encode(template, template_args)?;

    // Flatten the batch of encodings into one token-id sequence.
    let mut prompt_ids: Vec<u32> = Vec::new();
    for enc in encoded.iter() {
        prompt_ids.extend(enc.get_ids().iter().copied());
    }
    println!("Prompt tokens (raw): {:?}", prompt_ids);

    // Add a leading batch axis: [seq] -> [1, seq].
    let prompt_array = Array::from(&prompt_ids[..]).index(NewAxis);
    println!("Prompt tokens (array): {:?}", prompt_array);

    let mut kv_cache = Vec::new();
    let mut model = load_gemma3_model(model_dir)?;
    let generate = mlx_lm::models::gemma::gemma3::Generate::<ConcatKeyValueCache>::new(
        &mut model,
        &mut kv_cache,
        0.0,
        &prompt_array,
    );

    let mut pending = Vec::new();
    // zip with a counter caps generation at 256 tokens.
    for (next, step) in generate.zip(0..256) {
        let next = next.unwrap();
        pending.push(next.clone());

        // Evaluate eagerly on the very first token so prompt processing is
        // not deferred into the first flush.
        if step == 0 {
            eval(&pending).unwrap();
        }

        // Every 20 buffered tokens: materialize, decode, and print.
        if pending.len() % 20 == 0 {
            eval(&pending).unwrap();
            let ids: Vec<u32> = pending.drain(..).map(|t| t.item::<u32>()).collect();
            print!("{}", tokenizer.decode(&ids, true).unwrap());
        }
    }

    // Flush whatever is still buffered after the loop ends.
    eval(&pending).unwrap();
    let ids: Vec<u32> = pending.drain(..).map(|t| t.item::<u32>()).collect();
    println!("{}", tokenizer.decode(&ids, true).unwrap());

    println!("------");

    Ok(())
}

/// Entry point: run the example and propagate any error to the process exit.
fn main() -> anyhow::Result<()> {
    gemma3()?;
    Ok(())
}
5 changes: 5 additions & 0 deletions examples/lm/download.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#!/bin/bash

model_id="mlx-community/Qwen3-4B-bf16"

huggingface-cli download $model_id --local-dir ./cache/Qwen3-4B-bf16
11 changes: 10 additions & 1 deletion mlx-lm-utils/src/tokenizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,10 @@ impl Tokenizer {
{
let Self { inner, env } = self;

let add_special_tokens = args.add_special_tokens.unwrap_or(false);
let rendered_chats = apply_chat_template(env, model_template, args)?;
inner
.encode_batch(rendered_chats, false)
.encode_batch(rendered_chats, add_special_tokens)
.map_err(Into::into)
}
}
Expand Down Expand Up @@ -237,6 +238,7 @@ where
pub chat_template_id: Option<&'a str>,
pub add_generation_prompt: Option<bool>,
pub continue_final_message: Option<bool>,
pub add_special_tokens: Option<bool>,
}

pub fn load_model_chat_template_from_str(content: &str) -> std::io::Result<Option<String>> {
Expand All @@ -257,6 +259,11 @@ pub fn load_model_chat_template_from_file(
load_model_chat_template_from_str(&content)
}

/// Loads a Gemma chat template (a Jinja file) verbatim from `file`.
///
/// The file contents are returned as-is; no parsing or validation happens
/// here.
///
/// # Errors
///
/// Returns any `std::io::Error` raised while reading the file.
pub fn load_gemma_chat_template_from_file(file: impl AsRef<Path>) -> std::io::Result<String> {
    // `read_to_string` already yields `io::Result<String>`; returning it
    // directly avoids the redundant `let` + `?`/`Ok` round-trip.
    read_to_string(file)
}

// chat_template = self.get_chat_template(chat_template, tools)

// if isinstance(conversation, (list, tuple)) and (
Expand Down Expand Up @@ -445,6 +452,7 @@ where
chat_template_id,
add_generation_prompt,
continue_final_message,
add_special_tokens: _,
} = args;

let add_generation_prompt = add_generation_prompt.unwrap_or(false);
Expand Down Expand Up @@ -635,6 +643,7 @@ mod tests {
chat_template_id: None,
add_generation_prompt: None,
continue_final_message: None,
add_special_tokens: Some(false),
};

let encodings = tokenizer
Expand Down
Loading
Loading