diff --git a/src/model.rs b/src/model.rs
index d695bb9..dff1a8a 100644
--- a/src/model.rs
+++ b/src/model.rs
@@ -64,32 +64,12 @@ impl StaticModel {
         // Load the tokenizer
         let tokenizer = Tokenizer::from_file(&tok_path).map_err(|e| anyhow!("failed to load tokenizer: {e}"))?;
 
-        // Median-token-length hack for pre-truncation
-        let mut lens: Vec<usize> = tokenizer.get_vocab(false).keys().map(|tk| tk.len()).collect();
-        lens.sort_unstable();
-        let median_token_length = lens.get(lens.len() / 2).copied().unwrap_or(1);
-
         // Read normalize default from config.json
         let cfg_file = std::fs::File::open(&cfg_path).context("failed to read config.json")?;
         let cfg: Value = serde_json::from_reader(&cfg_file).context("failed to parse config.json")?;
         let cfg_norm = cfg.get("normalize").and_then(Value::as_bool).unwrap_or(true);
         let normalize = normalize.unwrap_or(cfg_norm);
 
-        // Serialize the tokenizer to JSON, then parse it and get the unk_token
-        let spec_json = tokenizer
-            .to_string(false)
-            .map_err(|e| anyhow!("tokenizer -> JSON failed: {e}"))?;
-        let spec: Value = serde_json::from_str(&spec_json)?;
-        let unk_token = spec
-            .get("model")
-            .and_then(|m| m.get("unk_token"))
-            .and_then(Value::as_str)
-            .unwrap_or("[UNK]");
-        let unk_token_id = tokenizer
-            .token_to_id(unk_token)
-            .ok_or_else(|| anyhow!("tokenizer claims unk_token='{unk_token}' but it isn't in the vocab"))?
-            as usize;
-
         // Load the safetensors
         let model_bytes = fs::read(&mdl_path).context("failed to read model.safetensors")?;
         let safet = SafeTensors::deserialize(&model_bytes).context("failed to parse safetensors")?;
@@ -115,7 +95,6 @@ impl StaticModel {
             Dtype::I8 => raw.iter().map(|&b| f32::from(b as i8)).collect(),
             other => return Err(anyhow!("unsupported tensor dtype: {other:?}")),
         };
-        let embeddings = Array2::from_shape_vec((rows, cols), floats).context("failed to build embeddings array")?;
 
         // Load optional weights for vocabulary quantization
        let weights = match safet.tensor("weights") {
@@ -154,6 +133,56 @@ impl StaticModel {
             Err(_) => None,
         };
 
+        Self::from_raw_parts(tokenizer, &floats, rows, cols, normalize, weights, token_mapping)
+    }
+
+    /// Construct from pre-parsed parts.
+    ///
+    /// # Arguments
+    /// * `tokenizer` - Pre-deserialized tokenizer
+    /// * `embeddings` - Raw f32 embedding data
+    /// * `rows` - Number of vocabulary entries
+    /// * `cols` - Embedding dimension
+    /// * `normalize` - Whether to L2-normalize output embeddings
+    /// * `weights` - Optional per-token weights for quantized models
+    /// * `token_mapping` - Optional token ID mapping for quantized models
+    pub fn from_raw_parts(
+        tokenizer: Tokenizer,
+        embeddings: &[f32],
+        rows: usize,
+        cols: usize,
+        normalize: bool,
+        weights: Option<Vec<f32>>,
+        token_mapping: Option<Vec<usize>>,
+    ) -> Result<Self> {
+        if embeddings.len() != rows * cols {
+            return Err(anyhow!(
+                "embeddings length {} != rows {} * cols {}",
+                embeddings.len(),
+                rows,
+                cols
+            ));
+        }
+
+        // Median-token-length hack for pre-truncation
+        let mut lens: Vec<usize> = tokenizer.get_vocab(false).keys().map(|tk| tk.len()).collect();
+        lens.sort_unstable();
+        let median_token_length = lens.get(lens.len() / 2).copied().unwrap_or(1);
+
+        // Get unk_token from tokenizer (optional - BPE tokenizers may not have one)
+        let spec_json = tokenizer
+            .to_string(false)
+            .map_err(|e| anyhow!("tokenizer -> JSON failed: {e}"))?;
+        let spec: Value = serde_json::from_str(&spec_json)?;
+        let unk_token = spec
+            .get("model")
+            .and_then(|m| m.get("unk_token"))
+            .and_then(Value::as_str);
+        let unk_token_id = unk_token.and_then(|tok| tokenizer.token_to_id(tok)).map(|id| id as usize);
+
+        let embeddings = Array2::from_shape_vec((rows, cols), embeddings.to_vec())
+            .context("failed to build embeddings array")?;
+
         Ok(Self {
             tokenizer,
             embeddings,
@@ -161,7 +190,7 @@ impl StaticModel {
             token_mapping,
             normalize,
             median_token_length,
-            unk_token_id: Some(unk_token_id),
+            unk_token_id,
         })
     }
 
diff --git a/tests/test_model.rs b/tests/test_model.rs
index f09b8c2..abb3761 100644
--- a/tests/test_model.rs
+++ b/tests/test_model.rs
@@ -70,3 +70,26 @@ fn test_normalization_flag_override() {
         "Without normalization override, norm should be larger"
     );
 }
+
+/// Test from_raw_parts constructor
+#[test]
+fn test_from_raw_parts() {
+    use std::fs;
+    use tokenizers::Tokenizer;
+    use safetensors::SafeTensors;
+
+    let path = "tests/fixtures/test-model-float32";
+    let tokenizer = Tokenizer::from_file(format!("{path}/tokenizer.json")).unwrap();
+    let bytes = fs::read(format!("{path}/model.safetensors")).unwrap();
+    let tensors = SafeTensors::deserialize(&bytes).unwrap();
+    let tensor = tensors.tensor("embeddings").unwrap();
+    let [rows, cols]: [usize; 2] = tensor.shape().try_into().unwrap();
+    let floats: Vec<f32> = tensor.data()
+        .chunks_exact(4)
+        .map(|b| f32::from_le_bytes(b.try_into().unwrap()))
+        .collect();
+
+    let model = StaticModel::from_raw_parts(tokenizer, &floats, rows, cols, true, None, None).unwrap();
+    let emb = model.encode_single("hello");
+    assert!(!emb.is_empty());
+}
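Usage note (not part of the patch): the new constructor lets a caller build a StaticModel without touching the file system once the embedding matrix is in memory. A minimal sketch of that pattern follows; the import path, the fixture tokenizer path, and the zeroed embedding table are illustrative assumptions only.

// Import path assumed from the upstream model2vec-rs layout; adjust to this crate's actual module tree.
use anyhow::{anyhow, Result};
use model2vec_rs::model::StaticModel;
use tokenizers::Tokenizer;

fn build_in_memory_model() -> Result<()> {
    // Any tokenizers::Tokenizer will do; loading one from a fixture file is just the simplest route.
    let tokenizer = Tokenizer::from_file("tests/fixtures/test-model-float32/tokenizer.json")
        .map_err(|e| anyhow!("failed to load tokenizer: {e}"))?;

    // Embeddings are passed as a flat f32 slice: `rows` vocabulary entries of `cols` dimensions.
    // A zeroed table keeps the sketch self-contained; real callers would supply trained vectors.
    let rows = tokenizer.get_vocab_size(false);
    let cols = 64;
    let floats = vec![0.0f32; rows * cols];

    // normalize = true; no quantization weights or token mapping.
    let model = StaticModel::from_raw_parts(tokenizer, &floats, rows, cols, true, None, None)?;
    let _embedding = model.encode_single("hello");
    Ok(())
}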