diff --git a/backends/candle/src/lib.rs b/backends/candle/src/lib.rs index e506a63d..4852d305 100644 --- a/backends/candle/src/lib.rs +++ b/backends/candle/src/lib.rs @@ -597,13 +597,30 @@ impl Backend for CandleBackend { let batch_size = batch.len(); let results = self.model.predict(batch).e()?; - - let results = results.to_dtype(DType::F32).e()?.to_vec2().e()?; + let results = results.to_dtype(DType::F32).e()?; let mut predictions = HashMap::with_capacity_and_hasher(batch_size, BuildNoHashHasher::default()); - for (i, r) in results.into_iter().enumerate() { - predictions.insert(i, r); + + match results.dims() { + [_n] => { + let scores = results.to_vec1::().e()?; + for (i, score) in scores.into_iter().enumerate() { + predictions.insert(i, vec![score]); + } + } + [_, _] => { + let results = results.to_vec2().e()?; + for (i, r) in results.into_iter().enumerate() { + predictions.insert(i, r); + } + } + dims => { + return Err(BackendError::Inference(format!( + "Unexpected tensor shape: {:?}", + dims + ))); + } } Ok(predictions) diff --git a/backends/candle/src/models/bert.rs b/backends/candle/src/models/bert.rs index 1720ce9d..48bc7f48 100644 --- a/backends/candle/src/models/bert.rs +++ b/backends/candle/src/models/bert.rs @@ -597,6 +597,9 @@ impl BertModel { }; (pool, None, splade) } + ModelType::ListwiseReranker => { + candle::bail!("`reranker` model type is not supported for Bert") + } }; let (embeddings, encoder) = match ( @@ -615,7 +618,6 @@ impl BertModel { } } }; - Ok(Self { embeddings, encoder, @@ -661,6 +663,9 @@ impl BertModel { }; (pool, None, splade) } + ModelType::ListwiseReranker => { + candle::bail!("`reranker` model type is not supported for RoBERTa") + } }; let (embeddings, encoder) = match ( diff --git a/backends/candle/src/models/distilbert.rs b/backends/candle/src/models/distilbert.rs index 7b8f0786..a569e558 100644 --- a/backends/candle/src/models/distilbert.rs +++ b/backends/candle/src/models/distilbert.rs @@ -462,6 +462,9 @@ impl DistilBertModel { (pool, None) } + ModelType::ListwiseReranker => { + candle::bail!("`reranker` model type is not supported for DistilBert") + } }; let (embeddings, encoder) = match ( diff --git a/backends/candle/src/models/flash_bert.rs b/backends/candle/src/models/flash_bert.rs index 9b14d9a0..bad6d3da 100644 --- a/backends/candle/src/models/flash_bert.rs +++ b/backends/candle/src/models/flash_bert.rs @@ -259,6 +259,7 @@ impl FlashBertModel { }; (pool, None, splade) } + ModelType::ListwiseReranker => todo!(), }; let (embeddings, encoder) = match ( @@ -326,6 +327,7 @@ impl FlashBertModel { }; (pool, None, splade) } + ModelType::ListwiseReranker => todo!(), }; let (embeddings, encoder) = match ( diff --git a/backends/candle/src/models/flash_distilbert.rs b/backends/candle/src/models/flash_distilbert.rs index 2664c660..e6eb2ece 100644 --- a/backends/candle/src/models/flash_distilbert.rs +++ b/backends/candle/src/models/flash_distilbert.rs @@ -200,6 +200,7 @@ impl FlashDistilBertModel { candle::bail!("`classifier` model type is not supported for DistilBert") } ModelType::Embedding(pool) => pool, + ModelType::ListwiseReranker => todo!(), }; let (embeddings, encoder) = match ( diff --git a/backends/candle/src/models/flash_gte.rs b/backends/candle/src/models/flash_gte.rs index 0d38c704..8965b810 100644 --- a/backends/candle/src/models/flash_gte.rs +++ b/backends/candle/src/models/flash_gte.rs @@ -191,6 +191,7 @@ impl FlashGTEModel { (pool, Some(classifier)) } ModelType::Embedding(pool) => (pool, None), + ModelType::ListwiseReranker => todo!(), }; 
let (word_embeddings, token_type_embeddings, layers, embeddings_norm) = diff --git a/backends/candle/src/models/flash_jina.rs b/backends/candle/src/models/flash_jina.rs index 05341b84..ac310d32 100644 --- a/backends/candle/src/models/flash_jina.rs +++ b/backends/candle/src/models/flash_jina.rs @@ -267,6 +267,7 @@ impl FlashJinaBertModel { } (pool, None) } + ModelType::ListwiseReranker => todo!(), }; let (embeddings, encoder) = match ( diff --git a/backends/candle/src/models/flash_jina_code.rs b/backends/candle/src/models/flash_jina_code.rs index e00f758d..2f1f5bd7 100644 --- a/backends/candle/src/models/flash_jina_code.rs +++ b/backends/candle/src/models/flash_jina_code.rs @@ -314,6 +314,7 @@ impl FlashJinaCodeBertModel { } pool } + ModelType::ListwiseReranker => todo!(), }; let (embeddings, encoder) = match ( diff --git a/backends/candle/src/models/flash_mistral.rs b/backends/candle/src/models/flash_mistral.rs index c8488f36..cdb43a89 100644 --- a/backends/candle/src/models/flash_mistral.rs +++ b/backends/candle/src/models/flash_mistral.rs @@ -251,6 +251,7 @@ impl FlashMistralModel { candle::bail!("`classifier` model type is not supported for Mistral") } ModelType::Embedding(pool) => pool, + ModelType::ListwiseReranker => todo!(), }; let embeddings = Embedding::new( diff --git a/backends/candle/src/models/flash_modernbert.rs b/backends/candle/src/models/flash_modernbert.rs index 3c876c63..e888ab3f 100644 --- a/backends/candle/src/models/flash_modernbert.rs +++ b/backends/candle/src/models/flash_modernbert.rs @@ -274,6 +274,7 @@ impl FlashModernBertModel { (pool, None) } + ModelType::ListwiseReranker => todo!(), }; let embeddings = ModernBertEmbeddings::load(vb.pp("model.embeddings"), config) diff --git a/backends/candle/src/models/flash_nomic.rs b/backends/candle/src/models/flash_nomic.rs index 32cd31b6..c453b757 100644 --- a/backends/candle/src/models/flash_nomic.rs +++ b/backends/candle/src/models/flash_nomic.rs @@ -228,6 +228,7 @@ impl FlashNomicBertModel { } pool } + ModelType::ListwiseReranker => todo!(), }; let embeddings = NomicBertEmbeddings::load(vb.clone(), config)?; diff --git a/backends/candle/src/models/flash_qwen2.rs b/backends/candle/src/models/flash_qwen2.rs index c9116311..9b76c989 100644 --- a/backends/candle/src/models/flash_qwen2.rs +++ b/backends/candle/src/models/flash_qwen2.rs @@ -259,6 +259,7 @@ impl FlashQwen2Model { candle::bail!("`classifier` model type is not supported for Qwen2") } ModelType::Embedding(pool) => pool, + ModelType::ListwiseReranker => todo!(), }; // Pushing the prefix for `model` is apparently only required if the model architecture is diff --git a/backends/candle/src/models/flash_qwen3.rs b/backends/candle/src/models/flash_qwen3.rs index 10f27bdd..7654541e 100644 --- a/backends/candle/src/models/flash_qwen3.rs +++ b/backends/candle/src/models/flash_qwen3.rs @@ -1,7 +1,7 @@ use crate::flash_attn::flash_attn_varlen; use crate::layers::{get_cos_sin, get_inv_freqs, HiddenAct, Linear, RMSNorm}; use crate::models::{Model, Qwen3Config}; -use candle::{DType, Device, IndexOp, Result, Tensor}; +use candle::{DType, Device, IndexOp, Result, Tensor, D}; use candle_nn::{Embedding, Module, VarBuilder}; use candle_rotary::apply_rotary_inplace; use text_embeddings_backend_core::{Batch, ModelType, Pool}; @@ -42,7 +42,10 @@ impl Qwen3Attention { "weight", )?; let query_bias = if config.attention_bias { - Some(vb.pp("q_proj").get(hidden_size, "bias")?) 
+ Some( + vb.pp("q_proj") + .get(num_attention_heads * attention_head_size, "bias")?, + ) } else { None }; @@ -85,7 +88,7 @@ impl Qwen3Attention { let q_norm = RMSNorm::load(vb.pp("q_norm"), attention_head_size, config.rms_norm_eps)?; let k_norm = RMSNorm::load(vb.pp("k_norm"), attention_head_size, config.rms_norm_eps)?; - let softmax_scale = (1. / (attention_head_size as f64).sqrt()) as f32; + let softmax_scale = 1.0 / (attention_head_size as f64).sqrt() as f32; Ok(Self { q_proj, @@ -148,6 +151,28 @@ impl Qwen3Attention { apply_rotary_inplace(&q, &k, &cos, &sin, true)?; + let (k, v) = if self.num_key_value_heads != self.num_attention_heads { + if self.num_attention_heads % self.num_key_value_heads != 0 { + candle::bail!("num_attention_heads must be a multiple of num_key_value_heads"); + } + let repeat = self.num_attention_heads / self.num_key_value_heads; + + let (total_tokens, n_kv_heads, head_dim) = k.dims3()?; + + let k = k + .unsqueeze(2)? + .expand((total_tokens, n_kv_heads, repeat, head_dim))? + .reshape((total_tokens, n_kv_heads * repeat, head_dim))?; + + let v = v + .unsqueeze(2)? + .expand((total_tokens, n_kv_heads, repeat, head_dim))? + .reshape((total_tokens, n_kv_heads * repeat, head_dim))?; + (k, v) + } else { + (k, v) + }; + let attention = flash_attn_varlen( &q, &k, @@ -277,16 +302,19 @@ impl Qwen3Layer { let mlp_output = self.mlp.forward(&normed_attn_res_output)?; - Ok((mlp_output, attn_res)) + let output = (&mlp_output + &attn_res)?; + Ok((output, attn_res)) } } pub struct FlashQwen3Model { embeddings: Embedding, + lm_head_weight: Tensor, layers: Vec, norm: RMSNorm, cos_cache: Tensor, sin_cache: Tensor, + model_type: ModelType, pool: Pool, pub device: Device, @@ -304,11 +332,12 @@ impl FlashQwen3Model { candle::bail!("FlashQwen3 requires DType::F16") } - let pool = match model_type { + let pool = match &model_type { ModelType::Classifier => { candle::bail!("`classifier` model type is not supported for Qwen3") } - ModelType::Embedding(pool) => pool, + ModelType::Embedding(pool) => pool.clone(), + ModelType::ListwiseReranker => Pool::LastToken, }; // The Qwen3-Reranker models contain the `model` key @@ -319,11 +348,13 @@ impl FlashQwen3Model { vb }; - let embeddings = Embedding::new( - vb.pp("embed_tokens") - .get((config.vocab_size, config.hidden_size), "weight")?, - config.hidden_size, - ); + let embed_weight = vb + .pp("embed_tokens") + .get((config.vocab_size, config.hidden_size), "weight")?; + + let embeddings = Embedding::new(embed_weight.clone(), config.hidden_size); + + let lm_head_weight = embed_weight; let layers = (0..config.num_hidden_layers) .map(|index| Qwen3Layer::load(vb.pp(format!("layers.{index}")), config)) @@ -346,10 +377,12 @@ impl FlashQwen3Model { Ok(Self { embeddings, + lm_head_weight, layers, norm, cos_cache, sin_cache, + model_type, pool, device: vb.device().clone(), span: tracing::span!(tracing::Level::TRACE, "model"), @@ -376,21 +409,19 @@ impl FlashQwen3Model { let cos = self.cos_cache.index_select(&position_ids, 0)?; let sin = self.sin_cache.index_select(&position_ids, 0)?; - let mut residual = None; for layer in &self.layers { - let (h, r) = layer.forward( + let (h, _r) = layer.forward( &hidden_states, - residual.as_ref(), + None, &cu_seqlens, &cos, &sin, batch.max_length as usize, )?; hidden_states = h; - residual = Some(r); } - let (outputs, _) = self.norm.forward(&hidden_states, residual.as_ref())?; + let (outputs, _) = self.norm.forward(&hidden_states, None)?; let has_pooling_requests = !batch.pooled_indices.is_empty(); let 
has_raw_requests = !batch.raw_indices.is_empty(); @@ -460,7 +491,8 @@ impl FlashQwen3Model { // Concatenate all results Some(Tensor::cat(&results?, 0)?) } else { - Some((outputs.sum_keepdim(0)? / (batch.max_length as f64))?) + let actual_len = batch.cumulative_seq_lengths[1] as f64; + Some((outputs.sum_keepdim(0)? / actual_len)?) } } Pool::Splade => { @@ -512,4 +544,69 @@ impl Model for FlashQwen3Model { fn embed(&self, batch: Batch) -> Result<(Option, Option)> { self.forward(batch) } + + fn predict(&self, batch: Batch) -> Result { + match &self.model_type { + ModelType::ListwiseReranker => { + let _enter = self.span.enter(); + + let batch_size = batch.cumulative_seq_lengths.len() - 1; + let shape = batch.input_ids.len(); + + let input_ids = Tensor::from_vec(batch.input_ids, shape, &self.device)?; + let position_ids = Tensor::from_vec(batch.position_ids, shape, &self.device)?; + let cu_seqlens = Tensor::from_vec( + batch.cumulative_seq_lengths.clone(), + batch_size + 1, + &self.device, + )?; + + let mut hidden_states = self.embeddings.forward(&input_ids)?; + + let cos = self.cos_cache.index_select(&position_ids, 0)?; + let sin = self.sin_cache.index_select(&position_ids, 0)?; + + for layer in &self.layers { + let (h, _r) = layer.forward( + &hidden_states, + None, + &cu_seqlens, + &cos, + &sin, + batch.max_length as usize, + )?; + hidden_states = h; + } + + let (outputs, _) = self.norm.forward(&hidden_states, None)?; + + let mut last_hidden_states = Vec::with_capacity(batch_size); + + for i in 0..batch_size { + let seq_end = batch.cumulative_seq_lengths[i + 1] as usize; + let last_token_idx = seq_end - 1; + + let h_last = outputs.i(last_token_idx)?; // [hidden_size] + last_hidden_states.push(h_last); + } + + let h_last = Tensor::stack(&last_hidden_states, 0)?; // [bs, hidden_size] + + // Correct token IDs for Qwen3 (verified from tokenizer) + let yes_id = 9454u32; // "yes" token ID + let no_id = 2901u32; // "no" token ID + + tracing::debug!("Using Qwen3 token IDs - yes: {}, no: {}", yes_id, no_id); + + let ids = Tensor::from_vec(vec![no_id, yes_id], 2, &self.device)?; + let w = self.lm_head_weight.index_select(&ids, 0)?; // [2, hidden_size] + let logits = h_last.matmul(&w.t()?)?; // [bs, 2] (no, yes) + let log_probs = candle_nn::ops::log_softmax(&logits, D::Minus1)?; + let scores = log_probs.i((.., 1))?.exp()?; // P("yes") ∈ (0,1) + + Ok(scores) + } + _ => candle::bail!("`predict` is only available for ModelType::ListwiseReranker"), + } + } } diff --git a/backends/candle/src/models/gte.rs b/backends/candle/src/models/gte.rs index d5cf3412..9514e5ce 100644 --- a/backends/candle/src/models/gte.rs +++ b/backends/candle/src/models/gte.rs @@ -406,6 +406,9 @@ impl GTEModel { (pool, Some(classifier)) } ModelType::Embedding(pool) => (pool, None), + ModelType::ListwiseReranker => { + candle::bail!("`reranker` model type is not supported for GTE") + } }; let (word_embeddings, token_type_embeddings, encoder, embeddings_norm) = diff --git a/backends/candle/src/models/jina.rs b/backends/candle/src/models/jina.rs index d54d49c6..0783dd65 100644 --- a/backends/candle/src/models/jina.rs +++ b/backends/candle/src/models/jina.rs @@ -436,6 +436,9 @@ impl JinaBertModel { } (pool, None) } + ModelType::ListwiseReranker => { + candle::bail!("`reranker` model type is not supported for Jina") + } }; let (embeddings, encoder) = match ( diff --git a/backends/candle/src/models/jina_code.rs b/backends/candle/src/models/jina_code.rs index 8cb6f65a..c6d6ea6c 100644 --- a/backends/candle/src/models/jina_code.rs +++ 
b/backends/candle/src/models/jina_code.rs @@ -364,6 +364,9 @@ impl JinaCodeBertModel { } pool } + ModelType::ListwiseReranker => { + candle::bail!("`reranker` model type is not supported for JinaCode") + } }; let (embeddings, encoder) = match ( diff --git a/backends/candle/src/models/modernbert.rs b/backends/candle/src/models/modernbert.rs index b94325ca..f7170166 100644 --- a/backends/candle/src/models/modernbert.rs +++ b/backends/candle/src/models/modernbert.rs @@ -498,6 +498,9 @@ impl ModernBertModel { (pool, None) } + ModelType::ListwiseReranker => { + candle::bail!("`reranker` model type is not supported for ModernBert") + } }; let embeddings = ModernBertEmbeddings::load(vb.pp("model.embeddings"), config) diff --git a/backends/candle/src/models/mpnet.rs b/backends/candle/src/models/mpnet.rs index 2fd52f31..28dbc7b3 100644 --- a/backends/candle/src/models/mpnet.rs +++ b/backends/candle/src/models/mpnet.rs @@ -450,6 +450,9 @@ impl MPNetModel { } pool } + ModelType::ListwiseReranker => { + candle::bail!("`reranker` model type is not supported for MPNet") + } }; let (embeddings, encoder) = match ( diff --git a/backends/candle/src/models/nomic.rs b/backends/candle/src/models/nomic.rs index 8748db38..3b3e031b 100644 --- a/backends/candle/src/models/nomic.rs +++ b/backends/candle/src/models/nomic.rs @@ -697,6 +697,9 @@ impl NomicBertModel { } pool } + ModelType::ListwiseReranker => { + candle::bail!("`reranker` model type is not supported for Nomic") + } }; let embeddings = NomicBertEmbeddings::load(vb.clone(), config)?; diff --git a/backends/candle/src/models/qwen3.rs b/backends/candle/src/models/qwen3.rs index 13309927..7a62845d 100644 --- a/backends/candle/src/models/qwen3.rs +++ b/backends/candle/src/models/qwen3.rs @@ -377,10 +377,12 @@ impl Qwen3Layer { pub struct Qwen3Model { embeddings: Embedding, + lm_head_weight: Tensor, layers: Vec, norm: RMSNorm, rotary_cache: (Tensor, Tensor), rotary_dim: usize, + model_type: ModelType, pool: Pool, num_attention_heads: usize, pad_token_id: u32, @@ -393,11 +395,12 @@ pub struct Qwen3Model { impl Qwen3Model { pub fn load(vb: VarBuilder, config: &Qwen3Config, model_type: ModelType) -> Result { - let pool = match model_type { + let pool = match &model_type { ModelType::Classifier => { candle::bail!("`classifier` model type is not supported for Qwen3") } - ModelType::Embedding(pool) => pool, + ModelType::Embedding(pool) => pool.clone(), + ModelType::ListwiseReranker => Pool::LastToken, }; // The Qwen3-Reranker models contain the `model` key @@ -408,11 +411,13 @@ impl Qwen3Model { vb }; - let embeddings = Embedding::new( - vb.pp("embed_tokens") - .get((config.vocab_size, config.hidden_size), "weight")?, - config.hidden_size, - ); + let embed_weight = vb + .pp("embed_tokens") + .get((config.vocab_size, config.hidden_size), "weight")?; + + let embeddings = Embedding::new(embed_weight.clone(), config.hidden_size); + + let lm_head_weight = embed_weight; let layers = (0..config.num_hidden_layers) .map(|index| Qwen3Layer::load(vb.pp(format!("layers.{index}")), config)) @@ -431,10 +436,12 @@ impl Qwen3Model { Ok(Self { embeddings, + lm_head_weight, layers, norm, rotary_cache, rotary_dim, + model_type, pool, pad_token_id: config.eos_token_id as u32, num_attention_heads: config.num_attention_heads, @@ -700,4 +707,63 @@ impl Model for Qwen3Model { fn embed(&self, batch: Batch) -> Result<(Option, Option)> { self.forward(batch) } + + fn predict(&self, batch: Batch) -> Result { + match &self.model_type { + ModelType::ListwiseReranker => { + // Extract needed 
values before moving batch + let batch_size = batch.len(); + let max_length = batch.max_length as usize; + + // Use the existing forward method to get hidden states + let (_, raw_embeddings) = self.forward(batch)?; + + let hidden_states = match raw_embeddings { + Some(embeddings) => embeddings, + None => candle::bail!("No hidden states returned from forward pass"), + }; + + // Project through LM head to get logits + let logits = hidden_states.matmul(&self.lm_head_weight.t()?)?; + + // Correct token IDs for Qwen3 (verified from tokenizer) + let yes_id = 9454u32; // "yes" token ID + let no_id = 2901u32; // "no" token ID + + tracing::debug!("Using Qwen3 token IDs - yes: {}, no: {}", yes_id, no_id); + + // Extract logits for last position of each sequence + let mut scores_vec = Vec::with_capacity(batch_size); + + for i in 0..batch_size { + // For left-padded sequences, the last position contains the actual output + let last_pos = max_length - 1; + + // Get logits for the last position + let last_logits = logits.i((i, last_pos, ..))?; + + // Extract yes/no logits directly + let yes_logit = last_logits.i(yes_id as usize)?; + let no_logit = last_logits.i(no_id as usize)?; + + // Stack [no, yes] and apply log_softmax + let logit_pair = Tensor::stack(&[&no_logit, &yes_logit], 0)?; + let log_probs = + candle_nn::ops::log_softmax(&logit_pair.unsqueeze(0)?, D::Minus1)?; + + // Extract yes probability (index 1) and exp + let yes_log_prob = log_probs.i((0, 1))?; + let score = yes_log_prob.exp()?.to_scalar::()?; + + scores_vec.push(score); + } + + // Convert to tensor + let scores = Tensor::from_vec(scores_vec, batch_size, &self.device)?; + + Ok(scores) + } + _ => candle::bail!("`predict` is only available for ModelType::ListwiseReranker"), + } + } } diff --git a/backends/candle/tests/test_qwen3_reranker.rs b/backends/candle/tests/test_qwen3_reranker.rs new file mode 100644 index 00000000..b967384b --- /dev/null +++ b/backends/candle/tests/test_qwen3_reranker.rs @@ -0,0 +1,145 @@ +mod common; + +use anyhow::Result; +use common::{batch, download_artifacts, load_tokenizer}; +use text_embeddings_backend_candle::CandleBackend; +use text_embeddings_backend_core::{Backend, ModelType}; + +#[test] +#[serial_test::serial] +fn test_qwen3_reranker() -> Result<()> { + if std::env::var("SKIP_DOWNLOAD_TESTS").is_ok() { + return Ok(()); + } + + let model_root = download_artifacts("Qwen/Qwen3-Reranker-0.6B", None, None)?; + let tokenizer = load_tokenizer(&model_root)?; + + let backend = CandleBackend::new( + &model_root, + "float32".to_string(), + ModelType::ListwiseReranker, + None, + )?; + + let prefix = "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. 
Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n"; + let suffix = "<|im_end|>\n<|im_start|>assistant\n\n\n\n\n"; + let instruct = "Given a web search query, retrieve relevant passages that answer the query"; + + let format_input = |query: &str, document: &str| -> String { + format!("{prefix}: {instruct}\n: {query}\n: {document}{suffix}") + }; + + let texts = [ + format_input( + "What is the capital of China?", + "The capital of China is Beijing.", + ), + format_input( + "What is the capital of China?", + "The capital of France is Paris.", + ), + format_input( + "What is the capital of China?", + "China is a large country in Asia.", + ), + ]; + + let input_batch = batch( + texts + .iter() + .map(|t| tokenizer.encode(t.as_str(), true).unwrap()) + .collect(), + vec![0, 1, 2], + vec![], + ); + + let predictions = backend.predict(input_batch)?; + let scores_vec: Vec = predictions.into_iter().flat_map(|(_, v)| v).collect(); + + assert_eq!(scores_vec.len(), 3, "Should return 3 scores for 3 inputs"); + + for (i, &score) in scores_vec.iter().enumerate() { + assert!( + score.is_finite() && (0.0..=1.0).contains(&score), + "Score[{}] = {} should be a valid probability", + i, + score + ); + } + + assert!( + scores_vec[0] > scores_vec[1], + "Beijing document (score={}) should score higher than Paris document (score={})", + scores_vec[0], + scores_vec[1] + ); + + assert!( + scores_vec[0] > scores_vec[2], + "Beijing document (score={}) should score higher than generic China document (score={})", + scores_vec[0], + scores_vec[2] + ); + + let single_text = format_input( + "What is machine learning?", + "Machine learning is a subset of artificial intelligence.", + ); + let input_single = batch( + vec![tokenizer.encode(single_text.as_str(), true).unwrap()], + vec![0], + vec![], + ); + + let single_predictions = backend.predict(input_single)?; + let single_score_vec: Vec = single_predictions + .into_iter() + .flat_map(|(_, v)| v) + .collect(); + + assert_eq!( + single_score_vec.len(), + 1, + "Should return 1 score for 1 input" + ); + assert!( + single_score_vec[0].is_finite() && (0.0..=1.0).contains(&single_score_vec[0]), + "Single score should be a valid probability" + ); + + Ok(()) +} + +#[test] +fn test_qwen3_reranker_model_detection() { + // Test that model names containing "reranker" are properly detected + let reranker_models = vec![ + "Qwen/Qwen3-Reranker-0.6B", + "Qwen/Qwen3-Reranker-4B", + "custom/qwen3-reranker-finetuned", + "org/model-qwen3-reranker", + ]; + + let non_reranker_models = vec![ + "Qwen/Qwen3-0.5B", + "Qwen/Qwen3-7B-Instruct", + "Qwen/Qwen3-Embedding-0.6B", + ]; + + for model_name in reranker_models { + assert!( + model_name.to_lowercase().contains("reranker"), + "Model {} should be detected as reranker", + model_name + ); + } + + for model_name in non_reranker_models { + assert!( + !model_name.to_lowercase().contains("reranker"), + "Model {} should NOT be detected as reranker", + model_name + ); + } +} diff --git a/backends/core/src/lib.rs b/backends/core/src/lib.rs index 8e134d2b..9d45c37d 100644 --- a/backends/core/src/lib.rs +++ b/backends/core/src/lib.rs @@ -51,6 +51,7 @@ pub trait Backend { pub enum ModelType { Classifier, Embedding(Pool), + ListwiseReranker, } #[derive(Debug, PartialEq, Clone, Deserialize)] diff --git a/backends/ort/src/lib.rs b/backends/ort/src/lib.rs index add5b33d..7f6b2b3c 100644 --- a/backends/ort/src/lib.rs +++ b/backends/ort/src/lib.rs @@ -38,6 +38,11 @@ impl OrtBackend { } pool => pool, }, + ModelType::ListwiseReranker => 
{ + return Err(BackendError::Start( + "Reranker models are not supported in the ONNX backend".to_string(), + )); + } }; // Get model path diff --git a/backends/python/src/lib.rs b/backends/python/src/lib.rs index 53255b07..e68bbe88 100644 --- a/backends/python/src/lib.rs +++ b/backends/python/src/lib.rs @@ -27,6 +27,7 @@ impl PythonBackend { let pool = match model_type { ModelType::Classifier => Pool::Cls, ModelType::Embedding(pool) => pool, + ModelType::ListwiseReranker => Pool::LastToken, }; let backend_process = management::BackendProcess::new( diff --git a/backends/src/dtype.rs b/backends/src/dtype.rs index 80292be7..4d8f8b4e 100644 --- a/backends/src/dtype.rs +++ b/backends/src/dtype.rs @@ -7,30 +7,20 @@ use clap::ValueEnum; #[cfg_attr(feature = "clap", derive(Clone, ValueEnum))] pub enum DType { // Float16 is not available on accelerate - #[cfg(any( - feature = "python", - all(feature = "candle", not(feature = "accelerate")) - ))] + #[cfg(all(feature = "candle", not(feature = "accelerate")))] Float16, - #[cfg(any(feature = "python", feature = "candle", feature = "ort"))] + #[cfg(any(feature = "candle", feature = "ort"))] Float32, - #[cfg(feature = "python")] - Bfloat16, } impl fmt::Display for DType { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { // Float16 is not available on accelerate - #[cfg(any( - feature = "python", - all(feature = "candle", not(feature = "accelerate")) - ))] + #[cfg(all(feature = "candle", not(feature = "accelerate")))] DType::Float16 => write!(f, "float16"), - #[cfg(any(feature = "python", feature = "candle", feature = "ort"))] + #[cfg(any(feature = "candle", feature = "ort"))] DType::Float32 => write!(f, "float32"), - #[cfg(feature = "python")] - DType::Bfloat16 => write!(f, "bfloat16"), } } } @@ -42,18 +32,9 @@ impl Default for DType { { DType::Float32 } - #[cfg(not(any( - feature = "accelerate", - feature = "mkl", - feature = "ort", - feature = "python" - )))] + #[cfg(not(any(feature = "accelerate", feature = "mkl", feature = "ort")))] { DType::Float16 } - #[cfg(feature = "python")] - { - DType::Bfloat16 - } } } diff --git a/backends/src/lib.rs b/backends/src/lib.rs index 073f94f4..5711fa7a 100644 --- a/backends/src/lib.rs +++ b/backends/src/lib.rs @@ -24,9 +24,6 @@ use text_embeddings_backend_candle::CandleBackend; #[cfg(feature = "ort")] use text_embeddings_backend_ort::OrtBackend; -#[cfg(feature = "python")] -use text_embeddings_backend_python::PythonBackend; - fn powers_of_two(max_value: usize) -> Vec { let mut result = Vec::new(); let mut power: usize = 1; @@ -170,7 +167,9 @@ impl Backend { for shape in shapes.iter() { let batch = self.create_warmup_batch(*shape, max_token as u32, seq_bucket_size as u32); match &self.model_type { - ModelType::Classifier => self.predict(batch).await.map(|_| ()), + ModelType::Classifier | ModelType::ListwiseReranker => { + self.predict(batch).await.map(|_| ()) + } ModelType::Embedding(_) => self.embed(batch).await.map(|_| ()), }?; tracing::info!("finish warmup for batch: {}, length: {}", shape.0, shape.1); @@ -179,7 +178,12 @@ impl Backend { } #[instrument(skip_all)] - pub fn create_warmup_batch(&self, shape: (u32, u32), max_token: u32, seq_bucket_size: u32) -> Batch { + pub fn create_warmup_batch( + &self, + shape: (u32, u32), + max_token: u32, + seq_bucket_size: u32, + ) -> Batch { let (batch_size, length) = shape; let min_length = length.saturating_sub(seq_bucket_size).saturating_add(1); let tmp_length = if min_length < length { @@ -275,7 +279,9 @@ impl Backend { }; match &self.model_type { - 
ModelType::Classifier => self.predict(batch).await.map(|_| ()), + ModelType::Classifier | ModelType::ListwiseReranker => { + self.predict(batch).await.map(|_| ()) + } ModelType::Embedding(_) => self.embed(batch).await.map(|_| ()), } } @@ -308,7 +314,9 @@ impl Backend { raw_indices: vec![], }; match &self.model_type { - ModelType::Classifier => self.predict(batch).await.map(|_| ()), + ModelType::Classifier | ModelType::ListwiseReranker => { + self.predict(batch).await.map(|_| ()) + } ModelType::Embedding(_) => self.embed(batch).await.map(|_| ()), } } @@ -387,7 +395,7 @@ async fn init_backend( } if let Some(api_repo) = api_repo.as_ref() { - if cfg!(feature = "python") || cfg!(feature = "candle") { + if cfg!(feature = "candle") { let start = std::time::Instant::now(); if download_safetensors(api_repo).await.is_err() { tracing::warn!("safetensors weights not found. Using `pytorch_model.bin` instead. Model loading will be significantly slower."); @@ -421,32 +429,6 @@ async fn init_backend( } } - if cfg!(feature = "python") { - #[cfg(feature = "python")] - { - let backend = std::thread::spawn(move || { - PythonBackend::new( - model_path.to_str().unwrap().to_string(), - dtype.to_string(), - model_type, - uds_path, - otlp_endpoint, - otlp_service_name, - ) - }) - .join() - .expect("Python Backend management thread failed"); - - match backend { - Ok(b) => return Ok(Box::new(b)), - Err(err) => { - tracing::error!("Could not start Python backend: {err}"); - backend_start_failed = true; - } - } - } - } - if backend_start_failed { Err(BackendError::Start( "Could not start a suitable backend".to_string(), diff --git a/core/src/infer.rs b/core/src/infer.rs index a2ff22c5..e6fa8eb4 100644 --- a/core/src/infer.rs +++ b/core/src/infer.rs @@ -499,7 +499,10 @@ impl Infer { #[instrument(skip(self))] pub fn is_classifier(&self) -> bool { - matches!(self.backend.model_type, ModelType::Classifier) + matches!( + self.backend.model_type, + ModelType::Classifier | ModelType::ListwiseReranker + ) } #[instrument(skip(self))] @@ -547,7 +550,7 @@ async fn batching_task(queue: Queue, notify: Arc, embed_sender: mpsc::Se async fn backend_task(backend: Backend, mut embed_receiver: mpsc::Receiver) { while let Some(batch) = embed_receiver.recv().await { match &backend.model_type { - ModelType::Classifier => { + ModelType::Classifier | ModelType::ListwiseReranker => { let results = backend.predict(batch.1).await; // Handle sending responses in another thread to avoid starving the backend diff --git a/core/src/tokenization.rs b/core/src/tokenization.rs index 7636afa8..8bac5d21 100644 --- a/core/src/tokenization.rs +++ b/core/src/tokenization.rs @@ -1,6 +1,7 @@ /// Payload tokenization logic use crate::TextEmbeddingsError; use std::collections::HashMap; +use text_embeddings_backend::ModelType; use tokenizers::tokenizer::Tokenizer; pub use tokenizers::Encoding as RawEncoding; use tokenizers::{TruncationDirection, TruncationParams, TruncationStrategy}; @@ -34,6 +35,7 @@ impl Tokenization { position_offset: usize, default_prompt: Option, prompts: Option>, + model_type: ModelType, ) -> Self { tracing::info!("Starting {workers} tokenization workers"); @@ -46,6 +48,7 @@ impl Tokenization { let receiver_clone = receiver.clone(); let default_prompt_clone = default_prompt.clone(); let prompts_clone = prompts.clone(); + let model_type_clone = model_type.clone(); // Spawn worker std::thread::spawn(move || { tokenizer_worker( @@ -54,6 +57,7 @@ impl Tokenization { position_offset, default_prompt_clone, prompts_clone, + model_type_clone, 
receiver_clone, ) }); @@ -172,6 +176,7 @@ fn tokenizer_worker( position_offset: usize, default_prompt: Option, prompts: Option>, + model_type: ModelType, receiver: async_channel::Receiver, ) { // Loop over requests @@ -203,6 +208,7 @@ fn tokenizer_worker( default_prompt_clone, prompt_name, prompts.as_ref(), + &model_type, &mut tokenizer, )); } @@ -232,6 +238,7 @@ fn tokenizer_worker( default_prompt_clone, prompt_name, prompts.as_ref(), + &model_type, &mut tokenizer, )); } @@ -261,6 +268,81 @@ fn decode_ids( .decode(&ids, skip_special_tokens)?) } +/// Format input for Qwen3 reranker models +fn format_qwen3_reranker_input( + inputs: EncodingInput, +) -> Result { + // The Qwen3 reranker expects a specific prompt format + match inputs { + EncodingInput::Single(text) => { + // For reranking, we expect the input to contain both query and document + // They should be separated by a delimiter like "||" + let parts: Vec<&str> = text.split("||").collect(); + if parts.len() != 2 { + return Err(TextEmbeddingsError::Validation( + "Qwen3 reranker expects input format: 'query||document'".to_string(), + )); + } + + let query = parts[0].trim(); + let document = parts[1].trim(); + let instruction = + "Given a web search query, retrieve relevant passages that answer the query"; + + let formatted = format!( + r#"<|im_start|>system +Judge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be "yes" or "no". +<|im_end|> +<|im_start|>user +: {} +: {} +: {} +<|im_end|> +<|im_start|>assistant + + + + +"#, + instruction, query, document + ); + + Ok(EncodingInput::Single(formatted)) + } + EncodingInput::Dual(query, document) => { + // If we already have query and document separated, format them directly + let instruction = + "Given a web search query, retrieve relevant passages that answer the query"; + + let formatted = format!( + r#"<|im_start|>system +Judge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be "yes" or "no". 
+<|im_end|> +<|im_start|>user +: {} +: {} +: {} +<|im_end|> +<|im_start|>assistant + + + + +"#, + instruction, query, document + ); + + Ok(EncodingInput::Single(formatted)) + } + EncodingInput::Ids(_) => { + // If already tokenized, we can't format it + Err(TextEmbeddingsError::Validation( + "Cannot format pre-tokenized input for Qwen3 reranker".to_string(), + )) + } + } +} + fn prepare_pre_prompt( default_prompt: Option, prompt_name: Option, @@ -291,8 +373,21 @@ fn tokenize_input( default_prompt: Option, prompt_name: Option, prompts: Option<&HashMap>, + model_type: &ModelType, tokenizer: &mut Tokenizer, ) -> Result<(Option, RawEncoding), TextEmbeddingsError> { + // Check if this is a Qwen3 reranker and apply special formatting + if matches!(model_type, ModelType::ListwiseReranker) { + tracing::debug!("Applying Qwen3 reranker formatting to input"); + // For Qwen3 reranker, we need to format the input with the special template + inputs = format_qwen3_reranker_input(inputs)?; + + // Debug log the formatted input + if let EncodingInput::Single(ref text) = inputs { + tracing::debug!("Formatted Qwen3 input: {}", text); + } + } + let pre_prompt = prepare_pre_prompt(default_prompt, prompt_name, prompts)?; let input_chars = inputs.count_chars(); @@ -372,6 +467,7 @@ fn encode_input( default_prompt: Option, prompt_name: Option, prompts: Option<&HashMap>, + model_type: &ModelType, tokenizer: &mut Tokenizer, ) -> Result { // Default truncation params @@ -390,6 +486,7 @@ fn encode_input( default_prompt, prompt_name, prompts, + model_type, tokenizer, )?; let seq_len = encoding.len(); diff --git a/router/Cargo.toml b/router/Cargo.toml index c411957a..9feefef7 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -85,7 +85,6 @@ grpc = ["metrics-exporter-prometheus/http-listener", "dep:prost", "dep:tonic", " metal = ["text-embeddings-backend/metal"] mkl = ["text-embeddings-backend/mkl"] accelerate = ["text-embeddings-backend/accelerate"] -python = ["text-embeddings-backend/python"] ort = ["text-embeddings-backend/ort"] candle = ["text-embeddings-backend/candle"] candle-cuda = ["candle", "text-embeddings-backend/flash-attn", "dep:cudarc"] diff --git a/router/src/grpc/server.rs b/router/src/grpc/server.rs index 130e519d..59f90ce5 100644 --- a/router/src/grpc/server.rs +++ b/router/src/grpc/server.rs @@ -283,7 +283,7 @@ impl TextEmbeddingsService { let id2label = match &self.info.model_type { ModelType::Classifier(classifier) => &classifier.id2label, - ModelType::Reranker(classifier) => &classifier.id2label, + ModelType::ListwiseReranker(classifier) => &classifier.id2label, _ => panic!(), }; @@ -563,7 +563,7 @@ impl grpc::info_server::Info for TextEmbeddingsService { let model_type = match self.info.model_type { ModelType::Classifier(_) => grpc::ModelType::Classifier, ModelType::Embedding(_) => grpc::ModelType::Embedding, - ModelType::Reranker(_) => grpc::ModelType::Reranker, + ModelType::ListwiseReranker(_) => grpc::ModelType::Reranker, }; Ok(Response::new(InfoResponse { @@ -906,7 +906,7 @@ impl grpc::rerank_server::Rerank for TextEmbeddingsService { tracing::error!("{message}"); Err(Status::new(Code::FailedPrecondition, message)) } - ModelType::Reranker(_) => Ok(()), + ModelType::ListwiseReranker(_) => Ok(()), ModelType::Embedding(_) => { let counter = metrics::counter!("te_request_failure", "err" => "model_type"); counter.increment(1); @@ -1084,7 +1084,7 @@ impl grpc::rerank_server::Rerank for TextEmbeddingsService { tracing::error!("{message}"); Err(Status::new(Code::FailedPrecondition, message)) } - 
ModelType::Reranker(_) => Ok(()), + ModelType::ListwiseReranker(_) => Ok(()), ModelType::Embedding(_) => { let counter = metrics::counter!("te_request_failure", "err" => "model_type"); counter.increment(1); @@ -1415,7 +1415,7 @@ pub async fn run( // Match on model type and set the health of the correct service(s) // - // If Reranker, we have both a predict and rerank service + // If ListwiseReranker, we have both a predict and rerank service // // This logic hints back to the user that if they try using the wrong service // given the model type, it will always return an error. @@ -1440,8 +1440,8 @@ pub async fn run( ) .await } - ModelType::Reranker(_) => { - // Reranker has both a predict and rerank service + ModelType::ListwiseReranker(_) => { + // ListwiseReranker has both a predict and rerank service health_reporter .set_service_status( >::NAME, diff --git a/router/src/http/server.rs b/router/src/http/server.rs index a22af962..2e37aba7 100644 --- a/router/src/http/server.rs +++ b/router/src/http/server.rs @@ -138,7 +138,7 @@ async fn predict( let id2label = match &info.model_type { ModelType::Classifier(classifier) => &classifier.id2label, - ModelType::Reranker(classifier) => &classifier.id2label, + ModelType::ListwiseReranker(classifier) => &classifier.id2label, _ => panic!(), }; @@ -330,7 +330,7 @@ async fn rerank( } match &info.model_type { - ModelType::Reranker(_) => Ok(()), + ModelType::ListwiseReranker(_) => Ok(()), ModelType::Classifier(_) | ModelType::Embedding(_) => { let counter = metrics::counter!("te_request_failure", "err" => "model_type"); counter.increment(1); @@ -1551,7 +1551,7 @@ async fn vertex_compatibility( } match info.model_type { - ModelType::Classifier(_) | ModelType::Reranker(_) => { + ModelType::Classifier(_) | ModelType::ListwiseReranker(_) => { let instance = serde_json::from_value::(instance) .map_err(ErrorResponse::from)?; futures @@ -1782,7 +1782,7 @@ pub async fn run( // AWS Sagemaker route .route("/invocations", post(predict)) } - ModelType::Reranker(_) => { + ModelType::ListwiseReranker(_) => { routes .route("/", post(rerank)) // AWS Sagemaker route diff --git a/router/src/lib.rs b/router/src/lib.rs index 901cc42e..bc5ee3dc 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -26,6 +26,7 @@ use std::fs; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::path::Path; use std::time::{Duration, Instant}; +use serde_json::json; use text_embeddings_backend::{DType, Pool}; use text_embeddings_core::download::{download_artifacts, ST_CONFIG_NAMES}; use text_embeddings_core::infer::Infer; @@ -104,13 +105,13 @@ pub async fn run( let dense_root = dense_path.map(|path| model_root.join(path)); // Load config - let config_path = model_root.join("config.json"); - let config = fs::read_to_string(config_path).context("`config.json` not found")?; + let main_config_path = model_root.join("config.json"); + let config_str = fs::read_to_string(&main_config_path).context("`config.json` not found")?; let config: ModelConfig = - serde_json::from_str(&config).context("Failed to parse `config.json`")?; + serde_json::from_str(&config_str).context("Failed to parse `config.json`")?; // Set model type from config - let backend_model_type = get_backend_model_type(&config, &model_root, pooling)?; + let backend_model_type = get_backend_model_type(&config, &model_root, &model_id, pooling)?; // Info model type let model_type = match &backend_model_type { @@ -128,7 +129,7 @@ pub async fn run( if n_classes > 1 { ModelType::Classifier(classifier_model) } else { - 
ModelType::Reranker(classifier_model) + ModelType::ListwiseReranker(classifier_model) } } text_embeddings_backend::ModelType::Embedding(pool) => { @@ -136,6 +137,16 @@ pub async fn run( pooling: pool.to_string(), }) } + text_embeddings_backend::ModelType::ListwiseReranker => { + // For Qwen3 reranker, we don't have id2label/label2id + // Create a dummy classifier model for now + ModelType::ListwiseReranker(ClassifierModel { + id2label: [("0".to_string(), "positive".to_string())] + .into_iter() + .collect(), + label2id: [("positive".to_string(), 0)].into_iter().collect(), + }) + } }; // Load tokenizer @@ -230,6 +241,7 @@ pub async fn run( position_offset, default_prompt, prompts, + backend_model_type.clone(), ); // Get dtype @@ -355,8 +367,45 @@ pub async fn run( fn get_backend_model_type( config: &ModelConfig, model_root: &Path, + model_id: &str, pooling: Option, ) -> Result { + // Check for reranker models using config-based detection + // 1. Explicit is_reranker flag + if config.is_reranker == Some(true) { + tracing::info!("Detected reranker model from config.is_reranker flag"); + if pooling.is_some() { + tracing::warn!( + "`--pooling` arg is set but model is a reranker. Ignoring `--pooling` arg." + ); + } + return Ok(text_embeddings_backend::ModelType::ListwiseReranker); + } + + // 2. Fallback to name-based detection for Qwen3 models + if config + .architectures + .iter() + .any(|arch| arch == "Qwen3ForCausalLM") + { + // Check if the model ID contains "reranker" + let model_name = model_id + .split('/') + .last() + .unwrap_or(model_id) + .to_lowercase(); + + if model_name.contains("reranker") { + tracing::info!("Detected Qwen3-Reranker model from model name"); + if pooling.is_some() { + tracing::warn!( + "`--pooling` arg is set but model is a reranker. Ignoring `--pooling` arg." + ); + } + return Ok(text_embeddings_backend::ModelType::ListwiseReranker); + } + } + for arch in &config.architectures { // Edge case affecting `Alibaba-NLP/gte-multilingual-base` and possibly other fine-tunes of // the same base model. More context at https://huggingface.co/Alibaba-NLP/gte-multilingual-base/discussions/7 @@ -371,6 +420,13 @@ fn get_backend_model_type( return Ok(text_embeddings_backend::ModelType::Embedding( text_embeddings_backend::Pool::Splade, )); + } else if arch == "Qwen3ForSequenceClassification" { + if pooling.is_some() { + tracing::warn!( + "`--pooling` arg is set but model is a reranker. Ignoring `--pooling` arg." + ); + } + return Ok(text_embeddings_backend::ModelType::ListwiseReranker); } else if arch.ends_with("Classification") { if pooling.is_some() { tracing::warn!( @@ -424,6 +480,8 @@ pub struct ModelConfig { pub pad_token_id: usize, pub id2label: Option>, pub label2id: Option>, + // Reranker-specific field + pub is_reranker: Option, } #[derive(Debug, Clone, PartialEq, Deserialize)] @@ -483,7 +541,7 @@ pub struct ClassifierModel { pub enum ModelType { Classifier(ClassifierModel), Embedding(EmbeddingModel), - Reranker(ClassifierModel), + ListwiseReranker(ClassifierModel), } #[derive(Clone, Debug, Serialize)]
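
Note on the listwise scoring added in flash_qwen3.rs and qwen3.rs: both `predict` implementations take the last-token hidden state, compute logits for the "no" and "yes" token ids against the tied embedding matrix, and return `exp(log_softmax([no, yes])[1])` as the relevance score. The sketch below is illustrative only (plain f32 math, no candle types); it shows that this two-way softmax is the logistic function of the logit difference, so scores always fall in (0, 1).

// Illustrative sketch of the P("yes") computation used by `predict`:
// a two-way softmax over the (no, yes) logits, equivalent to
// sigmoid(yes_logit - no_logit).
fn yes_probability(no_logit: f32, yes_logit: f32) -> f32 {
    // Subtract the max before exponentiating for numerical stability,
    // mirroring what log_softmax does internally.
    let m = no_logit.max(yes_logit);
    let e_no = (no_logit - m).exp();
    let e_yes = (yes_logit - m).exp();
    e_yes / (e_no + e_yes)
}

fn main() {
    let (no_logit, yes_logit) = (-1.3f32, 2.7f32);
    let p = yes_probability(no_logit, yes_logit);
    // Same value as the logistic of the logit difference.
    let sigmoid = 1.0 / (1.0 + (no_logit - yes_logit).exp());
    assert!((p - sigmoid).abs() < 1e-6);
    println!("P(yes) = {p:.4}"); // ~0.982
}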
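Note on the grouped-query attention change in flash_qwen3.rs: when `num_key_value_heads` is smaller than `num_attention_heads`, K and V are expanded along the head axis before `flash_attn_varlen`, so query head `h` attends against KV head `h / repeat`. A shape-only sketch with plain Vecs (no candle tensors), assuming the unsqueeze(2) → expand → reshape ordering used in the patch:

// Illustrative only: mirrors the KV-head expansion added to Qwen3Attention.
// A [tokens, kv_heads, head_dim] layout is repeated along the head axis so
// that consecutive groups of `repeat` query heads share one KV head.
fn repeat_kv(kv: &[Vec<Vec<f32>>], repeat: usize) -> Vec<Vec<Vec<f32>>> {
    kv.iter()
        .map(|token| {
            let mut heads = Vec::with_capacity(token.len() * repeat);
            for head in token {
                for _ in 0..repeat {
                    // Same data pushed `repeat` times back to back, matching
                    // unsqueeze(2).expand(..).reshape(..) in the diff.
                    heads.push(head.clone());
                }
            }
            heads
        })
        .collect()
}

fn main() {
    // 1 token, 2 KV heads, head_dim 2, expanded for 4 query heads (repeat = 2).
    let kv: Vec<Vec<Vec<f32>>> = vec![vec![vec![1.0, 1.0], vec![2.0, 2.0]]];
    let expanded = repeat_kv(&kv, 2);
    assert_eq!(expanded[0].len(), 4);
    // Query heads 0 and 1 share KV head 0; heads 2 and 3 share KV head 1.
    assert_eq!(expanded[0][1], vec![1.0, 1.0]);
    assert_eq!(expanded[0][2], vec![2.0, 2.0]);
}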
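Note on the request format introduced in core/src/tokenization.rs: for `ModelType::ListwiseReranker`, a single-text request must carry both query and document separated by "||"; the tokenization worker then rewrites the pair into the Qwen3 chat template (system instruction, instruct/query/document fields, assistant turn) before encoding. A minimal client-side sketch of that convention; `build_rerank_input` is a hypothetical helper, the "||" delimiter is the one `format_qwen3_reranker_input` splits on.

// Illustrative only: how a caller would format a rerank request for the
// single-text path handled by format_qwen3_reranker_input.
fn build_rerank_input(query: &str, document: &str) -> String {
    format!("{}||{}", query.trim(), document.trim())
}

fn main() {
    let input = build_rerank_input(
        "What is the capital of China?",
        "The capital of China is Beijing.",
    );
    // The worker splits on "||" and expands the pair into the chat template.
    let parts: Vec<&str> = input.split("||").collect();
    assert_eq!(parts.len(), 2);
    println!("{input}");
}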
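Note on the backend output handling in backends/candle/src/lib.rs: `predict` now accepts either a rank-1 result (one relevance score per input, the reranker path) or a rank-2 result (per-class scores, the classifier path) and normalizes both into per-input score vectors keyed by input index. A plain-Rust stand-in for the branch on `results.dims()`; the enum and function names here are made up.

// Illustrative only: normalizing both result shapes into HashMap<usize, Vec<f32>>.
use std::collections::HashMap;

enum PredictOutput {
    // [batch] — one relevance score per input (listwise reranker path).
    Scores(Vec<f32>),
    // [batch, n_classes] — per-class scores (classifier path).
    ClassScores(Vec<Vec<f32>>),
}

fn to_predictions(output: PredictOutput) -> HashMap<usize, Vec<f32>> {
    match output {
        PredictOutput::Scores(scores) => scores
            .into_iter()
            .enumerate()
            .map(|(i, s)| (i, vec![s]))
            .collect(),
        PredictOutput::ClassScores(rows) => rows.into_iter().enumerate().collect(),
    }
}

fn main() {
    let reranker = to_predictions(PredictOutput::Scores(vec![0.98, 0.02]));
    assert_eq!(reranker[&0], vec![0.98]);
    let classifier = to_predictions(PredictOutput::ClassScores(vec![vec![0.1, 0.9]]));
    assert_eq!(classifier[&0], vec![0.1, 0.9]);
}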