Skip to content

Commit c30aebc

Browse files
committed
feat: support qwen3 reranker
1 parent c8ff435 commit c30aebc

File tree

4 files changed

+253
-8
lines changed

4 files changed

+253
-8
lines changed

backends/candle/src/models/flash_qwen3.rs

Lines changed: 115 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -281,13 +281,97 @@ impl Qwen3Layer {
281281
}
282282
}
283283

284+
// Define ClassificationHead trait locally (following TEI pattern)
285+
trait ClassificationHead {
286+
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor>;
287+
}
288+
289+
// Qwen3 Classification Head implementation
290+
#[derive(Debug)]
291+
struct Qwen3ClassificationHead {
292+
dense: Linear,
293+
out_proj: Linear,
294+
activation: HiddenAct,
295+
span: tracing::Span,
296+
}
297+
298+
impl Qwen3ClassificationHead {
299+
pub fn load(vb: VarBuilder, config: &Qwen3Config) -> Result<Self> {
300+
let (dense, out_proj) = if vb.contains_tensor("score.dense.weight") {
301+
tracing::info!("Loading Qwen3 classifier with score layers");
302+
303+
let dense_weight = vb
304+
.pp("score.dense")
305+
.get((config.hidden_size, config.hidden_size), "weight")?;
306+
let dense_bias = vb.pp("score.dense").get(config.hidden_size, "bias")?;
307+
let dense = Linear::new(dense_weight, Some(dense_bias), None);
308+
309+
let out_proj_weight = vb
310+
.pp("score.out_proj")
311+
.get((1, config.hidden_size), "weight")?;
312+
let out_proj_bias = vb.pp("score.out_proj").get(1, "bias")?;
313+
let out_proj = Linear::new(out_proj_weight, Some(out_proj_bias), None);
314+
315+
(dense, out_proj)
316+
} else if vb.contains_tensor("classifier.dense.weight") {
317+
tracing::info!("Loading Qwen3 classifier with classifier layers");
318+
319+
let dense_weight = vb
320+
.pp("classifier.dense")
321+
.get((config.hidden_size, config.hidden_size), "weight")?;
322+
let dense_bias = vb.pp("classifier.dense").get(config.hidden_size, "bias")?;
323+
let dense = Linear::new(dense_weight, Some(dense_bias), None);
324+
325+
let out_proj_weight = vb
326+
.pp("classifier.out_proj")
327+
.get((1, config.hidden_size), "weight")?;
328+
let out_proj_bias = vb.pp("classifier.out_proj").get(1, "bias")?;
329+
let out_proj = Linear::new(out_proj_weight, Some(out_proj_bias), None);
330+
331+
(dense, out_proj)
332+
} else {
333+
candle::bail!(
334+
"Classification layers not found in model weights. \
335+
Expected 'score.dense.weight' or 'classifier.dense.weight' for reranker models. \
336+
This model may not be a trained reranker."
337+
);
338+
};
339+
340+
Ok(Self {
341+
dense,
342+
out_proj,
343+
activation: config.hidden_act.clone(),
344+
span: tracing::span!(tracing::Level::TRACE, "classifier"),
345+
})
346+
}
347+
}
348+
349+
impl ClassificationHead for Qwen3ClassificationHead {
350+
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
351+
let _enter = self.span.enter();
352+
353+
// Input is already pooled
354+
355+
// Apply dense layer with activation
356+
let hidden = self.dense.forward(hidden_states)?;
357+
let hidden = self.activation.forward(&hidden)?;
358+
359+
// Project to single score
360+
let score = self.out_proj.forward(&hidden)?;
361+
362+
// Squeeze to remove the last dimension if it's 1
363+
score.squeeze(candle::D::Minus1)
364+
}
365+
}
366+
284367
pub struct FlashQwen3Model {
285368
embeddings: Embedding,
286369
layers: Vec<Qwen3Layer>,
287370
norm: RMSNorm,
288371
cos_cache: Tensor,
289372
sin_cache: Tensor,
290373
pool: Pool,
374+
classifier: Option<Box<dyn ClassificationHead + Send>>,
291375
pub device: Device,
292376

293377
span: tracing::Span,
@@ -304,11 +388,19 @@ impl FlashQwen3Model {
304388
candle::bail!("FlashQwen3 requires DType::F16")
305389
}
306390

307-
let pool = match model_type {
391+
let (pool, classifier) = match model_type {
308392
ModelType::Classifier => {
309-
candle::bail!("`classifier` model type is not supported for Qwen3")
393+
let pool = Pool::LastToken;
394+
let classifier: Box<dyn ClassificationHead + Send> =
395+
Box::new(Qwen3ClassificationHead::load(vb.clone(), config)?);
396+
(pool, Some(classifier))
397+
}
398+
ModelType::Embedding(pool) => {
399+
if pool == Pool::Splade {
400+
candle::bail!("`splade` is not supported for Qwen3")
401+
}
402+
(pool, None)
310403
}
311-
ModelType::Embedding(pool) => pool,
312404
};
313405

314406
// The Qwen3-Reranker models contain the `model` key
@@ -351,6 +443,7 @@ impl FlashQwen3Model {
351443
cos_cache,
352444
sin_cache,
353445
pool,
446+
classifier,
354447
device: vb.device().clone(),
355448
span: tracing::span!(tracing::Level::TRACE, "model"),
356449
})
@@ -512,4 +605,23 @@ impl Model for FlashQwen3Model {
512605
fn embed(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {
513606
self.forward(batch)
514607
}
608+
609+
fn predict(&self, batch: Batch) -> Result<Tensor> {
610+
match &self.classifier {
611+
None => candle::bail!("`predict` is not implemented for this model"),
612+
Some(classifier) => {
613+
// Run forward pass to get hidden states
614+
let (pooled_embeddings, _) = self.forward(batch)?;
615+
match pooled_embeddings {
616+
Some(embeddings) => {
617+
let scores = classifier.forward(&embeddings)?;
618+
// Apply sigmoid to convert logits to probabilities
619+
let probabilities = candle_nn::ops::sigmoid(&scores)?;
620+
Ok(probabilities)
621+
}
622+
None => candle::bail!("No pooled embeddings returned for classification"),
623+
}
624+
}
625+
}
626+
}
515627
}

backends/candle/src/models/qwen3.rs

Lines changed: 107 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -375,13 +375,89 @@ impl Qwen3Layer {
375375
}
376376
}
377377

378+
trait ClassificationHead {
379+
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor>;
380+
}
381+
382+
#[derive(Debug)]
383+
struct Qwen3ClassificationHead {
384+
dense: Linear,
385+
out_proj: Linear,
386+
activation: HiddenAct,
387+
span: tracing::Span,
388+
}
389+
390+
impl Qwen3ClassificationHead {
391+
pub fn load(vb: VarBuilder, config: &Qwen3Config) -> Result<Self> {
392+
let (dense, out_proj) = if vb.contains_tensor("score.dense.weight") {
393+
tracing::info!("Loading Qwen3 classifier with score layers");
394+
395+
let dense_weight = vb
396+
.pp("score.dense")
397+
.get((config.hidden_size, config.hidden_size), "weight")?;
398+
let dense_bias = vb.pp("score.dense").get(config.hidden_size, "bias")?;
399+
let dense = Linear::new(dense_weight, Some(dense_bias), None);
400+
401+
let out_proj_weight = vb
402+
.pp("score.out_proj")
403+
.get((1, config.hidden_size), "weight")?;
404+
let out_proj_bias = vb.pp("score.out_proj").get(1, "bias")?;
405+
let out_proj = Linear::new(out_proj_weight, Some(out_proj_bias), None);
406+
407+
(dense, out_proj)
408+
} else if vb.contains_tensor("classifier.dense.weight") {
409+
tracing::info!("Loading Qwen3 classifier with classifier layers");
410+
411+
let dense_weight = vb
412+
.pp("classifier.dense")
413+
.get((config.hidden_size, config.hidden_size), "weight")?;
414+
let dense_bias = vb.pp("classifier.dense").get(config.hidden_size, "bias")?;
415+
let dense = Linear::new(dense_weight, Some(dense_bias), None);
416+
417+
let out_proj_weight = vb
418+
.pp("classifier.out_proj")
419+
.get((1, config.hidden_size), "weight")?;
420+
let out_proj_bias = vb.pp("classifier.out_proj").get(1, "bias")?;
421+
let out_proj = Linear::new(out_proj_weight, Some(out_proj_bias), None);
422+
423+
(dense, out_proj)
424+
} else {
425+
candle::bail!(
426+
"Classification layers not found in model weights. \
427+
Expected 'score.dense.weight' or 'classifier.dense.weight' for reranker models. \
428+
This model may not be a trained reranker."
429+
);
430+
};
431+
432+
Ok(Self {
433+
dense,
434+
out_proj,
435+
activation: config.hidden_act.clone(),
436+
span: tracing::span!(tracing::Level::TRACE, "classifier"),
437+
})
438+
}
439+
}
440+
441+
impl ClassificationHead for Qwen3ClassificationHead {
442+
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
443+
let _enter = self.span.enter();
444+
445+
let hidden = self.dense.forward(hidden_states)?;
446+
let hidden = self.activation.forward(&hidden)?;
447+
let score = self.out_proj.forward(&hidden)?;
448+
449+
score.squeeze(D::Minus1)
450+
}
451+
}
452+
378453
pub struct Qwen3Model {
379454
embeddings: Embedding,
380455
layers: Vec<Qwen3Layer>,
381456
norm: RMSNorm,
382457
rotary_cache: (Tensor, Tensor),
383458
rotary_dim: usize,
384459
pool: Pool,
460+
classifier: Option<Box<dyn ClassificationHead + Send>>,
385461
num_attention_heads: usize,
386462
pad_token_id: u32,
387463

@@ -393,11 +469,19 @@ pub struct Qwen3Model {
393469

394470
impl Qwen3Model {
395471
pub fn load(vb: VarBuilder, config: &Qwen3Config, model_type: ModelType) -> Result<Self> {
396-
let pool = match model_type {
472+
let (pool, classifier) = match model_type {
397473
ModelType::Classifier => {
398-
candle::bail!("`classifier` model type is not supported for Qwen3")
474+
let pool = Pool::LastToken;
475+
let classifier: Box<dyn ClassificationHead + Send> =
476+
Box::new(Qwen3ClassificationHead::load(vb.clone(), config)?);
477+
(pool, Some(classifier))
478+
}
479+
ModelType::Embedding(pool) => {
480+
if pool == Pool::Splade {
481+
candle::bail!("`splade` is not supported for Qwen3")
482+
}
483+
(pool, None)
399484
}
400-
ModelType::Embedding(pool) => pool,
401485
};
402486

403487
// The Qwen3-Reranker models contain the `model` key
@@ -436,6 +520,7 @@ impl Qwen3Model {
436520
rotary_cache,
437521
rotary_dim,
438522
pool,
523+
classifier,
439524
pad_token_id: config.eos_token_id as u32,
440525
num_attention_heads: config.num_attention_heads,
441526
dtype: vb.dtype(),
@@ -700,4 +785,23 @@ impl Model for Qwen3Model {
700785
fn embed(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {
701786
self.forward(batch)
702787
}
788+
789+
fn predict(&self, batch: Batch) -> Result<Tensor> {
790+
match &self.classifier {
791+
None => candle::bail!("`predict` is not implemented for this model"),
792+
Some(classifier) => {
793+
// Run forward pass to get hidden states
794+
let (pooled_embeddings, _) = self.forward(batch)?;
795+
match pooled_embeddings {
796+
Some(embeddings) => {
797+
let scores = classifier.forward(&embeddings)?;
798+
// Apply sigmoid to convert logits to probabilities
799+
let probabilities = candle_nn::ops::sigmoid(&scores)?;
800+
Ok(probabilities)
801+
}
802+
None => candle::bail!("No pooled embeddings returned for classification"),
803+
}
804+
}
805+
}
806+
}
703807
}

backends/src/lib.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,12 @@ impl Backend {
179179
}
180180

181181
#[instrument(skip_all)]
182-
pub fn create_warmup_batch(&self, shape: (u32, u32), max_token: u32, seq_bucket_size: u32) -> Batch {
182+
pub fn create_warmup_batch(
183+
&self,
184+
shape: (u32, u32),
185+
max_token: u32,
186+
seq_bucket_size: u32,
187+
) -> Batch {
183188
let (batch_size, length) = shape;
184189
let min_length = length.saturating_sub(seq_bucket_size).saturating_add(1);
185190
let tmp_length = if min_length < length {

router/src/lib.rs

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ pub async fn run(
110110
serde_json::from_str(&config).context("Failed to parse `config.json`")?;
111111

112112
// Set model type from config
113-
let backend_model_type = get_backend_model_type(&config, &model_root, pooling)?;
113+
let backend_model_type = get_backend_model_type(&config, &model_root, &model_id, pooling)?;
114114

115115
// Info model type
116116
let model_type = match &backend_model_type {
@@ -355,6 +355,7 @@ pub async fn run(
355355
fn get_backend_model_type(
356356
config: &ModelConfig,
357357
model_root: &Path,
358+
model_id: &str,
358359
pooling: Option<text_embeddings_backend::Pool>,
359360
) -> Result<text_embeddings_backend::ModelType> {
360361
for arch in &config.architectures {
@@ -381,6 +382,29 @@ fn get_backend_model_type(
381382
}
382383
}
383384

385+
// Qwen3-Reranker detection
386+
if config
387+
.architectures
388+
.iter()
389+
.any(|arch| arch == "Qwen3ForCausalLM")
390+
{
391+
let model_name = model_id
392+
.split('/')
393+
.last()
394+
.unwrap_or(model_id)
395+
.to_lowercase();
396+
397+
if model_name.contains("reranker") {
398+
tracing::info!("Detected Qwen3-Reranker model, treating as classifier");
399+
if pooling.is_some() {
400+
tracing::warn!(
401+
"`--pooling` arg is set but model is a reranker. Ignoring `--pooling` arg."
402+
);
403+
}
404+
return Ok(text_embeddings_backend::ModelType::Classifier);
405+
}
406+
}
407+
384408
if Some(text_embeddings_backend::Pool::Splade) == pooling {
385409
return Err(anyhow!(
386410
"Splade pooling is not supported: model is not a ForMaskedLM model"

0 commit comments

Comments
 (0)