|
 //! Functionality for modeling sampling strategies and logits processing in text generation
 //! with support for temperature-based sampling, top-k filtering, nucleus sampling (top-p),
 //! and combinations thereof.
-use candle::{Context, DType, Error, Result, Tensor};
+use candle::{DType, Error, Result, Tensor};
 use rand::{distr::Distribution, SeedableRng};
 
 #[derive(Clone, PartialEq, Debug)]
@@ -41,19 +41,12 @@ impl LogitsProcessor {
     }
 
     fn sample_argmax(&mut self, logits: Tensor) -> Result<u32> {
-        let logits_v: Vec<f32> = logits.to_vec1()?;
-        let next_token = logits_v
-            .iter()
-            .enumerate()
-            .max_by(|(_, u), (_, v)| u.total_cmp(v))
-            .map(|(i, _)| i as u32)
-            .context("empty logits")?;
-        Ok(next_token)
+        logits.argmax(candle::D::Minus1)?.to_scalar::<u32>()
     }
 
     fn sample_gumbel_softmax(&mut self, logits: &Tensor, temperature: f64) -> Result<u32> {
         let sampled = candle_nn::sampling::gumbel_softmax(logits, temperature, candle::D::Minus1)?;
-        sampled.to_vec0::<u32>()
+        sampled.to_scalar::<u32>()
     }
 
     fn sample_multinomial(&mut self, prs: &Vec<f32>) -> Result<u32> {
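The rewritten `sample_argmax` keeps the reduction on the device: `Tensor::argmax` over the last dimension produces a rank-0 `u32` tensor and `to_scalar` extracts the token id, instead of copying the whole logits vector to the host with `to_vec1` and scanning it in Rust. A minimal standalone sketch of the same call pattern, using a toy CPU tensor and the same `candle` crate alias as the file above:

```rust
use candle::{D, Device, Result, Tensor};

fn main() -> Result<()> {
    // Toy 1-D logits; argmax over the last dimension yields a rank-0 u32
    // tensor, and to_scalar turns it into a plain token id.
    let logits = Tensor::new(&[0.1f32, 2.5, -1.0, 0.3], &Device::Cpu)?;
    let token = logits.argmax(D::Minus1)?.to_scalar::<u32>()?;
    assert_eq!(token, 1); // index of the largest logit (2.5)
    Ok(())
}
```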
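The module doc comment describes the strategies the processor combines (greedy argmax, temperature scaling, top-k, and nucleus/top-p filtering). A hedged usage sketch, assuming this is candle-transformers' `generation::LogitsProcessor` with its usual `new(seed, temperature, top_p)` constructor and `sample` method; the tensor values here are made up:

```rust
use candle::{Device, Result, Tensor};
use candle_transformers::generation::LogitsProcessor;

fn main() -> Result<()> {
    // Dummy logits standing in for a model's output over a 4-token vocabulary.
    let logits = Tensor::new(&[1.0f32, 3.0, 0.5, 2.0], &Device::Cpu)?;

    // No temperature: sample() falls back to the argmax path exercised above.
    let mut greedy = LogitsProcessor::new(42, None, None);
    assert_eq!(greedy.sample(&logits)?, 1);

    // Temperature 0.8 with nucleus (top-p) filtering at 0.9.
    let mut sampler = LogitsProcessor::new(42, Some(0.8), Some(0.9));
    let _token = sampler.sample(&logits)?;
    Ok(())
}
```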