File tree Expand file tree Collapse file tree 1 file changed +9
-6
lines changed Expand file tree Collapse file tree 1 file changed +9
-6
lines changed Original file line number Diff line number Diff line change @@ -8,13 +8,16 @@ pub fn gumbel_softmax<D: candle::shape::Dim>(
88) -> Result < Tensor > {
99 if temperature <= 0.0 {
1010 logits. argmax ( dim)
11- } else if temperature == 1.0 {
12- let minus_g = logits. rand_like ( 1e-7 , 0.999 ) ?. log ( ) ?. neg ( ) ?. log ( ) ?;
13- let sampled = ( logits - minus_g) ?. argmax ( dim) ?;
14- Ok ( sampled)
1511 } else {
12+ // Cast to f32, doing the Gumbel softmax in bf16 is a bit unstable.
13+ let logits = logits. to_dtype ( candle:: DType :: F32 ) ?;
1614 let minus_g = logits. rand_like ( 1e-7 , 0.999 ) ?. log ( ) ?. neg ( ) ?. log ( ) ?;
17- let sampled = ( logits + minus_g * ( -temperature) ) ?. argmax ( dim) ?;
18- Ok ( sampled)
15+ if temperature == 1.0 {
16+ let sampled = ( logits - minus_g) ?. argmax ( dim) ?;
17+ Ok ( sampled)
18+ } else {
19+ let sampled = ( logits + minus_g * ( -temperature) ) ?. argmax ( dim) ?;
20+ Ok ( sampled)
21+ }
1922 }
2023}
You can’t perform that action at this time.
0 commit comments