Skip to content

[Feature] Qwen3 Reranker #695

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 21 additions & 4 deletions backends/candle/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -597,13 +597,30 @@ impl Backend for CandleBackend {
let batch_size = batch.len();

let results = self.model.predict(batch).e()?;

let results = results.to_dtype(DType::F32).e()?.to_vec2().e()?;
let results = results.to_dtype(DType::F32).e()?;

let mut predictions =
HashMap::with_capacity_and_hasher(batch_size, BuildNoHashHasher::default());
for (i, r) in results.into_iter().enumerate() {
predictions.insert(i, r);

match results.dims() {
[_n] => {
let scores = results.to_vec1::<f32>().e()?;
for (i, score) in scores.into_iter().enumerate() {
predictions.insert(i, vec![score]);
}
}
[_, _] => {
let results = results.to_vec2().e()?;
for (i, r) in results.into_iter().enumerate() {
predictions.insert(i, r);
}
}
dims => {
return Err(BackendError::Inference(format!(
"Unexpected tensor shape: {:?}",
dims
)));
}
}

Ok(predictions)
Expand Down
7 changes: 6 additions & 1 deletion backends/candle/src/models/bert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -597,6 +597,9 @@ impl BertModel {
};
(pool, None, splade)
}
ModelType::ListwiseReranker => {
candle::bail!("`reranker` model type is not supported for Bert")
}
};

let (embeddings, encoder) = match (
Expand All @@ -615,7 +618,6 @@ impl BertModel {
}
}
};

Ok(Self {
embeddings,
encoder,
Expand Down Expand Up @@ -661,6 +663,9 @@ impl BertModel {
};
(pool, None, splade)
}
ModelType::ListwiseReranker => {
candle::bail!("`reranker` model type is not supported for RoBERTa")
}
};

let (embeddings, encoder) = match (
Expand Down
3 changes: 3 additions & 0 deletions backends/candle/src/models/distilbert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,9 @@ impl DistilBertModel {

(pool, None)
}
ModelType::ListwiseReranker => {
candle::bail!("`reranker` model type is not supported for DistilBert")
}
};

let (embeddings, encoder) = match (
Expand Down
2 changes: 2 additions & 0 deletions backends/candle/src/models/flash_bert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,7 @@ impl FlashBertModel {
};
(pool, None, splade)
}
ModelType::ListwiseReranker => todo!(),
};

let (embeddings, encoder) = match (
Expand Down Expand Up @@ -326,6 +327,7 @@ impl FlashBertModel {
};
(pool, None, splade)
}
ModelType::ListwiseReranker => todo!(),
};

let (embeddings, encoder) = match (
Expand Down
1 change: 1 addition & 0 deletions backends/candle/src/models/flash_distilbert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ impl FlashDistilBertModel {
candle::bail!("`classifier` model type is not supported for DistilBert")
}
ModelType::Embedding(pool) => pool,
ModelType::ListwiseReranker => todo!(),
};

let (embeddings, encoder) = match (
Expand Down
1 change: 1 addition & 0 deletions backends/candle/src/models/flash_gte.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ impl FlashGTEModel {
(pool, Some(classifier))
}
ModelType::Embedding(pool) => (pool, None),
ModelType::ListwiseReranker => todo!(),
};

let (word_embeddings, token_type_embeddings, layers, embeddings_norm) =
Expand Down
1 change: 1 addition & 0 deletions backends/candle/src/models/flash_jina.rs
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,7 @@ impl FlashJinaBertModel {
}
(pool, None)
}
ModelType::ListwiseReranker => todo!(),
};

let (embeddings, encoder) = match (
Expand Down
1 change: 1 addition & 0 deletions backends/candle/src/models/flash_jina_code.rs
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,7 @@ impl FlashJinaCodeBertModel {
}
pool
}
ModelType::ListwiseReranker => todo!(),
};

