Skip to content

Commit a3fdb75

Browse files
authored
Merge pull request #19 from Gab-San/bambu/pytest-fixes
fastmachinelearning/hls4ml PR fastmachinelearning#1417 fixes
2 parents 8795238 + 1a2118a commit a3fdb75

File tree

6 files changed

+105
-57
lines changed

6 files changed

+105
-57
lines changed

test/pytest/test_clone_flatten.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,8 @@ def keras_model():
2727

2828

2929
@pytest.fixture
30-
@pytest.mark.parametrize('io_type', ['io_stream'])
31-
@pytest.mark.parametrize('backend', ['Vivado', 'Quartus', 'Catapult'])
32-
def hls_model(keras_model, backend, io_type):
30+
def hls_model(keras_model, request):
31+
io_type, backend = request.param
3332
hls_config = hls4ml.utils.config_from_keras_model(
3433
keras_model, default_precision='ap_int<6>', granularity='name', backend=backend
3534
)
@@ -46,8 +45,12 @@ def hls_model(keras_model, backend, io_type):
4645
return hls_model
4746

4847

49-
@pytest.mark.parametrize('io_type', ['io_stream'])
50-
@pytest.mark.parametrize('backend', ['Vivado', 'Quartus'])
48+
@pytest.mark.parametrize(
49+
'hls_model',
50+
[('io_stream', 'Vivado'), ('io_stream', 'Quartus')],
51+
indirect=True,
52+
ids=['io_stream_Vivado', 'io_stream_Quartus'],
53+
)
5154
def test_accuracy(data, keras_model, hls_model):
5255
X = data
5356
model = keras_model

test/pytest/test_cnn_mnist_qkeras.py

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -41,22 +41,8 @@ def mnist_model():
4141

4242

4343
@pytest.fixture
44-
@pytest.mark.parametrize(
45-
'backend,io_type,strategy',
46-
[
47-
('Quartus', 'io_parallel', 'resource'),
48-
('Quartus', 'io_stream', 'resource'),
49-
('Vivado', 'io_parallel', 'resource'),
50-
('Vivado', 'io_parallel', 'latency'),
51-
('Vivado', 'io_stream', 'latency'),
52-
('Vivado', 'io_stream', 'resource'),
53-
('Vitis', 'io_parallel', 'resource'),
54-
('Vitis', 'io_parallel', 'latency'),
55-
('Vitis', 'io_stream', 'latency'),
56-
('Vitis', 'io_stream', 'resource'),
57-
],
58-
)
59-
def hls_model(mnist_model, backend, io_type, strategy):
44+
def hls_model(mnist_model, request):
45+
backend, io_type, strategy = request.param
6046
keras_model = mnist_model
6147
hls_config = hls4ml.utils.config_from_keras_model(keras_model, granularity='name', backend=backend)
6248
hls_config['Model']['Strategy'] = strategy
@@ -72,7 +58,7 @@ def hls_model(mnist_model, backend, io_type, strategy):
7258

7359

7460
@pytest.mark.parametrize(
75-
'backend,io_type,strategy',
61+
'hls_model',
7662
[
7763
('Quartus', 'io_parallel', 'resource'),
7864
('Quartus', 'io_stream', 'resource'),
@@ -85,6 +71,19 @@ def hls_model(mnist_model, backend, io_type, strategy):
8571
('Vitis', 'io_stream', 'latency'),
8672
('Vitis', 'io_stream', 'resource'),
8773
],
74+
indirect=True,
75+
ids=[
76+
'Quartus_io_parallel_resource',
77+
'Quartus_io_stream_resource',
78+
'Vivado_io_parallel_resource',
79+
'Vivado_io_parallel_latency',
80+
'Vivado_io_stream_latency',
81+
'Vivado_io_stream_resource',
82+
'Vitis_io_parallel_resource',
83+
'Vitis_io_parallel_latency',
84+
'Vitis_io_stream_latency',
85+
'Vitis_io_stream_resource',
86+
],
8887
)
8988
def test_accuracy(mnist_data, mnist_model, hls_model):
9089
x_train, y_train, x_test, y_test = mnist_data

