
Commit 539f322

Add support for voyage-4-nano embedding model
Add two new config fields to Qwen3 to support voyage-4-nano and similar models:

- `use_bidirectional_attention`: when true, disables causal masking for embedding models that use full bidirectional attention
- `num_labels`: when set, loads a projection layer from `linear.weight` at the safetensors root level (e.g., 1024 -> 2048 for voyage-4-nano)

Both fields are backwards compatible, defaulting to the existing (disabled) behavior.

Changes:

- backends/candle/src/models/qwen3.rs: add config fields and CPU implementation
- backends/candle/src/models/flash_qwen3.rs: add CUDA/flash-attn implementation
- backends/candle/tests/test_voyage_nano.rs: CPU tests with snapshots
- backends/candle/tests/test_flash_voyage_nano.rs: CUDA tests
- README.md, docs/source/en/supported_models.md: add voyage-4-nano

Tested with voyageai/voyage-4-nano:

- Output dimension: 2048 (correct)
- Cosine similarity vs. transformers: 0.999965
- Inference time: ~9 ms on an L4 GPU (vs. ~35 ms with transformers)
1 parent cb9de7a commit 539f322
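As an aside (not part of the commit), here is a minimal sketch of how the two new optional fields would deserialize from a model's `config.json` with serde. The struct name and values are illustrative only; the numbers mirror the 1024 -> 2048 example from the commit message, and configs that omit both fields keep the default causal, no-projection behavior.

```rust
// Minimal sketch, not repository code: shows how the two new optional config
// fields deserialize from config.json. Struct is trimmed to the relevant
// fields; requires the serde and serde_json crates.
use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct Qwen3ConfigSketch {
    hidden_size: usize,
    #[serde(default)]
    use_bidirectional_attention: Option<bool>,
    #[serde(default)]
    num_labels: Option<usize>,
}

fn main() -> serde_json::Result<()> {
    // voyage-4-nano-style config: full bidirectional attention plus a
    // 1024 -> 2048 projection loaded from `linear.weight`.
    let raw = r#"{
        "hidden_size": 1024,
        "use_bidirectional_attention": true,
        "num_labels": 2048
    }"#;
    let cfg: Qwen3ConfigSketch = serde_json::from_str(raw)?;
    assert_eq!(cfg.use_bidirectional_attention, Some(true));
    assert_eq!(cfg.num_labels, Some(2048));

    // Older configs omit both fields and keep the default causal behavior.
    let legacy: Qwen3ConfigSketch = serde_json::from_str(r#"{ "hidden_size": 1024 }"#)?;
    assert_eq!(legacy.use_bidirectional_attention, None);
    assert_eq!(legacy.num_labels, None);
    Ok(())
}
```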

File tree: 8 files changed (+8388, -3 lines)


README.md

Lines changed: 1 addition & 0 deletions

```diff
@@ -91,6 +91,7 @@ Below are some examples of the currently supported models:
 | N/A | 396M | ModernBERT | [answerdotai/ModernBERT-large](https://hf.co/answerdotai/ModernBERT-large) |
 | N/A | 137M | JinaBERT | [jinaai/jina-embeddings-v2-base-en](https://hf.co/jinaai/jina-embeddings-v2-base-en) |
 | N/A | 137M | JinaBERT | [jinaai/jina-embeddings-v2-base-code](https://hf.co/jinaai/jina-embeddings-v2-base-code) |
+| N/A | 340M | Qwen3 | [voyageai/voyage-4-nano](https://hf.co/voyageai/voyage-4-nano) |
 
 To explore the list of best performing text embeddings models, visit the
 [Massive Text Embedding Benchmark (MTEB) Leaderboard](https://huggingface.co/spaces/mteb/leaderboard).
```

backends/candle/src/models/flash_qwen3.rs

Lines changed: 36 additions & 2 deletions

```diff
@@ -109,6 +109,7 @@ impl Qwen3Attention {
         cos: &Tensor,
         sin: &Tensor,
         max_s: usize,
+        causal: bool,
     ) -> Result<Tensor> {
         let _enter = self.span.enter();
 
@@ -158,7 +159,7 @@ impl Qwen3Attention {
             max_s,
             max_s,
             self.softmax_scale,
-            true,
+            causal,
             None,
             None,
         )?;
@@ -262,14 +263,15 @@ impl Qwen3Layer {
         cos: &Tensor,
         sin: &Tensor,
         max_s: usize,
+        causal: bool,
     ) -> Result<(Tensor, Tensor)> {
         let _enter = self.span.enter();
 
         let (normed_hidden_states, res) = self.input_layer_norm.forward(hidden_states, residual)?;
 
         let attn_output =
             self.attention
-                .forward(&normed_hidden_states, cu_seqlens, cos, sin, max_s)?;
+                .forward(&normed_hidden_states, cu_seqlens, cos, sin, max_s, causal)?;
 
         let (normed_attn_res_output, attn_res) = self
             .post_attention_layer_norm
@@ -285,9 +287,11 @@ pub struct FlashQwen3Model {
     embeddings: Embedding,
     layers: Vec<Qwen3Layer>,
     norm: RMSNorm,
+    projection: Option<Linear>,
     cos_cache: Tensor,
     sin_cache: Tensor,
     pool: Pool,
+    use_bidirectional_attention: bool,
     pub device: Device,
 
     span: tracing::Span,
@@ -313,6 +317,8 @@ impl FlashQwen3Model {
 
         // The Qwen3-Reranker models contain the `model` key
         // https://huggingface.co/collections/Qwen/qwen3-reranker-6841b22d0192d7ade9cdefea
+        // Keep reference to root vb for loading projection layer
+        let vb_root = vb.clone();
         let vb = if vb.contains_tensor("model.embed_tokens.weight") {
             vb.pp("model")
         } else {
@@ -331,6 +337,23 @@ impl FlashQwen3Model {
 
         let norm = RMSNorm::load(vb.pp("norm"), config.hidden_size, config.rms_norm_eps)?;
 
+        let projection = if let Some(num_labels) = config.num_labels {
+            if vb_root.contains_tensor("linear.weight") {
+                let projection_weight =
+                    vb_root.get((num_labels, config.hidden_size), "linear.weight")?;
+                Some(Linear::new(projection_weight, None, None))
+            } else {
+                tracing::warn!(
+                    "num_labels is set but linear.weight not found, skipping projection layer"
+                );
+                None
+            }
+        } else {
+            None
+        };
+
+        let use_bidirectional_attention = config.use_bidirectional_attention.unwrap_or(false);
+
         let inv_freqs = get_inv_freqs(
             layers[0].attention.attention_head_size,
             config.rope_theta,
@@ -348,9 +371,11 @@ impl FlashQwen3Model {
             embeddings,
             layers,
             norm,
+            projection,
             cos_cache,
             sin_cache,
             pool,
+            use_bidirectional_attention,
             device: vb.device().clone(),
             span: tracing::span!(tracing::Level::TRACE, "model"),
         })
@@ -376,6 +401,8 @@ impl FlashQwen3Model {
         let cos = index_select(&self.cos_cache, &position_ids, 0)?;
         let sin = index_select(&self.sin_cache, &position_ids, 0)?;
 
+        let causal = !self.use_bidirectional_attention;
+
         let mut residual = None;
         for layer in &self.layers {
             let (h, r) = layer.forward(
@@ -385,13 +412,20 @@ impl FlashQwen3Model {
                 &cos,
                 &sin,
                 batch.max_length as usize,
+                causal,
             )?;
             hidden_states = h;
             residual = Some(r);
         }
 
         let (outputs, _) = self.norm.forward(&hidden_states, residual.as_ref())?;
 
+        let outputs = if let Some(ref projection) = self.projection {
+            projection.forward(&outputs)?
+        } else {
+            outputs
+        };
+
         let has_pooling_requests = !batch.pooled_indices.is_empty();
         let has_raw_requests = !batch.raw_indices.is_empty();
 
```

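To illustrate what the `causal` flag threaded through this path toggles, here is a standalone sketch in plain Rust (no candle or flash-attn types, illustrative only): with causal masking a query may only attend to keys at or before its own position, while bidirectional attention leaves the full sequence visible. An additive bias of negative infinity stands in for a masked position.

```rust
// Minimal sketch, not repository code: builds the additive attention bias for a
// single sequence, either causal (upper triangle masked) or bidirectional (all zeros).
fn attention_bias(seq_len: usize, causal: bool) -> Vec<Vec<f32>> {
    (0..seq_len)
        .map(|query| {
            (0..seq_len)
                .map(|key| {
                    if causal && key > query {
                        f32::NEG_INFINITY // future position masked out
                    } else {
                        0.0 // position visible to this query
                    }
                })
                .collect()
        })
        .collect()
}

fn main() {
    // With `use_bidirectional_attention: true` the bias stays all-zero;
    // with the default causal behavior the upper triangle is masked.
    let bidirectional = attention_bias(4, false);
    assert!(bidirectional.iter().all(|row| row.iter().all(|&v| v == 0.0)));

    let causal = attention_bias(4, true);
    assert_eq!(causal[0][3], f32::NEG_INFINITY);
    assert_eq!(causal[3][0], 0.0);
}
```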
backends/candle/src/models/qwen3.rs

Lines changed: 36 additions & 1 deletion

```diff
@@ -24,6 +24,10 @@ pub struct Qwen3Config {
     pub sliding_window: Option<usize>,
     pub use_sliding_window: bool,
     pub eos_token_id: usize,
+    #[serde(default)]
+    pub use_bidirectional_attention: Option<bool>,
+    #[serde(default)]
+    pub num_labels: Option<usize>,
 }
 
 struct Qwen3Attention {
@@ -379,11 +383,13 @@ pub struct Qwen3Model {
     embeddings: Embedding,
     layers: Vec<Qwen3Layer>,
     norm: RMSNorm,
+    projection: Option<Linear>,
     rotary_cache: (Tensor, Tensor),
     rotary_dim: usize,
     pool: Pool,
     num_attention_heads: usize,
     pad_token_id: u32,
+    use_bidirectional_attention: bool,
 
     dtype: DType,
     device: Device,
@@ -402,6 +408,8 @@ impl Qwen3Model {
 
         // The Qwen3-Reranker models contain the `model` key
         // https://huggingface.co/collections/Qwen/qwen3-reranker-6841b22d0192d7ade9cdefea
+        // Keep reference to root vb for loading projection layer
+        let vb_root = vb.clone();
         let vb = if vb.contains_tensor("model.embed_tokens.weight") {
             vb.pp("model")
         } else {
@@ -420,6 +428,23 @@ impl Qwen3Model {
 
         let norm = RMSNorm::load(vb.pp("norm"), config.hidden_size, config.rms_norm_eps)?;
 
+        let projection = if let Some(num_labels) = config.num_labels {
+            if vb_root.contains_tensor("linear.weight") {
+                let projection_weight =
+                    vb_root.get((num_labels, config.hidden_size), "linear.weight")?;
+                Some(Linear::new(projection_weight, None, None))
+            } else {
+                tracing::warn!(
+                    "num_labels is set but linear.weight not found, skipping projection layer"
+                );
+                None
+            }
+        } else {
+            None
+        };
+
+        let use_bidirectional_attention = config.use_bidirectional_attention.unwrap_or(false);
+
         let rotary_dim = config
             .head_dim
             .unwrap_or(config.hidden_size / config.num_attention_heads);
@@ -433,11 +458,13 @@ impl Qwen3Model {
             embeddings,
             layers,
             norm,
+            projection,
             rotary_cache,
             rotary_dim,
             pool,
             pad_token_id: config.eos_token_id as u32,
             num_attention_heads: config.num_attention_heads,
+            use_bidirectional_attention,
             dtype: vb.dtype(),
             device: vb.device().clone(),
             span: tracing::span!(tracing::Level::TRACE, "model"),
@@ -555,7 +582,9 @@ impl Qwen3Model {
             (input_ids, position_ids, input_lengths, Some(attention_bias))
         };
 
-        let attention_bias = if let Some(attn_bias) = attention_bias {
+        let attention_bias = if self.use_bidirectional_attention {
+            attention_bias
+        } else if let Some(attn_bias) = attention_bias {
             Some(self.get_causal_attention_bias(attn_bias)?)
         } else {
             None
@@ -581,6 +610,12 @@ impl Qwen3Model {
 
         let (outputs, _) = self.norm.forward(&hidden_states, None)?;
 
+        let outputs = if let Some(ref projection) = self.projection {
+            projection.forward(&outputs)?
+        } else {
+            outputs
+        };
+
         let has_pooling_requests = !batch.pooled_indices.is_empty();
         let has_raw_requests = !batch.raw_indices.is_empty();
 
```

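For the projection path, a plain-Rust sketch with made-up shapes (not the repository's `Linear` type): `linear.weight` is stored as a `(num_labels, hidden_size)` matrix, so applying it maps each `hidden_size`-dimensional hidden state to a `num_labels`-dimensional embedding, which is how voyage-4-nano goes from 1024 to 2048 output dimensions.

```rust
// Minimal sketch, not repository code: each output dimension is the dot product
// of one weight row (length hidden_size) with the hidden state.
fn project(hidden: &[f32], weight: &[Vec<f32>]) -> Vec<f32> {
    weight
        .iter()
        .map(|row| row.iter().zip(hidden).map(|(w, h)| w * h).sum())
        .collect()
}

fn main() {
    // Hypothetical small shapes standing in for hidden_size = 1024, num_labels = 2048.
    let hidden_size = 4;
    let num_labels = 8;
    let hidden = vec![0.5_f32; hidden_size];
    let weight = vec![vec![0.1_f32; hidden_size]; num_labels];

    let projected = project(&hidden, &weight);
    // The output dimension follows num_labels, not hidden_size.
    assert_eq!(projected.len(), num_labels);
    println!("projected dimension: {}", projected.len());
}
```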