1616test_root_path = Path (__file__ ).parent
1717
1818@pytest .mark .parametrize ('backend' , ['Vivado' , 'Quartus' ])
19- def test_dense (backend ):
19+ @pytest .mark .parametrize ('io_type' , ['io_parallel' , 'io_stream' ])
20+ def test_dense (backend , io_type ):
2021 model = tf .keras .models .Sequential ()
2122 model .add (Dense (2 ,
2223 input_shape = (1 ,),
@@ -37,9 +38,9 @@ def test_dense(backend):
3738 keras_prediction = model .predict (X_input )
3839
3940 config = hls4ml .utils .config_from_keras_model (model )
40- output_dir = str (test_root_path / f'hls4mlprj_keras_api_dense_{ backend } ' )
41+ output_dir = str (test_root_path / f'hls4mlprj_keras_api_dense_{ backend } _ { io_type } ' )
4142
42- hls_model = hls4ml .converters .convert_from_keras_model (model , hls_config = config , output_dir = output_dir , backend = backend )
43+ hls_model = hls4ml .converters .convert_from_keras_model (model , hls_config = config , output_dir = output_dir , backend = backend , io_type = io_type )
4344
4445 hls_model .compile ()
4546
@@ -66,7 +67,8 @@ def test_dense(backend):
6667 Activation (activation = 'sigmoid' , name = 'Activation' )])
6768 #ThresholdedReLU(theta=1.0)])
6869@pytest .mark .parametrize ('backend' , ['Vivado' , 'Quartus' ])
69- def test_activations (activation_function , backend ):
70+ @pytest .mark .parametrize ('io_type' , ['io_parallel' , 'io_stream' ])
71+ def test_activations (activation_function , backend , io_type ):
7072 model = tf .keras .models .Sequential ()
7173 model .add (Dense (64 ,
7274 input_shape = (1 ,),
@@ -79,8 +81,8 @@ def test_activations(activation_function, backend):
7981 X_input = np .random .rand (100 ,1 )
8082 keras_prediction = model .predict (X_input )
8183 config = hls4ml .utils .config_from_keras_model (model )
82- output_dir = str (test_root_path / 'hls4mlprj_keras_api_activations_{}_{}' .format (activation_function .__class__ .__name__ , backend ))
83- hls_model = hls4ml .converters .convert_from_keras_model (model , hls_config = config , output_dir = output_dir , backend = backend )
84+ output_dir = str (test_root_path / 'hls4mlprj_keras_api_activations_{}_{}_{} ' .format (activation_function .__class__ .__name__ , backend , io_type ))
85+ hls_model = hls4ml .converters .convert_from_keras_model (model , hls_config = config , output_dir = output_dir , backend = backend , io_type = io_type )
8486 hls_model .compile ()
8587 hls_prediction = hls_model .predict (X_input )
8688
0 commit comments