@@ -800,6 +800,67 @@ def version_7(cls, ctx, node, **kwargs):
@tf_op(["Pad", "PadV2", "MirrorPad"], onnx_op="Pad")
class Pad:
+
+    @classmethod
+    def convert_symmetric_pads(cls, ctx, node):
+        """Currently there isn't a symmetric padding mode in ONNX, so we add a dummy row, use the reflect mode,
+        then remove the dummy row with Compress. Ex: 1234 -> 012340 -> 2101234043 -> 21123443. Only do this for
+        dims with non-zero pads (if pads are constant)."""
+        rank = ctx.get_rank(node.input[0])
+        utils.make_sure(rank is not None, "Cannot convert pad with symmetric mode and unknown rank")
+        utils.make_sure(ctx.opset >= 9, "opset 9 required for symmetric padding mode")
+        node.set_attr("mode", "reflect")
+        const_pads = None
+        consumers = ctx.find_output_consumers(node.output[0])
+        output_shape = ctx.get_shape(node.output[0])
+        if ctx.opset < 11:
+            const_pads = node.get_attr_value("pads")
+        elif node.inputs[1].is_const():
+            const_pads = node.inputs[1].get_tensor_value()
+        non_zero_axes = list(range(rank))
+        if const_pads is not None:
+            non_zero_axes = []
+            for i in range(rank):
+                if const_pads[i] != 0 or const_pads[i + rank] != 0:
+                    non_zero_axes.append(i)
+
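+        # Constant-pad every padded axis by one dummy element on each side so the
+        # reflect pad has an edge value to mirror around.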
+        inc_pads = [0] * (rank * 2)
+        for a in non_zero_axes:
+            inc_pads[a] = 1
+            inc_pads[a + rank] = 1
+
+        if ctx.opset < 11:
+            padded_inp = ctx.make_node("Pad", [node.input[0]], attr={'mode': 'constant', 'pads': inc_pads}).output[0]
+        else:
+            pad1_pads_const = ctx.make_const(utils.make_name("pad1_pads"), np.array(inc_pads, np.int64)).output[0]
+            padded_inp = ctx.make_node("Pad", [node.input[0], pad1_pads_const], attr={'mode': 'constant'}).output[0]
+        ctx.replace_input(node, node.input[0], padded_inp, 0)
+        ctx.update_node_shape_dtype(node, override=True)
+
+        output = node.output[0]
+        shape = ctx.make_node("Shape", [output]).output[0]
+        dims = ctx.make_node("Split", [shape], output_count=rank).output
+        two_false = ctx.make_const(utils.make_name("two_false"), np.array([False, False], np.bool)).output[0]
+        inv_second = ctx.make_const(utils.make_name("inv_second"), np.array([1, -1], np.int64)).output[0]
+        dec_second = ctx.make_const(utils.make_name("dec_second"), np.array([0, 1], np.int64)).output[0]
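+        # After the reflect pad, the dummies sit at index const_pads[a] from the front and
+        # index -1 - const_pads[a + rank] from the back of each padded axis; scatter False
+        # into a ones mask at those two positions and Compress them away.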
+        for a in non_zero_axes:
+            one_tensor = helper.make_tensor("value", onnx_pb.TensorProto.BOOL, dims=[1], vals=[1])
+            ones_of_shape = ctx.make_node("ConstantOfShape", [dims[a]], attr={'value': one_tensor}).output[0]
+            if const_pads is not None:
+                to_remove_val = [const_pads[a], -1 - const_pads[a + rank]]
+                to_remove = ctx.make_const(utils.make_name("to_remove"), np.array(to_remove_val, np.int64)).output[0]
+            else:
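+                # Dynamic pads: build [pads[a], -1 - pads[a + rank]] with graph ops
+                # (Mul by [1, -1], then Sub [0, 1]) to locate the two dummy elements.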
+                pads_idx = ctx.make_const(utils.make_name("pads_idx"), np.array([a, a + rank], np.int64)).output[0]
+                pads_vals = ctx.make_node("Gather", [node.input[1], pads_idx]).output[0]
+                pads_inv_second = ctx.make_node("Mul", [pads_vals, inv_second]).output[0]
+                to_remove = ctx.make_node("Sub", [pads_inv_second, dec_second]).output[0]
+            scatter_op = "ScatterElements" if ctx.opset >= 11 else "Scatter"
+            dims_to_keep = ctx.make_node(scatter_op, [ones_of_shape, to_remove, two_false]).output[0]
+            compress = ctx.make_node("Compress", [output, dims_to_keep], attr={'axis': a})
+            output = compress.output[0]
+        ctx.replace_all_inputs(node.output[0], output, consumers)
+        ctx.set_shape(output, output_shape)
+
    @classmethod
    def version_1(cls, ctx, node, **kwargs):
        node.type = "Pad"
@@ -812,7 +873,7 @@ def version_1(cls, ctx, node, **kwargs):
        if mode:
            mode = mode.s.decode("utf-8").lower()
            node.set_attr("mode", mode)
-        if mode not in [None, "constant", "reflect"]:
+        if mode not in [None, "symmetric", "constant", "reflect"]:
            raise ValueError(mode + " pad mode is not supported")

        if mode in [None, "constant"] and len(node.input) == 3:
@@ -836,21 +897,29 @@ def version_1(cls, ctx, node, **kwargs):
            ctx.set_dtype(cast_back_node.output[0], origin_dtype)
            ctx.copy_shape(node.name, cast_back_node.output[0])

+        if mode == "symmetric":
+            cls.convert_symmetric_pads(ctx, node)
+
    @classmethod
    def version_11(cls, ctx, node, **kwargs):
        mode = node.get_attr("mode")
        if mode:
            mode = mode.s.decode("utf-8").lower()
            node.set_attr("mode", mode)
-        if mode not in [None, "constant", "reflect"]:
+        if mode not in [None, "symmetric", "constant", "reflect"]:
            raise ValueError(mode + " pad mode is not supported")

-        # pads must be int64.
-        if ctx.get_dtype(node.input[1]) != onnx_pb.TensorProto.INT64:
-            ctx.insert_new_node_on_input(node, "Cast", node.input[1], to=onnx_pb.TensorProto.INT64)
-        ctx.insert_new_node_on_input(node, "Transpose", node.input[1])
-        shape_const = ctx.make_const(utils.make_name(node.name), np.array([-1]).astype(np.int64))
-        ctx.insert_new_node_on_input(node, "Reshape", [node.input[1], shape_const.name])
+        if not node.inputs[1].is_const():
+            # pads must be int64.
+            if ctx.get_dtype(node.input[1]) != onnx_pb.TensorProto.INT64:
+                ctx.insert_new_node_on_input(node, "Cast", node.input[1], to=onnx_pb.TensorProto.INT64)
+            ctx.insert_new_node_on_input(node, "Transpose", node.input[1])
+            shape_const = ctx.make_const(utils.make_name(node.name), np.array([-1]).astype(np.int64))
+            ctx.insert_new_node_on_input(node, "Reshape", [node.input[1], shape_const.name])
+        else:
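+            # TF paddings come as a [rank, 2] tensor of [begin, end] pairs, while ONNX Pad
+            # expects a flat [2 * rank] vector of all begins followed by all ends.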
+            paddings = node.inputs[1].get_tensor_value(as_list=False).astype(np.int64).transpose().flatten()
+            pad_const = ctx.make_const(utils.make_name("pad_const"), paddings)
+            ctx.replace_input(node, node.input[1], pad_const.output[0], 1)

        origin_dtype = ctx.get_dtype(node.output[0])
        if origin_dtype not in [TensorProto.FLOAT, TensorProto.DOUBLE,
@@ -865,6 +934,9 @@ def version_11(cls, ctx, node, **kwargs):
            ctx.set_dtype(cast_back_node.output[0], origin_dtype)
            ctx.copy_shape(node.name, cast_back_node.output[0])

+        if mode == "symmetric":
+            cls.convert_symmetric_pads(ctx, node)
+

@tf_op(["FusedBatchNorm", "FusedBatchNormV2", "FusedBatchNormV3"])
class BatchNorm:
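
As a quick sanity check on the symmetric-to-reflect trick in convert_symmetric_pads, here is a minimal numpy sketch of the same rewrite on a 1-D array; the helper name symmetric_via_reflect is illustrative only, not part of this PR:

import numpy as np

def symmetric_via_reflect(x, before, after):
    # Constant-pad one dummy element per side so reflect has an edge to mirror around.
    padded = np.pad(x, (1, 1), mode="constant")
    # Reflect-pad; the dummies land at index `before` and index `-1 - after`.
    reflected = np.pad(padded, (before, after), mode="reflect")
    # Drop the dummies, mirroring what Compress does in the ONNX graph.
    keep = np.ones(len(reflected), dtype=bool)
    keep[[before, len(reflected) - 1 - after]] = False
    return reflected[keep]

x = np.array([1, 2, 3, 4])
assert (symmetric_via_reflect(x, 2, 2) == np.pad(x, (2, 2), mode="symmetric")).all()
# 1234 -> 012340 -> 2101234043 -> 21123443, matching the docstring example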