Commit 4019192
Argmax & Skipped Softmax
1 parent 62046d7, commit 4019192

9 files changed: +165 -54 lines changed
Lines changed: 12 additions & 0 deletions

@@ -0,0 +1,12 @@
+from hls4ml.model.layers import Softmax
+from hls4ml.model.optimizer.optimizer import OptimizerPass
+
+class SkipSoftmax(OptimizerPass):
+    def match(self, node):
+        is_softmax = isinstance(node, Softmax)
+        remove_softmax = node.get_attr('skip', False)
+        return is_softmax and remove_softmax
+
+    def transform(self, model, node):
+        model.remove_node(node, rewire=True)
+        return True
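The pass fires only for Softmax nodes whose 'skip' attribute is set; removing the node with rewire=True connects the softmax's input directly to whatever consumed its output. A minimal usage sketch, mirroring the new test_softmax_skipped test further down (the layer name 'softmax' is simply the Keras layer name used there):

    import tensorflow as tf
    import hls4ml

    # Small Keras model ending in a softmax, as in the new test below.
    model = tf.keras.models.Sequential([
        tf.keras.layers.Dense(14, input_shape=(10,), name='dense'),
        tf.keras.layers.Activation('softmax', name='softmax'),
    ])

    cfg = hls4ml.utils.config_from_keras_model(model, granularity='name')
    cfg['LayerName']['softmax']['skip'] = True  # makes SkipSoftmax.match() return True

    hls_model = hls4ml.converters.convert_from_keras_model(
        model, hls_config=cfg, output_dir='hls4mlprj_skip_softmax', backend='Vivado')
    hls_model.compile()
    # The compiled model now returns the raw Dense outputs; no softmax remains on the data path.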

hls4ml/backends/quartus/quartus_backend.py

Lines changed: 7 additions & 2 deletions

@@ -43,6 +43,11 @@ def _register_flows(self):
         ]
         quantization_flow = register_flow('quantization', quantization_passes, requires=[init_flow], backend=self.name)

+        optimization_passes = [
+            'quartus:skip_softmax',
+        ]
+        optimization_flow = register_flow('optimize', optimization_passes, requires=[init_flow], backend=self.name)
+
         templates = self._get_layer_templates()
         template_flow = register_flow('apply_templates', templates, requires=[init_flow], backend=self.name)

@@ -57,15 +62,15 @@ def _register_flows(self):

         extras = [
             # Ideally this should be empty
-            opt_pass for opt_pass in all_passes if opt_pass not in initializers + quartus_types + templates + writer_passes
+            opt_pass for opt_pass in all_passes if opt_pass not in initializers + quartus_types + optimization_passes + templates + writer_passes
         ]

         if len(extras) > 0:
             extras_flow = register_flow('extras', extras, requires=[init_flow], backend=self.name)
         else:
             extras_flow = None

-        ip_flow_requirements = ['optimize', init_flow, streaming_flow, quantization_flow, quartus_types_flow, extras_flow, template_flow]
+        ip_flow_requirements = ['optimize', init_flow, streaming_flow, quantization_flow, optimization_flow, quartus_types_flow, extras_flow, template_flow]
         ip_flow_requirements = list(filter(None, ip_flow_requirements))

         self._default_flow = register_flow('ip', None, requires=ip_flow_requirements, backend=self.name)

hls4ml/backends/vivado/vivado_backend.py

Lines changed: 1 addition & 0 deletions

@@ -52,6 +52,7 @@ def _register_flows(self):

         optimization_passes = [
             'vivado:optimize_pointwise_conv',
+            'vivado:skip_softmax'
         ]
         optimization_flow = register_flow('optimize', optimization_passes, requires=[init_flow], backend=self.name)

hls4ml/model/layers.py

Lines changed: 2 additions & 1 deletion

@@ -698,7 +698,8 @@ def initialize(self):

 class Softmax(Activation):
     _expected_attributes = [
-        ChoiceAttribute('implementation', ['latency', 'stable', 'legacy'], default='stable')
+        ChoiceAttribute('implementation', ['latency', 'stable', 'argmax', 'legacy'], default='stable'),
+        Attribute('skip', value_type=bool, default=False),
     ]

     def initialize(self):
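Two attributes change here: 'argmax' becomes a fourth choice for the Softmax 'implementation' attribute, and a boolean 'skip' (default False) is added, which is exactly what the SkipSoftmax pass reads via node.get_attr('skip', False). A sketch of selecting the argmax variant from a name-granularity config follows; the config key name is an assumption (the existing pytest parametrizes it as 'strategy', and the line that applies it is outside this diff), so treat it as hypothetical:

    import hls4ml

    # 'model' is any Keras model with a layer named 'softmax', e.g. the one in the sketch above.
    cfg = hls4ml.utils.config_from_keras_model(model, granularity='name')
    # Hypothetical key name: intended to map onto the new 'argmax' choice of the
    # Softmax 'implementation' ChoiceAttribute; the actual key may differ.
    cfg['LayerName']['softmax']['implementation'] = 'argmax'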

hls4ml/templates/quartus/firmware/nnet_utils/nnet_activation.h

Lines changed: 25 additions & 2 deletions

@@ -20,7 +20,6 @@
 #ifndef NNET_ACTIVATION_H_
 #define NNET_ACTIVATION_H_

-//#include <cmath>
 #include "nnet_common.h"

 namespace nnet {

@@ -127,7 +126,7 @@ void sigmoid(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in])
 // Softmax Activation
 // *************************************************

-enum class softmax_implementation {latency=0, legacy=1, stable=2};
+enum class softmax_implementation {latency=0, legacy=1, stable=2, argmax=3};

 template<class data_T, typename CONFIG_T>
 inline unsigned softmax_idx_from_real_val(const data_T x){

@@ -248,6 +247,27 @@ void softmax_legacy(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) {
     }
 }

+template<class data_T, class res_T, typename CONFIG_T>
+void softmax_argmax(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) {
+    #pragma unroll
+    for (int i = 0; i < CONFIG_T::n_in; i++) {
+        res[i] = (res_T) 0;
+    }
+
+    hls_register data_T maximum = data[0];
+    hls_register int idx = 0;
+
+    #pragma ii 1
+    for (int i = 1; i < CONFIG_T::n_in; i++) {
+        if (data[i] > maximum) {
+            maximum = data[i];
+            idx = i;
+        }
+    }
+
+    res[idx] = (res_T) 1;
+}
+
 template<class data_T, class res_T, typename CONFIG_T>
 inline void softmax(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]){
     switch(CONFIG_T::implementation) {

@@ -263,6 +283,9 @@ inline void softmax(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]){
     default:
         softmax_stable<data_T, res_T, CONFIG_T>(data, res);
         break;
+    case softmax_implementation::argmax:
+        softmax_argmax<data_T, res_T, CONFIG_T>(data, res);
+        break;
     }
 }
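Functionally, softmax_argmax writes a one-hot vector with a 1 at the index of the largest input and 0 elsewhere, skipping the exponential and lookup-table logic of the other implementations. Because softmax is monotonic, the position of the maximum is unchanged, so the top-1 prediction is preserved. A small numpy sketch of the reference behaviour (not part of the commit):

    import numpy as np

    def softmax_argmax_ref(x):
        # One-hot vector at the position of the largest element.
        out = np.zeros_like(x)
        out[np.argmax(x)] = 1
        return out

    logits = np.array([0.3, 2.1, -1.0, 0.7])
    softmax = np.exp(logits) / np.exp(logits).sum()
    # Both select index 1, so the predicted class is identical.
    assert np.argmax(softmax) == np.argmax(softmax_argmax_ref(logits)) == 1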

hls4ml/templates/quartus/firmware/nnet_utils/nnet_activation_stream.h

Lines changed: 31 additions & 0 deletions

@@ -405,6 +405,34 @@ void softmax_legacy(stream<data_T> &data, stream<res_T> &res) {
     }
 }

+template<class data_T, class res_T, typename CONFIG_T>
+void softmax_argmax(stream<data_T> &data, stream<res_T> &res) {
+    #pragma ii 1
+    for (int i = 0; i < CONFIG_T::n_in / res_T::size; i++) {
+        data_T in_data = data.read();
+        res_T out_data;
+
+        #pragma unroll
+        for (int i = 0; i < CONFIG_T::n_in; i++) {
+            out_data[i] = (typename res_T::value_type) 0;
+        }
+
+        hls_register typename data_T::value_type maximum = in_data[0];
+        hls_register int idx = 0;
+
+        #pragma ii 1
+        for (int i = 1; i < CONFIG_T::n_in; i++) {
+            if (in_data[i] > maximum) {
+                maximum = in_data[i];
+                idx = i;
+            }
+        }
+
+        out_data[idx] = (typename res_T::value_type) 1;
+        res.write(out_data);
+    }
+}
+
 template<class data_T, class res_T, typename CONFIG_T>
 void softmax(stream<data_T> &data, stream<res_T> &res) {
     switch(CONFIG_T::implementation) {

@@ -417,6 +445,9 @@ void softmax(stream<data_T> &data, stream<res_T> &res) {
     case softmax_implementation::legacy:
         softmax_legacy<data_T, res_T, CONFIG_T>(data, res);
         break;
+    case softmax_implementation::argmax:
+        softmax_argmax<data_T, res_T, CONFIG_T>(data, res);
+        break;
     default:
         softmax_stable<data_T, res_T, CONFIG_T>(data, res);
         break;

hls4ml/templates/vivado/nnet_utils/nnet_activation.h

Lines changed: 26 additions & 2 deletions

@@ -155,7 +155,7 @@ void sigmoid(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in])
 // Softmax Activation
 // *************************************************

-enum class softmax_implementation {latency=0, legacy=1, stable=2};
+enum class softmax_implementation {latency=0, legacy=1, stable=2, argmax=3};

 inline float exp_fcn_float(float input) {
     return std::exp(input);

@@ -382,6 +382,27 @@ void softmax_legacy(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in])

 }

+template<class data_T, class res_T, typename CONFIG_T>
+void softmax_argmax(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) {
+    for (int i = 0; i < CONFIG_T::n_in; i++) {
+        #pragma HLS UNROLL
+        res[i] = (res_T) 0;
+    }
+
+    data_T maximum = data[0];
+    int idx = 0;
+
+    for (int i = 1; i < CONFIG_T::n_in; i++) {
+        #pragma HLS PIPELINE
+        if (data[i] > maximum) {
+            maximum = data[i];
+            idx = i;
+        }
+    }
+
+    res[idx] = (res_T) 1;
+}
+
 template<class data_T, class res_T, typename CONFIG_T>
 void softmax(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]){
     #pragma HLS inline

@@ -395,6 +416,9 @@ void softmax(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]){
     case softmax_implementation::legacy:
         softmax_legacy<data_T, res_T, CONFIG_T>(data, res);
         break;
+    case softmax_implementation::argmax:
+        softmax_argmax<data_T, res_T, CONFIG_T>(data, res);
+        break;
     }
 }

@@ -776,4 +800,4 @@ void ternary_tanh(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in])

 }

-#endif
+#endif

hls4ml/templates/vivado/nnet_utils/nnet_activation_stream.h

Lines changed: 34 additions & 46 deletions

@@ -314,6 +314,35 @@ void softmax_legacy(hls::stream<data_T> &data, hls::stream<res_T> &res) {
     }
 }

+template<class data_T, class res_T, typename CONFIG_T>
+void softmax_argmax(hls::stream<data_T> &data, hls::stream<res_T> &res) {
+    for (int i = 0; i < CONFIG_T::n_in / res_T::size; i++) {
+        #pragma HLS PIPELINE
+        data_T in_data = data.read();
+        res_T out_data;
+
+        for (int i = 0; i < CONFIG_T::n_in; i++) {
+            #pragma HLS UNROLL
+            out_data[i] = (typename res_T::value_type) 0;
+        }
+
+        typename data_T::value_type maximum = in_data[0];
+        int idx = 0;
+
+        for (int i = 1; i < CONFIG_T::n_in; i++) {
+            #pragma HLS PIPELINE
+            if (in_data[i] > maximum) {
+                maximum = in_data[i];
+                idx = i;
+            }
+        }
+
+        out_data[idx] = (typename res_T::value_type) 1;
+        res.write(out_data);
+    }
+}
+
+
 template<class data_T, class res_T, typename CONFIG_T>
 void softmax(hls::stream<data_T> &data, hls::stream<res_T> &res){
     assert(CONFIG_T::axis == -1);

@@ -328,7 +357,10 @@ void softmax(hls::stream<data_T> &data, hls::stream<res_T> &res){
     case softmax_implementation::legacy:
         softmax_legacy<data_T, res_T, CONFIG_T>(data, res);
         break;
-    }
+    case softmax_implementation::argmax:
+        softmax_argmax<data_T, res_T, CONFIG_T>(data, res);
+        break;
+    }
 }

@@ -637,51 +669,7 @@ void prelu(hls::stream<data_T> &data, typename data_T::value_type alpha[CONFIG_T
     }
 }

-// *************************************************
-// Binary TanH Activation
-// *************************************************
-template<class data_T, class res_T, typename CONFIG_T>
-void binary_tanh(hls::stream<data_T> &data, hls::stream<res_T> &res) {
-    PReLUActLoop: for (int i = 0; i < CONFIG_T::n_in / res_T::size; i++) {
-        #pragma HLS PIPELINE
-
-        data_T in_data = data.read();
-        res_T out_data;
-        #pragma HLS DATA_PACK variable=out_data
-
-        PReLUPackLoop: for (int j = 0; j < res_T::size; j++) {
-            #pragma HLS UNROLL
-            if(in_data[j] > 0) out_data[j] = (typename res_T::value_type) 1;
-            else out_data[j] = (typename res_T::value_type) -1;
-        }
-        res.write(out_data);
-    }
-}
-
-// *************************************************
-// Ternary TanH Activation
-// *************************************************
-template<class data_T, class res_T, typename CONFIG_T>
-void ternary_tanh(hls::stream<data_T> &data, hls::stream<res_T> &res) {
-    PReLUActLoop: for (int i = 0; i < CONFIG_T::n_in / res_T::size; i++) {
-        #pragma HLS PIPELINE
-
-        data_T in_data = data.read();
-        res_T out_data;
-        #pragma HLS DATA_PACK variable=out_data
-
-        PReLUPackLoop: for (int j = 0; j < res_T::size; j++) {
-            #pragma HLS UNROLL
-            if(in_data[j] > 1) out_data[j] = (typename res_T::value_type) 1;
-            else if (in_data[j] <= -1) out_data[j] = (typename res_T::value_type) -1;
-            else out_data[j] = (typename res_T::value_type) 0;
-        }
-        res.write(out_data);
-    }
-}
-
-

 }

-#endif
+#endif

test/pytest/test_softmax.py

Lines changed: 27 additions & 1 deletion

@@ -24,7 +24,7 @@ def generate_data(function, input_shape):
     return function((1000, *input_shape))

 @pytest.mark.parametrize('backend', ['Vivado', 'Quartus'])
-@pytest.mark.parametrize('strategy', ['stable'])
+@pytest.mark.parametrize('strategy', ['stable', 'argmax'])
 @pytest.mark.parametrize('function,input_shape,io_type', [
     (flat_distribution, (8,), 'io_parallel'),
     (high_accuracy_distribution, (8,), 'io_parallel'),

@@ -57,3 +57,29 @@ def test_softmax(backend, strategy, generate_data, input_shape, io_type, function):
     print('Accuracy hls4ml relative to keras: {}'.format(acc_hls4ml))

     assert acc_hls4ml >= 0.98
+
+@pytest.mark.parametrize('backend', ['Vivado', 'Quartus'])
+@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream'])
+def test_softmax_skipped(backend, io_type):
+    X = np.random.rand(100, 10)
+    model = tf.keras.models.Sequential()
+    model.add(tf.keras.layers.Dense(14, input_shape=(10,), name='dense'))
+    model.add(tf.keras.layers.Activation(activation='softmax', name='softmax'))
+    model.compile()
+
+    cfg = hls4ml.utils.config_from_keras_model(model, granularity='name')
+    cfg['LayerName']['softmax']['skip'] = True
+
+    odir = str(test_root_path / 'hls4mlprj_softmax_skipped_{}_{}').format(backend, io_type)
+    hls_model = hls4ml.converters.convert_from_keras_model(model, hls_config=cfg, io_type=io_type, output_dir=odir, backend=backend)
+    hls_model.compile()
+
+    # Verify Softmax was removed
+    hls_layers = list(hls_model.get_layers())  # 0 is Input, 1 is Dense, 2 is Softmax (if not removed)
+    assert len(hls_layers) == 2
+
+    # Verify hls4ml output is equal to Dense output
+    y_keras = model.predict(X)
+    y_hls4ml = hls_model.predict(X).reshape(y_keras.shape)
+    keras_trace = hls4ml.model.profiling.get_ymodel_keras(model, X)
+    np.testing.assert_allclose(y_hls4ml, keras_trace['dense'], rtol=0, atol=2e-2)
