feat: Add from_raw_parts() constructor
#33
base: main
```diff
@@ -64,32 +64,12 @@ impl StaticModel {
 // Load the tokenizer
 let tokenizer = Tokenizer::from_file(&tok_path).map_err(|e| anyhow!("failed to load tokenizer: {e}"))?;
 
-// Median-token-length hack for pre-truncation
-let mut lens: Vec<usize> = tokenizer.get_vocab(false).keys().map(|tk| tk.len()).collect();
-lens.sort_unstable();
-let median_token_length = lens.get(lens.len() / 2).copied().unwrap_or(1);
-
 // Read normalize default from config.json
 let cfg_file = std::fs::File::open(&cfg_path).context("failed to read config.json")?;
 let cfg: Value = serde_json::from_reader(&cfg_file).context("failed to parse config.json")?;
 let cfg_norm = cfg.get("normalize").and_then(Value::as_bool).unwrap_or(true);
 let normalize = normalize.unwrap_or(cfg_norm);
 
-// Serialize the tokenizer to JSON, then parse it and get the unk_token
-let spec_json = tokenizer
-    .to_string(false)
-    .map_err(|e| anyhow!("tokenizer -> JSON failed: {e}"))?;
-let spec: Value = serde_json::from_str(&spec_json)?;
-let unk_token = spec
-    .get("model")
-    .and_then(|m| m.get("unk_token"))
-    .and_then(Value::as_str)
-    .unwrap_or("[UNK]");
-let unk_token_id = tokenizer
-    .token_to_id(unk_token)
-    .ok_or_else(|| anyhow!("tokenizer claims unk_token='{unk_token}' but it isn't in the vocab"))?
-    as usize;
-
 // Load the safetensors
 let model_bytes = fs::read(&mdl_path).context("failed to read model.safetensors")?;
 let safet = SafeTensors::deserialize(&model_bytes).context("failed to parse safetensors")?;
```
```diff
@@ -115,7 +95,6 @@ impl StaticModel {
     Dtype::I8 => raw.iter().map(|&b| f32::from(b as i8)).collect(),
     other => return Err(anyhow!("unsupported tensor dtype: {other:?}")),
 };
-let embeddings = Array2::from_shape_vec((rows, cols), floats).context("failed to build embeddings array")?;
 
 // Load optional weights for vocabulary quantization
 let weights = match safet.tensor("weights") {
```
```diff
@@ -154,14 +133,64 @@ impl StaticModel {
     Err(_) => None,
 };
 
+Self::from_raw_parts(tokenizer, &floats, rows, cols, normalize, weights, token_mapping)
+}
+
+/// Construct from pre-parsed parts.
+///
+/// # Arguments
+/// * `tokenizer` - Pre-deserialized tokenizer
+/// * `embeddings` - Raw f32 embedding data
+/// * `rows` - Number of vocabulary entries
+/// * `cols` - Embedding dimension
+/// * `normalize` - Whether to L2-normalize output embeddings
+/// * `weights` - Optional per-token weights for quantized models
+/// * `token_mapping` - Optional token ID mapping for quantized models
+pub fn from_raw_parts(
+    tokenizer: Tokenizer,
+    embeddings: &[f32],
+    rows: usize,
+    cols: usize,
+    normalize: bool,
+    weights: Option<Vec<f32>>,
+    token_mapping: Option<Vec<usize>>,
+) -> Result<Self> {
+    if embeddings.len() != rows * cols {
+        return Err(anyhow!(
+            "embeddings length {} != rows {} * cols {}",
+            embeddings.len(),
+            rows,
+            cols
+        ));
+    }
+
+    // Median-token-length hack for pre-truncation
+    let mut lens: Vec<usize> = tokenizer.get_vocab(false).keys().map(|tk| tk.len()).collect();
+    lens.sort_unstable();
+    let median_token_length = lens.get(lens.len() / 2).copied().unwrap_or(1);
+
+    // Get unk_token from tokenizer (optional - BPE tokenizers may not have one)
+    let spec_json = tokenizer
+        .to_string(false)
+        .map_err(|e| anyhow!("tokenizer -> JSON failed: {e}"))?;
+    let spec: Value = serde_json::from_str(&spec_json)?;
+    let unk_token = spec
+        .get("model")
+        .and_then(|m| m.get("unk_token"))
+        .and_then(Value::as_str);
+    let unk_token_id = unk_token.and_then(|tok| tokenizer.token_to_id(tok)).map(|id| id as usize);
+
+    let embeddings = Array2::from_shape_vec((rows, cols), embeddings.to_vec())
+        .context("failed to build embeddings array")?;
```
Member

This does an extra full copy of the embeddings, which adds up for larger models. I think the easiest solution is to change `embeddings` in

```rust
let embeddings = Array2::from_shape_vec((rows, cols), embeddings)
    .context("failed to build embeddings array")?;
```

and then …
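For reference, a minimal sketch of that zero-copy path, using a hypothetical `build_embeddings` helper to stand in for the relevant part of `from_raw_parts` (the helper name and function boundary are illustrative, not from the PR):

```rust
use anyhow::{anyhow, Context, Result};
use ndarray::Array2;

// Sketch only: taking an owned Vec<f32> lets `Array2::from_shape_vec`
// reuse the allocation instead of copying a borrowed slice.
fn build_embeddings(embeddings: Vec<f32>, rows: usize, cols: usize) -> Result<Array2<f32>> {
    if embeddings.len() != rows * cols {
        return Err(anyhow!(
            "embeddings length {} != rows {} * cols {}",
            embeddings.len(),
            rows,
            cols
        ));
    }
    // `from_shape_vec` takes ownership of the Vec, so no extra copy happens here.
    Array2::from_shape_vec((rows, cols), embeddings).context("failed to build embeddings array")
}
```

Applied to the PR, this would presumably mean changing the `embeddings: &[f32]` parameter to `Vec<f32>` and passing `floats` by value at the call site.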
```diff
 
     Ok(Self {
         tokenizer,
         embeddings,
         weights,
         token_mapping,
         normalize,
         median_token_length,
-        unk_token_id: Some(unk_token_id),
+        unk_token_id,
     })
 }
```
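For context, a call to the new constructor might look roughly like this; the vocabulary size, dimension, and zero-filled data are placeholders for illustration, not part of the PR, and the import path for `StaticModel` is omitted:

```rust
use anyhow::Result;
use tokenizers::Tokenizer;

// Illustrative only: dimensions and data are made up.
fn example(tokenizer: Tokenizer) -> Result<StaticModel> {
    let rows = 32_000; // vocabulary entries
    let cols = 256;    // embedding dimension
    let raw = vec![0.0_f32; rows * cols];
    StaticModel::from_raw_parts(tokenizer, &raw, rows, cols, true, None, None)
}
```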
This line is too permissive; I agree that the previous logic was too strict, but I prefer a middle ground. So something like this should work, I think:
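The suggested code itself is cut off above. One plausible shape of that middle ground, sketched here as an assumption rather than the reviewer's actual suggestion, is to fail only when a declared `unk_token` cannot be resolved to a vocab id, and to accept tokenizers that declare none:

```rust
use anyhow::{anyhow, Result};
use tokenizers::Tokenizer;

// Assumed middle ground (not the reviewer's actual code): a tokenizer without an
// unk_token is fine, but a declared unk_token missing from the vocab is an error.
fn resolve_unk_token_id(tokenizer: &Tokenizer, unk_token: Option<&str>) -> Result<Option<usize>> {
    match unk_token {
        None => Ok(None), // e.g. some BPE tokenizers have no unk_token
        Some(tok) => {
            let id = tokenizer
                .token_to_id(tok)
                .ok_or_else(|| anyhow!("tokenizer claims unk_token='{tok}' but it isn't in the vocab"))?;
            Ok(Some(id as usize))
        }
    }
}
```

Inlined where `unk_token_id` is computed in `from_raw_parts`, this could replace the `and_then`/`map` chain shown in the diff.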