diff --git a/Cargo.lock b/Cargo.lock
index 437d2ccb..d85cd87e 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -841,6 +841,7 @@
 dependencies = [
  "sha2",
  "tempfile",
  "tera",
+ "tokenizers",
  "uuid",
 ]
diff --git a/Cargo.toml b/Cargo.toml
index 8f9640e0..0912184a 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -15,6 +15,9 @@
 serde = { version = "1.0.228", features = [ "serde_derive",] }
 tracing = "0.1.41"
 thiserror = "2.0.17"
+[workspace.dependencies.tokenizers]
+version = "0.22.1"
+
 [workspace.dependencies.ort]
 version = "=2.0.0-rc.10"
 features = [ "std", "tracing", "ndarray",]
diff --git a/encoderfile-core/Cargo.toml b/encoderfile-core/Cargo.toml
index a02d0a84..bcc6b9b2 100644
--- a/encoderfile-core/Cargo.toml
+++ b/encoderfile-core/Cargo.toml
@@ -48,7 +48,7 @@ description = "Distribute and run transformer encoders with a single file."
 default = [ "transport",]
 transport = [ "runtime", "opentelemetry", "opentelemetry-semantic-conventions", "opentelemetry-stdout", "opentelemetry_sdk", "opentelemetry-otlp", "rmcp", "rmcp-macros", "axum", "axum-server", "clap", "clap_derive", "prost", "tonic-prost", "tonic-types", "tonic-web", "tracing-opentelemetry", "tonic", "tower-http",]
 dev-utils = [ "runtime",]
-runtime = [ "transforms", "tokenizers", "ort",]
+runtime = [ "transforms", "ort",]
 transforms = [ "mlua",]
 
 [dependencies]
@@ -121,8 +121,7 @@
 version = "0.14.1"
 optional = true
 
 [dependencies.tokenizers]
-version = "0.22.1"
-optional = true
+workspace = true
 
 [dependencies.tonic-prost]
 version = "0.14.2"
diff --git a/encoderfile-core/src/common/config.rs b/encoderfile-core/src/common/config.rs
index 47a695fb..ef433bb5 100644
--- a/encoderfile-core/src/common/config.rs
+++ b/encoderfile-core/src/common/config.rs
@@ -1,11 +1,17 @@
 use super::model_type::ModelType;
-use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
+use tokenizers::PaddingParams;
 
-#[derive(Debug, Serialize, Deserialize, JsonSchema)]
+#[derive(Debug, Serialize, Deserialize)]
 pub struct Config {
     pub name: String,
     pub version: String,
     pub model_type: ModelType,
     pub transform: Option<String>,
+    pub tokenizer: TokenizerConfig,
+}
+
+#[derive(Debug, Default, Serialize, Deserialize)]
+pub struct TokenizerConfig {
+    pub padding: PaddingParams,
 }
diff --git a/encoderfile-core/src/dev_utils/mod.rs b/encoderfile-core/src/dev_utils/mod.rs
index 47a413f8..46424cd5 100644
--- a/encoderfile-core/src/dev_utils/mod.rs
+++ b/encoderfile-core/src/dev_utils/mod.rs
@@ -14,17 +14,18 @@
 const SEQUENCE_CLASSIFICATION_DIR: &str = "../models/sequence_classification";
 const TOKEN_CLASSIFICATION_DIR: &str = "../models/token_classification";
 
 pub fn get_state(dir: &str) -> AppState {
-    let model_config = Arc::new(get_model_config(dir));
-    let tokenizer = Arc::new(get_tokenizer(dir, &model_config));
-    let session = Arc::new(get_model(dir));
-
     let config = Arc::new(Config {
         name: "my-model".to_string(),
         version: "0.0.1".to_string(),
         model_type: T::enum_val(),
         transform: None,
+        tokenizer: Default::default(),
     });
 
+    let model_config = Arc::new(get_model_config(dir));
+    let tokenizer = Arc::new(get_tokenizer(dir, &config));
+    let session = Arc::new(get_model(dir));
+
     AppState::new(config, session, tokenizer, model_config)
 }
@@ -52,11 +53,11 @@
     serde_json::from_reader(reader).expect("Invalid model config")
 }
 
-fn get_tokenizer(dir: &str, config: &Arc<ModelConfig>) -> tokenizers::Tokenizer {
+fn get_tokenizer(dir: &str, ec_config: &Arc<Config>) -> tokenizers::Tokenizer {
     let tokenizer_str = std::fs::read_to_string(format!("{}/{}", dir, "tokenizer.json"))
         .expect("Tokenizer json not found");
 
-    crate::runtime::get_tokenizer_from_string(tokenizer_str.as_str(), config)
+    crate::runtime::get_tokenizer_from_string(tokenizer_str.as_str(), ec_config)
 }
 
 fn get_model(dir: &str) -> Mutex {
diff --git a/encoderfile-core/src/factory.rs b/encoderfile-core/src/factory.rs
index 6879154b..828f3488 100644
--- a/encoderfile-core/src/factory.rs
+++ b/encoderfile-core/src/factory.rs
@@ -85,7 +85,7 @@ where
     let config = get_config(config_str);
     let session = get_model(model_bytes);
     let model_config = get_model_config(model_config_str);
-    let tokenizer = get_tokenizer(tokenizer_json, &model_config);
+    let tokenizer = get_tokenizer(tokenizer_json, &config);
 
     let state = AppState::new(config, session, tokenizer, model_config);
diff --git a/encoderfile-core/src/runtime/tokenizer.rs b/encoderfile-core/src/runtime/tokenizer.rs
index a6d52918..2eba8565 100644
--- a/encoderfile-core/src/runtime/tokenizer.rs
+++ b/encoderfile-core/src/runtime/tokenizer.rs
@@ -1,48 +1,24 @@
-use crate::{common::ModelConfig, error::ApiError};
+use crate::{common::Config, error::ApiError};
 use anyhow::Result;
 use std::str::FromStr;
 use std::sync::{Arc, OnceLock};
-use tokenizers::{
-    Encoding, PaddingDirection, PaddingParams, PaddingStrategy, tokenizer::Tokenizer,
-};
+use tokenizers::{Encoding, tokenizer::Tokenizer};
 
 static TOKENIZER: OnceLock<Arc<Tokenizer>> = OnceLock::new();
 
-pub fn get_tokenizer(tokenizer_json: &str, model_config: &Arc<ModelConfig>) -> Arc<Tokenizer> {
+pub fn get_tokenizer(tokenizer_json: &str, ec_config: &Arc<Config>) -> Arc<Tokenizer> {
     TOKENIZER
-        .get_or_init(|| Arc::new(get_tokenizer_from_string(tokenizer_json, model_config)))
+        .get_or_init(|| Arc::new(get_tokenizer_from_string(tokenizer_json, ec_config)))
         .clone()
 }
 
-pub fn get_tokenizer_from_string(s: &str, config: &Arc<ModelConfig>) -> Tokenizer {
-    let pad_token_id = config.pad_token_id;
-
+pub fn get_tokenizer_from_string(s: &str, ec_config: &Arc<Config>) -> Tokenizer {
     let mut tokenizer = match Tokenizer::from_str(s) {
         Ok(t) => t,
         Err(e) => panic!("FATAL: Error loading tokenizer: {e:?}"),
     };
 
-    let pad_token = match tokenizer.id_to_token(pad_token_id) {
-        Some(tok) => tok,
-        None => panic!("Model requires a padding token."),
-    };
-
-    if tokenizer.get_padding().is_none() {
-        let params = PaddingParams {
-            strategy: PaddingStrategy::BatchLongest,
-            direction: PaddingDirection::Right,
-            pad_to_multiple_of: None,
-            pad_id: pad_token_id,
-            pad_type_id: 0,
-            pad_token,
-        };
-
-        tracing::warn!(
-            "No padding strategy specified in tokenizer config. Setting default: {:?}",
-            &params
-        );
-        tokenizer.with_padding(Some(params));
-    }
+    tokenizer.with_padding(Some(ec_config.tokenizer.padding.clone()));
 
     tokenizer
 }
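For reference (not part of the patch): the runtime above now applies whatever padding the embedded `Config` carries, instead of re-deriving it from `ModelConfig.pad_token_id`. A minimal sketch of a fully populated `TokenizerConfig`, using the same `tokenizers::PaddingParams` fields the removed fallback used to fill in; the literal values are illustrative, not defaults taken from any particular model:

use encoderfile_core::common::TokenizerConfig;
use tokenizers::{PaddingDirection, PaddingParams, PaddingStrategy};

fn example_tokenizer_config() -> TokenizerConfig {
    TokenizerConfig {
        padding: PaddingParams {
            strategy: PaddingStrategy::BatchLongest, // pad each batch to its longest sequence
            direction: PaddingDirection::Right,
            pad_to_multiple_of: None,
            pad_id: 0, // must agree with pad_token's id in tokenizer.json
            pad_type_id: 0,
            pad_token: "[PAD]".to_string(),
        },
    }
}

At runtime, `get_tokenizer_from_string` installs this value verbatim via `with_padding`, so whatever the build step resolved is exactly what inference uses.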
diff --git a/encoderfile/Cargo.toml b/encoderfile/Cargo.toml
index a6714c76..bd955825 100644
--- a/encoderfile/Cargo.toml
+++ b/encoderfile/Cargo.toml
@@ -23,6 +23,9 @@ dev-utils = []
 
 [dev-dependencies]
 tempfile = "3.23.0"
 
+[dependencies.tokenizers]
+workspace = true
+
 [dependencies.anyhow]
 workspace = true
diff --git a/encoderfile/src/config.rs b/encoderfile/src/config.rs
index 9fa2ecc4..c12baeb2 100644
--- a/encoderfile/src/config.rs
+++ b/encoderfile/src/config.rs
@@ -38,6 +38,7 @@ pub struct EncoderfileConfig {
     pub output_path: Option<PathBuf>,
     pub cache_dir: Option<PathBuf>,
     pub transform: Option<Transform>,
+    pub tokenizer: Option<TokenizerBuildConfig>,
     #[serde(default = "default_validate_transform")]
     pub validate_transform: bool,
     #[serde(default = "default_build")]
@@ -46,11 +47,13 @@ pub struct EncoderfileConfig {
 
 impl EncoderfileConfig {
     pub fn embedded_config(&self) -> Result<EmbeddedConfig> {
+        let tokenizer = self.validate_tokenizer()?;
         let config = EmbeddedConfig {
             name: self.name.clone(),
             version: self.version.clone(),
             model_type: self.model_type.clone(),
             transform: self.transform()?,
+            tokenizer,
         };
 
         Ok(config)
@@ -117,6 +120,18 @@
     }
 }
 
+#[derive(Debug, Serialize, Deserialize, JsonSchema)]
+pub struct TokenizerBuildConfig {
+    pub pad_strategy: Option<TokenizerPadStrategy>,
+}
+
+#[derive(Debug, Serialize, Deserialize, JsonSchema)]
+#[serde(untagged, rename_all = "snake_case")]
+pub enum TokenizerPadStrategy {
+    BatchLongest,
+    Fixed { fixed: usize },
+}
+
 #[derive(Debug, Serialize, Deserialize, JsonSchema)]
 #[serde(untagged)]
 pub enum Transform {
@@ -152,27 +167,60 @@ pub enum ModelPath {
         model_config_path: PathBuf,
         model_weights_path: PathBuf,
         tokenizer_path: PathBuf,
+        tokenizer_config_path: Option<PathBuf>,
     },
 }
 
-macro_rules! asset_path {
-    ($var:ident, $default:expr, $err:expr) => {
-        pub fn $var(&self) -> Result<PathBuf> {
-            let path = match self {
-                Self::Paths { $var, .. } => $var.clone(),
-                Self::Directory(dir) => {
-                    if !dir.is_dir() {
-                        bail!("No such directory: {:?}", dir);
-                    }
-                    dir.join($default)
+impl ModelPath {
+    fn resolve(
+        &self,
+        explicit: Option<PathBuf>,
+        default: impl FnOnce(&PathBuf) -> PathBuf,
+        err: &str,
+    ) -> Result<Option<PathBuf>> {
+        let path = match self {
+            Self::Paths { .. } => explicit,
+            Self::Directory(dir) => {
+                if !dir.is_dir() {
+                    bail!("No such directory: {:?}", dir);
                 }
-            };
+                Some(default(dir))
+            }
+        };
 
-            if !path.try_exists()? {
-                bail!("Could not locate {} at path: {:?}", $err, path);
+        match path {
+            Some(p) => {
+                if !p.try_exists()? {
+                    bail!("Could not locate {} at path: {:?}", err, p);
+                }
+                Ok(Some(p.canonicalize()?))
             }
+            None => Ok(None),
+        }
+    }
+}
+
+macro_rules! asset_path {
+    (@Optional $name:ident, $default:expr, $err:expr) => {
+        pub fn $name(&self) -> Result<Option<PathBuf>> {
+            let explicit = match self {
+                Self::Paths { $name, .. } => $name.clone(),
+                _ => None,
+            };
+
+            self.resolve(explicit, |dir| dir.join($default), $err)
+        }
+    };
+
+    ($name:ident, $default:expr, $err:expr) => {
+        pub fn $name(&self) -> Result<PathBuf> {
+            let explicit = match self {
+                Self::Paths { $name, .. } => Some($name.clone()),
+                _ => None,
+            };
 
-            Ok(path.canonicalize()?)
+            self.resolve(explicit, |dir| dir.join($default), $err)?
+                .ok_or_else(|| anyhow::anyhow!("Missing required path: {}", $err))
         }
     };
 }
@@ -181,6 +229,7 @@ impl ModelPath {
     asset_path!(model_config_path, "config.json", "model config");
     asset_path!(tokenizer_path, "tokenizer.json", "tokenizer");
     asset_path!(model_weights_path, "model.onnx", "model weights");
+    asset_path!(@Optional tokenizer_config_path, "tokenizer_config.json", "tokenizer config");
 }
 
 fn default_cache_dir() -> PathBuf {
@@ -222,12 +271,19 @@ mod tests {
         base
     }
 
+    // Create temp output dir
+    fn create_temp_output_dir() -> PathBuf {
+        create_test_dir("model")
+    }
+
     // Create a model dir populated with the required files
-    fn create_model_dir() -> PathBuf {
+    fn create_temp_model_dir() -> PathBuf {
        let base = create_test_dir("model");
         fs::write(base.join("config.json"), "{}").expect("Failed to create config.json");
         fs::write(base.join("tokenizer.json"), "{}").expect("Failed to create tokenizer.json");
         fs::write(base.join("model.onnx"), "onnx").expect("Failed to create model.onnx");
+        fs::write(base.join("tokenizer_config.json"), "{}")
+            .expect("Failed to create tokenizer_config.json");
         base
     }
@@ -243,12 +299,18 @@
 
     #[test]
     fn test_modelpath_directory_valid() {
-        let base = create_model_dir();
+        let base = create_temp_model_dir();
         let mp = ModelPath::Directory(base.clone());
 
         assert!(mp.model_config_path().unwrap().ends_with("config.json"));
         assert!(mp.tokenizer_path().unwrap().ends_with("tokenizer.json"));
         assert!(mp.model_weights_path().unwrap().ends_with("model.onnx"));
+        assert!(
+            mp.tokenizer_config_path()
+                .unwrap()
+                .unwrap()
+                .ends_with("tokenizer_config.json")
+        );
 
         cleanup(&base);
     }
@@ -266,11 +328,12 @@
 
     #[test]
     fn test_modelpath_explicit_paths() {
-        let base = create_model_dir();
+        let base = create_temp_model_dir();
         let mp = ModelPath::Paths {
             model_config_path: base.join("config.json"),
             tokenizer_path: base.join("tokenizer.json"),
             model_weights_path: base.join("model.onnx"),
+            tokenizer_config_path: Some(base.join("tokenizer_config.json")),
         };
 
         assert!(mp.model_config_path().is_ok());
@@ -310,17 +373,18 @@
 
     #[test]
     fn test_encoderfile_generated_dir() {
-        let base = create_model_dir();
+        let base = create_temp_output_dir();
 
         let cfg = EncoderfileConfig {
             name: "my-cool-model".into(),
             version: "1.0".into(),
-            path: ModelPath::Directory(base.clone()),
+            path: ModelPath::Directory("../models/embedding".into()),
             model_type: ModelType::Embedding,
             output_path: Some(base.clone()),
             cache_dir: Some(base.clone()),
             validate_transform: false,
             transform: None,
+            tokenizer: None,
             build: true,
         };
@@ -332,16 +396,17 @@
 
     #[test]
     fn test_encoderfile_to_tera_ctx() {
-        let base = create_model_dir();
+        let base = create_temp_output_dir();
 
         let cfg = EncoderfileConfig {
             name: "sadness".into(),
             version: "0.1.0".into(),
-            path: ModelPath::Directory(base.clone()),
+            path: ModelPath::Directory("../models/embedding".into()),
             model_type: ModelType::SequenceClassification,
             output_path: Some(base.clone()),
             cache_dir: Some(base.clone()),
             validate_transform: false,
             transform: Some(Transform::Inline("1+1".into())),
+            tokenizer: None,
             build: true,
         };
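For readability (not part of the patch), a rough hand-expansion of `asset_path!(@Optional tokenizer_config_path, "tokenizer_config.json", "tokenizer config")` under the macro arms above, as it would appear inside config.rs:

impl ModelPath {
    // Optional asset: a bare `Directory` falls back to <dir>/tokenizer_config.json,
    // explicit `Paths` may simply omit it, and only a path that was produced but
    // does not exist is an error.
    pub fn tokenizer_config_path(&self) -> Result<Option<PathBuf>> {
        let explicit = match self {
            Self::Paths { tokenizer_config_path, .. } => tokenizer_config_path.clone(),
            _ => None,
        };

        self.resolve(explicit, |dir| dir.join("tokenizer_config.json"), "tokenizer config")
    }
}

The required arm differs only in wrapping the explicit path in `Some` and turning a `None` result into a "Missing required path" error via `ok_or_else`.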
diff --git a/encoderfile/src/lib.rs b/encoderfile/src/lib.rs
index 16e4891b..b7f5b3b7 100644
--- a/encoderfile/src/lib.rs
+++ b/encoderfile/src/lib.rs
@@ -2,4 +2,5 @@ pub mod cli;
 pub mod config;
 pub mod model;
 pub mod templates;
+pub mod tokenizer;
 pub mod transforms;
diff --git a/encoderfile/src/tokenizer.rs b/encoderfile/src/tokenizer.rs
new file mode 100644
index 00000000..d71803c0
--- /dev/null
+++ b/encoderfile/src/tokenizer.rs
@@ -0,0 +1,334 @@
+// IMPORTANT NOTE:
+//
+// Tokenizer configuration is NOT a stable, self-contained artifact.
+//
+// In practice, tokenizer behavior is split across:
+// - tokenizer.json (partially serialized runtime state; SOMETIMES present, mostly in older models)
+// - tokenizer_config.json (optional, inconsistently populated)
+// - implicit defaults inside the `tokenizers` library
+// - and values that affect inference but are *never serialized*
+//
+// This means:
+// - Missing fields fail silently and fall back to defaults
+// - Backwards compatibility is heuristic, not contractual
+// - Some critical values (e.g. pad_token_id) must be re-derived at runtime
+// - A "valid" tokenizer config can still produce subtly wrong results
+//
+// This code exists to aggressively reconstruct a deterministic TokenizerConfig
+// for inference, emitting warnings where possible — but be aware:
+//
+// ⚠️ Incorrect or incomplete tokenizer configs may not crash.
+// ⚠️ They may instead produce silently incorrect model outputs.
+// ⚠️ If, heaven forbid, something silently breaks in Encoderfile, I bet $5 it is going to be this feature.
+//
+// This is not ideal and will be revisited in v1.0.0 once we have an opportunity to make breaking changes
+// in the way encoderfile.yml works, etc.
+
+use anyhow::Result;
+use encoderfile_core::common::TokenizerConfig;
+use std::str::FromStr;
+use tokenizers::{PaddingParams, PaddingStrategy, Tokenizer};
+
+use crate::config::{EncoderfileConfig, TokenizerPadStrategy};
+
+impl EncoderfileConfig {
+    pub fn validate_tokenizer(&self) -> Result<TokenizerConfig> {
+        let tokenizer = match Tokenizer::from_str(
+            std::fs::read_to_string(self.path.tokenizer_path()?)?.as_str(),
+        ) {
+            Ok(t) => t,
+            Err(e) => anyhow::bail!("FATAL: Failed to load tokenizer: {:?}", e),
+        };
+
+        let mut config = match self.path.tokenizer_config_path()? {
+            // if tokenizer_config.json is provided, use that
+            Some(tokenizer_config_path) => {
+                // open tokenizer_config
+                let contents = std::fs::read_to_string(tokenizer_config_path)?;
+                let tokenizer_config: serde_json::Value = serde_json::from_str(contents.as_str())?;
+
+                tokenizer_config_from_json_value(tokenizer_config, tokenizer)?
+            }
+            // otherwise check for any values given in tokenizer.json (backwards compatibility)
+            None => {
+                // will fail here if neither are given
+                from_tokenizer(tokenizer)?
+            }
+        };
+
+        // TODO: insert any overrides from encoderfile.yml here
+        let tokenizer_build_config = match &self.tokenizer {
+            Some(t) => t,
+            None => return Ok(config),
+        };
+
+        if let Some(s) = &tokenizer_build_config.pad_strategy {
+            config.padding.strategy = match s {
+                TokenizerPadStrategy::BatchLongest => PaddingStrategy::BatchLongest,
+                TokenizerPadStrategy::Fixed { fixed } => PaddingStrategy::Fixed(*fixed),
+            }
+        };
+
+        Ok(config)
+    }
+}
+
+fn from_tokenizer(tokenizer: Tokenizer) -> Result<TokenizerConfig> {
+    let padding = match tokenizer.get_padding() {
+        Some(p) => p.clone(),
+        None => {
+            let padding_params = PaddingParams::default();
+
+            eprintln!(
+                "WARNING: No padding params found in `tokenizer.json`. Using defaults: {:?}",
+                &padding_params
+            );
+
+            padding_params
+        }
+    };
+
+    Ok(TokenizerConfig { padding })
+}
+
+fn tokenizer_config_from_json_value(
+    val: serde_json::Value,
+    tokenizer: tokenizers::Tokenizer,
+) -> Result<TokenizerConfig> {
+    let mut builder = TokenizerConfigBuilder::new(
+        val.as_object()
+            .ok_or(anyhow::anyhow!("tokenizer_config.json must be an object"))?,
+    );
+
+    builder.field(
+        "padding_side",
+        |config, v| {
+            let side = v
+                .as_str()
+                .ok_or(anyhow::anyhow!("padding_side must be a str"))?;
+
+            config.padding.direction = match side {
+                "left" => tokenizers::PaddingDirection::Left,
+                "right" => tokenizers::PaddingDirection::Right,
+                _ => anyhow::bail!("padding_side must be \"left\" or \"right\""),
+            };
+
+            Ok(())
+        },
+        |config| config.padding.direction,
+    )?;
+
+    builder.field(
+        "pad_to_multiple_of",
+        |config, v| {
+            if v.is_null() {
+                config.padding.pad_to_multiple_of = None;
+                return Ok(());
+            }
+
+            config.padding.pad_to_multiple_of = v.as_u64().map(|i| Some(i as usize)).ok_or(
+                anyhow::anyhow!("pad_to_multiple_of must be an unsigned int or null"),
+            )?;
+
+            Ok(())
+        },
+        |config| config.padding.pad_to_multiple_of,
+    )?;
+
+    builder.field(
+        "pad_token",
+        |config, v| {
+            config.padding.pad_token = v
+                .as_str()
+                .ok_or(anyhow::anyhow!("pad_token must be a string"))?
+                .to_string();
+
+            Ok(())
+        },
+        |config| config.padding.pad_token.clone(),
+    )?;
+
+    builder.field(
+        "pad_token_type_id",
+        |config, v| {
+            config.padding.pad_type_id = v
+                .as_u64()
+                .map(|i| i as u32)
+                .ok_or(anyhow::anyhow!("pad_token_type_id must be an unsigned int"))?;
+
+            Ok(())
+        },
+        |config| config.padding.pad_type_id,
+    )?;
+
+    // now we fetch pad_token_id manually because it doesn't get serialized into tokenizer_config.json!
+    builder.set_pad_token_id(&tokenizer)?;
+
+    builder.build()
+}
+
+#[derive(Debug)]
+struct TokenizerConfigBuilder<'a> {
+    config: TokenizerConfig,
+    val: &'a serde_json::value::Map<String, serde_json::Value>,
+}
+
+impl<'a> TokenizerConfigBuilder<'a> {
+    fn new(val: &'a serde_json::value::Map<String, serde_json::Value>) -> Self {
+        Self {
+            config: TokenizerConfig::default(),
+            val,
+        }
+    }
+
+    fn build(self) -> Result<TokenizerConfig> {
+        Ok(self.config)
+    }
+
+    fn set_pad_token_id(&mut self, tokenizer: &Tokenizer) -> Result<()> {
+        let pad_token = self.config.padding.pad_token.as_str();
+        self.config.padding.pad_id = tokenizer.token_to_id(pad_token).ok_or(anyhow::anyhow!(
+            "pad_token set to {}, but token does not exist in tokenizer",
+            pad_token
+        ))?;
+
+        Ok(())
+    }
+
+    fn field<P, D, V>(
+        &mut self,
+        field: &str,
+        process_value_fn: P,
+        default_value_fn: D,
+    ) -> Result<()>
+    where
+        P: FnOnce(&mut TokenizerConfig, &serde_json::Value) -> Result<()>,
+        D: FnOnce(&TokenizerConfig) -> V,
+        V: std::fmt::Debug,
+    {
+        match self.val.get(field) {
+            Some(v) => process_value_fn(&mut self.config, v),
+            None => {
+                if !self.val.contains_key(field) {
+                    eprintln!(
+                        "WARNING: No {} found in tokenizer_config.json. Using default: {:?}",
+                        field,
+                        default_value_fn(&self.config),
+                    )
+                }
+
+                Ok(())
+            }
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use encoderfile_core::common::ModelType;
+
+    use crate::config::{ModelPath, TokenizerBuildConfig};
+
+    use super::*;
+
+    #[test]
+    fn test_validate_tokenizer() {
+        let config = EncoderfileConfig {
+            name: "my-model".into(),
+            version: "0.0.1".into(),
+            path: ModelPath::Directory("../models/embedding".into()),
+            model_type: ModelType::Embedding,
+            output_path: None,
+            cache_dir: None,
+            transform: None,
+            tokenizer: None,
+            validate_transform: false,
+            build: false,
+        };
+
+        let tokenizer_config = config
+            .validate_tokenizer()
+            .expect("Failed to validate tokenizer");
+
+        assert_eq!(format!("{:?}", tokenizer_config.padding.direction), "Right");
+        assert_eq!(
+            format!("{:?}", tokenizer_config.padding.strategy),
+            "BatchLongest"
+        );
+        assert_eq!(tokenizer_config.padding.pad_id, 0);
+        assert_eq!(tokenizer_config.padding.pad_token, "[PAD]");
+        assert!(tokenizer_config.padding.pad_to_multiple_of.is_none());
+        assert_eq!(tokenizer_config.padding.pad_type_id, 0);
+    }
+
+    #[test]
+    fn test_validate_tokenizer_fixed() {
+        let config = EncoderfileConfig {
+            name: "my-model".into(),
+            version: "0.0.1".into(),
+            path: ModelPath::Directory("../models/embedding".into()),
+            model_type: ModelType::Embedding,
+            output_path: None,
+            cache_dir: None,
+            transform: None,
+            tokenizer: Some(TokenizerBuildConfig {
+                pad_strategy: Some(TokenizerPadStrategy::Fixed { fixed: 512 }),
+            }),
+            validate_transform: false,
+            build: false,
+        };
+
+        let tokenizer_config = config
+            .validate_tokenizer()
+            .expect("Failed to validate tokenizer");
+
+        assert_eq!(format!("{:?}", tokenizer_config.padding.direction), "Right");
+        assert_eq!(
+            format!("{:?}", tokenizer_config.padding.strategy),
+            "Fixed(512)"
+        );
+        assert_eq!(tokenizer_config.padding.pad_id, 0);
+        assert_eq!(tokenizer_config.padding.pad_token, "[PAD]");
+        assert!(tokenizer_config.padding.pad_to_multiple_of.is_none());
+        assert_eq!(tokenizer_config.padding.pad_type_id, 0);
+    }
+
+    #[test]
+    fn test_validate_tokenizer_no_config() {
+        let path = ModelPath::Directory("../models/token_classification".into());
+
+        let explicit_path = ModelPath::Paths {
+            model_config_path: path.model_config_path().unwrap(),
+            model_weights_path: path.model_weights_path().unwrap(),
+            tokenizer_path: path.tokenizer_path().unwrap(),
+            tokenizer_config_path: None,
+        };
+
+        let config = EncoderfileConfig {
+            name: "my-model".into(),
+            version: "0.0.1".into(),
+            path: explicit_path,
+            model_type: ModelType::Embedding,
+            output_path: None,
+            cache_dir: None,
+            transform: None,
+            tokenizer: None,
+            validate_transform: false,
+            build: false,
+        };
+
+        let tokenizer_config = config
+            .validate_tokenizer()
+            .expect("Failed to validate tokenizer");
+
+        assert_eq!(format!("{:?}", tokenizer_config.padding.direction), "Right");
+        assert_eq!(
+            format!("{:?}", tokenizer_config.padding.strategy),
+            "BatchLongest"
+        );
+        assert_eq!(tokenizer_config.padding.pad_id, 0);
+        assert_eq!(tokenizer_config.padding.pad_token, "[PAD]");
+        assert!(tokenizer_config.padding.pad_to_multiple_of.is_none());
+        assert_eq!(tokenizer_config.padding.pad_type_id, 0);
+    }
+}
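For reference (not something the patch adds), a hypothetical tokenizer_config.json fragment reduced to the keys `tokenizer_config_from_json_value` actually consumes; real Hugging Face exports carry many more keys, all of which this builder ignores:

use serde_json::json;

fn example_tokenizer_config_json() -> serde_json::Value {
    json!({
        "padding_side": "right",    // -> PaddingParams.direction
        "pad_to_multiple_of": null, // -> PaddingParams.pad_to_multiple_of
        "pad_token": "[PAD]",       // -> PaddingParams.pad_token
        "pad_token_type_id": 0      // -> PaddingParams.pad_type_id
        // pad_token_id is deliberately absent: it is re-derived from tokenizer.json
        // via Tokenizer::token_to_id, since it is not serialized here.
    })
}

A key that is missing only triggers the eprintln! warning and keeps the default; pad_id always comes from the tokenizer itself via set_pad_token_id.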
diff --git a/encoderfile/src/transforms/validation/embedding.rs b/encoderfile/src/transforms/validation/embedding.rs
index 108e6ae4..a36ea1b1 100644
--- a/encoderfile/src/transforms/validation/embedding.rs
+++ b/encoderfile/src/transforms/validation/embedding.rs
@@ -71,6 +71,7 @@ mod tests {
             transform: None,
             validate_transform: true,
             build: true,
+            tokenizer: None,
         }
     }
diff --git a/encoderfile/src/transforms/validation/mod.rs b/encoderfile/src/transforms/validation/mod.rs
index baa1e05c..834488cb 100644
--- a/encoderfile/src/transforms/validation/mod.rs
+++ b/encoderfile/src/transforms/validation/mod.rs
@@ -105,6 +105,7 @@ mod tests {
             transform: None,
             validate_transform: true,
             build: true,
+            tokenizer: None,
         }
     }
@@ -148,6 +149,7 @@
             transform: Some(Transform::Inline(transform_str.to_string())),
             validate_transform: true,
             build: true,
+            tokenizer: None,
         };
 
         let model_config_str =
@@ -171,6 +173,7 @@
             transform: None,
             validate_transform: true,
             build: true,
+            tokenizer: None,
         };
 
         let model_config_str =
diff --git a/encoderfile/src/transforms/validation/sentence_embedding.rs b/encoderfile/src/transforms/validation/sentence_embedding.rs
index 28c7a65e..86cd28f4 100644
--- a/encoderfile/src/transforms/validation/sentence_embedding.rs
+++ b/encoderfile/src/transforms/validation/sentence_embedding.rs
@@ -74,6 +74,7 @@ mod tests {
             transform: None,
             validate_transform: true,
             build: true,
+            tokenizer: None,
         }
     }
diff --git a/encoderfile/src/transforms/validation/sequence_classification.rs b/encoderfile/src/transforms/validation/sequence_classification.rs
index 0452e448..a2315f0e 100644
--- a/encoderfile/src/transforms/validation/sequence_classification.rs
+++ b/encoderfile/src/transforms/validation/sequence_classification.rs
@@ -70,6 +70,7 @@ mod tests {
             transform: None,
             validate_transform: true,
             build: true,
+            tokenizer: None,
         }
     }
diff --git a/encoderfile/src/transforms/validation/token_classification.rs b/encoderfile/src/transforms/validation/token_classification.rs
index 5596fb5e..a9ddd601 100644
--- a/encoderfile/src/transforms/validation/token_classification.rs
+++ b/encoderfile/src/transforms/validation/token_classification.rs
@@ -71,6 +71,7 @@ mod tests {
             transform: None,
             validate_transform: true,
             build: true,
+            tokenizer: None,
         }
     }
diff --git a/schemas/encoderfile-config-schema.json b/schemas/encoderfile-config-schema.json
index 44a0e0d3..09f06ade 100644
--- a/schemas/encoderfile-config-schema.json
+++ b/schemas/encoderfile-config-schema.json
@@ -1,6 +1,6 @@
 {
   "$schema": "https://json-schema.org/draft/2020-12/schema",
-  "title": "Config",
+  "title": "BuildConfig",
   "type": "object",
   "properties": {
     "encoderfile": {
@@ -39,6 +39,16 @@
       "path": {
        "$ref": "#/$defs/ModelPath"
       },
+      "tokenizer": {
+        "anyOf": [
+          {
+            "$ref": "#/$defs/TokenizerBuildConfig"
+          },
+          {
+            "type": "null"
+          }
+        ]
+      },
       "transform": {
         "anyOf": [
           {
@@ -78,6 +88,12 @@
       "model_weights_path": {
         "type": "string"
       },
+      "tokenizer_config_path": {
+        "type": [
+          "string",
+          "null"
+        ]
+      },
       "tokenizer_path": {
         "type": "string"
       }
@@ -99,6 +115,41 @@
         "sentence_embedding"
       ]
     },
+    "TokenizerBuildConfig": {
+      "type": "object",
+      "properties": {
+        "pad_strategy": {
+          "anyOf": [
+            {
+              "$ref": "#/$defs/TokenizerPadStrategy"
+            },
+            {
+              "type": "null"
+            }
+          ]
+        }
+      }
+    },
+    "TokenizerPadStrategy": {
+      "anyOf": [
+        {
+          "type": "null"
+        },
+        {
+          "type": "object",
+          "properties": {
+            "fixed": {
+              "type": "integer",
+              "format": "uint",
+              "minimum": 0
+            }
+          },
+          "required": [
+            "fixed"
+          ]
+        }
+      ]
+    },
     "Transform": {
       "anyOf": [
         {
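Finally, a sketch of how the new build-time override selects a fixed padding strategy. Assumption: encoderfile.yml is deserialized through the same serde definitions, so the JSON round trip below mirrors a `tokenizer: { pad_strategy: { fixed: 512 } }` entry in that file; the actual YAML loader is not shown in this diff.

use serde_json::json;

use crate::config::{TokenizerBuildConfig, TokenizerPadStrategy}; // assumes this sits inside the encoderfile crate

fn parse_fixed_override() -> anyhow::Result<()> {
    let cfg: TokenizerBuildConfig = serde_json::from_value(json!({
        "pad_strategy": { "fixed": 512 }
    }))?;

    assert!(matches!(
        cfg.pad_strategy,
        Some(TokenizerPadStrategy::Fixed { fixed: 512 })
    ));
    Ok(())
}

An absent or null pad_strategy deserializes to None, which leaves whatever validate_tokenizer derived from tokenizer_config.json (or tokenizer.json) untouched.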