test/pytest/test_conv1d.py

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -27,26 +27,8 @@ def keras_model():
2727

2828

2929
@pytest.fixture
30-
@pytest.mark.parametrize(
31-
'backend,io_type,strategy',
32-
[
33-
('Quartus', 'io_parallel', 'resource'),
34-
('Quartus', 'io_stream', 'resource'),
35-
('oneAPI', 'io_parallel', 'resource'),
36-
('oneAPI', 'io_stream', 'resource'),
37-
('Vivado', 'io_parallel', 'resource'),
38-
('Vivado', 'io_parallel', 'latency'),
39-
('Vivado', 'io_stream', 'latency'),
40-
('Vivado', 'io_stream', 'resource'),
41-
('Vitis', 'io_parallel', 'resource'),
42-
('Vitis', 'io_parallel', 'latency'),
43-
('Vitis', 'io_stream', 'latency'),
44-
('Vitis', 'io_stream', 'resource'),
45-
('Catapult', 'io_stream', 'latency'),
46-
('Catapult', 'io_stream', 'resource'),
47-
],
48-
)
49-
def hls_model(keras_model, backend, io_type, strategy):
30+
def hls_model(keras_model, request):
31+
backend, io_type, strategy = request.param
5032
default_precision = (
5133
'ap_fixed<16,3,AP_RND_CONV,AP_SAT>' if backend == 'Vivado' else 'ac_fixed<16,3,true,AC_RND_CONV,AC_SAT>'
5234
)
@@ -82,7 +64,7 @@ def hls_model(keras_model, backend, io_type, strategy):
8264

8365

8466
@pytest.mark.parametrize(
85-
'backend,io_type,strategy',
67+
'hls_model',
8668
[
8769
('Quartus', 'io_parallel', 'resource'),
8870
('Quartus', 'io_stream', 'resource'),
@@ -99,6 +81,23 @@ def hls_model(keras_model, backend, io_type, strategy):
9981
('Catapult', 'io_stream', 'latency'),
10082
('Catapult', 'io_stream', 'resource'),
10183
],
84+
indirect=True,
85+
ids=[
86+
'Quartus_io_parallel_resource',
87+
'Quartus_io_stream_resource',
88+
'oneAPI_io_parallel_resource',
89+
'oneAPI_io_stream_resource',
90+
'Vivado_io_parallel_resource',
91+
'Vivado_io_parallel_latency',
92+
'Vivado_io_stream_latency',
93+
'Vivado_io_stream_resource',
94+
'Vitis_io_parallel_resource',
95+
'Vitis_io_parallel_latency',
96+
'Vitis_io_stream_latency',
97+
'Vitis_io_stream_resource',
98+
'Catapult_io_stream_latency',
99+
'Catapult_io_stream_resource',
100+
],
102101
)
103102
def test_accuracy(data, keras_model, hls_model):
104103
X = data

test/pytest/test_embed.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,8 @@ def keras_model():
2525

2626

2727
@pytest.fixture
28-
@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus', 'Catapult', 'oneAPI'])
29-
@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream'])
30-
def hls_model(keras_model, backend, io_type):
28+
def hls_model(keras_model, request):
29+
backend, io_type = request.param
3130
hls_config = hls4ml.utils.config_from_keras_model(
3231
keras_model, default_precision='ap_fixed<16,6>', granularity='name', backend=backend
3332
)
@@ -41,8 +40,34 @@ def hls_model(keras_model, backend, io_type):
4140
return hls_model
4241

4342

