Skip to content

Commit f801d3e

Browse files
authored
Fix some bugs found by running onnx-1.1.2 checker (#25)
* Fix some bugs found by running onnx checker 1. One-hot encoder has not "inputlist" attribute 2. Create RNN/LSTM/GRU nodes in a topological order 3. Change attribute name from "values" to "value" for Pad * Remove lines not needed * Append initializers' value info's to model input list * Fix InnerProduct's initializer names * Fix LSTM and discard one change * Fix some minor bugs * Fix a C-style naming problem for convolution conversion * Address comments and fix one more bug 1. Fix comments about doc 2. Update the interfaces of converting functions so that arguments can be passed in * Fix shape calculator for merge layer (support [N, C] inputs)
1 parent a28a535 commit f801d3e

File tree

13 files changed

+70
-92
lines changed

13 files changed

+70
-92
lines changed

onnxmltools/convert/common/_topology.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -631,8 +631,24 @@ def convert_topology(topology, model_name, doc_string):
631631
container.nodes[i], container.nodes[another_node_id] = \
632632
container.nodes[another_node_id], container.nodes[i]
633633

634+
# When calling ModelComponentContainer's add_initializer(...), nothing is added into the input list. However, in
635+
# ONNX initializers should also be model's (GraphProto) inputs. Thus, we create ValueInfoProto objects from
636+
# initializers (type: TensorProto) directly and then add them into model's input list.
637+
extra_inputs = [] # ValueInfoProto list of the initializers
638+
for tensor in container.initializers:
639+
# Sometimes (especially when creating optional input values such as RNN's initial hidden state), an initializer
640+
# is also one of the original model's input, so it has been added into the container's input list. If this is
641+
# the case, we need to skip one iteration to avoid duplicated inputs.
642+
if tensor.name in [value_info.name for value_info in container.inputs]:
643+
continue
644+
645+
# Initializers are always tensors so we can just call make_tensor_value_info(...)
646+
value_info = helper.make_tensor_value_info(tensor.name, tensor.data_type, tensor.dims)
647+
extra_inputs.append(value_info)
648+
634649
# Create a graph from its main components
635-
graph = helper.make_graph(container.nodes, model_name, container.inputs, container.outputs, container.initializers)
650+
graph = helper.make_graph(container.nodes, model_name, container.inputs + extra_inputs,
651+
container.outputs, container.initializers)
636652

637653
# Add extra information related to the graph
638654
graph.value_info.extend(container.value_info)

onnxmltools/convert/coreml/operator_converters/neural_network/BatchNorm.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ def convert_batch_normalization(scope, operator, container):
3131
inputs.append(bias_tensor_name)
3232

3333
attrs['epsilon'] = params.epsilon
34-
attrs['spatial'] = 1 # True
3534

3635
if op_type == 'BatchNormalization':
3736
mean_tensor_name = scope.get_unique_variable_name(op_type + '_mean')
@@ -43,6 +42,7 @@ def convert_batch_normalization(scope, operator, container):
4342
params.variance.floatValue)
4443
inputs.append(variance_tensor_name)
4544
attrs['momentum'] = 0.
45+
attrs['spatial'] = 1 # True
4646

4747
if not params.instanceNormalization and params.computeMeanVar:
4848
# In this case, we apply batch normalization and adjust the statistics stored according the the batch
@@ -63,8 +63,6 @@ def convert_batch_normalization(scope, operator, container):
6363
attrs['is_test'] = 1 # True
6464
else:
6565
raise ValueError('Unsupported operation mode')
66-
else:
67-
attrs['is_test'] = 1 # True
6866

6967
container.add_node(op_type, inputs, outputs, **attrs)
7068

onnxmltools/convert/coreml/operator_converters/neural_network/BidirectionalLSTM.py

Lines changed: 14 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -325,42 +325,39 @@ def convert_bidirectional_lstm(scope, operator, container):
325325
lstm_attrs['clip'] = lstm_params.cellClipThreshold
326326
lstm_attrs['input_forget'] = lstm_params.coupledInputAndForgetGate
327327

