Skip to content

Commit 6148265

Browse files
committed
enhance shape_inference
1 parent 49d7396 commit 6148265

File tree

1 file changed

+60
-31
lines changed

1 file changed

+60
-31
lines changed

tf2onnx/shape_inference.py

Lines changed: 60 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,14 @@
99
from __future__ import print_function
1010
from __future__ import unicode_literals
1111
import logging
12+
import numpy as np
1213
from onnx import onnx_pb
1314
from tf2onnx import utils
15+
from tf2onnx.rewriter import rnn_utils
1416

1517
# pylint: disable=logging-not-lazy,missing-docstring,consider-swap-variables
1618

1719

18-
1920
logger = logging.getLogger(__name__)
2021

2122
direct_ops = [
@@ -115,41 +116,16 @@ def infer_shape_for_node(g, node):
115116
return False
116117
return set_shape_from_input(g, shape_node.input[0], node.output[0])
117118

118-
if node.type == "ConcatV2":
119-
axis_node = node.inputs[-1]
120-
if not axis_node.is_const():
121-
return False
122-
123-
axis = axis_node.get_tensor_value()
124-
val = 0
125-
data_inputs = node.input[:-1]
126-
for i in data_inputs:
127-
s = g.get_shape(i)
128-
if s is None:
129-
return False
130-
131-
if s[axis] == -1:
132-
val = -1
133-
break
134-
val += s[axis]
135-
136-
s1 = g.get_shape(node.input[0])
137-
if axis < 0:
138-
axis += len(s1)
139-
new_shape = s1[:axis] + [val]
140-
if axis < len(s1) - 1:
141-
new_shape += s1[axis + 1:]
142-
143-
g.set_shape(node.output[0], new_shape)
144-
logger.debug("set ConcatV2 node [%s] with new shape %s", node.output[0], new_shape)
145-
return True
146-
147119
if node.type == "Gather":
148120
# uses the follwing link to know how to infer shape of output
149121
# https://www.tensorflow.org/api_docs/python/tf/gather
150122
shape_params = g.get_shape(node.input[0])
151123
shape_indices = g.get_shape(node.input[1])
152-
axis = node.input[2].get_tensor_value()
124+
# in lower tf version, gather only has 2 inputs
125+
if len(node.input) == 3:
126+
axis = node.input[2].get_tensor_value()
127+
else:
128+
axis = 0
153129

154130
shape = shape_params[:axis] + shape_indices + shape_params[axis + 1:]
155131
g.set_shape(node.output[0], shape)
@@ -194,6 +170,29 @@ def infer_shape_for_node(g, node):
194170
logger.debug("set [%s] with new shape %s", node.output[0], new_shape)
195171
return True
196172

173+
if node.type == "Unpack":
174+
input_shape = g.get_shape(node.input[0])
175+
if input_shape is None:
176+
return False
177+
178+
axis = node.get_attr("axis").i
179+
axis = axis if axis >= 0 else axis + len(input_shape)
180+
# the link below says that the rank of output is "rank(input) -1",
181+
# from this statement "num" must equal to input_shape[axis], and if not tf will throw a runtime error
182+
# https://www.tensorflow.org/api_docs/python/tf/unstack
183+
new_shape = input_shape[:axis] + input_shape[axis + 1:]
184+
for output in node.output:
185+
g.set_shape(output, new_shape)
186+
logger.debug("set %s node [%s] with new shape %s", node.type, output, new_shape)
187+
return True
188+
189+
if node.type in ["Minimum", "Maximum"]:
190+
# ops that are elementwise and support broadcasting
191+
input_shapes = [g.get_shape(node) for node in node.input]
192+
new_shape = broadcast_shape_inference(*input_shapes)
193+
g.set_shape(node.output[0], new_shape)
194+
return True
195+
197196
return False
198197

199198

@@ -213,6 +212,36 @@ def infer_input_shapes(g, node):
213212

214213

215214
def infer_output_shapes_with_partial_inputs(g, node):
215+
# output shape of concat op: only the dim val of concatenated dim will be changed
216+
# so only partial(at least one) input shapes need to be known to infer output shape of concat node
217+
if rnn_utils.is_concat_op(node):
218+
data_inputs = node.input[:-1]
219+
input_shapes = [g.get_shape(node) for node in data_inputs]
220+
input_shapes = [shape for shape in input_shapes if shape is not None]
221+
if len(input_shapes) == 0:
222+
logger.debug("all input shapes of concat node %s are None, can't infer its output shape", node.name)
223+
return False
224+
225+
new_shape = input_shapes[0]
226+
axis_node = node.inputs[-1]
227+
rank = len(new_shape)
228+
if not axis_node.is_const():
229+
g.set_shape(node.output[0], [-1] * rank)
230+
return True
231+
232+
axis = axis_node.get_tensor_value()
233+
axis = axis if axis >= 0 else axis + rank
234+
new_shape[axis] = -1
235+
if len(input_shapes) == len(data_inputs): # all input shapes are known
236+
concat_dim_vals = list(np.array(input_shapes)[:, axis])
237+
# only when inputs' shape are known, then val of concat dim can be calculated
238+
if concat_dim_vals.count(-1) == 0:
239+
new_shape[axis] = sum(concat_dim_vals)
240+
241+
g.set_shape(node.output[0], new_shape)
242+
logger.debug("set Concat node [%s] with new shape %s", node.output[0], new_shape)
243+
return True
244+
216245
if node.type == "Merge":
217246
s1 = g.get_shape(node.input[0])
218247
s2 = g.get_shape(node.input[1])

0 commit comments

Comments
 (0)