Skip to content

Commit cc6d6dc

Browse files
committed
Add more models for testing
1 parent 9c0835b commit cc6d6dc

File tree

1 file changed

+39
-17
lines changed

1 file changed

+39
-17
lines changed

test/pytest/test_conversion.py

Lines changed: 39 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -62,19 +62,19 @@ def single_layer_model_factory(layer):
6262
return None, None
6363

6464
def onnx_act_model(layer):
65-
inp = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.DOUBLE, [*layer.data_shape])
66-
out = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.DOUBLE, [None for i in range(len(layer.data_shape))])
65+
inp = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT, [*layer.data_shape])
66+
out = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [None for i in range(len(layer.data_shape))])
6767
model = onnx.helper.make_model(onnx.helper.make_graph([layer.layer], layer.output_dir, [inp], [out]))
6868
return model
6969

7070
def onnx_gemm_model(layer):
7171
assert isinstance(layer, LayerTestWrapper)
7272
wshape = (*layer.input_shape, *layer.output_shape)
73-
w = onnx.helper.make_tensor('b', onnx.TensorProto.DOUBLE, wshape, rand_neg1topos1(*wshape).flatten())
74-
b = onnx.helper.make_tensor('c', onnx.TensorProto.DOUBLE, layer.output_shape, rand_neg1topos1(*layer.output_shape).flatten())
73+
w = onnx.helper.make_tensor('b', onnx.TensorProto.FLOAT, wshape, rand_neg1topos1(*wshape).flatten())
74+
b = onnx.helper.make_tensor('c', onnx.TensorProto.FLOAT, layer.output_shape, rand_neg1topos1(*layer.output_shape).flatten())
7575
#node = onnx.helper.make_node('Gemm', inputs=['x', 'b', 'c'], outputs=['y'], name='gemm')
76-
inp = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.DOUBLE, [None,*layer.input_shape])
77-
out = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.DOUBLE, [None,None])
76+
inp = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT, [None,*layer.input_shape])
77+
out = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [None,None])
7878
graph = onnx.helper.make_graph([layer.layer], 'gemm', [inp], [out])
7979
graph.initializer.append(w)
8080
graph.initializer.append(b)
@@ -85,19 +85,30 @@ def onnx_gemm_model(layer):
8585
def onnx_matmul_model(layer):
8686
assert isinstance(layer, LayerTestWrapper)
8787
wshape = (*layer.input_shape, *layer.output_shape)
88-
w = onnx.helper.make_tensor('b', onnx.TensorProto.DOUBLE, wshape, rand_neg1topos1(*wshape).flatten())
88+
w = onnx.helper.make_tensor('b', onnx.TensorProto.FLOAT, wshape, rand_neg1topos1(*wshape).flatten())
8989
#node = onnx.helper.make_node('MatMul', inputs=['x', 'b'], outputs=['y'], name='matmul')
90-
inp = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.DOUBLE, [None,*layer.input_shape])
91-
out = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.DOUBLE, [None,None])
90+
inp = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT, [None,*layer.input_shape])
91+
out = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [None,None])
9292
graph = onnx.helper.make_graph([layer.layer], 'matmul', [inp], [out])
9393
graph.initializer.append(w)
9494
model = onnx.helper.make_model(graph)
9595
onnx.checker.check_model(model)
9696
return model
9797

98-
onnx_model_makers = {'Relu' : onnx_act_model,
99-
'Gemm' : onnx_gemm_model,
100-
'MatMul' : onnx_matmul_model}
98+
onnx_model_makers = {'Relu' : onnx_act_model,
99+
'Elu' : onnx_act_model,
100+
'Selu' : onnx_act_model,
101+
'PRelu' : onnx_act_model,
102+
'ThresholdedRelu' : onnx_act_model,
103+
'LeakyRelu' : onnx_act_model,
104+
'Tanh' : onnx_act_model,
105+
'Sigmoid' : onnx_act_model,
106+
'HardSigmoid' : onnx_act_model,
107+
'Softmax' : onnx_act_model,
108+
'Softplus' : onnx_act_model,
109+
'Clip' : onnx_act_model,
110+
'Gemm' : onnx_gemm_model,
111+
'MatMul' : onnx_matmul_model}
101112

