@@ -47,7 +47,8 @@ def _apply_basic_numerical_operation(scope, op_type, input_names, output_name, c
4747 else :
4848 op_version = 6
4949 else :
50- # Since ONNX-1.2 (opset 7), broadcasting behavior is Numpy-like, so we don't need to specify any attributes
50+ # Since ONNX-1.2 (opset 7), broadcasting behavior is NumPy-like,
51+ # so we don't need to specify any attributes
5152 op_version = 7
5253
5354 container .add_node (op_type , input_names , output_name , op_version = op_version , name = name , ** attrs )
@@ -82,7 +83,7 @@ def apply_batch_norm(scope, input_names, output_names, container, operator_name=
8283 epsilon = None , is_test = None , momentum = None , spatial = None ):
8384 name = _create_name_or_use_existing_one (scope , 'BatchNormalization' , operator_name )
8485
85- attrs = {'name' : name , 'epsilon' : epsilon , 'momentum' : momentum , 'spatial' : spatial }
86+ attrs = {'name' : name , 'epsilon' : epsilon , 'momentum' : momentum }
8687
8788 if container .target_opset < 6 :
8889 attrs ['consumed_inputs' ] = [0 ] * len (input_names )
@@ -91,12 +92,17 @@ def apply_batch_norm(scope, input_names, output_names, container, operator_name=
9192 if len (input_names ) > 4 :
9293 attrs ['consumed_inputs' ][4 ] = 2
9394 attrs ['is_test' ] = is_test
95+ attrs ['spatial' ] = spatial
9496 op_version = 1
9597 elif container .target_opset < 7 :
9698 attrs ['is_test' ] = is_test
99+ attrs ['spatial' ] = spatial
97100 op_version = 6
98- else :
101+ elif container .target_opset < 9 :
102+ attrs ['spatial' ] = spatial
99103 op_version = 7
104+ else :
105+ op_version = 9
100106
101107 container .add_node ('BatchNormalization' , input_names , output_names , op_version = op_version , ** attrs )
102108
@@ -111,16 +117,27 @@ def apply_cast(scope, input_name, output_name, container, operator_name=None, to
111117 d = onnx_proto .TensorProto .DataType .DESCRIPTOR
112118 allowed_type_name_and_type_enum_pairs = {v .number : k for k , v in d .values_by_name .items ()}
113119 if to not in allowed_type_name_and_type_enum_pairs :
114- raise ValueError ('Attribute to must be one of %s' % allowed_type_name_and_type_enum_pairs .keys ())
120+ raise ValueError ('Attribute "to" must be one of %s' % allowed_type_name_and_type_enum_pairs .keys ())
115121
116- if container .target_opset < 7 :
117- # Convert enum to string, for example, TensorProto.INT64 to 'INT64'
118- attrs ['to' ] = allowed_type_name_and_type_enum_pairs [to ]
119- op_version = 1
122+ if container .target_opset < 9 :
123+ if to in [onnx_proto .TensorProto .STRING , onnx_proto .TensorProto .COMPLEX64 , onnx_proto .TensorProto .COMPLEX128 ]:
124+ raise ValueError ('Attribute "to" cannot correspond to a String or Complex TensorProto type.' )
125+
126+ if container .target_opset < 6 :
127+ # Convert enum to string, for example, TensorProto.INT64 to 'INT64'
128+ attrs ['to' ] = allowed_type_name_and_type_enum_pairs [to ]
129+ op_version = 1
130+ else :
131+ # Enum, for example, TensorProto.INT64
132+ attrs ['to' ] = to
133+ op_version = 6
120134 else :
121- # Enum, for example, TensorProto.INT64
135+ # Enum value, for example, TensorProto.INT64
136+ # String casting is supported in opset 9
137+ if to in [onnx_proto .TensorProto .COMPLEX64 , onnx_proto .TensorProto .COMPLEX128 ]:
138+ raise ValueError ('Attribute "to" cannot correspond to a Complex TensorProto type.' )
122139 attrs ['to' ] = to
123- op_version = 7
140+ op_version = 9
124141
125142 container .add_node ('Cast' , input_name , output_name , op_version = op_version , ** attrs )
126143
@@ -345,10 +362,14 @@ def apply_upsample(scope, input_name, output_name, container, operator_name=None
345362 attrs ['width_scale' ] = float (scales [3 ])
346363 attrs ['mode' ] = mode .upper ()
347364 op_version = 1
348- else :
365+ elif container . target_opset < 9 :
349366 attrs ['scales' ] = list (map (float , scales ))
350367 attrs ['mode' ] = mode .lower ()
351368 op_version = 7
369+ else :
370+ attrs ['scales' ] = list (map (float , scales ))
371+ attrs ['mode' ] = mode .lower ()
372+ op_version = 9
352373
353374 container .add_node ('Upsample' , input_name , output_name , op_version = op_version , ** attrs )
354375
@@ -372,11 +393,16 @@ def apply_prelu(scope, input_name, output_name, container, operator_name=None, s
372393 if container .target_opset < 6 :
373394 container .add_node ('PRelu' , [input_name , slope_tensor_name ], output_name , op_version = 1 , name = name ,
374395 consumed_inputs = [0 , 0 ])
375- elif container .target_opset < 7 :
376- container .add_node ('PRelu' , [input_name , slope_tensor_name ], output_name , op_version = 6 , name = name )
377396 else :
378- container .add_node ('PRelu' , [input_name , slope_tensor_name ], output_name , op_version = 7 , name = name )
397+ if container .target_opset < 7 :
398+ op_version = 6
399+ elif container .target_opset < 9 :
400+ op_version = 7
401+ else :
402+ # opset 9 supports unidirectional broadcasting
403+ op_version = 9
379404
405+ container .add_node ('PRelu' , [input_name , slope_tensor_name ], output_name , op_version = op_version , name = name )
380406
381407def apply_elu (scope , input_name , output_name , container , operator_name = None , alpha = 1.0 ):
382408 _apply_unary_operation (scope , 'Elu' , input_name , output_name , container , operator_name , alpha = alpha )
0 commit comments