Skip to content

Commit d4bac37

Browse files
Fix the gumbel softmax by casting to f32. (#2928)
1 parent e98754f commit d4bac37

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

candle-nn/src/sampling.rs

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,16 @@ pub fn gumbel_softmax<D: candle::shape::Dim>(
88
) -> Result<Tensor> {
99
if temperature <= 0.0 {
1010
logits.argmax(dim)
11-
} else if temperature == 1.0 {
12-
let minus_g = logits.rand_like(1e-7, 0.999)?.log()?.neg()?.log()?;
13-
let sampled = (logits - minus_g)?.argmax(dim)?;
14-
Ok(sampled)
1511
} else {
12+
// Cast to f32, doing the Gumbel softmax in bf16 is a bit unstable.
13+
let logits = logits.to_dtype(candle::DType::F32)?;
1614
let minus_g = logits.rand_like(1e-7, 0.999)?.log()?.neg()?.log()?;
17-
let sampled = (logits + minus_g * (-temperature))?.argmax(dim)?;
18-
Ok(sampled)
15+
if temperature == 1.0 {
16+
let sampled = (logits - minus_g)?.argmax(dim)?;
17+
Ok(sampled)
18+
} else {
19+
let sampled = (logits + minus_g * (-temperature))?.argmax(dim)?;
20+
Ok(sampled)
21+
}
1922
}
2023
}

0 commit comments

Comments
 (0)