
Commit 32ee8dc

Merge pull request #627 from bo3z/argmax-softmax
Argmax Softmax
2 parents d31d938 + 2df478d

9 files changed: +159 −9 lines

Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
+from hls4ml.model.layers import Softmax
+from hls4ml.model.optimizer.optimizer import OptimizerPass
+
+class SkipSoftmax(OptimizerPass):
+    def match(self, node):
+        is_softmax = isinstance(node, Softmax)
+        remove_softmax = node.get_attr('skip', False)
+        return is_softmax and remove_softmax
+
+    def transform(self, model, node):
+        model.remove_node(node, rewire=True)
+        return True
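For context, this pass only removes a Softmax node when the user opts in through the layer config. A minimal usage sketch (the model and the layer name 'softmax' are illustrative, not part of this commit; the same flow appears in the new pytest below):

import hls4ml

# Assumes an existing Keras model whose final activation layer is named 'softmax'.
config = hls4ml.utils.config_from_keras_model(model, granularity='name')
config['LayerName']['softmax']['skip'] = True   # SkipSoftmax.match() now returns True for this node
hls_model = hls4ml.converters.convert_from_keras_model(model, hls_config=config)
hls_model.compile()                             # the Softmax layer is removed and its input rewired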

hls4ml/backends/quartus/quartus_backend.py

Lines changed: 1 addition & 0 deletions
@@ -62,6 +62,7 @@ def _register_flows(self):
         optimization_passes = [
             'quartus:remove_final_reshape',
             'quartus:optimize_pointwise_conv',
+            'quartus:skip_softmax'
         ]
         optimization_flow = register_flow('optimize', optimization_passes, requires=[init_flow], backend=self.name)

hls4ml/backends/vivado/vivado_backend.py

Lines changed: 1 addition & 0 deletions
@@ -52,6 +52,7 @@ def _register_flows(self):
         optimization_passes = [
             'vivado:remove_final_reshape',
             'vivado:optimize_pointwise_conv',
+            'vivado:skip_softmax'
        ]
         optimization_flow = register_flow('optimize', optimization_passes, requires=[init_flow], backend=self.name)

hls4ml/model/layers.py

Lines changed: 2 additions & 1 deletion
@@ -697,7 +697,8 @@ def initialize(self):
 
 class Softmax(Activation):
     _expected_attributes = [
-        ChoiceAttribute('implementation', ['latency', 'stable', 'legacy'], default='stable')
+        ChoiceAttribute('implementation', ['latency', 'stable', 'argmax', 'legacy'], default='stable'),
+        Attribute('skip', value_type=bool, default=False),
     ]
 
     def initialize(self):
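As a sketch of how these attributes are consumed (softmax_node is a hypothetical Softmax layer instance; get_attr is the same accessor the SkipSoftmax pass above uses):

# 'implementation' is one of 'latency', 'stable', 'argmax', 'legacy' (default 'stable');
# 'skip' defaults to False, so existing models are unaffected unless they opt in.
impl = softmax_node.get_attr('implementation')
skip = softmax_node.get_attr('skip', False)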

hls4ml/templates/quartus/firmware/nnet_utils/nnet_activation.h

Lines changed: 25 additions & 2 deletions
@@ -20,7 +20,6 @@
 #ifndef NNET_ACTIVATION_H_
 #define NNET_ACTIVATION_H_
 
-//#include <cmath>
 #include "nnet_common.h"
 
 namespace nnet {
@@ -127,7 +126,7 @@ void sigmoid(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in])
 // Softmax Activation
 // *************************************************
 
-enum class softmax_implementation {latency=0, legacy=1, stable=2};
+enum class softmax_implementation {latency=0, legacy=1, stable=2, argmax=3};
 
 template<class data_T, typename CONFIG_T>
 inline unsigned softmax_stable_idx_from_real_val(const data_T x){
@@ -242,6 +241,27 @@ void softmax_legacy(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) {
 }
 }
 
+template<class data_T, class res_T, typename CONFIG_T>
+void softmax_argmax(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) {
+    #pragma unroll
+    for (int i = 0; i < CONFIG_T::n_in; i++) {
+        res[i] = (res_T) 0;
+    }
+
+    hls_register data_T maximum = data[0];
+    hls_register int idx = 0;
+
+    #pragma ii 1
+    for (int i = 1; i < CONFIG_T::n_in; i++) {
+        if (data[i] > maximum) {
+            maximum = data[i];
+            idx = i;
+        }
+    }
+
+    res[idx] = (res_T) 1;
+}
+
 template<class data_T, class res_T, typename CONFIG_T>
 inline void softmax(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]){
     switch(CONFIG_T::implementation) {
@@ -257,6 +277,9 @@ inline void softmax(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]){
     default:
         softmax_stable<data_T, res_T, CONFIG_T>(data, res);
        break;
+    case softmax_implementation::argmax:
+        softmax_argmax<data_T, res_T, CONFIG_T>(data, res);
+        break;
     }
 }
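Functionally, softmax_argmax skips the exponentiation and division entirely and emits a one-hot vector at the index of the largest input; because softmax is strictly increasing, the position of that 1 is the same class a full softmax would rank highest. A minimal NumPy sketch of the same behaviour (reference only, not part of the HLS code):

import numpy as np

def softmax_argmax_ref(data):
    # All outputs are 0 except a single 1 at the position of the maximum input.
    res = np.zeros_like(data, dtype=float)
    res[np.argmax(data)] = 1.0
    return res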

hls4ml/templates/quartus/firmware/nnet_utils/nnet_activation_stream.h

Lines changed: 31 additions & 0 deletions
@@ -417,6 +417,34 @@ void softmax_legacy(stream<data_T> &data, stream<res_T> &res) {
 }
 }
 
+template<class data_T, class res_T, typename CONFIG_T>
+void softmax_argmax(stream<data_T> &data, stream<res_T> &res) {
+    #pragma ii 1
+    for (int i = 0; i < CONFIG_T::n_in / res_T::size; i++) {
+        data_T in_data = data.read();
+        res_T out_data;
+
+        #pragma unroll
+        for (int i = 0; i < res_T::size; i++) {
+            out_data[i] = (typename res_T::value_type) 0;
+        }
+
+        hls_register typename data_T::value_type maximum = in_data[0];
+        hls_register int idx = 0;
+
+        #pragma ii 1
+        for (int i = 1; i < res_T::size; i++) {
+            if (in_data[i] > maximum) {
+                maximum = in_data[i];
+                idx = i;
+            }
+        }
+
+        out_data[idx] = (typename res_T::value_type) 1;
+        res.write(out_data);
+    }
+}
+
 template<class data_T, class res_T, typename CONFIG_T>
 void softmax(stream<data_T> &data, stream<res_T> &res) {
     switch(CONFIG_T::implementation) {
@@ -429,6 +457,9 @@ void softmax(stream<data_T> &data, stream<res_T> &res) {
     case softmax_implementation::legacy:
         softmax_legacy<data_T, res_T, CONFIG_T>(data, res);
         break;
+    case softmax_implementation::argmax:
+        softmax_argmax<data_T, res_T, CONFIG_T>(data, res);
+        break;
     default:
         softmax_stable<data_T, res_T, CONFIG_T>(data, res);
         break;
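Note that the streaming variant works pack by pack: each res_T::size-wide beat read from the stream is zeroed and given its own one-hot output. A NumPy sketch of that per-pack behaviour (reference only; pack_size stands in for res_T::size):

import numpy as np

def softmax_argmax_stream_ref(data, pack_size):
    # Process the flat input in beats of pack_size elements, as the stream kernel does;
    # each beat gets a 1 at the index of its own maximum and 0 elsewhere.
    out = np.zeros_like(data, dtype=float)
    for start in range(0, len(data), pack_size):
        beat = data[start:start + pack_size]
        out[start + np.argmax(beat)] = 1.0
    return out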

hls4ml/templates/vivado/nnet_utils/nnet_activation.h

Lines changed: 26 additions & 2 deletions
@@ -155,7 +155,7 @@ void sigmoid(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in])
 // Softmax Activation
 // *************************************************
 
-enum class softmax_implementation {latency=0, legacy=1, stable=2};
+enum class softmax_implementation {latency=0, legacy=1, stable=2, argmax=3};
 
 inline float exp_fcn_float(float input) {
     return std::exp(input);
@@ -382,6 +382,27 @@ void softmax_legacy(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in])
 
 }
 
+template<class data_T, class res_T, typename CONFIG_T>
+void softmax_argmax(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) {
+    for (int i = 0; i < CONFIG_T::n_in; i++) {
+        #pragma HLS UNROLL
+        res[i] = (res_T) 0;
+    }
+
+    data_T maximum = data[0];
+    int idx = 0;
+
+    for (int i = 1; i < CONFIG_T::n_in; i++) {
+        #pragma HLS PIPELINE
+        if (data[i] > maximum) {
+            maximum = data[i];
+            idx = i;
+        }
+    }
+
+    res[idx] = (res_T) 1;
+}
+
 template<class data_T, class res_T, typename CONFIG_T>
 void softmax(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]){
     #pragma HLS inline
@@ -395,6 +416,9 @@ void softmax(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]){
     case softmax_implementation::legacy:
         softmax_legacy<data_T, res_T, CONFIG_T>(data, res);
         break;
+    case softmax_implementation::argmax:
+        softmax_argmax<data_T, res_T, CONFIG_T>(data, res);
+        break;
     }
 }
 
@@ -776,4 +800,4 @@ void ternary_tanh(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in])
 
 }
 
-#endif
+#endif

hls4ml/templates/vivado/nnet_utils/nnet_activation_stream.h

Lines changed: 34 additions & 3 deletions
@@ -314,6 +314,35 @@ void softmax_legacy(hls::stream<data_T> &data, hls::stream<res_T> &res) {
 }
 }
 
+template<class data_T, class res_T, typename CONFIG_T>
+void softmax_argmax(hls::stream<data_T> &data, hls::stream<res_T> &res) {
+    for (int i = 0; i < CONFIG_T::n_in / res_T::size; i++) {
+        #pragma HLS PIPELINE
+        data_T in_data = data.read();
+        res_T out_data;
+
+        for (int i = 0; i < res_T::size; i++) {
+            #pragma HLS UNROLL
+            out_data[i] = (typename res_T::value_type) 0;
+        }
+
+        typename data_T::value_type maximum = in_data[0];
+        int idx = 0;
+
+        for (int i = 1; i < res_T::size; i++) {
+            #pragma HLS PIPELINE
+            if (in_data[i] > maximum) {
+                maximum = in_data[i];
+                idx = i;
+            }
+        }
+
+        out_data[idx] = (typename res_T::value_type) 1;
+        res.write(out_data);
+    }
+}
+
+
 template<class data_T, class res_T, typename CONFIG_T>
 void softmax(hls::stream<data_T> &data, hls::stream<res_T> &res){
     assert(CONFIG_T::axis == -1);
@@ -328,7 +357,10 @@ void softmax(hls::stream<data_T> &data, hls::stream<res_T> &res){
     case softmax_implementation::legacy:
         softmax_legacy<data_T, res_T, CONFIG_T>(data, res);
         break;
-    }
+    case softmax_implementation::argmax:
+        softmax_argmax<data_T, res_T, CONFIG_T>(data, res);
+        break;
+    }
 }
 
 // *************************************************
@@ -681,7 +713,6 @@ void ternary_tanh(hls::stream<data_T> &data, hls::stream<res_T> &res) {
 }
 
 
-
 }
 
-#endif
+#endif

test/pytest/test_softmax.py

Lines changed: 27 additions & 1 deletion
@@ -24,7 +24,7 @@ def generate_data(function, input_shape):
     return function((1000, *input_shape))
 
 @pytest.mark.parametrize('backend', ['Vivado', 'Quartus'])
-@pytest.mark.parametrize('strategy', ['stable'])
+@pytest.mark.parametrize('strategy', ['stable', 'argmax'])
 @pytest.mark.parametrize('function,input_shape,io_type', [
     (flat_distribution, (8,), 'io_parallel'),
     (high_accuracy_distribution, (8,), 'io_parallel'),
@@ -57,3 +57,29 @@ def test_softmax(backend, strategy, generate_data, input_shape, io_type, function):
     print('Accuracy hls4ml relative to keras: {}'.format(acc_hls4ml))
 
     assert acc_hls4ml >= 0.98
+
+@pytest.mark.parametrize('backend', ['Vivado', 'Quartus'])
+@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream'])
+def test_softmax_skipped(backend, io_type):
+    X = np.random.rand(100, 10)
+    model = tf.keras.models.Sequential()
+    model.add(tf.keras.layers.Dense(14, input_shape=(10, ), name='dense'))
+    model.add(tf.keras.layers.Activation(activation='softmax', name='softmax'))
+    model.compile()
+
+    cfg = hls4ml.utils.config_from_keras_model(model, granularity='name')
+    cfg['LayerName']['softmax']['skip'] = True
+
+    odir = str(test_root_path / 'hls4mlprj_softmax_skipped_{}_{}').format(backend, io_type)
+    hls_model = hls4ml.converters.convert_from_keras_model(model, hls_config=cfg, io_type=io_type, output_dir=odir, backend=backend)
+    hls_model.compile()
+
+    # Verify Softmax was removed
+    hls_layers = list(hls_model.get_layers())  # 0 is Input, 1 is Dense, 2 is Softmax (if not removed)
+    assert len(hls_layers) == 2
+
+    # Verify hls4ml output is equal to Dense output
+    y_keras = model.predict(X)
+    y_hls4ml = hls_model.predict(X).reshape(y_keras.shape)
+    keras_trace = hls4ml.model.profiling.get_ymodel_keras(model, X)
+    np.testing.assert_allclose(y_hls4ml, keras_trace['dense'], rtol=0, atol=2e-2)
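To exercise just the new skip test locally, a standard pytest selection works (sketch; assumes pytest is installed and run from the repository root):

import pytest

# Select only the new test across both backends and io_types.
pytest.main(['test/pytest/test_softmax.py', '-k', 'test_softmax_skipped'])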
