1010
1111# To conveniently run QONNX inference
1212from qonnx .core .modelwrapper import ModelWrapper
13+ from qonnx .transformation .channels_last import ConvertToChannelsLastAndClean
14+ from qonnx .transformation .gemm_to_matmul import GemmToMatMul
1315
1416import hls4ml
1517
@@ -99,10 +101,23 @@ def sep_conv_model():
99101 return model
100102
101103
104+ @pytest .fixture (scope = 'module' )
105+ def two_layer_keras_model ():
106+ """
107+ Load a simple, two-layer, originally keras, unquantized model
108+ """
109+ dl_file = str (example_model_path / "onnx/two_layer_keras.onnx" )
110+ assert os .path .isfile (dl_file )
111+
112+ model = ModelWrapper (dl_file )
113+ model = qonnx .util .cleanup .cleanup_model (model )
114+ return model
115+
116+
102117@pytest .fixture (scope = 'module' )
103118def three_layer_keras_model ():
104119 """
105- Load a simple, originally keras unquantized model
120+ Load a simple, three-layer, originally keras, unquantized model
106121 """
107122 dl_file = str (example_model_path / "onnx/three_layer_keras.onnx" )
108123 assert os .path .isfile (dl_file )
@@ -112,6 +127,84 @@ def three_layer_keras_model():
112127 return model
113128
114129
130+ @pytest .fixture (scope = 'module' )
131+ def two_layer_pytorch_model ():
132+ """
133+ Load a simple, two-layer, originally pytorch, unquantized model
134+ """
135+ dl_file = str (example_model_path / "onnx/two_layer_keras.onnx" )
136+ assert os .path .isfile (dl_file )
137+
138+ model = ModelWrapper (dl_file )
139+ model = qonnx .util .cleanup .cleanup_model (model )
140+ model = model .transform (GemmToMatMul ())
141+ model = qonnx .util .cleanup .cleanup_model (model )
142+ return model
143+
144+
145+ @pytest .fixture (scope = 'module' )
146+ def three_layer_pytorch_model ():
147+ """
148+ Load a simple, three-layer, originally pytorch, unquantized model
149+ """
150+ dl_file = str (example_model_path / "onnx/three_layer_pytorch.onnx" )
151+ assert os .path .isfile (dl_file )
152+
153+ model = ModelWrapper (dl_file )
154+ model = qonnx .util .cleanup .cleanup_model (model )
155+ model = model .transform (GemmToMatMul ())
156+ model = qonnx .util .cleanup .cleanup_model (model )
157+ return model
158+
159+
160+ @pytest .fixture (scope = 'module' )
161+ def conv1d_small_keras_model ():
162+ """
163+ Load a simple conv1d, originally keras, unquantized model
164+ """
165+ dl_file = str (example_model_path / "onnx/conv1d_small_keras.onnx" )
166+ assert os .path .isfile (dl_file )
167+
168+ model = ModelWrapper (dl_file )
169+ model = qonnx .util .cleanup .cleanup_model (model )
170+ model = model .transform (ConvertToChannelsLastAndClean ())
171+ model = model .transform (GemmToMatMul ())
172+ model = qonnx .util .cleanup .cleanup_model (model )
173+ return model
174+
175+
176+ @pytest .fixture (scope = 'module' )
177+ def conv2d_small_keras_model ():
178+ """
179+ Load a simple conv2d, originally keras, unquantized model
180+ """
181+ dl_file = str (example_model_path / "onnx/conv2d_small_keras.onnx" )
182+ assert os .path .isfile (dl_file )
183+
184+ model = ModelWrapper (dl_file )
185+ model = qonnx .util .cleanup .cleanup_model (model )
186+ model = model .transform (ConvertToChannelsLastAndClean ())
187+ model = model .transform (GemmToMatMul ())
188+ model = qonnx .util .cleanup .cleanup_model (model )
189+ return model
190+
191+
192+ @pytest .fixture (scope = 'module' )
193+ def conv2d_small_mp_keras_model ():
194+ """
195+ Load a conv2d model with max pooling, originally keras, unquantized model
196+ """
197+ dl_file = str (example_model_path / "onnx/conv2d_small_mp_keras.onnx" )
198+ assert os .path .isfile (dl_file )
199+
200+ model = ModelWrapper (dl_file )
201+ model = qonnx .util .cleanup .cleanup_model (model )
202+ model = model .transform (ConvertToChannelsLastAndClean ())
203+ model = model .transform (GemmToMatMul ())
204+ model = qonnx .util .cleanup .cleanup_model (model )
205+ return model
206+
207+
115208# The actual tests
116209
117210
@@ -216,25 +309,43 @@ def test_sep_conv(sep_conv_model, backend):
216309 np .testing .assert_allclose (y_qonnx .ravel (), y_hls4ml .ravel (), atol = 1e-2 , rtol = 1 )
217310
218311
312+ @pytest .mark .parametrize (
313+ 'model_name' ,
314+ [
315+ 'two_layer_keras_model' ,
316+ 'three_layer_keras_model' ,
317+ 'two_layer_pytorch_model' ,
318+ 'three_layer_pytorch_model' ,
319+ 'conv1d_small_keras_model' ,
320+ 'conv2d_small_keras_model' ,
321+ 'conv2d_small_mp_keras_model' ,
322+ ],
323+ )
219324@pytest .mark .parametrize ('backend' , ['Vitis' ])
220325@pytest .mark .parametrize ('io_type' , ['io_parallel' , 'io_stream' ])
221- def test_three_layer_keras (three_layer_keras_model , io_type , backend ):
222- model = three_layer_keras_model
326+ def test_simple_model (model_name , io_type , backend , request ):
327+ if model_name == 'conv2d_small_mp_keras_model' and io_type == 'io_stream' :
328+ # Not yet supported due to an issue with channels last conversion
329+ # There is a qonnx PR.
330+ pytest .skip ()
331+ model = request .getfixturevalue (model_name )
223332 ishape = tuple (model .get_tensor_shape (model .graph .input [0 ].name ))
224333 X = np .random .uniform (low = 0 , high = 1 , size = np .prod (ishape )).reshape (ishape )
225- X = (np .round (X * 2 ** 16 ) * 2 ** - 16 ).astype (np .float32 )
334+ X = (np .round (X * 2 ** 10 ) * 2 ** - 10 ).astype (np .float32 )
226335 idict = {model .graph .input [0 ].name : X }
227336 y_qonnx = oxe .execute_onnx (model , idict )[model .graph .output [0 ].name ]
228337
229338 config = hls4ml .utils .config .config_from_onnx_model (
230- model , granularity = 'name' , backend = backend , default_precision = 'fixed<32,16 >'
339+ model , granularity = 'name' , backend = backend , default_precision = 'fixed<16,6 >'
231340 )
232341
233- config ['LayerName' ]['Softmax_0' ]['Implementation' ] = 'legacy'
342+ for layer in config ['LayerName' ]:
343+ if layer .startswith ('Softmax' ):
344+ config ['LayerName' ][layer ]['Implementation' ] = 'legacy'
234345
235346 hls_model = hls4ml .converters .convert_from_onnx_model (
236347 model ,
237- output_dir = str (test_root_path / f'hls4mlprj_onnx_three_layer_keras_ { io_type } _{ backend } ' ),
348+ output_dir = str (test_root_path / f'hls4mlprj_onnx_ { model_name } _ { io_type } _{ backend } ' ),
238349 io_type = io_type ,
239350 backend = backend ,
240351 hls_config = config ,
0 commit comments