diff --git a/bindings/python/src/trainers.rs b/bindings/python/src/trainers.rs
index 53415fff0..2566564f3 100644
--- a/bindings/python/src/trainers.rs
+++ b/bindings/python/src/trainers.rs
@@ -7,6 +7,7 @@ use pyo3::prelude::*;
 use pyo3::types::*;
 use serde::{Deserialize, Serialize};
 use tk::models::TrainerWrapper;
+use tk::utils::ProgressFormat;
 use tk::Trainer;
 use tokenizers as tk;
 
@@ -209,6 +210,39 @@ impl PyBpeTrainer {
         setter!(self_, BpeTrainer, show_progress, show_progress);
     }
 
+    /// Get the progress output format ("indicatif", "json", or "silent")
+    #[getter]
+    fn get_progress_format(self_: PyRef<Self>) -> String {
+        let format = getter!(self_, BpeTrainer, progress_format);
+        match format {
+            ProgressFormat::Indicatif => "indicatif".to_string(),
+            ProgressFormat::JsonLines => "json".to_string(),
+            ProgressFormat::Silent => "silent".to_string(),
+        }
+    }
+
+    /// Set the progress output format ("indicatif", "json", or "silent")
+    #[setter]
+    fn set_progress_format(self_: PyRef<Self>, format: &str) {
+        let fmt = match format {
+            "json" => ProgressFormat::JsonLines,
+            "silent" => ProgressFormat::Silent,
+            _ => ProgressFormat::Indicatif,
+        };
+        setter!(self_, BpeTrainer, progress_format, fmt);
+    }
+
+    /// Get the number of unique words after feeding the corpus
+    #[pyo3(name = "get_word_count")]
+    fn get_word_count(self_: PyRef<Self>) -> usize {
+        let super_ = self_.as_ref();
+        if let TrainerWrapper::BpeTrainer(ref trainer) = *super_.trainer.read().unwrap() {
+            trainer.get_word_count()
+        } else {
+            0
+        }
+    }
+
     #[getter]
     fn get_special_tokens(self_: PyRef<Self>) -> Vec<PyAddedToken> {
         getter!(
@@ -308,7 +342,7 @@ impl PyBpeTrainer {
     #[new]
     #[pyo3(
         signature = (**kwargs),
-        text_signature = "(self, vocab_size=30000, min_frequency=0, show_progress=True, special_tokens=[], limit_alphabet=None, initial_alphabet=[], continuing_subword_prefix=None, end_of_word_suffix=None, max_token_length=None, words={})"
+        text_signature = "(self, vocab_size=30000, min_frequency=0, show_progress=True, progress_format=\"indicatif\", special_tokens=[], limit_alphabet=None, initial_alphabet=[], continuing_subword_prefix=None, end_of_word_suffix=None, max_token_length=None, words={})"
     )]
     pub fn new(kwargs: Option<&Bound<'_, PyDict>>) -> PyResult<(Self, PyTrainer)> {
         let mut builder = tk::models::bpe::BpeTrainer::builder();
@@ -319,6 +353,15 @@ impl PyBpeTrainer {
                     "vocab_size" => builder = builder.vocab_size(val.extract()?),
                     "min_frequency" => builder = builder.min_frequency(val.extract()?),
                     "show_progress" => builder = builder.show_progress(val.extract()?),
+                    "progress_format" => {
+                        let fmt: String = val.extract()?;
+                        let format = match fmt.as_str() {
+                            "json" => ProgressFormat::JsonLines,
+                            "silent" => ProgressFormat::Silent,
+                            _ => ProgressFormat::Indicatif,
+                        };
+                        builder = builder.progress_format(format);
+                    }
                     "special_tokens" => {
                         builder = builder.special_tokens(
                             val.downcast::<PyList>()?
diff --git a/tokenizers/src/lib.rs b/tokenizers/src/lib.rs
index 7841314d0..08e145a25 100644
--- a/tokenizers/src/lib.rs
+++ b/tokenizers/src/lib.rs
@@ -151,6 +151,9 @@ pub use tokenizer::*;
 // Re-export also parallelism utils
 pub use utils::parallelism;
 
+// Re-export ProgressFormat for trainer configuration
+pub use utils::ProgressFormat;
+
 // Re-export for from_pretrained
 #[cfg(feature = "http")]
 pub use utils::from_pretrained::FromPretrainedParameters;
diff --git a/tokenizers/src/models/bpe/trainer.rs b/tokenizers/src/models/bpe/trainer.rs
index cda6aea65..df68c655e 100644
--- a/tokenizers/src/models/bpe/trainer.rs
+++ b/tokenizers/src/models/bpe/trainer.rs
@@ -3,7 +3,7 @@
 use super::{Pair, WithFirstLastIterator, Word, BPE};
 use crate::parallelism::*;
 use crate::tokenizer::{AddedToken, Result, Trainer};
-use crate::utils::progress::{ProgressBar, ProgressStyle};
+use crate::utils::progress::{ProgressBar, ProgressFormat, ProgressStyle};
 use ahash::{AHashMap, AHashSet};
 use compact_str::CompactString;
 use dary_heap::OctonaryHeap;
@@ -42,6 +42,7 @@ struct Config {
     min_frequency: u64,
     vocab_size: usize,
     show_progress: bool,
+    progress_format: ProgressFormat,
     special_tokens: Vec<AddedToken>,
     limit_alphabet: Option<usize>,
     initial_alphabet: AHashSet<char>,
@@ -63,6 +64,7 @@ impl Default for BpeTrainerBuilder {
                min_frequency: 0,
                vocab_size: 30000,
                show_progress: true,
+               progress_format: ProgressFormat::default(),
                special_tokens: vec![],
                limit_alphabet: None,
                initial_alphabet: AHashSet::new(),
@@ -101,6 +103,18 @@ impl BpeTrainerBuilder {
         self
     }
 
+    /// Set the progress output format
+    ///
+    /// Controls how progress information is reported during training.
+    /// - `Indicatif` (default): Interactive terminal progress bars
+    /// - `JsonLines`: Machine-readable JSON lines to stderr
+    /// - `Silent`: No progress output
+    #[must_use]
+    pub fn progress_format(mut self, format: ProgressFormat) -> Self {
+        self.config.progress_format = format;
+        self
+    }
+
     /// Set the special tokens
     #[must_use]
     pub fn special_tokens(mut self, tokens: Vec<AddedToken>) -> Self {
@@ -150,6 +164,7 @@ impl BpeTrainerBuilder {
            min_frequency: self.config.min_frequency,
            vocab_size: self.config.vocab_size,
            show_progress: self.config.show_progress,
+           progress_format: self.config.progress_format,
            special_tokens: self.config.special_tokens,
            limit_alphabet: self.config.limit_alphabet,
            initial_alphabet: self.config.initial_alphabet,
@@ -186,6 +201,8 @@ pub struct BpeTrainer {
     pub vocab_size: usize,
     /// Whether to show progress while training
     pub show_progress: bool,
+    /// Progress output format (Indicatif, JsonLines, or Silent)
+    pub progress_format: ProgressFormat,
     /// A list of special tokens that the model should know of
     pub special_tokens: Vec<AddedToken>,
     /// Whether to limit the number of initial tokens that can be kept before computing merges
@@ -222,9 +239,15 @@ impl BpeTrainer {
         BpeTrainerBuilder::new()
     }
 
-    /// Setup a progress bar if asked to show progress
+    /// Returns the number of unique words in the corpus after feeding.
+    /// This can be used to estimate training time before starting.
+    pub fn get_word_count(&self) -> usize {
+        self.words.len()
+    }
+
+    /// Setup a progress bar if asked to show progress (only for Indicatif format)
     fn setup_progress(&self) -> Option<ProgressBar> {
-        if self.show_progress {
+        if self.show_progress && self.progress_format == ProgressFormat::Indicatif {
             let p = ProgressBar::new(0);
             p.set_style(
                 ProgressStyle::default_bar()
@@ -237,13 +260,24 @@ impl BpeTrainer {
         }
     }
 
+    /// Emit JSON progress line to stderr (for JsonLines format)
+    fn emit_json_progress(&self, stage: &str, current: usize, total: usize) {
+        if self.progress_format == ProgressFormat::JsonLines {
+            eprintln!(
+                r#"{{"stage":"{}","current":{},"total":{}}}"#,
+                stage, current, total
+            );
+        }
+    }
+
     /// Set the progress bar in the finish state
-    fn finalize_progress(&self, p: &Option<ProgressBar>, final_len: usize) {
+    fn finalize_progress(&self, p: &Option<ProgressBar>, final_len: usize, stage: &str) {
         if let Some(p) = p {
             p.set_length(final_len as u64);
             p.finish();
             println!();
         }
+        self.emit_json_progress(stage, final_len, final_len);
     }
 
     /// Update the progress bar with the new provided length and message
@@ -253,6 +287,8 @@ impl BpeTrainer {
             p.set_length(len as u64);
             p.reset();
         }
+        // Emit initial JSON progress for this stage
+        self.emit_json_progress(message, 0, len);
     }
 
     /// Add the provided special tokens to the initial vocabulary
@@ -444,7 +480,7 @@ impl BpeTrainer {
         self.update_progress(&progress, word_counts.len(), "Tokenize words");
         let (mut words, counts) =
             self.tokenize_words(word_counts, &mut word_to_id, &mut id_to_word, &progress);
-        self.finalize_progress(&progress, words.len());
+        self.finalize_progress(&progress, words.len(), "Tokenize words");
 
         //
         // 4. Count pairs in words
@@ -463,7 +499,7 @@ impl BpeTrainer {
                 });
             }
         });
-        self.finalize_progress(&progress, words.len());
+        self.finalize_progress(&progress, words.len(), "Count pairs");
 
         //
         // 5. Do merges
@@ -565,8 +601,9 @@ impl BpeTrainer {
             if let Some(p) = &progress {
                 p.inc(1);
             }
+            self.emit_json_progress("Compute merges", merges.len(), self.vocab_size);
         }
-        self.finalize_progress(&progress, merges.len());
+        self.finalize_progress(&progress, merges.len(), "Compute merges");
 
         // Transfer new vocab & options to model
         //model.vocab = word_to_id;
diff --git a/tokenizers/src/utils/mod.rs b/tokenizers/src/utils/mod.rs
index 636bee660..c9450b322 100644
--- a/tokenizers/src/utils/mod.rs
+++ b/tokenizers/src/utils/mod.rs
@@ -20,6 +20,9 @@ pub mod parallelism;
 pub(crate) mod progress;
 pub mod truncation;
 
+// Re-export ProgressFormat for public API
+pub use progress::ProgressFormat;
+
 use ahash::AHashMap;
 use serde::{Serialize, Serializer};
 use std::collections::BTreeMap;
diff --git a/tokenizers/src/utils/progress.rs b/tokenizers/src/utils/progress.rs
index 96e9f6082..393315163 100644
--- a/tokenizers/src/utils/progress.rs
+++ b/tokenizers/src/utils/progress.rs
@@ -1,3 +1,20 @@
+use serde::{Deserialize, Serialize};
+
+/// Progress output format for training operations.
+///
+/// Controls how progress information is reported during tokenizer training.
+/// Default is `Indicatif`, which shows interactive terminal progress bars.
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
+pub enum ProgressFormat {
+    /// Interactive terminal progress bars using indicatif (default behavior)
+    #[default]
+    Indicatif,
+    /// Machine-readable JSON lines to stderr for programmatic consumption
+    JsonLines,
+    /// No progress output
+    Silent,
+}
+
 #[cfg(feature = "progressbar")]
 pub(crate) use indicatif::{ProgressBar, ProgressStyle};
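For reference, a minimal sketch of how the additions above might be exercised from the Rust side. It assumes the crate's existing `BpeTrainer::builder()`, `Trainer::feed`, and `Trainer::train` API; the `ProgressFormat` re-export, the `progress_format` builder method, and the `get_word_count` accessor are the pieces introduced by this diff, and the two-line corpus is illustrative only.

```rust
use tokenizers::models::bpe::{BpeTrainer, BPE};
use tokenizers::{ProgressFormat, Trainer};

fn main() -> tokenizers::Result<()> {
    // Report progress as JSON lines on stderr instead of drawing
    // indicatif progress bars (ProgressFormat::Silent disables output).
    let mut trainer = BpeTrainer::builder()
        .vocab_size(30_000)
        .progress_format(ProgressFormat::JsonLines)
        .build();

    let mut model = BPE::default();

    // Feed a tiny illustrative corpus; the closure maps each sequence
    // to the list of "words" the trainer should count.
    trainer.feed(
        ["hello world", "hello tokenizers"].iter(),
        |seq| Ok(vec![seq.to_string()]),
    )?;

    // New accessor from this diff: number of unique words seen after feeding.
    eprintln!("unique words: {}", trainer.get_word_count());

    // Training now emits lines such as:
    //   {"stage":"Compute merges","current":123,"total":30000}
    trainer.train(&mut model)?;
    Ok(())
}
```

From Python, the same setting would be reachable through the new keyword argument exposed in the updated `text_signature`, e.g. `BpeTrainer(progress_format="json")`, with `"silent"` and `"indicatif"` as the other accepted values.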