Skip to content

Commit bf3d3f2

Browse files
authored
Use Tensor::argmax instead of manual cpu impl (#3173)
1 parent 836540f commit bf3d3f2

File tree

1 file changed

+3
-10
lines changed
  • candle-transformers/src/generation

1 file changed

+3
-10
lines changed

candle-transformers/src/generation/mod.rs

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
//! Functionality for modeling sampling strategies and logits processing in text generation
44
//! with support for temperature-based sampling, top-k filtering, nucleus sampling (top-p),
55
//! and combinations thereof.
6-
use candle::{Context, DType, Error, Result, Tensor};
6+
use candle::{DType, Error, Result, Tensor};
77
use rand::{distr::Distribution, SeedableRng};
88

99
#[derive(Clone, PartialEq, Debug)]
@@ -41,19 +41,12 @@ impl LogitsProcessor {
4141
}
4242

4343
fn sample_argmax(&mut self, logits: Tensor) -> Result<u32> {
44-
let logits_v: Vec<f32> = logits.to_vec1()?;
45-
let next_token = logits_v
46-
.iter()
47-
.enumerate()
48-
.max_by(|(_, u), (_, v)| u.total_cmp(v))
49-
.map(|(i, _)| i as u32)
50-
.context("empty logits")?;
51-
Ok(next_token)
44+
logits.argmax(candle::D::Minus1)?.to_scalar::<u32>()
5245
}
5346

5447
fn sample_gumbel_softmax(&mut self, logits: &Tensor, temperature: f64) -> Result<u32> {
5548
let sampled = candle_nn::sampling::gumbel_softmax(logits, temperature, candle::D::Minus1)?;
56-
sampled.to_vec0::<u32>()
49+
sampled.to_scalar::<u32>()
5750
}
5851

5952
fn sample_multinomial(&mut self, prs: &Vec<f32>) -> Result<u32> {

0 commit comments

Comments
 (0)