45 changes: 44 additions & 1 deletion bindings/python/src/trainers.rs
@@ -7,6 +7,7 @@ use pyo3::prelude::*;
use pyo3::types::*;
use serde::{Deserialize, Serialize};
use tk::models::TrainerWrapper;
use tk::utils::ProgressFormat;
use tk::Trainer;
use tokenizers as tk;

@@ -209,6 +210,39 @@ impl PyBpeTrainer {
setter!(self_, BpeTrainer, show_progress, show_progress);
}

/// Get the progress output format ("indicatif", "json", or "silent")
#[getter]
fn get_progress_format(self_: PyRef<Self>) -> String {
let format = getter!(self_, BpeTrainer, progress_format);
match format {
ProgressFormat::Indicatif => "indicatif".to_string(),
ProgressFormat::JsonLines => "json".to_string(),
ProgressFormat::Silent => "silent".to_string(),
}
}

/// Set the progress output format ("indicatif", "json", or "silent")
#[setter]
fn set_progress_format(self_: PyRef<Self>, format: &str) {
let fmt = match format {
"json" => ProgressFormat::JsonLines,
"silent" => ProgressFormat::Silent,
_ => ProgressFormat::Indicatif,
};
setter!(self_, BpeTrainer, progress_format, fmt);
}

/// Get the number of unique words collected after the corpus has been fed
#[pyo3(name = "get_word_count")]
fn get_word_count(self_: PyRef<Self>) -> usize {
let super_ = self_.as_ref();
if let TrainerWrapper::BpeTrainer(ref trainer) = *super_.trainer.read().unwrap() {
trainer.get_word_count()
} else {
0
}
}

#[getter]
fn get_special_tokens(self_: PyRef<Self>) -> Vec<PyAddedToken> {
getter!(
@@ -308,7 +342,7 @@ impl PyBpeTrainer {
#[new]
#[pyo3(
signature = (**kwargs),
text_signature = "(self, vocab_size=30000, min_frequency=0, show_progress=True, special_tokens=[], limit_alphabet=None, initial_alphabet=[], continuing_subword_prefix=None, end_of_word_suffix=None, max_token_length=None, words={})"
text_signature = "(self, vocab_size=30000, min_frequency=0, show_progress=True, progress_format=\"indicatif\", special_tokens=[], limit_alphabet=None, initial_alphabet=[], continuing_subword_prefix=None, end_of_word_suffix=None, max_token_length=None, words={})"
)]
pub fn new(kwargs: Option<&Bound<'_, PyDict>>) -> PyResult<(Self, PyTrainer)> {
let mut builder = tk::models::bpe::BpeTrainer::builder();
@@ -319,6 +353,15 @@
"vocab_size" => builder = builder.vocab_size(val.extract()?),
"min_frequency" => builder = builder.min_frequency(val.extract()?),
"show_progress" => builder = builder.show_progress(val.extract()?),
"progress_format" => {
let fmt: String = val.extract()?;
let format = match fmt.as_str() {
"json" => ProgressFormat::JsonLines,
"silent" => ProgressFormat::Silent,
_ => ProgressFormat::Indicatif,
};
builder = builder.progress_format(format);
}
"special_tokens" => {
builder = builder.special_tokens(
val.downcast::<PyList>()?
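Note that the setter above maps any unrecognized string to `ProgressFormat::Indicatif` rather than raising a Python error. A minimal sketch of that mapping as a standalone Rust helper, to make the fallback visible at a glance (the helper name `parse_progress_format` is hypothetical, not part of this diff):

```rust
use tokenizers::ProgressFormat;

// Hypothetical helper mirroring the binding's string-to-enum mapping.
// Any string other than "json" or "silent" falls back to Indicatif,
// matching the `_` arm in `set_progress_format` above.
fn parse_progress_format(s: &str) -> ProgressFormat {
    match s {
        "json" => ProgressFormat::JsonLines,
        "silent" => ProgressFormat::Silent,
        _ => ProgressFormat::Indicatif,
    }
}
```

The silent fallback keeps the Python API forgiving, at the cost of hiding typos such as `progress_format="jsonl"`.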
3 changes: 3 additions & 0 deletions tokenizers/src/lib.rs
@@ -151,6 +151,9 @@ pub use tokenizer::*;
// Re-export also parallelism utils
pub use utils::parallelism;

// Re-export ProgressFormat for trainer configuration
pub use utils::ProgressFormat;

// Re-export for from_pretrained
#[cfg(feature = "http")]
pub use utils::from_pretrained::FromPretrainedParameters;
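With the crate-root re-export in place, callers can name the enum without spelling out the `utils` path; a small sketch:

```rust
// Both paths refer to the same type once the re-export lands.
use tokenizers::ProgressFormat;

fn main() {
    let fmt = ProgressFormat::JsonLines;
    assert_eq!(fmt, tokenizers::utils::ProgressFormat::JsonLines);
}
```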
51 changes: 44 additions & 7 deletions tokenizers/src/models/bpe/trainer.rs
@@ -3,7 +3,7 @@
use super::{Pair, WithFirstLastIterator, Word, BPE};
use crate::parallelism::*;
use crate::tokenizer::{AddedToken, Result, Trainer};
use crate::utils::progress::{ProgressBar, ProgressStyle};
use crate::utils::progress::{ProgressBar, ProgressFormat, ProgressStyle};
use ahash::{AHashMap, AHashSet};
use compact_str::CompactString;
use dary_heap::OctonaryHeap;
@@ -42,6 +42,7 @@ struct Config {
min_frequency: u64,
vocab_size: usize,
show_progress: bool,
progress_format: ProgressFormat,
special_tokens: Vec<AddedToken>,
limit_alphabet: Option<usize>,
initial_alphabet: AHashSet<char>,
@@ -63,6 +64,7 @@ impl Default for BpeTrainerBuilder {
min_frequency: 0,
vocab_size: 30000,
show_progress: true,
progress_format: ProgressFormat::default(),
special_tokens: vec![],
limit_alphabet: None,
initial_alphabet: AHashSet::new(),
@@ -101,6 +103,18 @@ impl BpeTrainerBuilder {
self
}

/// Set the progress output format
///
/// Controls how progress information is reported during training.
/// - `Indicatif` (default): Interactive terminal progress bars
/// - `JsonLines`: Machine-readable JSON lines to stderr
/// - `Silent`: No progress output
#[must_use]
pub fn progress_format(mut self, format: ProgressFormat) -> Self {
self.config.progress_format = format;
self
}

/// Set the special tokens
#[must_use]
pub fn special_tokens(mut self, tokens: Vec<AddedToken>) -> Self {
@@ -150,6 +164,7 @@ impl BpeTrainerBuilder {
min_frequency: self.config.min_frequency,
vocab_size: self.config.vocab_size,
show_progress: self.config.show_progress,
progress_format: self.config.progress_format,
special_tokens: self.config.special_tokens,
limit_alphabet: self.config.limit_alphabet,
initial_alphabet: self.config.initial_alphabet,
@@ -186,6 +201,8 @@ pub struct BpeTrainer {
pub vocab_size: usize,
/// Whether to show progress while training
pub show_progress: bool,
/// Progress output format (Indicatif, JsonLines, or Silent)
pub progress_format: ProgressFormat,
/// A list of special tokens that the model should know of
pub special_tokens: Vec<AddedToken>,
/// Whether to limit the number of initial tokens that can be kept before computing merges
@@ -222,9 +239,15 @@ impl BpeTrainer {
BpeTrainerBuilder::new()
}

/// Setup a progress bar if asked to show progress
/// Returns the number of unique words in the corpus after feeding.
/// This can be used to estimate training time before starting.
pub fn get_word_count(&self) -> usize {
self.words.len()
}

/// Setup a progress bar if asked to show progress (only for Indicatif format)
fn setup_progress(&self) -> Option<ProgressBar> {
if self.show_progress {
if self.show_progress && self.progress_format == ProgressFormat::Indicatif {
let p = ProgressBar::new(0);
p.set_style(
ProgressStyle::default_bar()
@@ -237,13 +260,24 @@
}
}

/// Emit a JSON progress line to stderr (JsonLines format only)
fn emit_json_progress(&self, stage: &str, current: usize, total: usize) {
if self.progress_format == ProgressFormat::JsonLines {
eprintln!(
r#"{{"stage":"{}","current":{},"total":{}}}"#,
stage, current, total
);
}
}

/// Set the progress bar in the finish state
fn finalize_progress(&self, p: &Option<ProgressBar>, final_len: usize) {
fn finalize_progress(&self, p: &Option<ProgressBar>, final_len: usize, stage: &str) {
if let Some(p) = p {
p.set_length(final_len as u64);
p.finish();
println!();
}
self.emit_json_progress(stage, final_len, final_len);
}

/// Update the progress bar with the new provided length and message
@@ -253,6 +287,8 @@
p.set_length(len as u64);
p.reset();
}
// Emit initial JSON progress for this stage
self.emit_json_progress(message, 0, len);
}

/// Add the provided special tokens to the initial vocabulary
@@ -444,7 +480,7 @@ impl BpeTrainer {
self.update_progress(&progress, word_counts.len(), "Tokenize words");
let (mut words, counts) =
self.tokenize_words(word_counts, &mut word_to_id, &mut id_to_word, &progress);
self.finalize_progress(&progress, words.len());
self.finalize_progress(&progress, words.len(), "Tokenize words");

//
// 4. Count pairs in words
@@ -463,7 +499,7 @@
});
}
});
self.finalize_progress(&progress, words.len());
self.finalize_progress(&progress, words.len(), "Count pairs");

//
// 5. Do merges
@@ -565,8 +601,9 @@
if let Some(p) = &progress {
p.inc(1);
}
self.emit_json_progress("Compute merges", merges.len(), self.vocab_size);
}
self.finalize_progress(&progress, merges.len());
self.finalize_progress(&progress, merges.len(), "Compute merges");

// Transfer new vocab & options to model
//model.vocab = word_to_id;
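End to end, a hedged sketch of how a Rust caller might opt into machine-readable progress with the new builder method (the vocab size is illustrative; corpus handling is elided):

```rust
use tokenizers::models::bpe::BpeTrainer;
use tokenizers::ProgressFormat;

fn json_progress_trainer() -> BpeTrainer {
    // With JsonLines, setup_progress returns None (no indicatif bar) and
    // emit_json_progress instead writes one JSON object per update to stderr.
    BpeTrainer::builder()
        .vocab_size(30_000)
        .progress_format(ProgressFormat::JsonLines)
        .build()
}
```

During training this emits lines of the form `{"stage":"Compute merges","current":1500,"total":30000}` (values illustrative); the `stage` strings are the progress messages visible in the hunks above, e.g. `Tokenize words`, `Count pairs`, and `Compute merges`.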
3 changes: 3 additions & 0 deletions tokenizers/src/utils/mod.rs
@@ -20,6 +20,9 @@ pub mod parallelism;
pub(crate) mod progress;
pub mod truncation;

// Re-export ProgressFormat for public API
pub use progress::ProgressFormat;

use ahash::AHashMap;
use serde::{Serialize, Serializer};
use std::collections::BTreeMap;
17 changes: 17 additions & 0 deletions tokenizers/src/utils/progress.rs
@@ -1,3 +1,20 @@
use serde::{Deserialize, Serialize};

/// Progress output format for training operations.
///
/// Controls how progress information is reported during tokenizer training.
/// Default is `Indicatif` which shows interactive terminal progress bars.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
pub enum ProgressFormat {
/// Interactive terminal progress bars using indicatif (default behavior)
#[default]
Indicatif,
/// Machine-readable JSON lines to stderr for programmatic consumption
JsonLines,
/// No progress output
Silent,
}

#[cfg(feature = "progressbar")]
pub(crate) use indicatif::{ProgressBar, ProgressStyle};

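Because `ProgressFormat` derives `Serialize`/`Deserialize` with no rename attributes, serde uses the Rust variant names, which differ from the `"json"`/`"silent"` strings accepted by the Python binding. A quick round-trip sketch (assuming `serde_json` as the serializer):

```rust
use tokenizers::ProgressFormat;

fn main() -> serde_json::Result<()> {
    // Unit variants of an externally tagged enum serialize as bare strings.
    let json = serde_json::to_string(&ProgressFormat::JsonLines)?;
    assert_eq!(json, "\"JsonLines\"");

    let back: ProgressFormat = serde_json::from_str("\"Silent\"")?;
    assert_eq!(back, ProgressFormat::Silent);
    Ok(())
}
```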