let (embeddings, encoder) = match (
Expand Down
1 change: 1 addition & 0 deletions backends/candle/src/models/flash_mistral.rs
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,7 @@ impl FlashMistralModel {
candle::bail!("`classifier` model type is not supported for Mistral")
}
ModelType::Embedding(pool) => pool,
ModelType::ListwiseReranker => todo!(),
};

let embeddings = Embedding::new(
Expand Down
1 change: 1 addition & 0 deletions backends/candle/src/models/flash_modernbert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,7 @@ impl FlashModernBertModel {

(pool, None)
}
ModelType::ListwiseReranker => todo!(),
};

let embeddings = ModernBertEmbeddings::load(vb.pp("model.embeddings"), config)
Expand Down
1 change: 1 addition & 0 deletions backends/candle/src/models/flash_nomic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ impl FlashNomicBertModel {
}
pool
}
ModelType::ListwiseReranker => todo!(),
};

let embeddings = NomicBertEmbeddings::load(vb.clone(), config)?;
Expand Down
1 change: 1 addition & 0 deletions backends/candle/src/models/flash_qwen2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,7 @@ impl FlashQwen2Model {
candle::bail!("`classifier` model type is not supported for Qwen2")
}
ModelType::Embedding(pool) => pool,
ModelType::ListwiseReranker => todo!(),
};

