
Commit 2f1e3ee

Add bidirectional attention and projection layer support for voyage-4-nano
This change adds support for the voyageai/voyage-4-nano model, which is based on the Qwen3 architecture with two key modifications:

1. Bidirectional attention (is_causal=False) instead of causal attention
   - Added `use_bidirectional_attention` config field (default: false)
   - When true, skips causal masking in attention
2. Projection layer (1024 -> 2048 dimensions)
   - Added `num_labels` config field for the output projection dimension
   - When set, loads "linear.weight" and applies the projection after the final norm

The voyage-4-nano config.json includes:

    "use_bidirectional_attention": true
    "num_labels": 2048

Both the flash (CUDA) and non-flash implementations are updated.
1 parent cb9de7a commit 2f1e3ee
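
A quick way to read the two new config fields: a hypothetical helper (illustration only, not code from this commit) showing the output embedding dimension a client should expect.

```rust
// Hypothetical helper, illustration only: the output dimension implied by the
// new `num_labels` config field.
fn output_dim(hidden_size: usize, num_labels: Option<usize>) -> usize {
    // With `num_labels` set, the final-norm output (hidden_size wide) is
    // projected to `num_labels` dimensions; otherwise it is returned as-is.
    num_labels.unwrap_or(hidden_size)
}

fn main() {
    assert_eq!(output_dim(1024, Some(2048)), 2048); // voyage-4-nano
    assert_eq!(output_dim(1024, None), 1024); // stock Qwen3 embedding model
}
```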

File tree

2 files changed (+89, -3 lines)


backends/candle/src/models/flash_qwen3.rs

Lines changed: 43 additions & 2 deletions
@@ -109,6 +109,7 @@ impl Qwen3Attention {
         cos: &Tensor,
         sin: &Tensor,
         max_s: usize,
+        causal: bool, // voyage-4-nano: false for bidirectional attention
     ) -> Result<Tensor> {
         let _enter = self.span.enter();
 
@@ -158,7 +159,7 @@ impl Qwen3Attention {
             max_s,
             max_s,
             self.softmax_scale,
-            true,
+            causal, // voyage-4-nano: configurable causal flag
             None,
             None,
         )?;
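
What the new `causal` argument controls, as a toy sketch (plain Rust, not the flash-attention kernel): the set of key positions a query may attend to.

```rust
// Toy sketch only: key positions visible to the query at position `q` in a
// sequence of length `n`, with and without causal masking.
fn allowed_keys(q: usize, n: usize, causal: bool) -> Vec<usize> {
    (0..n).filter(|&k| !causal || k <= q).collect()
}

fn main() {
    // causal = true (the previously hard-coded behaviour): token 1 sees keys 0..=1.
    assert_eq!(allowed_keys(1, 4, true), vec![0, 1]);
    // causal = false (use_bidirectional_attention): token 1 sees all keys.
    assert_eq!(allowed_keys(1, 4, false), vec![0, 1, 2, 3]);
}
```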
@@ -262,14 +263,15 @@ impl Qwen3Layer {
         cos: &Tensor,
         sin: &Tensor,
         max_s: usize,
+        causal: bool, // voyage-4-nano: false for bidirectional attention
     ) -> Result<(Tensor, Tensor)> {
         let _enter = self.span.enter();
 
         let (normed_hidden_states, res) = self.input_layer_norm.forward(hidden_states, residual)?;
 
         let attn_output =
             self.attention
-                .forward(&normed_hidden_states, cu_seqlens, cos, sin, max_s)?;
+                .forward(&normed_hidden_states, cu_seqlens, cos, sin, max_s, causal)?;
 
         let (normed_attn_res_output, attn_res) = self
             .post_attention_layer_norm
@@ -285,9 +287,11 @@ pub struct FlashQwen3Model {
     embeddings: Embedding,
     layers: Vec<Qwen3Layer>,
     norm: RMSNorm,
+    projection: Option<Linear>, // voyage-4-nano: 1024 -> 2048 projection
     cos_cache: Tensor,
     sin_cache: Tensor,
     pool: Pool,
+    use_bidirectional_attention: bool, // voyage-4-nano: skip causal masking
     pub device: Device,
 
     span: tracing::Span,
@@ -331,6 +335,30 @@ impl FlashQwen3Model {
 
         let norm = RMSNorm::load(vb.pp("norm"), config.hidden_size, config.rms_norm_eps)?;
 
+        // voyage-4-nano: load projection layer if num_labels is set
+        // The projection transforms hidden_size (1024) to num_labels (2048)
+        let projection = if let Some(num_labels) = config.num_labels {
+            // Try to load from the model root (voyage-4-nano uses "linear.weight")
+            let vb_root = if vb.contains_tensor("linear.weight") {
+                vb.clone()
+            } else {
+                // Also check under "model" prefix for reranker-style models
+                vb.pp("..") // go up one level if we're already in "model"
+            };
+
+            if vb_root.contains_tensor("linear.weight") {
+                let projection_weight = vb_root.get((num_labels, config.hidden_size), "linear.weight")?;
+                Some(Linear::new(projection_weight, None, None))
+            } else {
+                tracing::warn!("num_labels is set but linear.weight not found, skipping projection layer");
+                None
+            }
+        } else {
+            None
+        };
+
+        let use_bidirectional_attention = config.use_bidirectional_attention.unwrap_or(false);
+
         let inv_freqs = get_inv_freqs(
             layers[0].attention.attention_head_size,
             config.rope_theta,
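
Shape-wise, "linear.weight" is loaded as [num_labels, hidden_size], so the projection is effectively a matmul against its transpose. A standalone sketch, assuming candle_core as a dependency (the model itself goes through its own Linear wrapper):

```rust
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Stand-in for the final-norm output of 4 tokens: [n_tokens, hidden_size].
    let hidden = Tensor::zeros((4, 1024), DType::F32, &device)?;
    // "linear.weight" as loaded above: [num_labels, hidden_size] = [2048, 1024].
    let weight = Tensor::zeros((2048, 1024), DType::F32, &device)?;
    // hidden @ weight^T -> [n_tokens, num_labels].
    let projected = hidden.matmul(&weight.t()?)?;
    assert_eq!(projected.dims(), &[4, 2048]);
    Ok(())
}
```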
@@ -348,9 +376,11 @@ impl FlashQwen3Model {
             embeddings,
             layers,
             norm,
+            projection,
             cos_cache,
             sin_cache,
             pool,
+            use_bidirectional_attention,
             device: vb.device().clone(),
             span: tracing::span!(tracing::Level::TRACE, "model"),
         })
@@ -376,6 +406,9 @@ impl FlashQwen3Model {
         let cos = index_select(&self.cos_cache, &position_ids, 0)?;
         let sin = index_select(&self.sin_cache, &position_ids, 0)?;
 
+        // voyage-4-nano: use bidirectional attention (causal=false) if configured
+        let causal = !self.use_bidirectional_attention;
+
         let mut residual = None;
         for layer in &self.layers {
             let (h, r) = layer.forward(
@@ -385,13 +418,21 @@ impl FlashQwen3Model {
                 &cos,
                 &sin,
                 batch.max_length as usize,
+                causal,
             )?;
             hidden_states = h;
             residual = Some(r);
         }
 
         let (outputs, _) = self.norm.forward(&hidden_states, residual.as_ref())?;
 
+        // voyage-4-nano: apply projection layer if present (1024 -> 2048)
+        let outputs = if let Some(ref projection) = self.projection {
+            projection.forward(&outputs)?
+        } else {
+            outputs
+        };
+
         let has_pooling_requests = !batch.pooled_indices.is_empty();
         let has_raw_requests = !batch.raw_indices.is_empty();
 
backends/candle/src/models/qwen3.rs

Lines changed: 46 additions & 1 deletion
@@ -24,6 +24,12 @@ pub struct Qwen3Config {
     pub sliding_window: Option<usize>,
     pub use_sliding_window: bool,
     pub eos_token_id: usize,
+    // voyage-4-nano support: bidirectional attention (is_causal=False)
+    #[serde(default)]
+    pub use_bidirectional_attention: Option<bool>,
+    // voyage-4-nano support: projection layer output dimension (1024 -> 2048)
+    #[serde(default)]
+    pub num_labels: Option<usize>,
 }
 
 struct Qwen3Attention {
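
A minimal deserialization sketch of the two new fields (struct pared down for illustration; assumes serde and serde_json as dependencies). Existing Qwen3 configs that lack the keys still parse, with both fields defaulting to None, so their behaviour is unchanged.

```rust
use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct MiniQwen3Config {
    hidden_size: usize,
    #[serde(default)]
    use_bidirectional_attention: Option<bool>,
    #[serde(default)]
    num_labels: Option<usize>,
}

fn main() -> serde_json::Result<()> {
    // voyage-4-nano style config: both new fields present.
    let voyage: MiniQwen3Config = serde_json::from_str(
        r#"{ "hidden_size": 1024, "use_bidirectional_attention": true, "num_labels": 2048 }"#,
    )?;
    assert_eq!(voyage.use_bidirectional_attention, Some(true));
    assert_eq!(voyage.num_labels, Some(2048));

    // Stock Qwen3 config: keys absent, fields default to None.
    let stock: MiniQwen3Config = serde_json::from_str(r#"{ "hidden_size": 1024 }"#)?;
    assert_eq!(stock.use_bidirectional_attention, None);
    assert_eq!(stock.num_labels, None);
    Ok(())
}
```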
@@ -379,11 +385,13 @@ pub struct Qwen3Model {
     embeddings: Embedding,
     layers: Vec<Qwen3Layer>,
     norm: RMSNorm,
+    projection: Option<Linear>, // voyage-4-nano: 1024 -> 2048 projection
     rotary_cache: (Tensor, Tensor),
     rotary_dim: usize,
     pool: Pool,
     num_attention_heads: usize,
     pad_token_id: u32,
+    use_bidirectional_attention: bool, // voyage-4-nano: skip causal masking
 
     dtype: DType,
     device: Device,
@@ -420,6 +428,30 @@ impl Qwen3Model {
 
         let norm = RMSNorm::load(vb.pp("norm"), config.hidden_size, config.rms_norm_eps)?;
 
+        // voyage-4-nano: load projection layer if num_labels is set
+        // The projection transforms hidden_size (1024) to num_labels (2048)
+        let projection = if let Some(num_labels) = config.num_labels {
+            // Try to load from the model root (voyage-4-nano uses "linear.weight")
+            let vb_root = if vb.contains_tensor("linear.weight") {
+                vb.clone()
+            } else {
+                // Also check under "model" prefix for reranker-style models
+                vb.pp("..") // go up one level if we're already in "model"
+            };
+
+            if vb_root.contains_tensor("linear.weight") {
+                let projection_weight = vb_root.get((num_labels, config.hidden_size), "linear.weight")?;
+                Some(Linear::new(projection_weight, None, None))
+            } else {
+                tracing::warn!("num_labels is set but linear.weight not found, skipping projection layer");
+                None
+            }
+        } else {
+            None
+        };
+
+        let use_bidirectional_attention = config.use_bidirectional_attention.unwrap_or(false);
+
         let rotary_dim = config
             .head_dim
             .unwrap_or(config.hidden_size / config.num_attention_heads);
@@ -433,11 +465,13 @@ impl Qwen3Model {
             embeddings,
             layers,
             norm,
+            projection,
             rotary_cache,
             rotary_dim,
             pool,
             pad_token_id: config.eos_token_id as u32,
             num_attention_heads: config.num_attention_heads,
+            use_bidirectional_attention,
             dtype: vb.dtype(),
             device: vb.device().clone(),
             span: tracing::span!(tracing::Level::TRACE, "model"),
@@ -555,7 +589,11 @@ impl Qwen3Model {
             (input_ids, position_ids, input_lengths, Some(attention_bias))
         };
 
-        let attention_bias = if let Some(attn_bias) = attention_bias {
+        // voyage-4-nano: skip causal masking when using bidirectional attention
+        let attention_bias = if self.use_bidirectional_attention {
+            // Bidirectional attention: only use padding mask (no causal mask)
+            attention_bias
+        } else if let Some(attn_bias) = attention_bias {
             Some(self.get_causal_attention_bias(attn_bias)?)
         } else {
             None
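
The effect on the bias in this non-flash path, as a simplified sketch in plain Rust (not the actual candle tensor code): both modes keep the padding term, and only the causal mode adds the future-token term.

```rust
// Simplified sketch: additive attention bias for one padded sequence.
// `seq_len` counts real tokens; positions >= seq_len are padding.
fn attention_bias(max_len: usize, seq_len: usize, causal: bool) -> Vec<Vec<f32>> {
    let mut bias = vec![vec![0.0f32; max_len]; max_len];
    for q in 0..max_len {
        for k in 0..max_len {
            let padded = k >= seq_len;
            let future = causal && k > q;
            if padded || future {
                bias[q][k] = f32::NEG_INFINITY;
            }
        }
    }
    bias
}

fn main() {
    // Bidirectional: only the padding position is masked for the first query.
    assert_eq!(attention_bias(4, 3, false)[0], vec![0.0, 0.0, 0.0, f32::NEG_INFINITY]);
    // Causal: future positions are masked as well.
    assert_eq!(
        attention_bias(4, 3, true)[0],
        vec![0.0, f32::NEG_INFINITY, f32::NEG_INFINITY, f32::NEG_INFINITY]
    );
}
```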
@@ -581,6 +619,13 @@ impl Qwen3Model {
 
         let (outputs, _) = self.norm.forward(&hidden_states, None)?;
 
+        // voyage-4-nano: apply projection layer if present (1024 -> 2048)
+        let outputs = if let Some(ref projection) = self.projection {
+            projection.forward(&outputs)?
+        } else {
+            outputs
+        };
+
         let has_pooling_requests = !batch.pooled_indices.is_empty();
         let has_raw_requests = !batch.raw_indices.is_empty();
 