102113
def validate_model_predictions(model, hls_model, shape, fdata=np.random.rand, test=np.testing.assert_allclose, test_kwargs={'atol':1e-2, 'rtol':1e-2}):
103114
'''Generate random data with shape, execute inference on model and hls_model, and test for correctness'''
@@ -109,7 +120,7 @@ def validate_model_predictions(model, hls_model, shape, fdata=np.random.rand, te
109120
y_ref = model(torch.Tensor(X)).detach().numpy()
110121
elif isinstance(model, onnx.onnx_ml_pb2.ModelProto):
111122
session = onnxruntime.InferenceSession(f'{model.metadata_props[0].value}/model.onnx')
112-
y_ref = session.run(['y'], {'x' : X})[0]
123+
y_ref = session.run(['y'], {'x' : X.astype(np.float32)})[0]
113124
y_hls = hls_model.predict(X)
114125
# Reshape output for Conv models. Use the hls_model's shape for extra validation
115126
if len(y_ref.shape) > 2:
@@ -223,20 +234,31 @@ def __repr__(self):
223234
#(torch.nn.Linear(16, 16, bias=False), f'{odb}pytorch_linear', (16,), None, 100),
224235
# Activations
225236
(torch.nn.ReLU(), f'{odb}pytorch_relu', (16,), None, 100),
226-
(torch.nn.LeakyReLU(negative_slope=1.0), f'{odb}pytorch_activation_leakyrelu_1', (16,), None, 100),
237+
#(torch.nn.LeakyReLU(negative_slope=1.0), f'{odb}pytorch_activation_leakyrelu_1', (16,), None, 100),
227238
#(torch.nn.LeakyReLU(negative_slope=0.5), f'{odb}pytorch_activation_leakyrelu_1', (16,), None, 100),
228-
(torch.nn.ELU(alpha=1.0), f'{odb}pytorch_activation_elu_1', (16,), None, 100),
239+
#(torch.nn.ELU(alpha=1.0), f'{odb}pytorch_activation_elu_1', (16,), None, 100),
229240
#(torch.nn.ELU(alpha=0.5), f'{odb}pytorch_activation_elu_1', (16,), None, 100),
230241
] # close pytorch_layers
231242

232243
# TODO: find out why Gemm tests don't pass
233244
onnx_layers = [(onnx.helper.make_node('MatMul', inputs=['x', 'b'], outputs=['y'], name='matmul'), f'{odb}onnx_matmul_1', (16,), (16,), 100),
234245
(onnx.helper.make_node('MatMul', inputs=['x', 'b'], outputs=['y'], name='matmul'), f'{odb}onnx_matmul_2', (16,), (8,), 100),
235-
(onnx.helper.make_node('MatMul', inputs=['x', 'b'], outputs=['y'], name='matmul'), f'{odb}onnx_matmul_2', (8,), (16,), 100),
246+
(onnx.helper.make_node('MatMul', inputs=['x', 'b'], outputs=['y'], name='matmul'), f'{odb}onnx_matmul_3', (8,), (16,), 100),
236247
#(onnx.helper.make_node('Gemm', inputs=['x', 'b', 'c'], outputs=['y'], name='gemm'), f'{odb}onnx_gemm_1', (16,), (16,), 100),
237248
#(onnx.helper.make_node('Gemm', inputs=['x', 'b', 'c'], outputs=['y'], name='gemm'), f'{odb}onnx_gemm_2', (16,), (8,), 100),
238249
#(onnx.helper.make_node('Gemm', inputs=['x', 'b', 'c'], outputs=['y'], name='gemm'), f'{odb}onnx_gemm_3', (8,), (16,), 100),
239-
(onnx.helper.make_node('Relu', inputs=['x'], outputs=['y'], name='relu'), f'{odb}onnx_relu', (1,), None, 100),
250+
(onnx.helper.make_node('Relu', inputs=['x'], outputs=['y'], name='relu'), f'{odb}onnx_act_relu', (1,), None, 100),
251+
(onnx.helper.make_node('Elu', inputs=['x'], outputs=['y'], name='elu', alpha=1.0), f'{odb}onnx_act_elu_1', (1,), None, 100),
252+
(onnx.helper.make_node('Elu', inputs=['x'], outputs=['y'], name='elu', alpha=0.5), f'{odb}onnx_elu_2', (1,), None, 100),
253+
(onnx.helper.make_node('LeakyRelu', inputs=['x'], outputs=['y'], name='leakyrelu', alpha=1.0), f'{odb}onnx_act_leakyrelu_1', (1,), None, 100),
254+
(onnx.helper.make_node('LeakyRelu', inputs=['x'], outputs=['y'], name='leakyrelu', alpha=0.5), f'{odb}onnx_leakyrelu_2', (1,), None, 100),
255+
(onnx.helper.make_node('ThresholdedRelu', inputs=['x'], outputs=['y'], name='thresholdedrelu', alpha=1.0), f'{odb}onnx_act_thresholdedrelu_1', (1,), None, 100),
256+
(onnx.helper.make_node('ThresholdedRelu', inputs=['x'], outputs=['y'], name='thresholdedrelu', alpha=0.5), f'{odb}onnx_act_thresholdedrelu_2', (1,), None, 100),
257+
(onnx.helper.make_node('Tanh', inputs=['x'], outputs=['y'], name='tanh'), f'{odb}onnx_act_tanh', (1,), None, 100),
258+
(onnx.helper.make_node('Sigmoid', inputs=['x'], outputs=['y'], name='sigmoid'), f'{odb}onnx_act_sigmoid', (1,), None, 100),
259+
#(onnx.helper.make_node('HardSigmoid', inputs=['x'], outputs=['y'], name='hardsigmoid'), f'{odb}onnx_act_hardsigmoid', (1,), None, 100),
260+
(onnx.helper.make_node('Softplus', inputs=['x'], outputs=['y'], name='softplus'), f'{odb}onnx_act_softplus', (1,), None, 100),
261+
#(onnx.helper.make_node('Clip', inputs=['x'], outputs=['y'], name='clip'), f'{odb}onnx_act_clip_1', (1,), None, 100),
240262
] #close onnx_layers
241263

242264
layers = [*keras_layers, *pytorch_layers, *onnx_layers]

0 commit comments

Comments
 (0)