6
6
pat = SpaceToBatchND->DepthwiseConv2dNative->BatchToSpaceND
7
7
"""
8
8
9
+ import numpy as np
9
10
from tf2onnx .graph_matcher import OpTypePattern , GraphMatcher
10
11
11
12
# pylint: disable=invalid-name,unused-argument,missing-docstring, unused-variable
@@ -18,30 +19,30 @@ def rewrite_conv_dilations(g, ops):
18
19
OpTypePattern ("SpaceToBatchND" , name = "space_to_batch" , inputs = [
19
20
OpTypePattern ("*" ),
20
21
OpTypePattern ("Const|ConstV2" ),
21
- OpTypePattern ("Const|ConstV2 " ),
22
+ OpTypePattern ("* " ),
22
23
]),
23
24
OpTypePattern ("*" ),
24
25
]),
25
26
OpTypePattern ("Const|ConstV2" ),
26
- OpTypePattern ("Const|ConstV2 " ),
27
+ OpTypePattern ("* " ),
27
28
])
28
29
pattern2 = \
29
30
OpTypePattern ("BatchToSpaceND" , name = "batch_to_space" , inputs = [
30
31
OpTypePattern ("Squeeze" , name = "squeeze" , inputs = [
31
- OpTypePattern ("DepthwiseConv2dNative| Conv2D|Conv3D " , name = "conv" , inputs = [
32
+ OpTypePattern ("Conv2D" , name = "conv" , inputs = [
32
33
OpTypePattern ("ExpandDims" , name = "expand" , inputs = [
33
34
OpTypePattern ("SpaceToBatchND" , name = "space_to_batch" , inputs = [
34
35
OpTypePattern ("*" ),
35
36
OpTypePattern ("Const|ConstV2" ),
36
- OpTypePattern ("Const|ConstV2 " ),
37
+ OpTypePattern ("* " ),
37
38
]),
38
39
OpTypePattern ("Const|ConstV2" ),
39
40
]),
40
41
OpTypePattern ("*" ),
41
42
]),
42
43
]),
43
44
OpTypePattern ("Const|ConstV2" ),
44
- OpTypePattern ("Const|ConstV2 " ),
45
+ OpTypePattern ("* " ),
45
46
])
46
47
47
48
for pattern in [pattern1 , pattern2 ]:
@@ -61,9 +62,7 @@ def rewrite_conv_dilations(g, ops):
61
62
continue
62
63
63
64
block_shape1 = space_to_batch .inputs [1 ].get_tensor_value (as_list = True )
64
- paddings = space_to_batch .inputs [2 ].get_tensor_value (as_list = True )
65
65
block_shape2 = batch_to_space .inputs [1 ].get_tensor_value (as_list = True )
66
- crops = batch_to_space .inputs [2 ].get_tensor_value (as_list = True )
67
66
68
67
if block_shape1 != block_shape2 :
69
68
continue
@@ -79,15 +78,27 @@ def rewrite_conv_dilations(g, ops):
79
78
if conv .get_attr_value ("padding" ) != b"VALID" :
80
79
continue
81
80
82
-
83
- base_start_pad = [p [0 ] for p in paddings ]
84
- if any (c [0 ] != 0 for c in crops ):
85
- continue
86
- base_end_pad = [p [1 ] - c [1 ] for p , c in zip (paddings , crops )]
87
- if not all (0 <= p [1 ] - bp < bs for p , bp , bs in zip (paddings , base_end_pad , block_shape1 )):
88
- continue
81
+ if space_to_batch .inputs [2 ].is_const () and batch_to_space .inputs [2 ].is_const ():
82
+ paddings = space_to_batch .inputs [2 ].get_tensor_value (as_list = True )
83
+ crops = batch_to_space .inputs [2 ].get_tensor_value (as_list = True )
84
+ base_start_pad = [p [0 ] for p in paddings ]
85
+ if any (c [0 ] != 0 for c in crops ):
86
+ continue
87
+ base_end_pad = [p [1 ] - c [1 ] for p , c in zip (paddings , crops )]
88
+ if not all (0 <= p [1 ] - bp < bs for p , bp , bs in zip (paddings , base_end_pad , block_shape1 )):
89
+ continue
90
+ pad_mode = "EXPLICIT"
91
+ if is_conv_1d :
92
+ base_start_pad = [0 ] + base_start_pad
93
+ base_end_pad = [0 ] + base_end_pad
94
+ base_pad_flat = [0 , 0 ] + [x for s , e in zip (base_start_pad , base_end_pad ) for x in [s , e ]] + [0 , 0 ]
95
+ else :
96
+ pad_mode = determine_pad_mode (space_to_batch .inputs [2 ])
97
+ if pad_mode is None :
98
+ continue
89
99
90
100
if is_conv_1d :
101
+ block_shape1 = [1 ] + block_shape1
91
102
inp = space_to_batch .input [0 ]
92
103
g .replace_inputs (expand , [inp , expand .input [1 ]])
93
104
g .copy_shape (batch_to_space .output [0 ], squeeze .output [0 ])
@@ -96,20 +107,39 @@ def rewrite_conv_dilations(g, ops):
96
107
g .set_shape (squeeze .input [0 ], squeeze_out_shape [:1 ] + [1 ] + squeeze_out_shape [1 :])
97
108
expand_inp_shape = g .get_shape (expand .input [0 ])
98
109
g .set_shape (expand .output [0 ], expand_inp_shape [:1 ] + [1 ] + expand_inp_shape [1 :])
99
-
100
- base_start_pad = [0 ] + base_start_pad
101
- base_end_pad = [0 ] + base_end_pad
102
- block_shape1 = [1 ] + block_shape1
103
110
else :
104
111
inp = space_to_batch .input [0 ]
105
112
kernel = conv .input [1 ]
106
113
g .replace_inputs (conv , [inp , kernel ])
107
114
g .copy_shape (batch_to_space .output [0 ], conv .output [0 ])
108
115
g .replace_all_inputs (batch_to_space .output [0 ], conv .output [0 ])
109
116
110
- base_pad_flat = [0 , 0 ] + [x for s , e in zip (base_start_pad , base_end_pad ) for x in [s , e ]] + [0 , 0 ]
111
117
conv .set_attr ("dilations" , [1 ] + block_shape1 + [1 ])
112
- conv .set_attr ("explicit_paddings" , base_pad_flat )
113
- conv .set_attr ("padding" , "EXPLICIT" )
118
+ conv .set_attr ("padding" , pad_mode )
119
+ if pad_mode == "EXPLICIT" :
120
+ conv .set_attr ("explicit_paddings" , base_pad_flat )
114
121
115
122
return g .get_nodes ()
123
+
124
def determine_pad_mode(paddings_inp_node):
    """Infer the conv padding mode implied by a dynamic paddings input.

    Walks upstream through ops that merely rearrange tensors until it reaches
    the node that actually produces the padding amounts.

    Returns:
        "VALID" when the paddings reduce to a bare FloorMod (i.e. only the
        padding needed to make dimensions divisible by the block shape), or to
        a FloorMod plus an all-zero constant; None when the mode cannot be
        determined (including the SAME case, which is deliberately rejected —
        see the comment below).
    """
    # Shape-manipulation ops: follow their first non-const input upstream.
    pass_through = frozenset(
        ("Concat", "ConcatV2", "ConcatV3", "StridedSlice", "Pack", "ExpandDims", "Identity"))
    node = paddings_inp_node
    while node.type in pass_through:
        candidates = [inp for inp in node.inputs if not inp.is_const()]
        if not candidates:
            # Everything feeding this op is constant; nothing dynamic to trace.
            return None
        node = candidates[0]
    if node.type == "FloorMod":
        # Pure divisibility padding: equivalent to VALID conv padding.
        return "VALID"
    if node.type in ("Add", "AddV2"):
        # Expect FloorMod on one side; the other operand is the extra padding.
        first, second = node.inputs[0], node.inputs[1]
        if first.type == "FloorMod":
            extra = second
        elif second.type == "FloorMod":
            extra = first
        else:
            return None
        if extra.is_const():
            if np.any(extra.get_tensor_value(as_list=False)):
                #return "SAME" ORT doesn't implement dilations for SAME autopadding yet
                return None
            return "VALID"
    return None
0 commit comments