73 changes: 51 additions & 22 deletions src/model.rs
@@ -64,32 +64,12 @@ impl StaticModel {
// Load the tokenizer
let tokenizer = Tokenizer::from_file(&tok_path).map_err(|e| anyhow!("failed to load tokenizer: {e}"))?;

// Median-token-length hack for pre-truncation
let mut lens: Vec<usize> = tokenizer.get_vocab(false).keys().map(|tk| tk.len()).collect();
lens.sort_unstable();
let median_token_length = lens.get(lens.len() / 2).copied().unwrap_or(1);

// Read normalize default from config.json
let cfg_file = std::fs::File::open(&cfg_path).context("failed to read config.json")?;
let cfg: Value = serde_json::from_reader(&cfg_file).context("failed to parse config.json")?;
let cfg_norm = cfg.get("normalize").and_then(Value::as_bool).unwrap_or(true);
let normalize = normalize.unwrap_or(cfg_norm);

// Serialize the tokenizer to JSON, then parse it and get the unk_token
let spec_json = tokenizer
.to_string(false)
.map_err(|e| anyhow!("tokenizer -> JSON failed: {e}"))?;
let spec: Value = serde_json::from_str(&spec_json)?;
let unk_token = spec
.get("model")
.and_then(|m| m.get("unk_token"))
.and_then(Value::as_str)
.unwrap_or("[UNK]");
let unk_token_id = tokenizer
.token_to_id(unk_token)
.ok_or_else(|| anyhow!("tokenizer claims unk_token='{unk_token}' but it isn't in the vocab"))?
as usize;

// Load the safetensors
let model_bytes = fs::read(&mdl_path).context("failed to read model.safetensors")?;
let safet = SafeTensors::deserialize(&model_bytes).context("failed to parse safetensors")?;
@@ -115,7 +95,6 @@ impl StaticModel {
Dtype::I8 => raw.iter().map(|&b| f32::from(b as i8)).collect(),
other => return Err(anyhow!("unsupported tensor dtype: {other:?}")),
};
let embeddings = Array2::from_shape_vec((rows, cols), floats).context("failed to build embeddings array")?;

// Load optional weights for vocabulary quantization
let weights = match safet.tensor("weights") {
@@ -154,14 +133,64 @@ impl StaticModel {
Err(_) => None,
};

Self::from_raw_parts(tokenizer, &floats, rows, cols, normalize, weights, token_mapping)
}

/// Construct from pre-parsed parts.
///
/// # Arguments
/// * `tokenizer` - Pre-deserialized tokenizer
/// * `embeddings` - Raw f32 embedding data
/// * `rows` - Number of vocabulary entries
/// * `cols` - Embedding dimension
/// * `normalize` - Whether to L2-normalize output embeddings
/// * `weights` - Optional per-token weights for quantized models
/// * `token_mapping` - Optional token ID mapping for quantized models
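///
/// # Example
///
/// A minimal sketch (marked `ignore`): it assumes a `Tokenizer` and a flat,
/// row-major `f32` buffer of length `rows * cols` are already in hand.
///
/// ```ignore
/// let model = StaticModel::from_raw_parts(tokenizer, &floats, rows, cols, true, None, None)?;
/// let embedding = model.encode_single("hello");
/// ```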
pub fn from_raw_parts(
tokenizer: Tokenizer,
embeddings: &[f32],
rows: usize,
cols: usize,
normalize: bool,
weights: Option<Vec<f32>>,
token_mapping: Option<Vec<usize>>,
) -> Result<Self> {
if embeddings.len() != rows * cols {
return Err(anyhow!(
"embeddings length {} != rows {} * cols {}",
embeddings.len(),
rows,
cols
));
}

// Median-token-length hack for pre-truncation
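// Gives a rough character budget (max_tokens * median_token_length) so very long
// inputs can be cut down before tokenization rather than after.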
let mut lens: Vec<usize> = tokenizer.get_vocab(false).keys().map(|tk| tk.len()).collect();
lens.sort_unstable();
let median_token_length = lens.get(lens.len() / 2).copied().unwrap_or(1);

// Get unk_token from tokenizer (optional - BPE tokenizers may not have one)
let spec_json = tokenizer
.to_string(false)
.map_err(|e| anyhow!("tokenizer -> JSON failed: {e}"))?;
let spec: Value = serde_json::from_str(&spec_json)?;
let unk_token = spec
.get("model")
.and_then(|m| m.get("unk_token"))
.and_then(Value::as_str);
let unk_token_id = unk_token.and_then(|tok| tokenizer.token_to_id(tok)).map(|id| id as usize);

let embeddings = Array2::from_shape_vec((rows, cols), embeddings.to_vec())
.context("failed to build embeddings array")?;

Ok(Self {
tokenizer,
embeddings,
weights,
token_mapping,
normalize,
median_token_length,
unk_token_id: Some(unk_token_id),
unk_token_id,
})
}

23 changes: 23 additions & 0 deletions tests/test_model.rs
@@ -70,3 +70,26 @@ fn test_normalization_flag_override() {
"Without normalization override, norm should be larger"
);
}

/// Test from_raw_parts constructor
#[test]
fn test_from_raw_parts() {
use std::fs;
use tokenizers::Tokenizer;
use safetensors::SafeTensors;

let path = "tests/fixtures/test-model-float32";
let tokenizer = Tokenizer::from_file(format!("{path}/tokenizer.json")).unwrap();
let bytes = fs::read(format!("{path}/model.safetensors")).unwrap();
let tensors = SafeTensors::deserialize(&bytes).unwrap();
let tensor = tensors.tensor("embeddings").unwrap();
let [rows, cols]: [usize; 2] = tensor.shape().try_into().unwrap();
let floats: Vec<f32> = tensor.data()
.chunks_exact(4)
.map(|b| f32::from_le_bytes(b.try_into().unwrap()))
.collect();

let model = StaticModel::from_raw_parts(tokenizer, &floats, rows, cols, true, None, None).unwrap();
let emb = model.encode_single("hello");
assert!(!emb.is_empty());
}
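
A follow-up test could also exercise the new dimension check in `from_raw_parts`. A minimal sketch, reusing the same fixture tokenizer (not part of this diff):

```rust
#[test]
fn test_from_raw_parts_length_mismatch() {
    use tokenizers::Tokenizer;

    let path = "tests/fixtures/test-model-float32";
    let tokenizer = Tokenizer::from_file(format!("{path}/tokenizer.json")).unwrap();

    // 8 floats cannot fill a 3 x 4 matrix, so construction must fail.
    let floats = vec![0.0_f32; 8];
    let result = StaticModel::from_raw_parts(tokenizer, &floats, 3, 4, true, None, None);
    assert!(result.is_err(), "length mismatch should be rejected");
}
```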