328-
# Create output part of CoreML LSTM
329-
if lstm_params.sequenceOutput:
330-
lstm_y_name = scope.get_unique_variable_name(lstm_op_name + '_Y')
331-
lstm_outputs.append(lstm_y_name)
328+
# Create the major LSTM operator. We assign a tensor name to each output of LSTM. However, variables can be
329+
# undefined in some cases. For example, when output_sequence=False, the first output is not meaningful.
330+
lstm_y_name = scope.get_unique_variable_name(lstm_op_name + '_Y')
331+
lstm_y_h_name = scope.get_unique_variable_name(lstm_op_name + '_Y_h')
332+
lstm_y_c_name = scope.get_unique_variable_name(lstm_op_name + '_Y_c')
333+
lstm_outputs.extend([lstm_y_name, lstm_y_h_name, lstm_y_c_name])
334+
container.add_node('LSTM', lstm_inputs, lstm_outputs, **lstm_attrs)
332335

336+
# Create post-processing operators for converting ONNX LSTM outputs to CoreML ones
337+
if lstm_params.sequenceOutput:
333338
container.add_node('Reshape', lstm_y_name, operator.outputs[0].full_name,
334339
name=scope.get_unique_operator_name('Reshape'), shape=[-1, 2 * hidden_size])
335340

336341
if len(operator.outputs) > 1:
337-
lstm_y_h_name = scope.get_unique_variable_name(lstm_op_name + '_Y_h')
338-
lstm_outputs.append(lstm_y_h_name)
339-
340342
lstm_y_h_reshape_name = scope.get_unique_variable_name(lstm_op_name + '_Y_h_reshape')
341343
container.add_node('Reshape', lstm_y_h_name, lstm_y_h_reshape_name,
342344
name=scope.get_unique_operator_name('Reshape'), shape=[2, hidden_size])
343345

344346
container.add_node('Split', lstm_y_h_reshape_name,
345347
[operator.outputs[1].full_name, operator.outputs[3].full_name],
346-
op_version=2, name=scope.get_unique_operator_name('Split'), split=[1, 1, ], axis=0)
348+
op_version=2, name=scope.get_unique_operator_name('Split'), split=[1, 1], axis=0)
347349
else:
348-
# Here we ingore ONNX RNN's first output because it's useless.
349-
lstm_outputs.append(scope.get_unique_variable_name('isolated'))
350-
351-
# Handle the second output of ONNX LSTM. It will become the first and the second outputs of
352-
# CoreML's LSTM.
353-
lstm_y_name = scope.get_unique_variable_name(lstm_op_name + '_Y')
354-
lstm_outputs.append(lstm_y_name)
350+
# Here we ignore ONNX RNN's first output because it's useless. The second output of ONNX LSTM will be used to
351+
# generate the first and the second outputs of CoreML LSTM.
355352

