@@ -131,10 +131,12 @@ enum class softmax_implementation {latency=0, legacy=1, stable=2, argmax=3};
131131template <class data_T , typename CONFIG_T>
132132inline unsigned softmax_stable_idx_from_real_val (const data_T x){
133133 // Number of address bits for table
134- static constexpr int N = ceillog2 (CONFIG_T::table_size);
134+ static constexpr int N = ceillog2 (CONFIG_T::table_size);
135135
136136 // Slice the top N bits of the input
137- hls_register ac_int<N, false > y = x.template slc <N>(x.width -N-1 );
137+ hls_register ac_int<N, false > y = x.template slc <N>(x.width -N-1 );
138+ // If x is the most negative value, the slice will be 0, so we need to set the 0-th bit to ensure correctness
139+ if (x != 0 && y == 0 ) y[0 ] = 1 ;
138140 return y.to_uint ();
139141}
140142
@@ -158,11 +160,18 @@ void softmax_stable(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]){
158160 Op_max<data_T> op_max;
159161 hls_register data_T x_max = reduce<data_T, CONFIG_T::n_in, Op_max<data_T>>(data, op_max);
160162
163+ // For the diffs, use the same type as the input but force rounding and saturation
164+ hls_register ac_fixed<data_T::width, data_T::i_width, true , AC_RND, AC_SAT> d_xi_xmax[CONFIG_T::n_in];
165+ for (unsigned i = 0 ; i < CONFIG_T::n_in; i++){
166+ #pragma HLS unroll
167+ d_xi_xmax[i] = data[i] - x_max;
168+ }
169+
161170 // Calculate all the e^x's
162171 hls_register typename CONFIG_T::exp_table_t exp_res[CONFIG_T::n_in];
163172 #pragma unroll
164173 for (unsigned i = 0 ; i < CONFIG_T::n_in; i++) {
165- exp_res[i] = exp_table[softmax_stable_idx_from_real_val<data_T, CONFIG_T>(data [i] - x_max )];
174+ exp_res[i] = exp_table[softmax_stable_idx_from_real_val<data_T, CONFIG_T>(d_xi_xmax [i])];
166175 }
167176
168177 // Explicitly sum previously calculated exponentials with an adder tree
0 commit comments