@@ -24,7 +24,7 @@ def generate_data(function, input_shape):
2424 return function ((1000 , * input_shape ))
2525
2626@pytest .mark .parametrize ('backend' , ['Vivado' , 'Quartus' ])
27- @pytest .mark .parametrize ('strategy' , ['stable' ])
27+ @pytest .mark .parametrize ('strategy' , ['stable' , 'argmax' ])
2828@pytest .mark .parametrize ('function,input_shape,io_type' , [
2929 (flat_distribution , (8 ,), 'io_parallel' ),
3030 (high_accuracy_distribution , (8 ,), 'io_parallel' ),
@@ -57,3 +57,29 @@ def test_softmax(backend, strategy, generate_data, input_shape, io_type, functio
5757 print ('Accuracy hls4ml relative to keras: {}' .format (acc_hls4ml ))
5858
5959 assert acc_hls4ml >= 0.98
60+
61+ @pytest .mark .parametrize ('backend' , ['Vivado' , 'Quartus' ])
62+ @pytest .mark .parametrize ('io_type' , ['io_parallel' , 'io_stream' ])
63+ def test_softmax_skipped (backend , io_type ):
64+ X = np .random .rand (100 , 10 )
65+ model = tf .keras .models .Sequential ()
66+ model .add (tf .keras .layers .Dense (14 , input_shape = (10 , ), name = 'dense' ))
67+ model .add (tf .keras .layers .Activation (activation = 'softmax' , name = 'softmax' ))
68+ model .compile ()
69+
70+ cfg = hls4ml .utils .config_from_keras_model (model , granularity = 'name' )
71+ cfg ['LayerName' ]['softmax' ]['skip' ] = True
72+
73+ odir = str (test_root_path / 'hls4mlprj_softmax_skipped_{}_{}' ).format (backend , io_type )
74+ hls_model = hls4ml .converters .convert_from_keras_model (model , hls_config = cfg , io_type = io_type , output_dir = odir , backend = backend )
75+ hls_model .compile ()
76+
77+ # Verify Softmax was removed
78+ hls_layers = list (hls_model .get_layers ()) # 0 is Input, 1 is Dense, 2 is Softmax (if not removed)
79+ assert len (hls_layers )== 2
80+
81+ # Verify hls4ml output is equal to Dense output
82+ y_keras = model .predict (X )
83+ y_hls4ml = hls_model .predict (X ).reshape (y_keras .shape )
84+ keras_trace = hls4ml .model .profiling .get_ymodel_keras (model , X )
85+ np .testing .assert_allclose (y_hls4ml , keras_trace ['dense' ], rtol = 0 , atol = 2e-2 )
0 commit comments