73 changes: 51 additions & 22 deletions src/model.rs
@@ -64,32 +64,12 @@ impl StaticModel {
// Load the tokenizer
let tokenizer = Tokenizer::from_file(&tok_path).map_err(|e| anyhow!("failed to load tokenizer: {e}"))?;

// Median-token-length hack for pre-truncation
let mut lens: Vec<usize> = tokenizer.get_vocab(false).keys().map(|tk| tk.len()).collect();
lens.sort_unstable();
let median_token_length = lens.get(lens.len() / 2).copied().unwrap_or(1);

// Read normalize default from config.json
let cfg_file = std::fs::File::open(&cfg_path).context("failed to read config.json")?;
let cfg: Value = serde_json::from_reader(&cfg_file).context("failed to parse config.json")?;
let cfg_norm = cfg.get("normalize").and_then(Value::as_bool).unwrap_or(true);
let normalize = normalize.unwrap_or(cfg_norm);

// Serialize the tokenizer to JSON, then parse it and get the unk_token
let spec_json = tokenizer
.to_string(false)
.map_err(|e| anyhow!("tokenizer -> JSON failed: {e}"))?;
let spec: Value = serde_json::from_str(&spec_json)?;
let unk_token = spec
.get("model")
.and_then(|m| m.get("unk_token"))
.and_then(Value::as_str)
.unwrap_or("[UNK]");
let unk_token_id = tokenizer
.token_to_id(unk_token)
.ok_or_else(|| anyhow!("tokenizer claims unk_token='{unk_token}' but it isn't in the vocab"))?
as usize;

// Load the safetensors
let model_bytes = fs::read(&mdl_path).context("failed to read model.safetensors")?;
let safet = SafeTensors::deserialize(&model_bytes).context("failed to parse safetensors")?;
@@ -115,7 +95,6 @@ impl StaticModel {
Dtype::I8 => raw.iter().map(|&b| f32::from(b as i8)).collect(),
other => return Err(anyhow!("unsupported tensor dtype: {other:?}")),
};
let embeddings = Array2::from_shape_vec((rows, cols), floats).context("failed to build embeddings array")?;

// Load optional weights for vocabulary quantization
let weights = match safet.tensor("weights") {
@@ -154,14 +133,64 @@ impl StaticModel {
Err(_) => None,
};

Self::from_raw_parts(tokenizer, &floats, rows, cols, normalize, weights, token_mapping)
}

/// Construct from pre-parsed parts.
///
/// # Arguments
/// * `tokenizer` - Pre-deserialized tokenizer
/// * `embeddings` - Raw f32 embedding data
/// * `rows` - Number of vocabulary entries
/// * `cols` - Embedding dimension
/// * `normalize` - Whether to L2-normalize output embeddings
/// * `weights` - Optional per-token weights for quantized models
/// * `token_mapping` - Optional token ID mapping for quantized models
pub fn from_raw_parts(
tokenizer: Tokenizer,
embeddings: &[f32],
rows: usize,
cols: usize,
normalize: bool,
weights: Option<Vec<f32>>,
token_mapping: Option<Vec<usize>>,
) -> Result<Self> {
if embeddings.len() != rows * cols {
return Err(anyhow!(
"embeddings length {} != rows {} * cols {}",
embeddings.len(),
rows,
cols
));
}

// Median-token-length hack for pre-truncation
let mut lens: Vec<usize> = tokenizer.get_vocab(false).keys().map(|tk| tk.len()).collect();
lens.sort_unstable();
let median_token_length = lens.get(lens.len() / 2).copied().unwrap_or(1);

// Get unk_token from tokenizer (optional - BPE tokenizers may not have one)
let spec_json = tokenizer
.to_string(false)
.map_err(|e| anyhow!("tokenizer -> JSON failed: {e}"))?;
let spec: Value = serde_json::from_str(&spec_json)?;
let unk_token = spec
.get("model")
.and_then(|m| m.get("unk_token"))
.and_then(Value::as_str);
let unk_token_id = unk_token.and_then(|tok| tokenizer.token_to_id(tok)).map(|id| id as usize);
Member commented:

This line is too permissive; I agree that the previous logic was too strict, but I prefer a middle ground:

  • if unk_token is absent: unk_token_id = None
  • if unk_token is present but not found in vocab: error
  • if unk_token is present and in vocab: use the unk_token

So something like this should work I think:

let unk_token_id = match unk_token {
    None => None, // Allow None if tokenizer does not declare one
    Some(tok) => match tokenizer.token_to_id(tok) {
        Some(id) => Some(id as usize),
        None => {
            return Err(anyhow!(
                "tokenizer declares unk_token='{tok}' but it isn't in the vocab"
            ))
        }
    },
};


let embeddings = Array2::from_shape_vec((rows, cols), embeddings.to_vec())
Member commented:

This does an extra full copy of the embeddings, which adds up for larger models. I think the easiest solution is to change the embeddings parameter of from_raw_parts from embeddings: &[f32] to embeddings: Vec<f32>, and then this becomes:

let embeddings = Array2::from_shape_vec((rows, cols), embeddings)
    .context("failed to build embeddings array")?;

And then Self::from_raw_parts(tokenizer, &floats, rows, cols, normalize, weights, token_mapping) becomes Self::from_raw_parts(tokenizer, floats, rows, cols, normalize, weights, token_mapping). This way from_pretrained can move the embeddings in directly instead of copying them.
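
For reference, a minimal sketch of what the suggested signature change would look like (the elided body is just for illustration, not the final implementation):

pub fn from_raw_parts(
    tokenizer: Tokenizer,
    embeddings: Vec<f32>, // take ownership so the caller's buffer is moved, not copied
    rows: usize,
    cols: usize,
    normalize: bool,
    weights: Option<Vec<f32>>,
    token_mapping: Option<Vec<usize>>,
) -> Result<Self> {
    // ... length check, median-token-length, unk_token handling ...

    // from_shape_vec consumes the Vec, so no extra allocation happens here
    let embeddings = Array2::from_shape_vec((rows, cols), embeddings)
        .context("failed to build embeddings array")?;

    // ... construct Self ...
}

Callers that only have a borrowed slice can still pass slice.to_vec(), so nothing is lost for them.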

.context("failed to build embeddings array")?;

Ok(Self {
tokenizer,
embeddings,
weights,
token_mapping,
normalize,
median_token_length,
unk_token_id: Some(unk_token_id),
unk_token_id,
})
}

23 changes: 23 additions & 0 deletions tests/test_model.rs
@@ -70,3 +70,26 @@ fn test_normalization_flag_override() {
"Without normalization override, norm should be larger"
);
}

/// Test from_raw_parts constructor
#[test]
fn test_from_raw_parts() {
use std::fs;
use tokenizers::Tokenizer;
use safetensors::SafeTensors;

let path = "tests/fixtures/test-model-float32";
let tokenizer = Tokenizer::from_file(format!("{path}/tokenizer.json")).unwrap();
let bytes = fs::read(format!("{path}/model.safetensors")).unwrap();
let tensors = SafeTensors::deserialize(&bytes).unwrap();
let tensor = tensors.tensor("embeddings").unwrap();
let [rows, cols]: [usize; 2] = tensor.shape().try_into().unwrap();
let floats: Vec<f32> = tensor.data()
.chunks_exact(4)
.map(|b| f32::from_le_bytes(b.try_into().unwrap()))
.collect();

let model = StaticModel::from_raw_parts(tokenizer, &floats, rows, cols, true, None, None).unwrap();
let emb = model.encode_single("hello");
assert!(!emb.is_empty());
}