Skip to content

Commit c16c0bd

Browse files
committed
refactor: support qwen3 reranker
1 parent c30aebc commit c16c0bd

File tree

20 files changed

+540
-276
lines changed

20 files changed

+540
-276
lines changed

backends/candle/src/lib.rs

Lines changed: 21 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -597,13 +597,30 @@ impl Backend for CandleBackend {
597597
let batch_size = batch.len();
598598

599599
let results = self.model.predict(batch).e()?;
600-
601-
let results = results.to_dtype(DType::F32).e()?.to_vec2().e()?;
600+
let results = results.to_dtype(DType::F32).e()?;
602601

603602
let mut predictions =
604603
HashMap::with_capacity_and_hasher(batch_size, BuildNoHashHasher::default());
605-
for (i, r) in results.into_iter().enumerate() {
606-
predictions.insert(i, r);
604+
605+
match results.dims() {
606+
[_n] => {
607+
let scores = results.to_vec1::<f32>().e()?;
608+
for (i, score) in scores.into_iter().enumerate() {
609+
predictions.insert(i, vec![score]);
610+
}
611+
}
612+
[_, _] => {
613+
let results = results.to_vec2().e()?;
614+
for (i, r) in results.into_iter().enumerate() {
615+
predictions.insert(i, r);
616+
}
617+
}
618+
dims => {
619+
return Err(BackendError::Inference(format!(
620+
"Unexpected tensor shape: {:?}",
621+
dims
622+
)));
623+
}
607624
}
608625

609626
Ok(predictions)

backends/candle/src/models/bert.rs

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -597,6 +597,9 @@ impl BertModel {
597597
};
598598
(pool, None, splade)
599599
}
600+
ModelType::ListwiseReranker => {
601+
candle::bail!("`reranker` model type is not supported for Bert")
602+
}
600603
};
601604

602605
let (embeddings, encoder) = match (
@@ -615,7 +618,6 @@ impl BertModel {
615618
}
616619
}
617620
};
618-
619621
Ok(Self {
620622
embeddings,
621623
encoder,
@@ -661,6 +663,9 @@ impl BertModel {
661663
};
662664
(pool, None, splade)
663665
}
666+
ModelType::ListwiseReranker => {
667+
candle::bail!("`reranker` model type is not supported for RoBERTa")
668+
}
664669
};
665670

666671
let (embeddings, encoder) = match (

backends/candle/src/models/distilbert.rs

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -462,6 +462,9 @@ impl DistilBertModel {
462462

463463
(pool, None)
464464
}
465+
ModelType::ListwiseReranker => {
466+
candle::bail!("`reranker` model type is not supported for DistilBert")
467+
}
465468
};
466469

467470
let (embeddings, encoder) = match (

0 commit comments

Comments
 (0)