Skip to content

Commit 004f0a4

Browse files
committed
refactor: support qwen3 reranker
1 parent c30aebc commit 004f0a4

File tree

20 files changed

+546
-276
lines changed

20 files changed

+546
-276
lines changed

backends/candle/src/lib.rs

Lines changed: 24 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -597,13 +597,33 @@ impl Backend for CandleBackend {
597597
let batch_size = batch.len();
598598

599599
let results = self.model.predict(batch).e()?;
600-
601-
let results = results.to_dtype(DType::F32).e()?.to_vec2().e()?;
600+
let results = results.to_dtype(DType::F32).e()?;
602601

603602
let mut predictions =
604603
HashMap::with_capacity_and_hasher(batch_size, BuildNoHashHasher::default());
605-
for (i, r) in results.into_iter().enumerate() {
606-
predictions.insert(i, r);
604+
605+
// Handle both 1D and 2D tensors
606+
match results.dims() {
607+
[_n] => {
608+
// 1D tensor (e.g., ListwiseReranker returns a vector of scores)
609+
let scores = results.to_vec1::<f32>().e()?;
610+
for (i, score) in scores.into_iter().enumerate() {
611+
predictions.insert(i, vec![score]);
612+
}
613+
}
614+
[_, _] => {
615+
// 2D tensor (normal case)
616+
let results = results.to_vec2().e()?;
617+
for (i, r) in results.into_iter().enumerate() {
618+
predictions.insert(i, r);
619+
}
620+
}
621+
dims => {
622+
return Err(BackendError::Inference(format!(
623+
"Unexpected tensor shape: {:?}",
624+
dims
625+
)));
626+
}
607627
}
608628

609629
Ok(predictions)

backends/candle/src/models/bert.rs

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -597,6 +597,9 @@ impl BertModel {
597597
};
598598
(pool, None, splade)
599599
}
600+
ModelType::ListwiseReranker => {
601+
candle::bail!("`reranker` model type is not supported for Bert")
602+
}
600603
};
601604

602605
let (embeddings, encoder) = match (
@@ -615,7 +618,6 @@ impl BertModel {
615618
}
616619
}
617620
};
618-
619621
Ok(Self {
620622
embeddings,
621623
encoder,
@@ -661,6 +663,9 @@ impl BertModel {
661663
};
662664
(pool, None, splade)
663665
}
666+
ModelType::ListwiseReranker => {
667+
candle::bail!("`reranker` model type is not supported for RoBERTa")
668+
}
664669
};
665670

666671
let (embeddings, encoder) = match (

backends/candle/src/models/distilbert.rs

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -462,6 +462,9 @@ impl DistilBertModel {
462462

463463
(pool, None)
464464
}
465+
ModelType::ListwiseReranker => {
466+
candle::bail!("`reranker` model type is not supported for DistilBert")
467+
}
465468
};
466469

467470
let (embeddings, encoder) = match (

0 commit comments

Comments
 (0)