356353
# Directly reshape ONNX LSTM's 2nd output to CoreML LSTM's 1st output.
357-
container.add_node('Reshape', lstm_y_name, operator.outputs[0].full_name,
354+
container.add_node('Reshape', lstm_y_h_name, operator.outputs[0].full_name,
358355
name=scope.get_unique_operator_name('Reshape'), shape=[1, 2 * hidden_size])
359356

360357
if len(operator.outputs) > 1:
361358
lstm_y_reshape_name = scope.get_unique_variable_name(lstm_op_name + '_Y_reshape')
362359

363-
container.add_node('Reshape', lstm_y_name, lstm_y_reshape_name,
360+
container.add_node('Reshape', lstm_y_h_name, lstm_y_reshape_name,
364361
name=scope.get_unique_operator_name('Reshape'), shape=[2, hidden_size])
365362

366363
container.add_node('Split', lstm_y_reshape_name,
@@ -369,9 +366,6 @@ def convert_bidirectional_lstm(scope, operator, container):
369366

370367
# Output cell state if necessary
371368
if len(operator.outputs) > 2:
372-
lstm_y_c_name = scope.get_unique_variable_name(lstm_op_name + '_Y_c')
373-
lstm_outputs.append(lstm_y_c_name)
374-
375369
lstm_y_c_reshape_name = scope.get_unique_variable_name(lstm_op_name + '_Y_c_reshape')
376370
container.add_node('Reshape', lstm_y_c_name, lstm_y_c_reshape_name,
377371
name=scope.get_unique_operator_name('Reshape'), shape=[2, hidden_size])
@@ -380,8 +374,5 @@ def convert_bidirectional_lstm(scope, operator, container):
380374
[operator.outputs[2].full_name, operator.outputs[4].full_name],
381375
op_version=2, name=scope.get_unique_operator_name('Split'), split=[1, 1], axis=0)
382376

383-
# Create the major LSTM operator
384-
container.add_node('LSTM', lstm_inputs, lstm_outputs, **lstm_attrs)
385-
386377

387378
register_converter('biDirectionalLSTM', convert_bidirectional_lstm)

onnxmltools/convert/coreml/operator_converters/neural_network/Convolution.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,13 @@ def convert_convolution(scope, operator, container):
2323
if params.isDeconvolution:
2424
shape_w[0] = params.kernelChannels
2525
shape_w[1] = int(params.outputChannels / n_groups)
26-
name_w = operator.full_name + '.W'
26+
name_w = scope.get_unique_variable_name(operator.full_name + '_W')
2727
inputs.append(name_w)
2828
container.add_initializer(name_w, onnx_proto.TensorProto.FLOAT, shape_w, params.weights.floatValue)
2929

3030
if params.hasBias:
3131
shape_b = [len(params.bias.floatValue)]
32-
name_b = operator.full_name + '.B'
32+
name_b = scope.get_unique_variable_name(operator.full_name + '_B')
3333
inputs.append(name_b)
3434
container.add_initializer(name_b, onnx_proto.TensorProto.FLOAT, shape_b, params.bias.floatValue)
3535

onnxmltools/convert/coreml/operator_converters/neural_network/GRU.py

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -155,39 +155,32 @@ def convert_gru(scope, operator, container):
155155
gru_attrs['output_sequence'] = params.sequenceOutput
156156
gru_attrs['hidden_size'] = hidden_size
157157

158+
# Create the major GRU operator in ONNX.
159+
gru_y_name = scope.get_unique_variable_name(gru_op_name + '_Y')
160+
gru_y_h_name = scope.get_unique_variable_name(gru_op_name + '_Y_h')
161+
gru_outputs.extend([gru_y_name, gru_y_h_name])
162+
container.add_node('GRU', gru_inputs, gru_outputs, **gru_attrs)
163+
164+
# To simulate CoreML LSTM, we add post-processing operators to adjust ONNX LSTM outputs
158165
if params.sequenceOutput:
159166
# Again, the output shapes in ONNX's GRU is not consistent with that in CoreML, so we need
160167
# to adjust the result produced by ONNX according to CoreML format.
161-
gru_y_name = scope.get_unique_variable_name(gru_op_name + '_Y')
162-
gru_outputs.append(gru_y_name)
163168
container.add_node('Reshape', gru_y_name, operator.outputs[0].full_name,
164169
name=scope.get_unique_operator_name('Reshape'), shape=[-1, hidden_size])
165170

166171
# Handle the second output, the last hidden state of a sequence, if exists.
167172
if len(operator.outputs) == 2:
168-
gru_y_h_name = scope.get_unique_variable_name(gru_op_name + '_Y_h')
169-
gru_outputs.append(gru_y_h_name)
170173
container.add_node('Reshape', gru_y_h_name, operator.outputs[1].full_name,
171174
name=scope.get_unique_operator_name('Reshape'), shape=[1, hidden_size])
172175
else:
173176
# Recall that when sequence output is false, the first and the second outputs of GRU
174177
# are identical. Thus, we can ignore ONNX GRU's first output.
175-
gru_outputs.append(scope.get_unique_variable_name('isloated'))
176-
177-
# As the two outputs are always identical, so we just need to compute one of them and
178-
# produce the other using identiy operator.
179-
gru_y_name = scope.get_unique_variable_name(gru_op_name + '_Y')
180-
gru_outputs.append(gru_y_name)
181-
182-
container.add_node('Reshape', gru_y_name, operator.outputs[0].full_name,
178+
container.add_node('Reshape', gru_y_h_name, operator.outputs[0].full_name,
183179
name=scope.get_unique_operator_name('Reshape'), shape=[1, hidden_size])
184180

185181
if len(operator.outputs) == 2:
186182
container.add_node('Identity', operator.outputs[0].full_name, operator.outputs[1].full_name,
187183
name=scope.get_unique_operator_name('Identity'))
188184

189-
# Finally, we create the major GRU operator in ONNX.
190-
container.add_node('GRU', gru_inputs, gru_outputs, **gru_attrs)
191-
192185

193186
register_converter('gru', convert_gru)

onnxmltools/convert/coreml/operator_converters/neural_network/InnerProduct.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,12 @@ def convert_inner_product(scope, operator, container):
1515
outputs = [variable.full_name for variable in operator.outputs]
1616
attrs = {'name': operator.full_name}
1717

18-
name_w = operator.full_name + '.W'
18+
name_w = scope.get_unique_variable_name(operator.full_name + '_W')
1919
shape_w = [params.outputChannels, params.inputChannels]
2020
inputs.append(name_w)
2121
container.add_initializer(name_w, onnx_proto.TensorProto.FLOAT, shape_w, params.weights.floatValue)
2222

23-
name_b = operator.full_name + '.B'
23+
name_b = scope.get_unique_variable_name(operator.full_name + '_B')
2424
shape_b = [params.outputChannels]
2525
inputs.append(name_b)
2626
if params.hasBias:

onnxmltools/convert/coreml/operator_converters/neural_network/LSTM.py

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -249,28 +249,27 @@ def convert_unidirectional_lstm(scope, operator, container):
249249
lstm_attrs['clip'] = lstm_params.cellClipThreshold
250250
lstm_attrs['input_forget'] = lstm_params.coupledInputAndForgetGate
251251

252+
# Create the main LSTM operator
253+
lstm_y_name = scope.get_unique_variable_name(lstm_op_name + '_Y')
254+
lstm_y_h_name = scope.get_unique_variable_name(lstm_op_name + '_Y_h')
255+
lstm_c_name = scope.get_unique_variable_name(lstm_op_name + '_Y_c')
256+
lstm_outputs.extend([lstm_y_name, lstm_y_h_name, lstm_c_name])
257+
container.add_node('LSTM', lstm_inputs, lstm_outputs, **lstm_attrs)
258+
252259
# Handle the first output of LSTM
253260
if lstm_params.sequenceOutput:
254261
# Handle the first output of LSTM
255-
lstm_y_name = scope.get_unique_variable_name(lstm_op_name + '_Y')
256-
lstm_outputs.append(lstm_y_name)
257262
container.add_node('Reshape', lstm_y_name, operator.outputs[0].full_name,
258263
name=scope.get_unique_operator_name('Reshape'), shape=[-1, hidden_size])
259264

260265
# Handle the second output of LSTM
261266
if len(operator.outputs) > 1:
262-
lstm_y_h_name = scope.get_unique_variable_name(lstm_op_name + '_Y_h')
263-
lstm_outputs.append(lstm_y_h_name)
264267
container.add_node('Reshape', lstm_y_h_name, operator.outputs[1].full_name,
265268
name=scope.get_unique_operator_name('Reshape'), shape=[1, hidden_size])
266269
else:
267-
# Here we ingore ONNX RNN's first output because it's useless.
268-
lstm_outputs.append(scope.get_unique_variable_name('isolated'))
269-
270-
# Use the second output of ONNX LSTM to produce the first output of CoreML LSTM
271-
lstm_y_name = scope.get_unique_variable_name(lstm_op_name + '_Y')
272-
lstm_outputs.append(lstm_y_name)
273-
container.add_node('Reshape', lstm_y_name, operator.outputs[0].full_name,
270+
# Here we ingore ONNX RNN's first output because it's useless and use the second output of ONNX LSTM to produce
271+
# the first output of CoreML LSTM
272+
container.add_node('Reshape', lstm_y_h_name, operator.outputs[0].full_name,
274273
name=scope.get_unique_operator_name('Reshape'), shape=[1, hidden_size])
275274

276275
# Create the second LSTM output from the first output
@@ -280,13 +279,8 @@ def convert_unidirectional_lstm(scope, operator, container):
280279

281280
# Handle the cell state output of LSTM
282281
if len(operator.outputs) > 2:
283-
lstm_c_name = scope.get_unique_variable_name(lstm_op_name + '_Y_c')
284-
lstm_outputs.append(lstm_c_name)
285282
container.add_node('Reshape', lstm_c_name, operator.outputs[2].full_name,
286283
name=scope.get_unique_operator_name('Reshape'), shape=[1, hidden_size])
287284

288-
# Finally, the main LSTM operator is created
289-
container.add_node('LSTM', lstm_inputs, lstm_outputs, **lstm_attrs)
290-
291285

292286
register_converter('uniDirectionalLSTM', convert_unidirectional_lstm)

onnxmltools/convert/coreml/operator_converters/neural_network/Pad.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def convert_padding(scope, operator, container):
3636
attrs['pads'] = pads
3737

3838
if pad_type == 'constant':
39-
attrs['values'] = params.constant.value
39+
attrs['value'] = params.constant.value
4040

4141
container.add_node(op_type, operator.input_full_names, operator.output_full_names, op_version=2, **attrs)
4242

onnxmltools/convert/coreml/operator_converters/neural_network/Pool.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,7 @@ def create_legacy_pad(scope, input_name, output_name, H_in, W_in, k_h, k_w,
7676
# N_end_index, C_end_index, H_end_index, W_end_index]
7777
# Because only H- and W-axes are padded in CoreML, we leave padding amounts of N- and C-axes zeros.
7878
pads = [0, 0, pad_t, pad_l, 0, 0, pad_b, pad_r]
79-
attrs = {'name': scope.get_unique_operator_name('Pad'), 'kernel_shape': [k_h, k_w],
80-
'strides': [k_h, k_w], 'pads': pads, 'value': padded_value}
79+
attrs = {'name': scope.get_unique_operator_name('Pad'), 'pads': pads, 'value': padded_value}
8180
container.add_node('Pad', input_name, output_name, op_version=2, **attrs)
8281

8382

onnxmltools/convert/coreml/operator_converters/neural_network/SimpleRNN.py

Lines changed: 7 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -188,36 +188,26 @@ def convert_simple_rnn(scope, operator, container):
188188
rnn_attrs['output_sequence'] = params.sequenceOutput
189189
rnn_attrs['hidden_size'] = hidden_size
190190

191+
# We use the collected information to build ONNX's RNN
192+
rnn_y_name = scope.get_unique_variable_name(rnn_op_name + '_Y')
193+
rnn_h_name = scope.get_unique_variable_name(rnn_op_name + '_Y_h')
194+
container.add_node('RNN', rnn_inputs, [rnn_y_name, rnn_h_name], **rnn_attrs)
195+
191196
# Set up outputs' of RNN
192-
rnn_outputs = []
193197
if params.sequenceOutput:
194-
# Create ONNX's RNN output, which needs to be reshaped to fit CoreML standard.
195-
rnn_y_name = scope.get_unique_variable_name(rnn_op_name + '_Y')
196-
rnn_outputs.append(rnn_y_name)
197-
198198
# Connect ONNX's output and CoreML's output via a reshape operator
199199
container.add_node('Reshape', rnn_y_name, operator.outputs[0].full_name,
200200
name=scope.get_unique_operator_name('Reshape'), shape=[-1, hidden_size])
201201

202202
# Handel the second RNN output (aka last hidden state), which is optional.
203203
if len(operator.outputs) == 2:
204-
# Create ONNX's RNN output, which needs to be reshaped to fit CoreML standard.
205-
rnn_h_name = scope.get_unique_variable_name(rnn_op_name + '_Y_h')
206-
rnn_outputs.append(rnn_h_name)
207-
208204
# Connect ONNX's output and CoreML's output via a reshape operator
209205
container.add_node('Reshape', rnn_h_name, operator.outputs[1].full_name,
210206
name=scope.get_unique_operator_name('Reshape'),
211207
shape=[1, hidden_size])
212208
else:
213-
# Here we ignore ONNX RNN's first output by assigning it an isolated name. Isolated names
214-
# are not connected with anything else.
215-
rnn_outputs.append(scope.get_unique_variable_name('isolated'))
216-
217-
# According to CoreML, the two outputs are always identical, so we just need to compute one of
218-
# them and produce the other one using an identiy operator.
219-
rnn_h_name = scope.get_unique_variable_name(rnn_op_name + '_Y_h')
220-
rnn_outputs.append(rnn_h_name)
209+
# According to CoreML, its two outputs are always identical, so we just need to compute one of them and produce
210+
# the other one using an identity operator. Note that the first ONNX RNN output is undefined in this case.
221211

222212
# Reshape last hidden state's ONNX format to its CoreML format
223213
container.add_node('Reshape', rnn_h_name, operator.outputs[0].full_name,
@@ -228,8 +218,5 @@ def convert_simple_rnn(scope, operator, container):
228218
container.add_node('Identity', operator.outputs[0].full_name, operator.outputs[1].full_name,
229219
name=scope.get_unique_operator_name('Identity'))
230220

231-
# Finally, we use the collected information to build ONNX's RNN
232-
container.add_node('RNN', rnn_inputs, rnn_outputs, **rnn_attrs)
233-
234221

235222
register_converter('simpleRecurrent', convert_simple_rnn)

0 commit comments

Comments
 (0)