Commit 5763cd3

refactor: support qwen3 reranker

1 parent c30aebc commit 5763cd3

20 files changed: +528 -276 lines

backends/candle/src/lib.rs

Lines changed: 21 additions & 4 deletions
@@ -597,13 +597,30 @@ impl Backend for CandleBackend {
         let batch_size = batch.len();
 
         let results = self.model.predict(batch).e()?;
-
-        let results = results.to_dtype(DType::F32).e()?.to_vec2().e()?;
+        let results = results.to_dtype(DType::F32).e()?;
 
         let mut predictions =
             HashMap::with_capacity_and_hasher(batch_size, BuildNoHashHasher::default());
-        for (i, r) in results.into_iter().enumerate() {
-            predictions.insert(i, r);
+
+        match results.dims() {
+            [_n] => {
+                let scores = results.to_vec1::<f32>().e()?;
+                for (i, score) in scores.into_iter().enumerate() {
+                    predictions.insert(i, vec![score]);
+                }
+            }
+            [_, _] => {
+                let results = results.to_vec2().e()?;
+                for (i, r) in results.into_iter().enumerate() {
+                    predictions.insert(i, r);
+                }
+            }
+            dims => {
+                return Err(BackendError::Inference(format!(
+                    "Unexpected tensor shape: {:?}",
+                    dims
+                )));
+            }
         }
 
         Ok(predictions)
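Reranker models now return one score per sequence (a 1-D [n] tensor) while classifiers keep returning a row of logits per sequence ([n, classes]); the match above normalizes both into Vec<f32> rows. A minimal sketch of the same dispatch, not part of the commit, assuming only candle-core (the helper name to_rows is invented):

    use candle_core::{bail, Device, Result, Tensor};

    fn to_rows(results: &Tensor) -> Result<Vec<Vec<f32>>> {
        match results.dims() {
            // One score per sequence: wrap each scalar in its own row.
            [_n] => Ok(results
                .to_vec1::<f32>()?
                .into_iter()
                .map(|score| vec![score])
                .collect()),
            // One row of logits per sequence: already the right shape.
            [_, _] => results.to_vec2::<f32>(),
            dims => bail!("Unexpected tensor shape: {:?}", dims),
        }
    }

    fn main() -> Result<()> {
        let device = Device::Cpu;
        let scores = Tensor::from_vec(vec![0.9f32, 0.1], 2, &device)?; // reranker output
        let logits = Tensor::from_vec(vec![0.2f32, 0.8, 0.6, 0.4], (2, 2), &device)?; // classifier output
        assert_eq!(to_rows(&scores)?, vec![vec![0.9], vec![0.1]]);
        assert_eq!(to_rows(&logits)?, vec![vec![0.2, 0.8], vec![0.6, 0.4]]);
        Ok(())
    }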

backends/candle/src/models/bert.rs

Lines changed: 6 additions & 1 deletion
@@ -597,6 +597,9 @@ impl BertModel {
             };
             (pool, None, splade)
         }
+        ModelType::ListwiseReranker => {
+            candle::bail!("`reranker` model type is not supported for Bert")
+        }
     };
 
     let (embeddings, encoder) = match (
@@ -615,7 +618,6 @@
             }
         }
     };
-
     Ok(Self {
         embeddings,
         encoder,
@@ -661,6 +663,9 @@
             };
             (pool, None, splade)
         }
+        ModelType::ListwiseReranker => {
+            candle::bail!("`reranker` model type is not supported for RoBERTa")
+        }
     };
 
     let (embeddings, encoder) = match (

backends/candle/src/models/distilbert.rs

Lines changed: 3 additions & 0 deletions
@@ -462,6 +462,9 @@ impl DistilBertModel {
 
                 (pool, None)
             }
+            ModelType::ListwiseReranker => {
+                candle::bail!("`reranker` model type is not supported for DistilBert")
+            }
         };
 
         let (embeddings, encoder) = match (
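The Bert, RoBERTa, and DistilBert loaders all gain the same guard: their ModelType match now rejects the new ListwiseReranker variant at load time, since encoder-only models have no causal LM head to score with. A minimal sketch of that dispatch shape with stand-in types (the crate's real ModelType and Pool differ and live elsewhere):

    // Stand-in types for illustration only.
    #[derive(Debug)]
    enum Pool { LastToken }
    #[derive(Debug)]
    enum ModelType { Classifier, Embedding(Pool), ListwiseReranker }

    fn load_pool(model_type: ModelType) -> Result<Pool, String> {
        match model_type {
            ModelType::Classifier => Ok(Pool::LastToken), // placeholder arm
            ModelType::Embedding(pool) => Ok(pool),
            // The new arm: encoder-only models bail out at load time.
            ModelType::ListwiseReranker => {
                Err("`reranker` model type is not supported for Bert".into())
            }
        }
    }

    fn main() {
        assert!(load_pool(ModelType::ListwiseReranker).is_err());
        assert!(load_pool(ModelType::Embedding(Pool::LastToken)).is_ok());
        assert!(load_pool(ModelType::Classifier).is_ok());
    }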

backends/candle/src/models/flash_qwen3.rs

Lines changed: 105 additions & 123 deletions
@@ -42,7 +42,10 @@ impl Qwen3Attention {
             "weight",
         )?;
         let query_bias = if config.attention_bias {
-            Some(vb.pp("q_proj").get(hidden_size, "bias")?)
+            Some(
+                vb.pp("q_proj")
+                    .get(num_attention_heads * attention_head_size, "bias")?,
+            )
         } else {
             None
         };
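The bias fix above matters because Qwen3 configures the head dimension independently of hidden_size, so q_proj's output width is num_attention_heads * attention_head_size rather than hidden_size. A quick check with illustrative (invented) numbers:

    fn main() {
        // Hypothetical config where the old code would fail to load the bias.
        let (hidden_size, num_attention_heads, attention_head_size) = (1024, 16, 128);
        let q_proj_out = num_attention_heads * attention_head_size;
        assert_eq!(q_proj_out, 2048); // correct bias length
        assert_ne!(q_proj_out, hidden_size); // the old `get(hidden_size, "bias")` mismatches
    }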
@@ -85,7 +88,7 @@
         let q_norm = RMSNorm::load(vb.pp("q_norm"), attention_head_size, config.rms_norm_eps)?;
         let k_norm = RMSNorm::load(vb.pp("k_norm"), attention_head_size, config.rms_norm_eps)?;
 
-        let softmax_scale = (1. / (attention_head_size as f64).sqrt()) as f32;
+        let softmax_scale = 1.0 / (attention_head_size as f64).sqrt() as f32;
 
         Ok(Self {
             q_proj,
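Both spellings of softmax_scale compute the usual attention scale 1/sqrt(d_head); since `as` binds tighter than `/`, the new form narrows the f64 square root to f32 first and then divides in f32. A check with a head size whose root is exact in both precisions:

    fn main() {
        let attention_head_size = 64usize; // sqrt(64) = 8, exact in f32 and f64
        let old = (1. / (attention_head_size as f64).sqrt()) as f32;
        let new = 1.0 / (attention_head_size as f64).sqrt() as f32;
        assert_eq!(old, new); // 0.125 either way
    }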
@@ -148,6 +151,28 @@ impl Qwen3Attention {
 
         apply_rotary_inplace(&q, &k, &cos, &sin, true)?;
 
+        let (k, v) = if self.num_key_value_heads != self.num_attention_heads {
+            if self.num_attention_heads % self.num_key_value_heads != 0 {
+                candle::bail!("num_attention_heads must be a multiple of num_key_value_heads");
+            }
+            let repeat = self.num_attention_heads / self.num_key_value_heads;
+
+            let (total_tokens, n_kv_heads, head_dim) = k.dims3()?;
+
+            let k = k
+                .unsqueeze(2)?
+                .expand((total_tokens, n_kv_heads, repeat, head_dim))?
+                .reshape((total_tokens, n_kv_heads * repeat, head_dim))?;
+
+            let v = v
+                .unsqueeze(2)?
+                .expand((total_tokens, n_kv_heads, repeat, head_dim))?
+                .reshape((total_tokens, n_kv_heads * repeat, head_dim))?;
+            (k, v)
+        } else {
+            (k, v)
+        };
+
         let attention = flash_attn_varlen(
             &q,
             &k,
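The inserted block is the standard grouped-query-attention (GQA) expansion: each KV head is duplicated num_attention_heads / num_key_value_heads times along the head axis so flash attention sees matching Q and KV head counts. A self-contained toy run of the same unsqueeze/expand/reshape chain (candle-core on CPU, sizes invented):

    use candle_core::{Device, Result, Tensor};

    fn main() -> Result<()> {
        let device = Device::Cpu;
        // Toy sizes: 2 tokens, 4 query heads sharing 2 KV heads, head_dim 3.
        let (total_tokens, n_kv_heads, head_dim, repeat) = (2, 2, 3, 2);
        let k = Tensor::arange(0f32, (total_tokens * n_kv_heads * head_dim) as f32, &device)?
            .reshape((total_tokens, n_kv_heads, head_dim))?;
        // Same chain as the diff: add a repeat axis, broadcast, fold it in.
        let k_rep = k
            .unsqueeze(2)?
            .expand((total_tokens, n_kv_heads, repeat, head_dim))?
            .reshape((total_tokens, n_kv_heads * repeat, head_dim))?;
        assert_eq!(k_rep.dims3()?, (2, 4, 3));
        // Expanded heads 0 and 1 are copies of KV head 0.
        let rows = k_rep.to_vec3::<f32>()?;
        assert_eq!(rows[0][0], rows[0][1]);
        Ok(())
    }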
@@ -277,101 +302,20 @@ impl Qwen3Layer {
 
         let mlp_output = self.mlp.forward(&normed_attn_res_output)?;
 
-        Ok((mlp_output, attn_res))
-    }
-}
-
-// Define ClassificationHead trait locally (following TEI pattern)
-trait ClassificationHead {
-    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor>;
-}
-
-// Qwen3 Classification Head implementation
-#[derive(Debug)]
-struct Qwen3ClassificationHead {
-    dense: Linear,
-    out_proj: Linear,
-    activation: HiddenAct,
-    span: tracing::Span,
-}
-
-impl Qwen3ClassificationHead {
-    pub fn load(vb: VarBuilder, config: &Qwen3Config) -> Result<Self> {
-        let (dense, out_proj) = if vb.contains_tensor("score.dense.weight") {
-            tracing::info!("Loading Qwen3 classifier with score layers");
-
-            let dense_weight = vb
-                .pp("score.dense")
-                .get((config.hidden_size, config.hidden_size), "weight")?;
-            let dense_bias = vb.pp("score.dense").get(config.hidden_size, "bias")?;
-            let dense = Linear::new(dense_weight, Some(dense_bias), None);
-
-            let out_proj_weight = vb
-                .pp("score.out_proj")
-                .get((1, config.hidden_size), "weight")?;
-            let out_proj_bias = vb.pp("score.out_proj").get(1, "bias")?;
-            let out_proj = Linear::new(out_proj_weight, Some(out_proj_bias), None);
-
-            (dense, out_proj)
-        } else if vb.contains_tensor("classifier.dense.weight") {
-            tracing::info!("Loading Qwen3 classifier with classifier layers");
-
-            let dense_weight = vb
-                .pp("classifier.dense")
-                .get((config.hidden_size, config.hidden_size), "weight")?;
-            let dense_bias = vb.pp("classifier.dense").get(config.hidden_size, "bias")?;
-            let dense = Linear::new(dense_weight, Some(dense_bias), None);
-
-            let out_proj_weight = vb
-                .pp("classifier.out_proj")
-                .get((1, config.hidden_size), "weight")?;
-            let out_proj_bias = vb.pp("classifier.out_proj").get(1, "bias")?;
-            let out_proj = Linear::new(out_proj_weight, Some(out_proj_bias), None);
-
-            (dense, out_proj)
-        } else {
-            candle::bail!(
-                "Classification layers not found in model weights. \
-                 Expected 'score.dense.weight' or 'classifier.dense.weight' for reranker models. \
-                 This model may not be a trained reranker."
-            );
-        };
-
-        Ok(Self {
-            dense,
-            out_proj,
-            activation: config.hidden_act.clone(),
-            span: tracing::span!(tracing::Level::TRACE, "classifier"),
-        })
-    }
-}
-
-impl ClassificationHead for Qwen3ClassificationHead {
-    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
-        let _enter = self.span.enter();
-
-        // Input is already pooled
-
-        // Apply dense layer with activation
-        let hidden = self.dense.forward(hidden_states)?;
-        let hidden = self.activation.forward(&hidden)?;
-
-        // Project to single score
-        let score = self.out_proj.forward(&hidden)?;
-
-        // Squeeze to remove the last dimension if it's 1
-        score.squeeze(candle::D::Minus1)
+        let output = (&mlp_output + &attn_res)?;
+        Ok((output, attn_res))
     }
 }
 
 pub struct FlashQwen3Model {
     embeddings: Embedding,
+    lm_head_weight: Tensor,
     layers: Vec<Qwen3Layer>,
     norm: RMSNorm,
     cos_cache: Tensor,
     sin_cache: Tensor,
+    model_type: ModelType,
     pool: Pool,
-    classifier: Option<Box<dyn ClassificationHead + Send>>,
     pub device: Device,
 
     span: tracing::Span,
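Two related changes land in this hunk: Qwen3Layer::forward now folds the attention residual into its MLP output before returning (so callers pass None instead of threading a residual into the next layer's fused add+RMSNorm), and the struct swaps the classifier head for model_type plus the tied lm_head_weight. A toy check of the new return contract, not from this commit, assuming only candle-core:

    use candle_core::{Device, Result, Tensor};

    fn main() -> Result<()> {
        let device = Device::Cpu;
        let attn_res = Tensor::from_vec(vec![1f32, 2.0], 2, &device)?;
        let mlp_output = Tensor::from_vec(vec![0.5f32, -0.5], 2, &device)?;
        // New contract: the layer returns mlp_output + attn_res directly...
        let output = (&mlp_output + &attn_res)?;
        // ...instead of (mlp_output, attn_res) with the add deferred downstream.
        assert_eq!(output.to_vec1::<f32>()?, vec![1.5, 1.5]);
        Ok(())
    }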
@@ -388,19 +332,12 @@ impl FlashQwen3Model {
             candle::bail!("FlashQwen3 requires DType::F16")
         }
 
-        let (pool, classifier) = match model_type {
+        let pool = match &model_type {
             ModelType::Classifier => {
-                let pool = Pool::LastToken;
-                let classifier: Box<dyn ClassificationHead + Send> =
-                    Box::new(Qwen3ClassificationHead::load(vb.clone(), config)?);
-                (pool, Some(classifier))
-            }
-            ModelType::Embedding(pool) => {
-                if pool == Pool::Splade {
-                    candle::bail!("`splade` is not supported for Qwen3")
-                }
-                (pool, None)
+                candle::bail!("`classifier` model type is not supported for Qwen3")
             }
+            ModelType::Embedding(pool) => pool.clone(),
+            ModelType::ListwiseReranker => Pool::LastToken,
         };
 
         // The Qwen3-Reranker models contain the `model` key
@@ -411,11 +348,13 @@
             vb
         };
 
-        let embeddings = Embedding::new(
-            vb.pp("embed_tokens")
-                .get((config.vocab_size, config.hidden_size), "weight")?,
-            config.hidden_size,
-        );
+        let embed_weight = vb
+            .pp("embed_tokens")
+            .get((config.vocab_size, config.hidden_size), "weight")?;
+
+        let embeddings = Embedding::new(embed_weight.clone(), config.hidden_size);
+
+        let lm_head_weight = embed_weight;
 
         let layers = (0..config.num_hidden_layers)
             .map(|index| Qwen3Layer::load(vb.pp(format!("layers.{index}")), config))
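Reusing embed_weight as lm_head_weight assumes weight tying, i.e. a checkpoint whose LM head shares embed_tokens; this holds for at least the smaller Qwen3-Reranker checkpoints, but a checkpoint shipping a separate lm_head tensor would need that weight loaded instead. Scoring then only ever needs two rows of the [vocab_size, hidden_size] matrix, as in this toy selection (sizes and ids invented):

    use candle_core::{Device, Result, Tensor};

    fn main() -> Result<()> {
        let device = Device::Cpu;
        // Toy tied weight: vocab_size 4, hidden_size 3.
        let embed_weight = Tensor::arange(0f32, 12., &device)?.reshape((4, 3))?;
        // Pull out the rows for two hypothetical token ids, as predict() does
        // for the "no"/"yes" ids.
        let ids = Tensor::from_vec(vec![1u32, 3], 2, &device)?;
        let w = embed_weight.index_select(&ids, 0)?;
        assert_eq!(w.dims2()?, (2, 3));
        Ok(())
    }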
@@ -438,12 +377,13 @@
 
         Ok(Self {
             embeddings,
+            lm_head_weight,
             layers,
             norm,
             cos_cache,
             sin_cache,
+            model_type,
             pool,
-            classifier,
             device: vb.device().clone(),
             span: tracing::span!(tracing::Level::TRACE, "model"),
         })
@@ -469,21 +409,19 @@
         let cos = self.cos_cache.index_select(&position_ids, 0)?;
         let sin = self.sin_cache.index_select(&position_ids, 0)?;
 
-        let mut residual = None;
         for layer in &self.layers {
-            let (h, r) = layer.forward(
+            let (h, _r) = layer.forward(
                 &hidden_states,
-                residual.as_ref(),
+                None,
                 &cu_seqlens,
                 &cos,
                 &sin,
                 batch.max_length as usize,
             )?;
             hidden_states = h;
-            residual = Some(r);
         }
 
-        let (outputs, _) = self.norm.forward(&hidden_states, residual.as_ref())?;
+        let (outputs, _) = self.norm.forward(&hidden_states, None)?;
 
         let has_pooling_requests = !batch.pooled_indices.is_empty();
         let has_raw_requests = !batch.raw_indices.is_empty();
@@ -553,7 +491,8 @@
                     // Concatenate all results
                     Some(Tensor::cat(&results?, 0)?)
                 } else {
-                    Some((outputs.sum_keepdim(0)? / (batch.max_length as f64))?)
+                    let actual_len = batch.cumulative_seq_lengths[1] as f64;
+                    Some((outputs.sum_keepdim(0)? / actual_len)?)
                 }
             }
             Pool::Splade => {
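The mean-pooling fix in numbers: the single-sequence path previously divided the token sum by the padded max_length, while cumulative_seq_lengths[1] holds the sequence's true token count. With, say, 5 real tokens padded to length 8 (invented values):

    fn main() {
        let token_sum = 10.0f64; // one embedding dimension summed over 5 tokens
        let (max_length, actual_len) = (8.0f64, 5.0f64);
        assert_eq!(token_sum / max_length, 1.25); // old: biased low for short sequences
        assert_eq!(token_sum / actual_len, 2.0); // new: the true mean
    }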
@@ -607,21 +546,64 @@ impl Model for FlashQwen3Model {
     }
 
     fn predict(&self, batch: Batch) -> Result<Tensor> {
-        match &self.classifier {
-            None => candle::bail!("`predict` is not implemented for this model"),
-            Some(classifier) => {
-                // Run forward pass to get hidden states
-                let (pooled_embeddings, _) = self.forward(batch)?;
-                match pooled_embeddings {
-                    Some(embeddings) => {
-                        let scores = classifier.forward(&embeddings)?;
-                        // Apply sigmoid to convert logits to probabilities
-                        let probabilities = candle_nn::ops::sigmoid(&scores)?;
-                        Ok(probabilities)
-                    }
-                    None => candle::bail!("No pooled embeddings returned for classification"),
+        match &self.model_type {
+            ModelType::ListwiseReranker => {
+                let _enter = self.span.enter();
+
+                let batch_size = batch.cumulative_seq_lengths.len() - 1;
+                let shape = batch.input_ids.len();
+
+                let input_ids = Tensor::from_vec(batch.input_ids, shape, &self.device)?;
+                let position_ids = Tensor::from_vec(batch.position_ids, shape, &self.device)?;
+                let cu_seqlens = Tensor::from_vec(
+                    batch.cumulative_seq_lengths.clone(),
+                    batch_size + 1,
+                    &self.device,
+                )?;
+
+                let mut hidden_states = self.embeddings.forward(&input_ids)?;
+
+                let cos = self.cos_cache.index_select(&position_ids, 0)?;
+                let sin = self.sin_cache.index_select(&position_ids, 0)?;
+
+                for layer in &self.layers {
+                    let (h, _r) = layer.forward(
+                        &hidden_states,
+                        None,
+                        &cu_seqlens,
+                        &cos,
+                        &sin,
+                        batch.max_length as usize,
+                    )?;
+                    hidden_states = h;
                 }
+
+                let (outputs, _) = self.norm.forward(&hidden_states, None)?;
+
+                let mut last_hidden_states = Vec::with_capacity(batch_size);
+
+                for i in 0..batch_size {
+                    let seq_end = batch.cumulative_seq_lengths[i + 1] as usize;
+                    let last_token_idx = seq_end - 1;
+
+                    let h_last = outputs.i(last_token_idx)?; // [hidden_size]
+                    last_hidden_states.push(h_last);
+                }
+
+                let h_last = Tensor::stack(&last_hidden_states, 0)?; // [bs, hidden_size]
+
+                let true_id = 9693u32;
+                let false_id = 2152u32;
+
+                let ids = Tensor::from_vec(vec![false_id, true_id], 2, &self.device)?;
+                let w = self.lm_head_weight.index_select(&ids, 0)?; // [2, hidden_size]
+                let logits = h_last.matmul(&w.t()?)?; // [bs, 2] (no, yes)
+                let log_probs = candle_nn::ops::log_softmax(&logits, D::Minus1)?;
+                let scores = log_probs.i((.., 1))?.exp()?; // P("yes") ∈ (0,1)
+
+                Ok(scores)
             }
+            _ => candle::bail!("`predict` is only available for ModelType::ListwiseReranker"),
         }
     }
 }
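The tail of predict() is the Qwen3-Reranker scoring scheme: take each sequence's last-token hidden state, project it onto the tied LM-head rows for the hardcoded "no"/"yes" token ids (2152 and 9693), and read P("yes") off a two-way softmax. In scalar form that probability equals sigmoid(z_yes - z_no), a quick identity check:

    fn main() {
        let (z_no, z_yes) = (-1.2f64, 0.7f64);
        // Softmax over the (no, yes) pair, evaluated at "yes"...
        let softmax_yes = z_yes.exp() / (z_no.exp() + z_yes.exp());
        // ...equals the sigmoid of the logit difference.
        let sigmoid = 1.0 / (1.0 + (-(z_yes - z_no)).exp());
        assert!((softmax_yes - sigmoid).abs() < 1e-12);
    }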
