Commit 381af70

Update launcher to auto-convert fast tokenizer (#48)
This PR updates the launcher to save the fast tokenizer to `/tmp/<uuid>/tokenizer.json` for use by the router, in case the model does not have an already-converted fast tokenizer or the one it has is out of sync with its slow tokenizer.

Signed-off-by: Daniel Clark <[email protected]>
Co-authored-by: Daniel Clark <[email protected]>
1 parent 88f2a0b commit 381af70
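
In effect, the launcher shells out to a Python one-liner equivalent to the following (the model name, revision, and `/tmp/<uuid>` save directory are illustrative placeholders):

    from transformers import AutoTokenizer

    # Load the tokenizer, converting from the slow tokenizer when no
    # up-to-date tokenizer.json exists, then write tokenizer.json into
    # the per-run save directory for the router to use.
    tok = AutoTokenizer.from_pretrained("org/model", revision="main")
    tok.save_pretrained("/tmp/<uuid>")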

File tree (3 files changed, +94 -40 lines)

  Cargo.lock
  launcher/Cargo.toml
  launcher/src/main.rs

Cargo.lock

Lines changed: 6 additions & 7 deletions
Some generated files are not rendered by default.

launcher/Cargo.toml

Lines changed: 1 addition & 1 deletion
@@ -12,4 +12,4 @@ nix = "0.27.1"
 serde_json = "^1.0.114"
 tracing = "0.1.40"
 tracing-subscriber = { version = "0.3.18", features = ["json"] }
-
+uuid = { version = "1.7.0", features = ["v4", "fast-rng"] }

launcher/src/main.rs

Lines changed: 87 additions & 32 deletions
@@ -7,7 +7,7 @@ use std::{
     io,
     io::{BufRead, BufReader, ErrorKind, Write},
     os::unix::process::{CommandExt, ExitStatusExt},
-    path::Path,
+    path::{Path, PathBuf},
     process::{Command, ExitCode, ExitStatus, Stdio},
     sync::{
         atomic::{AtomicBool, Ordering},
@@ -25,7 +25,8 @@ use nix::{
     sys::signal::{self, Signal},
     unistd::Pid,
 };
-use tracing::{info, warn};
+use tracing::{error, info, warn};
+use uuid::Uuid;
 
 // In most cases this gives the best performance for inferencing
 const DEFAULT_PYTORCH_CUDA_ALLOC_CONF: &str = "expandable_segments:True";
@@ -135,13 +136,27 @@ fn main() -> ExitCode {
     // Determine number of shards based on command line arg and env vars
     let num_shard = find_num_shards(args.num_shard);
 
-    // Resolve fast tokenizer path
-    let tokenizer_path = resolve_tokenizer_path(&args.model_name, args.revision.as_deref())
-        .expect("Could not find tokenizer for model");
+    let config_path: PathBuf = resolve_config_path(&args.model_name, args.revision.as_deref())
+        .expect("Failed to resolve config path")
+        .into();
+
+    // Save fast tokenizer to /tmp
+    let save_path = format!("/tmp/{}", Uuid::new_v4());
+    let tokenizer_path =
+        if save_fast_tokenizer(&args.model_name, args.revision.as_deref(), &save_path).is_ok() {
+            format!("/{save_path}/tokenizer.json")
+        } else {
+            warn!("Failed to (re-)convert tokenizer, falling back to use existing fast tokenizer");
+            let tokenizer_path = config_path.parent().unwrap().join("tokenizer.json");
+            if tokenizer_path.is_file() {
+                tokenizer_path.to_string_lossy().to_string()
+            } else {
+                panic!("No existing fast tokenizer (tokenizer.json) found")
+            }
+        };
 
     // Determine max sequence length based on command line arg and env vars
-    let config_json_path = tokenizer_path.replace("tokenizer.json", "config.json");
-    let max_sequence_length = get_max_sequence_length(args.max_sequence_length, &config_json_path);
+    let max_sequence_length = get_max_sequence_length(args.max_sequence_length, &config_path);
 
     match env::var("MAX_BATCH_WEIGHT") {
         Ok(max_batch_weight) if !max_batch_weight.trim().is_empty() => {
@@ -415,15 +430,15 @@ fn num_cuda_devices() -> Option<usize> {
 /// 1. MAX_SEQUENCE_LENGTH set by user
 /// 2. The sequence length specified in config.json
 /// 3. A default of 2048
-fn get_max_sequence_length(max_sequence_length: Option<usize>, config_json_path: &String) -> usize {
+fn get_max_sequence_length(max_sequence_length: Option<usize>, config_path: &Path) -> usize {
     if let Some(max_sequence_length) = max_sequence_length {
         info!(
             "Using configured max_sequence_length: {}",
             max_sequence_length
         );
         return max_sequence_length;
     }
-    if let Ok(model_config) = get_config_json(config_json_path) {
+    if let Ok(model_config) = get_config_json(config_path) {
         if let Some(length) = get_max_sequence_length_from_config(&model_config) {
             info!(
                 "Loaded max_sequence_length from model config.json: {}",
@@ -440,7 +455,7 @@ fn get_max_sequence_length(max_sequence_length: Option<usize>, config_json_path:
 }
 
 /// Opens the model's config.json file and reads into serde_json value
-fn get_config_json(config_path: &String) -> Result<serde_json::Value, std::io::Error> {
+fn get_config_json(config_path: &Path) -> Result<serde_json::Value, std::io::Error> {
     let reader = BufReader::new(File::open(config_path)?);
     Ok(serde_json::from_reader(reader)?)
 }
@@ -734,7 +749,19 @@ fn shutdown_shards(shutdown: Arc<Mutex<bool>>, shutdown_receiver: &mpsc::Receive
     let _ = shutdown_receiver.recv();
 }
 
-fn resolve_tokenizer_path(model_name: &str, revision: Option<&str>) -> Result<String, io::Error> {
+fn write_termination_log(msg: &str) -> Result<(), io::Error> {
+    // Writes a message to the termination log.
+    // Creates the logfile if it doesn't exist.
+    let mut f = File::options()
+        .write(true)
+        .create(true)
+        .truncate(true)
+        .open("/dev/termination-log")?;
+    writeln!(f, "{}", msg)?;
+    Ok(())
+}
+
+fn resolve_config_path(model_name: &str, revision: Option<&str>) -> Result<String, io::Error> {
     let cache = env::var("TRANSFORMERS_CACHE")
         .or_else(|_| env::var("HUGGINGFACE_HUB_CACHE"))
         .ok();
@@ -753,20 +780,19 @@ fn resolve_tokenizer_path(model_name: &str, revision: Option<&str>) -> Result<St
             true => fs::read_to_string(ref_path)?,
             false => revision.to_string(),
         };
-        let tok_path = dir.join("snapshots").join(&revision).join("tokenizer.json");
-        if tok_path.try_exists()? {
-            Ok(tok_path.to_string_lossy().into())
+        let config_path = dir.join("snapshots").join(&revision).join("config.json");
+        if config_path.try_exists()? {
+            Ok(config_path.to_string_lossy().into())
         } else {
-            Err(io::Error::new(
-                ErrorKind::NotFound,
-                format!(
-                    "Tokenizer file not found in local cache for model {model_name}, revision {revision}"
-                )
-            ))
+            let message = format!(
+                "Config file not found in local cache for model {model_name}, revision {revision}"
+            );
+            error!(message);
+            Err(io::Error::new(ErrorKind::NotFound, message))
         }
     } else {
         // Try treating model_name as explicit model path
-        let try_path = Path::new(&model_name).join("tokenizer.json");
+        let try_path = Path::new(&model_name).join("config.json");
         if try_path.try_exists()? {
             Ok(try_path.to_string_lossy().into())
         } else {
@@ -775,20 +801,49 @@ fn resolve_tokenizer_path(model_name: &str, revision: Option<&str>) -> Result<St
             } else {
                 format!("Model {model_name} not found in local cache")
             };
-
+            error!(message);
             Err(io::Error::new(ErrorKind::NotFound, message))
         }
     }
 }
 
-fn write_termination_log(msg: &str) -> Result<(), io::Error> {
-    // Writes a message to the termination log.
-    // Creates the logfile if it doesn't exist.
-    let mut f = File::options()
-        .write(true)
-        .create(true)
-        .truncate(true)
-        .open("/dev/termination-log")?;
-    writeln!(f, "{}", msg)?;
-    Ok(())
+/// Convert and save fast tokenizer via transformers `AutoTokenizer.from_pretrained`.
+fn save_fast_tokenizer(
+    model_name: &str,
+    revision: Option<&str>,
+    save_path: &str,
+) -> Result<(), io::Error> {
+    info!("Saving fast tokenizer for `{model_name}` to `{save_path}`");
+    let model_name = model_name.escape_default();
+    let revision = revision.map(|v| v.escape_default());
+    let code = if let Some(revision) = revision {
+        format!(
+            "from transformers import AutoTokenizer; \
+            AutoTokenizer.from_pretrained(\"{model_name}\", \
+            revision=\"{revision}\").save_pretrained(\"{save_path}\")"
+        )
+    } else {
+        format!(
+            "from transformers import AutoTokenizer; \
+            AutoTokenizer.from_pretrained(\"{model_name}\").save_pretrained(\"{save_path}\")"
+        )
+    };
+    match Command::new("python").args(["-c", &code]).status() {
+        Ok(status) => {
+            if status.success() {
+                Ok(())
+            } else {
+                let message = "Python process for tokenizer conversion failed".to_string();
+                error!(
+                    exit_code = ?status.code(),
+                    message
+                );
+                Err(io::Error::new(ErrorKind::Other, message))
+            }
+        }
+        Err(e) => {
+            error!("Failed to launch python process for tokenizer conversion");
+            Err(e)
+        }
+    }
 }
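
The router then consumes the saved `tokenizer.json` directly. As an illustrative sanity check (not part of this PR; assumes the `tokenizers` package is installed), the converted artifact can be loaded on its own:

    from tokenizers import Tokenizer

    # tokenizer.json written by save_pretrained above; the same path the
    # launcher passes on to the router.
    tok = Tokenizer.from_file("/tmp/<uuid>/tokenizer.json")
    print(tok.encode("hello world").ids)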
