Skip to content

Commit f601fd8

Browse files
authored
Update modernbert.rs (#3010)
1 parent 9fe6232 commit f601fd8

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

candle-transformers/src/models/modernbert.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -488,7 +488,7 @@ impl ModernBertForSequenceClassification {
488488
pub fn forward(&self, xs: &Tensor, mask: &Tensor) -> Result<Tensor> {
489489
let output = self.model.forward(xs, mask)?;
490490
let last_hidden_state = match self.classifier_pooling {
491-
ClassifierPooling::CLS => output.i((.., .., 0))?,
491+
ClassifierPooling::CLS => output.i((.., 0, ..))?,
492492
ClassifierPooling::MEAN => {
493493
let unsqueezed_mask = &mask.unsqueeze(D::Minus1)?.to_dtype(DType::F32)?;
494494
let sum_output = output.broadcast_mul(unsqueezed_mask)?.sum(1)?;

0 commit comments

Comments
 (0)