@@ -42,47 +42,48 @@ Tensor& softmax_out(
4242 // Adjust for negative dim
4343 dim = dim < 0 ? dim + nonzero_dim (in) : dim;
4444
45- ET_SWITCH_FLOATH_TYPES (in.scalar_type (), ctx, " _softmax.out" , CTYPE, [&]() {
46- const CTYPE* const in_data = in.const_data_ptr <CTYPE>();
47- CTYPE* const out_data = out.mutable_data_ptr <CTYPE>();
45+ ET_SWITCH_FLOATHBF16_TYPES (
46+ in.scalar_type (), ctx, " _softmax.out" , CTYPE, [&]() {
47+ const CTYPE* const in_data = in.const_data_ptr <CTYPE>();
48+ CTYPE* const out_data = out.mutable_data_ptr <CTYPE>();
4849
49- apply_over_dim (
50- [in_data, out_data](
51- const size_t size, const size_t stride, const size_t base) {
52- // calculate max in softmax dim. During softmax computation each
53- // value is subtracted by the maximum in value before calling exp
54- // to preserve numerical stability.
55- const CTYPE max_in = apply_unary_reduce_fn (
56- [](const CTYPE val_in, CTYPE val_accum) {
57- return std::max (val_in, val_accum);
58- },
59- in_data + base,
60- size,
61- stride);
50+ apply_over_dim (
51+ [in_data, out_data](
52+ const size_t size, const size_t stride, const size_t base) {
53+ // calculate max in softmax dim. During softmax computation each
54+ // value is subtracted by the maximum in value before calling exp
55+ // to preserve numerical stability.
56+ const CTYPE max_in = apply_unary_reduce_fn (
57+ [](const CTYPE val_in, CTYPE val_accum) {
58+ return std::max (val_in, val_accum);
59+ },
60+ in_data + base,
61+ size,
62+ stride);
6263
63- const CTYPE temp_sum = apply_unary_map_reduce_fn<CTYPE, CTYPE>(
64- [max_in](const CTYPE val_in) {
65- return std::exp (val_in - max_in);
66- },
67- [](const CTYPE mapped_in, CTYPE val_accum) {
68- return val_accum + mapped_in;
69- },
70- in_data + base,
71- size,
72- stride);
64+ const CTYPE temp_sum = apply_unary_map_reduce_fn<CTYPE, CTYPE>(
65+ [max_in](const CTYPE val_in) {
66+ return std::exp (val_in - max_in);
67+ },
68+ [](const CTYPE mapped_in, CTYPE val_accum) {
69+ return val_accum + mapped_in;
70+ },
71+ in_data + base,
72+ size,
73+ stride);
7374
74- apply_unary_map_fn (
75- [max_in, temp_sum](const CTYPE val_in) {
76- return std::exp (val_in - max_in) / temp_sum;
77- },
78- in_data + base,
79- out_data + base,
80- size,
81- stride);
82- },
83- in,
84- dim);
85- });
75+ apply_unary_map_fn (
76+ [max_in, temp_sum](const CTYPE val_in) {
77+ return std::exp (val_in - max_in) / temp_sum;
78+ },
79+ in_data + base,
80+ out_data + base,
81+ size,
82+ stride);
83+ },
84+ in,
85+ dim);
86+ });
8687
8788 return out;
8889}
0 commit comments