9
9
from __future__ import print_function
10
10
from __future__ import unicode_literals
11
11
import logging
12
+ import numpy as np
12
13
from onnx import onnx_pb
13
14
from tf2onnx import utils
15
+ from tf2onnx .rewriter import rnn_utils
14
16
15
17
# pylint: disable=logging-not-lazy,missing-docstring,consider-swap-variables
16
18
17
19
18
-
19
20
logger = logging .getLogger (__name__ )
20
21
21
22
direct_ops = [
@@ -115,41 +116,16 @@ def infer_shape_for_node(g, node):
115
116
return False
116
117
return set_shape_from_input (g , shape_node .input [0 ], node .output [0 ])
117
118
118
- if node .type == "ConcatV2" :
119
- axis_node = node .inputs [- 1 ]
120
- if not axis_node .is_const ():
121
- return False
122
-
123
- axis = axis_node .get_tensor_value ()
124
- val = 0
125
- data_inputs = node .input [:- 1 ]
126
- for i in data_inputs :
127
- s = g .get_shape (i )
128
- if s is None :
129
- return False
130
-
131
- if s [axis ] == - 1 :
132
- val = - 1
133
- break
134
- val += s [axis ]
135
-
136
- s1 = g .get_shape (node .input [0 ])
137
- if axis < 0 :
138
- axis += len (s1 )
139
- new_shape = s1 [:axis ] + [val ]
140
- if axis < len (s1 ) - 1 :
141
- new_shape += s1 [axis + 1 :]
142
-
143
- g .set_shape (node .output [0 ], new_shape )
144
- logger .debug ("set ConcatV2 node [%s] with new shape %s" , node .output [0 ], new_shape )
145
- return True
146
-
147
119
if node .type == "Gather" :
148
120
# uses the follwing link to know how to infer shape of output
149
121
# https://www.tensorflow.org/api_docs/python/tf/gather
150
122
shape_params = g .get_shape (node .input [0 ])
151
123
shape_indices = g .get_shape (node .input [1 ])
152
- axis = node .input [2 ].get_tensor_value ()
124
+ # in lower tf version, gather only has 2 inputs
125
+ if len (node .input ) == 3 :
126
+ axis = node .input [2 ].get_tensor_value ()
127
+ else :
128
+ axis = 0
153
129
154
130
shape = shape_params [:axis ] + shape_indices + shape_params [axis + 1 :]
155
131
g .set_shape (node .output [0 ], shape )
@@ -194,6 +170,29 @@ def infer_shape_for_node(g, node):
194
170
logger .debug ("set [%s] with new shape %s" , node .output [0 ], new_shape )
195
171
return True
196
172
173
+ if node .type == "Unpack" :
174
+ input_shape = g .get_shape (node .input [0 ])
175
+ if input_shape is None :
176
+ return False
177
+
178
+ axis = node .get_attr ("axis" ).i
179
+ axis = axis if axis >= 0 else axis + len (input_shape )
180
+ # the link below says that the rank of output is "rank(input) -1",
181
+ # from this statement "num" must equal to input_shape[axis], and if not tf will throw a runtime error
182
+ # https://www.tensorflow.org/api_docs/python/tf/unstack
183
+ new_shape = input_shape [:axis ] + input_shape [axis + 1 :]
184
+ for output in node .output :
185
+ g .set_shape (output , new_shape )
186
+ logger .debug ("set %s node [%s] with new shape %s" , node .type , output , new_shape )
187
+ return True
188
+
189
+ if node .type in ["Minimum" , "Maximum" ]:
190
+ # ops that are elementwise and support broadcasting
191
+ input_shapes = [g .get_shape (node ) for node in node .input ]
192
+ new_shape = broadcast_shape_inference (* input_shapes )
193
+ g .set_shape (node .output [0 ], new_shape )
194
+ return True
195
+
197
196
return False
198
197
199
198
@@ -213,6 +212,36 @@ def infer_input_shapes(g, node):
213
212
214
213
215
214
def infer_output_shapes_with_partial_inputs (g , node ):
215
+ # output shape of concat op: only the dim val of concatenated dim will be changed
216
+ # so only partial(at least one) input shapes need to be known to infer output shape of concat node
217
+ if rnn_utils .is_concat_op (node ):
218
+ data_inputs = node .input [:- 1 ]
219
+ input_shapes = [g .get_shape (node ) for node in data_inputs ]
220
+ input_shapes = [shape for shape in input_shapes if shape is not None ]
221
+ if len (input_shapes ) == 0 :
222
+ logger .debug ("all input shapes of concat node %s are None, can't infer its output shape" , node .name )
223
+ return False
224
+
225
+ new_shape = input_shapes [0 ]
226
+ axis_node = node .inputs [- 1 ]
227
+ rank = len (new_shape )
228
+ if not axis_node .is_const ():
229
+ g .set_shape (node .output [0 ], [- 1 ] * rank )
230
+ return True
231
+
232
+ axis = axis_node .get_tensor_value ()
233
+ axis = axis if axis >= 0 else axis + rank
234
+ new_shape [axis ] = - 1
235
+ if len (input_shapes ) == len (data_inputs ): # all input shapes are known
236
+ concat_dim_vals = list (np .array (input_shapes )[:, axis ])
237
+ # only when inputs' shape are known, then val of concat dim can be calculated
238
+ if concat_dim_vals .count (- 1 ) == 0 :
239
+ new_shape [axis ] = sum (concat_dim_vals )
240
+
241
+ g .set_shape (node .output [0 ], new_shape )
242
+ logger .debug ("set Concat node [%s] with new shape %s" , node .output [0 ], new_shape )
243
+ return True
244
+
216
245
if node .type == "Merge" :
217
246
s1 = g .get_shape (node .input [0 ])
218
247
s2 = g .get_shape (node .input [1 ])
0 commit comments