1 change: 1 addition & 0 deletions Cargo.lock


3 changes: 3 additions & 0 deletions Cargo.toml
@@ -15,6 +15,9 @@ serde = { version = "1.0.228", features = [ "serde_derive",] }
tracing = "0.1.41"
thiserror = "2.0.17"

+[workspace.dependencies.tokenizers]
+version = "0.22.1"

[workspace.dependencies.ort]
version = "=2.0.0-rc.10"
features = [ "std", "tracing", "ndarray",]
5 changes: 2 additions & 3 deletions encoderfile-core/Cargo.toml
@@ -48,7 +48,7 @@ description = "Distribute and run transformer encoders with a single file."
default = [ "transport",]
transport = [ "runtime", "opentelemetry", "opentelemetry-semantic-conventions", "opentelemetry-stdout", "opentelemetry_sdk", "opentelemetry-otlp", "rmcp", "rmcp-macros", "axum", "axum-server", "clap", "clap_derive", "prost", "tonic-prost", "tonic-types", "tonic-web", "tracing-opentelemetry", "tonic", "tower-http",]
dev-utils = [ "runtime",]
runtime = [ "transforms", "tokenizers", "ort",]
runtime = [ "transforms", "ort",]
transforms = [ "mlua",]

[dependencies]
@@ -121,8 +121,7 @@ version = "0.14.1"
optional = true

[dependencies.tokenizers]
version = "0.22.1"
optional = true
workspace = true

[dependencies.tonic-prost]
version = "0.14.2"
10 changes: 8 additions & 2 deletions encoderfile-core/src/common/config.rs
@@ -1,11 +1,17 @@
use super::model_type::ModelType;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
+use tokenizers::PaddingParams;

-#[derive(Debug, Serialize, Deserialize, JsonSchema)]
+#[derive(Debug, Serialize, Deserialize)]
pub struct Config {
pub name: String,
pub version: String,
pub model_type: ModelType,
pub transform: Option<String>,
+pub tokenizer: TokenizerConfig,
}

+#[derive(Debug, Default, Serialize, Deserialize)]
+pub struct TokenizerConfig {
+pub padding: PaddingParams,
+}
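
A minimal sketch (not from the PR) of what constructing the new embedded config looks like in code. `PaddingParams` implements `Default` in the `tokenizers` crate, which is what `TokenizerConfig`'s derived `Default` relies on; the field values here are illustrative:

    use tokenizers::PaddingParams;

    // Illustrative construction of the new Config shape.
    let config = Config {
        name: "my-model".to_string(),
        version: "0.0.1".to_string(),
        model_type: ModelType::Embedding,
        transform: None,
        tokenizer: TokenizerConfig { padding: PaddingParams::default() },
    };
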
13 changes: 7 additions & 6 deletions encoderfile-core/src/dev_utils/mod.rs
@@ -14,17 +14,18 @@ const SEQUENCE_CLASSIFICATION_DIR: &str = "../models/sequence_classification";
const TOKEN_CLASSIFICATION_DIR: &str = "../models/token_classification";

pub fn get_state<T: ModelTypeSpec>(dir: &str) -> AppState<T> {
-let model_config = Arc::new(get_model_config(dir));
-let tokenizer = Arc::new(get_tokenizer(dir, &model_config));
-let session = Arc::new(get_model(dir));

let config = Arc::new(Config {
name: "my-model".to_string(),
version: "0.0.1".to_string(),
model_type: T::enum_val(),
transform: None,
+tokenizer: Default::default(),
});

+let model_config = Arc::new(get_model_config(dir));
+let tokenizer = Arc::new(get_tokenizer(dir, &config));
+let session = Arc::new(get_model(dir));

AppState::new(config, session, tokenizer, model_config)
}

@@ -52,11 +53,11 @@ fn get_model_config(dir: &str) -> ModelConfig {
serde_json::from_reader(reader).expect("Invalid model config")
}

-fn get_tokenizer(dir: &str, config: &Arc<ModelConfig>) -> tokenizers::Tokenizer {
+fn get_tokenizer(dir: &str, ec_config: &Arc<Config>) -> tokenizers::Tokenizer {
let tokenizer_str = std::fs::read_to_string(format!("{}/{}", dir, "tokenizer.json"))
.expect("Tokenizer json not found");

-crate::runtime::get_tokenizer_from_string(tokenizer_str.as_str(), config)
+crate::runtime::get_tokenizer_from_string(tokenizer_str.as_str(), ec_config)
}

fn get_model(dir: &str) -> Mutex<Session> {
2 changes: 1 addition & 1 deletion encoderfile-core/src/factory.rs
@@ -85,7 +85,7 @@ where
let config = get_config(config_str);
let session = get_model(model_bytes);
let model_config = get_model_config(model_config_str);
-let tokenizer = get_tokenizer(tokenizer_json, &model_config);
+let tokenizer = get_tokenizer(tokenizer_json, &config);

let state = AppState::new(config, session, tokenizer, model_config);

36 changes: 6 additions & 30 deletions encoderfile-core/src/runtime/tokenizer.rs
@@ -1,48 +1,24 @@
-use crate::{common::ModelConfig, error::ApiError};
+use crate::{common::Config, error::ApiError};
use anyhow::Result;
use std::str::FromStr;
use std::sync::{Arc, OnceLock};
-use tokenizers::{
-Encoding, PaddingDirection, PaddingParams, PaddingStrategy, tokenizer::Tokenizer,
-};
+use tokenizers::{Encoding, tokenizer::Tokenizer};

static TOKENIZER: OnceLock<Arc<Tokenizer>> = OnceLock::new();

-pub fn get_tokenizer(tokenizer_json: &str, model_config: &Arc<ModelConfig>) -> Arc<Tokenizer> {
+pub fn get_tokenizer(tokenizer_json: &str, ec_config: &Arc<Config>) -> Arc<Tokenizer> {
TOKENIZER
-.get_or_init(|| Arc::new(get_tokenizer_from_string(tokenizer_json, model_config)))
+.get_or_init(|| Arc::new(get_tokenizer_from_string(tokenizer_json, ec_config)))
.clone()
}

-pub fn get_tokenizer_from_string(s: &str, config: &Arc<ModelConfig>) -> Tokenizer {
-let pad_token_id = config.pad_token_id;
-
+pub fn get_tokenizer_from_string(s: &str, ec_config: &Arc<Config>) -> Tokenizer {
let mut tokenizer = match Tokenizer::from_str(s) {
Ok(t) => t,
Err(e) => panic!("FATAL: Error loading tokenizer: {e:?}"),
};

-let pad_token = match tokenizer.id_to_token(pad_token_id) {
-Some(tok) => tok,
-None => panic!("Model requires a padding token."),
-};
-
-if tokenizer.get_padding().is_none() {
-let params = PaddingParams {
-strategy: PaddingStrategy::BatchLongest,
-direction: PaddingDirection::Right,
-pad_to_multiple_of: None,
-pad_id: pad_token_id,
-pad_type_id: 0,
-pad_token,
-};
-
-tracing::warn!(
-"No padding strategy specified in tokenizer config. Setting default: {:?}",
-&params
-);
-tokenizer.with_padding(Some(params));
-}
+tokenizer.with_padding(Some(ec_config.tokenizer.padding.clone()));

tokenizer
}
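
A behavior change worth noting: the runtime no longer synthesizes a `BatchLongest` padding strategy from the model's `pad_token_id` when the tokenizer file omits one; padding now always comes from the embedded config, and `with_padding(Some(..))` is called unconditionally, overriding whatever `tokenizer.json` specifies. A rough usage sketch, assuming a `Config` built as in the `dev_utils` example above:

    // Sketch of the new call path; `tokenizer_json` and `config` are assumed.
    let tokenizer = get_tokenizer_from_string(tokenizer_json, &Arc::new(config));
    // with_padding(Some(..)) runs unconditionally, so padding is always set:
    assert!(tokenizer.get_padding().is_some());
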
3 changes: 3 additions & 0 deletions encoderfile/Cargo.toml
@@ -23,6 +23,9 @@ dev-utils = []
[dev-dependencies]
tempfile = "3.23.0"

+[dependencies.tokenizers]
+workspace = true

[dependencies.anyhow]
workspace = true

107 changes: 86 additions & 21 deletions encoderfile/src/config.rs
@@ -38,6 +38,7 @@ pub struct EncoderfileConfig {
pub output_path: Option<PathBuf>,
pub cache_dir: Option<PathBuf>,
pub transform: Option<Transform>,
+pub tokenizer: Option<TokenizerBuildConfig>,
#[serde(default = "default_validate_transform")]
pub validate_transform: bool,
#[serde(default = "default_build")]
@@ -46,11 +47,13 @@

impl EncoderfileConfig {
pub fn embedded_config(&self) -> Result<EmbeddedConfig> {
+let tokenizer = self.validate_tokenizer()?;
let config = EmbeddedConfig {
name: self.name.clone(),
version: self.version.clone(),
model_type: self.model_type.clone(),
transform: self.transform()?,
+tokenizer,
};

Ok(config)
@@ -117,6 +120,18 @@ impl EncoderfileConfig {
}
}

+#[derive(Debug, Serialize, Deserialize, JsonSchema)]
+pub struct TokenizerBuildConfig {
+pub pad_strategy: Option<TokenizerPadStrategy>,
+}
+
+#[derive(Debug, Serialize, Deserialize, JsonSchema)]
+#[serde(untagged, rename_all = "snake_case")]
+pub enum TokenizerPadStrategy {
+BatchLongest,
+Fixed { fixed: usize },
+}

#[derive(Debug, Serialize, Deserialize, JsonSchema)]
#[serde(untagged)]
pub enum Transform {
@@ -152,27 +167,60 @@ pub enum ModelPath {
model_config_path: PathBuf,
model_weights_path: PathBuf,
tokenizer_path: PathBuf,
+tokenizer_config_path: Option<PathBuf>,
},
}

-macro_rules! asset_path {
-    ($var:ident, $default:expr, $err:expr) => {
-        pub fn $var(&self) -> Result<PathBuf> {
-            let path = match self {
-                Self::Paths { $var, .. } => $var.clone(),
-                Self::Directory(dir) => {
-                    if !dir.is_dir() {
-                        bail!("No such directory: {:?}", dir);
-                    }
-                    dir.join($default)
-                }
-            };
-
-            if !path.try_exists()? {
-                bail!("Could not locate {} at path: {:?}", $err, path);
-            }
-
-            Ok(path.canonicalize()?)
-        }
-    };
-}
+impl ModelPath {
+    fn resolve(
+        &self,
+        explicit: Option<PathBuf>,
+        default: impl FnOnce(&PathBuf) -> PathBuf,
+        err: &str,
+    ) -> Result<Option<PathBuf>> {
+        let path = match self {
+            Self::Paths { .. } => explicit,
+            Self::Directory(dir) => {
+                if !dir.is_dir() {
+                    bail!("No such directory: {:?}", dir);
+                }
+                Some(default(dir))
+            }
+        };
+
+        match path {
+            Some(p) => {
+                if !p.try_exists()? {
+                    bail!("Could not locate {} at path: {:?}", err, p);
+                }
+                Ok(Some(p.canonicalize()?))
+            }
+            None => Ok(None),
+        }
+    }
+}
+
+macro_rules! asset_path {
+    (@Optional $name:ident, $default:expr, $err:expr) => {
+        pub fn $name(&self) -> Result<Option<PathBuf>> {
+            let explicit = match self {
+                Self::Paths { $name, .. } => $name.clone(),
+                _ => None,
+            };
+
+            self.resolve(explicit, |dir| dir.join($default), $err)
+        }
+    };
+
+    ($name:ident, $default:expr, $err:expr) => {
+        pub fn $name(&self) -> Result<PathBuf> {
+            let explicit = match self {
+                Self::Paths { $name, .. } => Some($name.clone()),
+                _ => None,
+            };
+
+            self.resolve(explicit, |dir| dir.join($default), $err)?
+                .ok_or_else(|| anyhow::anyhow!("Missing required path: {}", $err))
+        }
+    };
+}
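
The net effect: both macro arms now funnel path lookup, existence checks, and canonicalization through `resolve`, with the required arm unwrapping the `Option`. A sketch of the resulting accessor API (the directory path is illustrative):

    // Required accessor: Result<PathBuf>, errors if the file is missing.
    let mp = ModelPath::Directory("../models/embedding".into());
    let weights = mp.model_weights_path()?;
    // Optional accessor: Result<Option<PathBuf>>, Ok(None) when Paths omits it.
    let tok_cfg = mp.tokenizer_config_path()?;
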
@@ -181,6 +229,7 @@ impl ModelPath {
asset_path!(model_config_path, "config.json", "model config");
asset_path!(tokenizer_path, "tokenizer.json", "tokenizer");
asset_path!(model_weights_path, "model.onnx", "model weights");
+asset_path!(@Optional tokenizer_config_path, "tokenizer_config.json", "tokenizer config");
}

fn default_cache_dir() -> PathBuf {
@@ -222,12 +271,19 @@ mod tests {
base
}

+// Create temp output dir
+fn create_temp_output_dir() -> PathBuf {
+create_test_dir("model")
+}

// Create a model dir populated with the required files
-fn create_model_dir() -> PathBuf {
+fn create_temp_model_dir() -> PathBuf {
let base = create_test_dir("model");
fs::write(base.join("config.json"), "{}").expect("Failed to create config.json");
fs::write(base.join("tokenizer.json"), "{}").expect("Failed to create tokenizer.json");
fs::write(base.join("model.onnx"), "onnx").expect("Failed to create model.onnx");
fs::write(base.join("tokenizer_config.json"), "{}")
.expect("Failed to create tokenizer_config.json");
base
}

@@ -243,12 +299,18 @@

#[test]
fn test_modelpath_directory_valid() {
-let base = create_model_dir();
+let base = create_temp_model_dir();
let mp = ModelPath::Directory(base.clone());

assert!(mp.model_config_path().unwrap().ends_with("config.json"));
assert!(mp.tokenizer_path().unwrap().ends_with("tokenizer.json"));
assert!(mp.model_weights_path().unwrap().ends_with("model.onnx"));
+assert!(
+mp.tokenizer_config_path()
+.unwrap()
+.unwrap()
+.ends_with("tokenizer_config.json")
+);

cleanup(&base);
}
@@ -266,11 +328,12 @@

#[test]
fn test_modelpath_explicit_paths() {
-let base = create_model_dir();
+let base = create_temp_model_dir();
let mp = ModelPath::Paths {
model_config_path: base.join("config.json"),
tokenizer_path: base.join("tokenizer.json"),
model_weights_path: base.join("model.onnx"),
tokenizer_config_path: Some(base.join("tokenizer_config.json")),
};

assert!(mp.model_config_path().is_ok());
@@ -310,17 +373,18 @@

#[test]
fn test_encoderfile_generated_dir() {
-let base = create_model_dir();
+let base = create_temp_output_dir();

let cfg = EncoderfileConfig {
name: "my-cool-model".into(),
version: "1.0".into(),
-path: ModelPath::Directory(base.clone()),
+path: ModelPath::Directory("../models/embedding".into()),
model_type: ModelType::Embedding,
output_path: Some(base.clone()),
cache_dir: Some(base.clone()),
validate_transform: false,
transform: None,
+tokenizer: None,
build: true,
};

@@ -332,16 +396,17 @@

#[test]
fn test_encoderfile_to_tera_ctx() {
-let base = create_model_dir();
+let base = create_temp_output_dir();
let cfg = EncoderfileConfig {
name: "sadness".into(),
version: "0.1.0".into(),
-path: ModelPath::Directory(base.clone()),
+path: ModelPath::Directory("../models/embedding".into()),
model_type: ModelType::SequenceClassification,
output_path: Some(base.clone()),
cache_dir: Some(base.clone()),
validate_transform: false,
transform: Some(Transform::Inline("1+1".into())),
+tokenizer: None,
build: true,
};

1 change: 1 addition & 0 deletions encoderfile/src/lib.rs
@@ -2,4 +2,5 @@ pub mod cli;
pub mod config;
pub mod model;
pub mod templates;
+pub mod tokenizer;
pub mod transforms;