Skip to content

Commit 3fdc224

Browse files
Fix projection layer loading for voyage-4-nano
The projection layer weight (`linear.weight`) is at the root level of the safetensors file, not under the `model` prefix, so we capture the root `VarBuilder` before applying the `model` prefix. Tested with voyage-4-nano:
- Output dimension: 2048 (correct)
- Cosine similarity vs transformers: 0.999965
- Inference time: 7.8 ms on an L4 GPU
1 parent 2f1e3ee commit 3fdc224

File tree

2 files changed

+6
-16
lines changed

2 files changed

+6
-16
lines changed

backends/candle/src/models/flash_qwen3.rs

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,8 @@ impl FlashQwen3Model {
317317

318318
// The Qwen3-Reranker models contain the `model` key
319319
// https://huggingface.co/collections/Qwen/qwen3-reranker-6841b22d0192d7ade9cdefea
320+
// Keep reference to root vb for loading projection layer
321+
let vb_root = vb.clone();
320322
let vb = if vb.contains_tensor("model.embed_tokens.weight") {
321323
vb.pp("model")
322324
} else {
@@ -337,15 +339,8 @@ impl FlashQwen3Model {
337339

338340
// voyage-4-nano: load projection layer if num_labels is set
339341
// The projection transforms hidden_size (1024) to num_labels (2048)
342+
// Use vb_root (root level) since linear.weight is at root, not under "model"
340343
let projection = if let Some(num_labels) = config.num_labels {
341-
// Try to load from the model root (voyage-4-nano uses "linear.weight")
342-
let vb_root = if vb.contains_tensor("linear.weight") {
343-
vb.clone()
344-
} else {
345-
// Also check under "model" prefix for reranker-style models
346-
vb.pp("..") // go up one level if we're already in "model"
347-
};
348-
349344
if vb_root.contains_tensor("linear.weight") {
350345
let projection_weight = vb_root.get((num_labels, config.hidden_size), "linear.weight")?;
351346
Some(Linear::new(projection_weight, None, None))

backends/candle/src/models/qwen3.rs

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -410,6 +410,8 @@ impl Qwen3Model {
410410

411411
// The Qwen3-Reranker models contain the `model` key
412412
// https://huggingface.co/collections/Qwen/qwen3-reranker-6841b22d0192d7ade9cdefea
413+
// Keep reference to root vb for loading projection layer
414+
let vb_root = vb.clone();
413415
let vb = if vb.contains_tensor("model.embed_tokens.weight") {
414416
vb.pp("model")
415417
} else {
@@ -430,15 +432,8 @@ impl Qwen3Model {
430432

431433
// voyage-4-nano: load projection layer if num_labels is set
432434
// The projection transforms hidden_size (1024) to num_labels (2048)
435+
// Use vb_root (root level) since linear.weight is at root, not under "model"
433436
let projection = if let Some(num_labels) = config.num_labels {
434-
// Try to load from the model root (voyage-4-nano uses "linear.weight")
435-
let vb_root = if vb.contains_tensor("linear.weight") {
436-
vb.clone()
437-
} else {
438-
// Also check under "model" prefix for reranker-style models
439-
vb.pp("..") // go up one level if we're already in "model"
440-
};
441-
442437
if vb_root.contains_tensor("linear.weight") {
443438
let projection_weight = vb_root.get((num_labels, config.hidden_size), "linear.weight")?;
444439
Some(Linear::new(projection_weight, None, None))

0 commit comments

Comments (0)