Skip to content

Commit d3f5858

Browse files
committed
Fix Argmax io_stream implementation
1 parent 4019192 commit d3f5858

File tree

2 files changed

+49
-6
lines changed

2 files changed

+49
-6
lines changed

hls4ml/templates/quartus/firmware/nnet_utils/nnet_activation_stream.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -406,22 +406,22 @@ void softmax_legacy(stream<data_T> &data, stream<res_T> &res) {
406406
}
407407

408408
template<class data_T, class res_T, typename CONFIG_T>
409-
void softmax_argmax(stream<data_T> &data, stream<res_T> &res) {
409+
void softmax_argmax(stream<data_T> &data, stream<res_T> &res) {
410410
#pragma ii 1
411411
for (int i = 0; i < CONFIG_T::n_in / res_T::size; i++) {
412412
data_T in_data = data.read();
413413
res_T out_data;
414414

415415
#pragma unroll
416-
for (int i = 0; i < CONFIG_T::n_in; i++) {
416+
for (int i = 0; i < res_T::size; i++) {
417417
out_data[i] = (typename res_T::value_type) 0;
418418
}
419419

420420
hls_register typename data_T::value_type maximum = in_data[0];
421421
hls_register int idx = 0;
422422

423423
#pragma ii 1
424-
for (int i = 1; i < CONFIG_T::n_in; i++) {
424+
for (int i = 1; i < res_T::size; i++) {
425425
if (in_data[i] > maximum) {
426426
maximum = in_data[i];
427427
idx = i;

hls4ml/templates/vivado/nnet_utils/nnet_activation_stream.h

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -315,21 +315,21 @@ void softmax_legacy(hls::stream<data_T> &data, hls::stream<res_T> &res) {
315315
}
316316

317317
template<class data_T, class res_T, typename CONFIG_T>
318-
void softmax_argmax(hls::stream<data_T> &data, hls::stream<res_T> &res) {
318+
void softmax_argmax(hls::stream<data_T> &data, hls::stream<res_T> &res) {
319319
for (int i = 0; i < CONFIG_T::n_in / res_T::size; i++) {
320320
#pragma HLS PIPELINE
321321
data_T in_data = data.read();
322322
res_T out_data;
323323

324-
for (int i = 0; i < CONFIG_T::n_in; i++) {
324+
for (int i = 0; i < res_T::size; i++) {
325325
#pragma HLS UNROLL
326326
out_data[i] = (typename res_T::value_type) 0;
327327
}
328328

329329
typename data_T::value_type maximum = in_data[0];
330330
int idx = 0;
331331

332-
for (int i = 1; i < CONFIG_T::n_in; i++) {
332+
for (int i = 1; i < res_T::size; i++) {
333333
#pragma HLS PIPELINE
334334
if (in_data[i] > maximum) {
335335
maximum = in_data[i];
@@ -669,6 +669,49 @@ void prelu(hls::stream<data_T> &data, typename data_T::value_type alpha[CONFIG_T
669669
}
670670
}
671671

672+
// *************************************************
673+
// Binary TanH Activation
674+
// *************************************************
675+
template<class data_T, class res_T, typename CONFIG_T>
676+
void binary_tanh(hls::stream<data_T> &data, hls::stream<res_T> &res) {
677+
PReLUActLoop: for (int i = 0; i < CONFIG_T::n_in / res_T::size; i++) {
678+
#pragma HLS PIPELINE
679+
680+
data_T in_data = data.read();
681+
res_T out_data;
682+
#pragma HLS DATA_PACK variable=out_data
683+
684+
PReLUPackLoop: for (int j = 0; j < res_T::size; j++) {
685+
#pragma HLS UNROLL
686+
if(in_data[j] > 0) out_data[j] = (typename res_T::value_type) 1;
687+
else out_data[j] = (typename res_T::value_type) -1;
688+
}
689+
res.write(out_data);
690+
}
691+
}
692+
693+
// *************************************************
694+
// Ternary TanH Activation
695+
// *************************************************
696+
template<class data_T, class res_T, typename CONFIG_T>
697+
void ternary_tanh(hls::stream<data_T> &data, hls::stream<res_T> &res) {
698+
PReLUActLoop: for (int i = 0; i < CONFIG_T::n_in / res_T::size; i++) {
699+
#pragma HLS PIPELINE
700+
701+
data_T in_data = data.read();
702+
res_T out_data;
703+
#pragma HLS DATA_PACK variable=out_data
704+
705+
PReLUPackLoop: for (int j = 0; j < res_T::size; j++) {
706+
#pragma HLS UNROLL
707+
if(in_data[j] > 1) out_data[j] = (typename res_T::value_type) 1;
708+
else if (in_data[j] <=-1) out_data[j] = (typename res_T::value_type) -1;
709+
else out_data[j] = (typename res_T::value_type) 0;
710+
}
711+
res.write(out_data);
712+
}
713+
}
714+
672715

673716
}
674717

0 commit comments

Comments
 (0)