
Commit c49a22d

vinitra authored and Wenbing Li committed
Opset 9 Updates: BatchNorm, Cast, PRelu, Upsample (#241)
* updates for batch normalization and cast
* common opset 9 updates in apply_operation for PRelu, upsample
* refactoring apply_prelu
* remove spatial attribute from batch norm
* cast op: string support in opset 9, restrict attributes for opset 6, 1
* upsample: explanatory comment, whitespace for readability
* fixing comment on upsample, adding comment to prelu
1 parent eef63ee · commit c49a22d

File tree

1 file changed: +40 -14 lines changed

onnxmltools/convert/common/_apply_operation.py

Lines changed: 40 additions & 14 deletions
@@ -47,7 +47,8 @@ def _apply_basic_numerical_operation(scope, op_type, input_names, output_name, c
         else:
             op_version = 6
     else:
-        # Since ONNX-1.2 (opset 7), broadcasting behavior is Numpy-like, so we don't need to specify any attributes
+        # Since ONNX-1.2 (opset 7), broadcasting behavior is NumPy-like,
+        # so we don't need to specify any attributes
         op_version = 7
 
     container.add_node(op_type, input_names, output_name, op_version=op_version, name=name, **attrs)
@@ -82,7 +83,7 @@ def apply_batch_norm(scope, input_names, output_names, container, operator_name=
                      epsilon=None, is_test=None, momentum=None, spatial=None):
     name = _create_name_or_use_existing_one(scope, 'BatchNormalization', operator_name)
 
-    attrs = {'name': name, 'epsilon': epsilon, 'momentum': momentum, 'spatial': spatial}
+    attrs = {'name': name, 'epsilon': epsilon, 'momentum': momentum}
 
     if container.target_opset < 6:
         attrs['consumed_inputs'] = [0] * len(input_names)
@@ -91,12 +92,17 @@ def apply_batch_norm(scope, input_names, output_names, container, operator_name=
         if len(input_names) > 4:
             attrs['consumed_inputs'][4] = 2
         attrs['is_test'] = is_test
+        attrs['spatial'] = spatial
         op_version = 1
     elif container.target_opset < 7:
         attrs['is_test'] = is_test
+        attrs['spatial'] = spatial
         op_version = 6
-    else:
+    elif container.target_opset < 9:
+        attrs['spatial'] = spatial
         op_version = 7
+    else:
+        op_version = 9
 
     container.add_node('BatchNormalization', input_names, output_names, op_version=op_version, **attrs)
 
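The batch-norm branch above reduces to a plain opset-to-attributes mapping. Below is a minimal standalone sketch (the helper name and the asserts are illustrative only, not part of onnxmltools) of what apply_batch_norm now emits: is_test and spatial only survive on the older opsets, and from target opset 9 onward no extra attribute is added and op_version becomes 9.

def select_batch_norm_version(target_opset, is_test=1, spatial=1):
    # Mirrors the dispatch in apply_batch_norm: which optional attributes and
    # which op_version are used for a given target opset.
    attrs = {}
    if target_opset < 6:
        attrs['is_test'] = is_test
        attrs['spatial'] = spatial
        op_version = 1
    elif target_opset < 7:
        attrs['is_test'] = is_test
        attrs['spatial'] = spatial
        op_version = 6
    elif target_opset < 9:
        attrs['spatial'] = spatial
        op_version = 7
    else:
        # BatchNormalization-9 dropped the spatial attribute, so nothing extra is set.
        op_version = 9
    return op_version, attrs

assert select_batch_norm_version(8) == (7, {'spatial': 1})
assert select_batch_norm_version(9) == (9, {})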

@@ -111,16 +117,27 @@ def apply_cast(scope, input_name, output_name, container, operator_name=None, to
     d = onnx_proto.TensorProto.DataType.DESCRIPTOR
     allowed_type_name_and_type_enum_pairs = {v.number: k for k, v in d.values_by_name.items()}
     if to not in allowed_type_name_and_type_enum_pairs:
-        raise ValueError('Attribute to must be one of %s' % allowed_type_name_and_type_enum_pairs.keys())
+        raise ValueError('Attribute "to" must be one of %s' % allowed_type_name_and_type_enum_pairs.keys())
 
-    if container.target_opset < 7:
-        # Convert enum to string, for example, TensorProto.INT64 to 'INT64'
-        attrs['to'] = allowed_type_name_and_type_enum_pairs[to]
-        op_version = 1
+    if container.target_opset < 9:
+        if to in [onnx_proto.TensorProto.STRING, onnx_proto.TensorProto.COMPLEX64, onnx_proto.TensorProto.COMPLEX128]:
+            raise ValueError('Attribute "to" cannot correspond to a String or Complex TensorProto type.')
+
+        if container.target_opset < 6:
+            # Convert enum to string, for example, TensorProto.INT64 to 'INT64'
+            attrs['to'] = allowed_type_name_and_type_enum_pairs[to]
+            op_version = 1
+        else:
+            # Enum, for example, TensorProto.INT64
+            attrs['to'] = to
+            op_version = 6
     else:
-        # Enum, for example, TensorProto.INT64
+        # Enum value, for example, TensorProto.INT64
+        # String casting is supported in opset 9
+        if to in [onnx_proto.TensorProto.COMPLEX64, onnx_proto.TensorProto.COMPLEX128]:
+            raise ValueError('Attribute "to" cannot correspond to a Complex TensorProto type.')
         attrs['to'] = to
-        op_version = 7
+        op_version = 9
 
     container.add_node('Cast', input_name, output_name, op_version=op_version, **attrs)
 
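For Cast the same ladder also decides which target types are representable at all. A short sketch (hypothetical helper; assumes the onnx package is installed and uses the protobuf enum's Name() in place of the descriptor lookup above) of the resulting (op_version, 'to') pairs, including the opset-9 change that makes string casting legal:

from onnx import TensorProto

def select_cast_version(target_opset, to):
    if target_opset < 9:
        if to in [TensorProto.STRING, TensorProto.COMPLEX64, TensorProto.COMPLEX128]:
            raise ValueError('String/Complex casts are not representable before opset 9.')
        # Opset 1 expects the type name, opset 6 the enum value.
        return (1, TensorProto.DataType.Name(to)) if target_opset < 6 else (6, to)
    if to in [TensorProto.COMPLEX64, TensorProto.COMPLEX128]:
        raise ValueError('Complex casts are still not representable in opset 9.')
    return 9, to  # string casting is allowed from opset 9 on

print(select_cast_version(5, TensorProto.INT64))   # (1, 'INT64')
print(select_cast_version(8, TensorProto.INT64))   # (6, 7)  -- INT64 enum value
print(select_cast_version(9, TensorProto.STRING))  # (9, 8)  -- STRING enum value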

@@ -345,10 +362,14 @@ def apply_upsample(scope, input_name, output_name, container, operator_name=None
         attrs['width_scale'] = float(scales[3])
         attrs['mode'] = mode.upper()
         op_version = 1
-    else:
+    elif container.target_opset < 9:
         attrs['scales'] = list(map(float, scales))
         attrs['mode'] = mode.lower()
         op_version = 7
+    else:
+        attrs['scales'] = list(map(float, scales))
+        attrs['mode'] = mode.lower()
+        op_version = 9
 
     container.add_node('Upsample', input_name, output_name, op_version=op_version, **attrs)
 
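The Upsample change follows the same pattern; the opset-9 branch emits the same scales/mode attributes as opset 7 and only bumps op_version. A condensed sketch of the selection (hypothetical helper, assuming 4-D NCHW scales such as [1, 1, 2, 2]):

def select_upsample_attrs(target_opset, scales, mode='nearest'):
    if target_opset < 7:
        # Opset 1 takes separate height/width scale attributes and an upper-case mode.
        return 1, {'height_scale': float(scales[2]),
                   'width_scale': float(scales[3]),
                   'mode': mode.upper()}
    # From opset 7 on, a single 'scales' attribute and a lower-case mode are used.
    op_version = 7 if target_opset < 9 else 9
    return op_version, {'scales': list(map(float, scales)), 'mode': mode.lower()}

print(select_upsample_attrs(8, [1, 1, 2, 2]))  # (7, {'scales': [1.0, 1.0, 2.0, 2.0], 'mode': 'nearest'})
print(select_upsample_attrs(9, [1, 1, 2, 2]))  # (9, {'scales': [1.0, 1.0, 2.0, 2.0], 'mode': 'nearest'})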

@@ -372,11 +393,16 @@ def apply_prelu(scope, input_name, output_name, container, operator_name=None, s
     if container.target_opset < 6:
         container.add_node('PRelu', [input_name, slope_tensor_name], output_name, op_version=1, name=name,
                            consumed_inputs=[0, 0])
-    elif container.target_opset < 7:
-        container.add_node('PRelu', [input_name, slope_tensor_name], output_name, op_version=6, name=name)
     else:
-        container.add_node('PRelu', [input_name, slope_tensor_name], output_name, op_version=7, name=name)
+        if container.target_opset < 7:
+            op_version = 6
+        elif container.target_opset < 9:
+            op_version = 7
+        else:
+            # opset 9 supports unidirectional broadcasting
+            op_version = 9
 
+        container.add_node('PRelu', [input_name, slope_tensor_name], output_name, op_version=op_version, name=name)
 
 def apply_elu(scope, input_name, output_name, container, operator_name=None, alpha=1.0):
     _apply_unary_operation(scope, 'Elu', input_name, output_name, container, operator_name, alpha=alpha)
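The PRelu refactoring collapses three near-identical add_node calls into one call parameterised by op_version; the only per-opset decision left is which version to request. A condensed sketch of that choice (hypothetical helper, not part of the diff):

def select_prelu_version(target_opset):
    if target_opset < 6:
        return 1  # handled separately above: needs consumed_inputs=[0, 0]
    if target_opset < 7:
        return 6
    if target_opset < 9:
        return 7
    return 9  # opset 9 supports unidirectional (NumPy-style) broadcasting for the slope

assert [select_prelu_version(v) for v in (1, 6, 7, 8, 9, 10)] == [1, 6, 7, 7, 9, 9]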
