@@ -314,6 +314,35 @@ void softmax_legacy(hls::stream<data_T> &data, hls::stream<res_T> &res) {
314314 }
315315}
316316
317+ template <class data_T , class res_T , typename CONFIG_T>
318+ void softmax_argmax (hls::stream<data_T> &data, hls::stream<res_T> &res) {
319+ for (int i = 0 ; i < CONFIG_T::n_in / res_T::size; i++) {
320+ #pragma HLS PIPELINE
321+ data_T in_data = data.read ();
322+ res_T out_data;
323+
324+ for (int i = 0 ; i < CONFIG_T::n_in; i++) {
325+ #pragma HLS UNROLL
326+ out_data[i] = (typename res_T::value_type) 0 ;
327+ }
328+
329+ typename data_T::value_type maximum = in_data[0 ];
330+ int idx = 0 ;
331+
332+ for (int i = 1 ; i < CONFIG_T::n_in; i++) {
333+ #pragma HLS PIPELINE
334+ if (in_data[i] > maximum) {
335+ maximum = in_data[i];
336+ idx = i;
337+ }
338+ }
339+
340+ out_data[idx] = (typename res_T::value_type) 1 ;
341+ res.write (out_data);
342+ }
343+ }
344+
345+
317346template <class data_T , class res_T , typename CONFIG_T>
318347void softmax (hls::stream<data_T> &data, hls::stream<res_T> &res){
319348 assert (CONFIG_T::axis == -1 );
@@ -328,7 +357,10 @@ void softmax(hls::stream<data_T> &data, hls::stream<res_T> &res){
328357 case softmax_implementation::legacy:
329358 softmax_legacy<data_T, res_T, CONFIG_T>(data, res);
330359 break ;
331- }
360+ case softmax_implementation::argmax:
361+ softmax_argmax<data_T, res_T, CONFIG_T>(data, res);
362+ break ;
363+ }
332364}
333365
334366// *************************************************
@@ -637,51 +669,7 @@ void prelu(hls::stream<data_T> &data, typename data_T::value_type alpha[CONFIG_T
637669 }
638670}
639671
640- // *************************************************
641- // Binary TanH Activation
642- // *************************************************
643- template <class data_T , class res_T , typename CONFIG_T>
644- void binary_tanh (hls::stream<data_T> &data, hls::stream<res_T> &res) {
645- PReLUActLoop: for (int i = 0 ; i < CONFIG_T::n_in / res_T::size; i++) {
646- #pragma HLS PIPELINE
647-
648- data_T in_data = data.read ();
649- res_T out_data;
650- #pragma HLS DATA_PACK variable=out_data
651-
652- PReLUPackLoop: for (int j = 0 ; j < res_T::size; j++) {
653- #pragma HLS UNROLL
654- if (in_data[j] > 0 ) out_data[j] = (typename res_T::value_type) 1 ;
655- else out_data[j] = (typename res_T::value_type) -1 ;
656- }
657- res.write (out_data);
658- }
659- }
660-
661- // *************************************************
662- // Ternary TanH Activation
663- // *************************************************
664- template <class data_T , class res_T , typename CONFIG_T>
665- void ternary_tanh (hls::stream<data_T> &data, hls::stream<res_T> &res) {
666- PReLUActLoop: for (int i = 0 ; i < CONFIG_T::n_in / res_T::size; i++) {
667- #pragma HLS PIPELINE
668-
669- data_T in_data = data.read ();
670- res_T out_data;
671- #pragma HLS DATA_PACK variable=out_data
672-
673- PReLUPackLoop: for (int j = 0 ; j < res_T::size; j++) {
674- #pragma HLS UNROLL
675- if (in_data[j] > 1 ) out_data[j] = (typename res_T::value_type) 1 ;
676- else if (in_data[j] <=-1 ) out_data[j] = (typename res_T::value_type) -1 ;
677- else out_data[j] = (typename res_T::value_type) 0 ;
678- }
679- res.write (out_data);
680- }
681- }
682-
683-
684672
685673}
686674
687- #endif
675+ #endif
0 commit comments