// Pushing the prefix for `model` is apparently only required if the model architecture is
Expand Down
131 changes: 114 additions & 17 deletions backends/candle/src/models/flash_qwen3.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::flash_attn::flash_attn_varlen;
use crate::layers::{get_cos_sin, get_inv_freqs, HiddenAct, Linear, RMSNorm};
use crate::models::{Model, Qwen3Config};
use candle::{DType, Device, IndexOp, Result, Tensor};
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{Embedding, Module, VarBuilder};
use candle_rotary::apply_rotary_inplace;
use text_embeddings_backend_core::{Batch, ModelType, Pool};
Expand Down Expand Up @@ -42,7 +42,10 @@ impl Qwen3Attention {
"weight",
)?;
let query_bias = if config.attention_bias {
Some(vb.pp("q_proj").get(hidden_size, "bias")?)
Some(
vb.pp("q_proj")
.get(num_attention_heads * attention_head_size, "bias")?,
)
} else {
None
};
Expand Down Expand Up @@ -85,7 +88,7 @@ impl Qwen3Attention {
let q_norm = RMSNorm::load(vb.pp("q_norm"), attention_head_size, config.rms_norm_eps)?;
let k_norm = RMSNorm::load(vb.pp("k_norm"), attention_head_size, config.rms_norm_eps)?;

let softmax_scale = (1. / (attention_head_size as f64).sqrt()) as f32;
let softmax_scale = 1.0 / (attention_head_size as f64).sqrt() as f32;

Ok(Self {
q_proj,
Expand Down Expand Up @@ -148,6 +151,28 @@ impl Qwen3Attention {

apply_rotary_inplace(&q, &k, &cos, &sin, true)?;

let (k, v) = if self.num_key_value_heads != self.num_attention_heads {
if self.num_attention_heads % self.num_key_value_heads != 0 {
candle::bail!("num_attention_heads must be a multiple of num_key_value_heads");
}
let repeat = self.num_attention_heads / self.num_key_value_heads;

let (total_tokens, n_kv_heads, head_dim) = k.dims3()?;

let k = k
.unsqueeze(2)?
.expand((total_tokens, n_kv_heads, repeat, head_dim))?
.reshape((total_tokens, n_kv_heads * repeat, head_dim))?;

let v = v
.unsqueeze(2)?
.expand((total_tokens, n_kv_heads, repeat, head_dim))?
.reshape((total_tokens, n_kv_heads * repeat, head_dim))?;
(k, v)
} else {
(k, v)
};

let attention = flash_attn_varlen(
&q,
&k,
Expand Down Expand Up @@ -277,16 +302,19 @@ impl Qwen3Layer {

let mlp_output = self.mlp.forward(&normed_attn_res_output)?;

Ok((mlp_output, attn_res))
let output = (&mlp_output + &attn_res)?;
Ok((output, attn_res))
}
}

pub struct FlashQwen3Model {
embeddings: Embedding,
lm_head_weight: Tensor,
layers: Vec<Qwen3Layer>,
norm: RMSNorm,
cos_cache: Tensor,
sin_cache: Tensor,
model_type: ModelType,
pool: Pool,
pub device: Device,

Expand All @@ -304,11 +332,12 @@ impl FlashQwen3Model {
candle::bail!("FlashQwen3 requires DType::F16")
}

let pool = match model_type {
let pool = match &model_type {
ModelType::Classifier => {
candle::bail!("`classifier` model type is not supported for Qwen3")
}
ModelType::Embedding(pool) => pool,
ModelType::Embedding(pool) => pool.clone(),
ModelType::ListwiseReranker => Pool::LastToken,
};

// The Qwen3-Reranker models contain the `model` key
Expand All @@ -319,11 +348,13 @@ impl FlashQwen3Model {
vb
};

let embeddings = Embedding::new(
vb.pp("embed_tokens")
.get((config.vocab_size, config.hidden_size), "weight")?,
config.hidden_size,
);
let embed_weight = vb
.pp("embed_tokens")
.get((config.vocab_size, config.hidden_size), "weight")?;

let embeddings = Embedding::new(embed_weight.clone(), config.hidden_size);

let lm_head_weight = embed_weight;

let layers = (0..config.num_hidden_layers)
.map(|index| Qwen3Layer::load(vb.pp(format!("layers.{index}")), config))
Expand All @@ -346,10 +377,12 @@ impl FlashQwen3Model {

Ok(Self {
embeddings,
lm_head_weight,
layers,
norm,
cos_cache,
sin_cache,
model_type,
pool,
device: vb.device().clone(),
span: tracing::span!(tracing::Level::TRACE, "model"),
Expand All @@ -376,21 +409,19 @@ impl FlashQwen3Model {
let cos = self.cos_cache.index_select(&position_ids, 0)?;
let sin = self.sin_cache.index_select(&position_ids, 0)?;

let mut residual = None;
for layer in &self.layers {
let (h, r) = layer.forward(
let (h, _r) = layer.forward(
&hidden_states,
residual.as_ref(),
None,
&cu_seqlens,
&cos,
&sin,
batch.max_length as usize,
)?;
hidden_states = h;
residual = Some(r);
}

let (outputs, _) = self.norm.forward(&hidden_states, residual.as_ref())?;
let (outputs, _) = self.norm.forward(&hidden_states, None)?;

let has_pooling_requests = !batch.pooled_indices.is_empty();
let has_raw_requests = !batch.raw_indices.is_empty();
Expand Down Expand Up @@ -460,7 +491,8 @@ impl FlashQwen3Model {
// Concatenate all results
Some(Tensor::cat(&results?, 0)?)
} else {
Some((outputs.sum_keepdim(0)? / (batch.max_length as f64))?)
let actual_len = batch.cumulative_seq_lengths[1] as f64;
Some((outputs.sum_keepdim(0)? / actual_len)?)
}
}
Pool::Splade => {
Expand Down Expand Up @@ -512,4 +544,69 @@ impl Model for FlashQwen3Model {
fn embed(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {
self.forward(batch)
}

/// Scores every sequence of a packed `batch` with the Qwen3 reranker head.
///
/// The packed token stream is run through the embedding layer, every decoder
/// layer and the final norm. For each sequence the hidden state of its last
/// token is projected onto the "no" and "yes" rows of `lm_head_weight`, and
/// the softmax probability assigned to "yes" is returned — one score per
/// sequence, in batch order.
///
/// # Errors
///
/// Returns an error when the model was not loaded as
/// `ModelType::ListwiseReranker`, when the batch holds no sequence, when a
/// sequence is empty (it has no last token to score), or when any tensor
/// operation fails.
fn predict(&self, batch: Batch) -> Result<Tensor> {
    // Vocabulary ids of the "yes" / "no" answer tokens. They are hard-coded;
    // the original author's note says they were verified against the Qwen3
    // tokenizer — NOTE(review): confirm if the tokenizer ever changes.
    const YES_TOKEN_ID: u32 = 9454;
    const NO_TOKEN_ID: u32 = 2901;

    match &self.model_type {
        ModelType::ListwiseReranker => {
            let _enter = self.span.enter();

            // `cumulative_seq_lengths` holds one boundary more than there are
            // sequences. `saturating_sub` keeps an empty vector from
            // underflowing `usize`.
            let batch_size = batch.cumulative_seq_lengths.len().saturating_sub(1);
            if batch_size == 0 {
                candle::bail!("`predict` requires a batch with at least one sequence");
            }

            // Position of the last token of every sequence inside the packed
            // token stream. An empty sequence would make `end - 1` point at the
            // previous sequence (or underflow), so it is rejected explicitly.
            let mut last_token_ids = Vec::with_capacity(batch_size);
            for bounds in batch.cumulative_seq_lengths.windows(2) {
                let (start, end) = (bounds[0], bounds[1]);
                if end <= start {
                    candle::bail!("`predict` received an empty sequence in the batch");
                }
                last_token_ids.push(end - 1);
            }

            let shape = batch.input_ids.len();

            let input_ids = Tensor::from_vec(batch.input_ids, shape, &self.device)?;
            let position_ids = Tensor::from_vec(batch.position_ids, shape, &self.device)?;
            // The boundaries are no longer needed on the host, so the vector
            // is moved into the tensor instead of being cloned.
            let cu_seqlens = Tensor::from_vec(
                batch.cumulative_seq_lengths,
                batch_size + 1,
                &self.device,
            )?;
            let last_token_ids = Tensor::from_vec(last_token_ids, batch_size, &self.device)?;

            let mut hidden_states = self.embeddings.forward(&input_ids)?;

            // Rotary tables for exactly the positions present in this batch.
            let cos = self.cos_cache.index_select(&position_ids, 0)?;
            let sin = self.sin_cache.index_select(&position_ids, 0)?;

            for layer in &self.layers {
                // The second element returned by the layer is not used here.
                let (h, _r) = layer.forward(
                    &hidden_states,
                    None,
                    &cu_seqlens,
                    &cos,
                    &sin,
                    batch.max_length as usize,
                )?;
                hidden_states = h;
            }

            let (outputs, _) = self.norm.forward(&hidden_states, None)?;

            // Gather all last-token hidden states in a single op.
            let h_last = outputs.index_select(&last_token_ids, 0)?; // [bs, hidden_size]

            tracing::debug!(
                "Using Qwen3 token IDs - yes: {}, no: {}",
                YES_TOKEN_ID,
                NO_TOKEN_ID
            );

            // Only the two answer rows of the LM head are needed, so the full
            // vocabulary projection is skipped.
            let ids = Tensor::from_vec(vec![NO_TOKEN_ID, YES_TOKEN_ID], 2, &self.device)?;
            let w = self.lm_head_weight.index_select(&ids, 0)?; // [2, hidden_size]
            let logits = h_last.matmul(&w.t()?)?; // [bs, 2] (no, yes)
            let log_probs = candle_nn::ops::log_softmax(&logits, D::Minus1)?;
            // Column 1 is "yes"; exp(log_softmax) gives P("yes") in (0, 1).
            let scores = log_probs.i((.., 1))?.exp()?;

            Ok(scores)
        }
        _ => candle::bail!("`predict` is only available for ModelType::ListwiseReranker"),
    }
}
}
3 changes: 3 additions & 0 deletions backends/candle/src/models/gte.rs
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,9 @@ impl GTEModel {
(pool, Some(classifier))
}
ModelType::Embedding(pool) => (pool, None),
ModelType::ListwiseReranker => {
candle::bail!("`reranker` model type is not supported for GTE")
}
};

let (word_embeddings, token_type_embeddings, encoder, embeddings_norm) =
Expand Down
Loading