Skip to content

Commit 6fd7f56

Browse files
committed
Fix stable softmax strategy in Quartus
1 parent 97e300a commit 6fd7f56

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

hls4ml/templates/quartus/firmware/nnet_utils/nnet_activation.h

Lines changed: 12 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -131,10 +131,12 @@ enum class softmax_implementation {latency=0, legacy=1, stable=2, argmax=3};
131131
template<class data_T, typename CONFIG_T>
132132
inline unsigned softmax_stable_idx_from_real_val(const data_T x){
133133
// Converts x into an unsigned index for the exp table; N is the number of address bits, ceillog2(CONFIG_T::table_size)
134-
static constexpr int N = ceillog2(CONFIG_T::table_size);
134+
static constexpr int N = ceillog2(CONFIG_T::table_size);
135135

136136
// Slice N bits of x starting at bit (x.width-N-1), i.e. the N bits directly below the top bit of x
137-
hls_register ac_int<N, false> y = x.template slc<N>(x.width-N-1);
137+
hls_register ac_int<N, false> y = x.template slc<N>(x.width-N-1);
138+
// A non-zero x whose sliced bits are all 0 (e.g. the most negative value) would get the same index as x == 0; set bit 0 so it maps to index 1 instead
139+
if (x != 0 && y == 0) y[0] = 1;
138140
return y.to_uint();
139141
}
140142

@@ -158,11 +160,18 @@ void softmax_stable(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]){
158160
Op_max<data_T> op_max;
159161
hls_register data_T x_max = reduce<data_T, CONFIG_T::n_in, Op_max<data_T>>(data, op_max);
160162

163+
// For the diffs, use the same type as the input but force rounding and saturation
164+
hls_register ac_fixed<data_T::width, data_T::i_width, true, AC_RND, AC_SAT> d_xi_xmax[CONFIG_T::n_in];
165+
for(unsigned i = 0; i < CONFIG_T::n_in; i++){
166+
#pragma HLS unroll
167+
d_xi_xmax[i] = data[i] - x_max;
168+
}
169+
161170
// Calculate all the e^x's
162171
hls_register typename CONFIG_T::exp_table_t exp_res[CONFIG_T::n_in];
163172
#pragma unroll
164173
for(unsigned i = 0; i < CONFIG_T::n_in; i++) {
165-
exp_res[i] = exp_table[softmax_stable_idx_from_real_val<data_T, CONFIG_T>(data[i] - x_max)];
174+
exp_res[i] = exp_table[softmax_stable_idx_from_real_val<data_T, CONFIG_T>(d_xi_xmax[i])];
166175
}
167176

168177
// Explicitly sum previously calculated exponentials with an adder tree

0 commit comments

Comments
 (0)