44-
@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus', 'Catapult', 'oneAPI'])
45-
@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream'])
43+
@pytest.mark.parametrize(
44+
'hls_model',
45+
[
46+
('Vivado', 'io_parallel'),
47+
('Vitis', 'io_parallel'),
48+
('Quartus', 'io_parallel'),
49+
('Catapult', 'io_parallel'),
50+
('oneAPI', 'io_parallel'),
51+
('Vivado', 'io_stream'),
52+
('Vitis', 'io_stream'),
53+
('Quartus', 'io_stream'),
54+
('Catapult', 'io_stream'),
55+
('oneAPI', 'io_stream'),
56+
],
57+
ids=[
58+
'vivado_parallel',
59+
'vitis_parallel',
60+
'quartus_parallel',
61+
'catapult_parallel',
62+
'oneapi_parallel',
63+
'vivado_stream',
64+
'vitis_stream',
65+
'quartus_stream',
66+
'catapult_stream',
67+
'oneapi_stream',
68+
],
69+
indirect=True,
70+
)
4671
def test_embedding_accuracy(data, keras_model, hls_model):
4772
X = data
4873
model = keras_model

test/pytest/test_qkeras.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,11 +70,12 @@ def load_jettagging_model():
7070

7171
# TODO - Paramaterize for Quartus (different strategies?)
7272
@pytest.fixture
73-
@pytest.mark.parametrize('strategy', ['latency', 'resource'])
74-
def convert(load_jettagging_model, strategy):
73+
def convert(load_jettagging_model, request):
7574
"""
7675
Convert a QKeras model trained on the jet tagging dataset
7776
"""
77+
78+
strategy = request.param
7879
model = load_jettagging_model
7980

8081
config = hls4ml.utils.config_from_keras_model(model, granularity='name', backend='Vivado')
@@ -91,8 +92,8 @@ def convert(load_jettagging_model, strategy):
9192
return hls_model
9293

9394

94-
@pytest.mark.parametrize('strategy', ['latency', 'resource'])
95-
def test_accuracy(convert, load_jettagging_model, get_jettagging_data, strategy):
95+
@pytest.mark.parametrize('convert', ['latency', 'resource'], indirect=True, ids=['latency', 'resource'])
96+
def test_accuracy(convert, load_jettagging_model, get_jettagging_data):
9697
"""
9798
Test the hls4ml-evaluated accuracy of a 3 hidden layer QKeras model trained on
9899
the jet tagging dataset. QKeras model accuracy is required to be over 70%, and

test/pytest/test_transpose_concat.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,8 @@ def keras_model():
2828

2929

3030
@pytest.fixture
31-
@pytest.mark.parametrize('io_type', ['io_stream', 'io_parallel'])
32-
@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus', 'oneAPI'])
33-
def hls_model(keras_model, backend, io_type):
31+
def hls_model(keras_model, request):
32+
io_type, backend = request.param
3433
hls_config = hls4ml.utils.config_from_keras_model(
3534
keras_model, default_precision='ap_fixed<16,3,AP_RND_CONV,AP_SAT>', granularity='name', backend=backend
3635
)
@@ -44,8 +43,30 @@ def hls_model(keras_model, backend, io_type):
4443
return hls_model
4544

4645

47-
@pytest.mark.parametrize('io_type', ['io_stream', 'io_parallel'])
48-
@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus', 'oneAPI'])
46+
@pytest.mark.parametrize(
47+
'hls_model',
48+
[
49+
('io_stream', 'Vivado'),
50+
('io_stream', 'Vitis'),
51+
('io_stream', 'Quartus'),
52+
('io_stream', 'oneAPI'),
53+
('io_parallel', 'Vivado'),
54+
('io_parallel', 'Vitis'),
55+
('io_parallel', 'Quartus'),
56+
('io_parallel', 'oneAPI'),
57+
],
58+
indirect=True,
59+
ids=[
60+
'vivado_stream',
61+
'vitis_streamq',
62+
'quartus_stream',
63+
'oneapi_stream',
64+
'vivado_parallel',
65+
'vitis_parallel',
66+
'quartus_parallel',
67+
'oneapi_parallel',
68+
],
69+
)
4970
def test_accuracy(data, keras_model, hls_model):
5071
X = data
5172
model = keras_model

0 commit comments

Comments
 (0)