Skip to content

Commit 633a49b

Browse files
committed
relax cond_rewriter shape check
1 parent a8328b5 commit 633a49b

File tree

1 file changed

+9
-5
lines changed

1 file changed

+9
-5
lines changed

tf2onnx/rewriter/cond_rewriter.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
logging.basicConfig(level=logging.INFO)
1717
log = logging.getLogger("tf2onnx.rewriter.cond_rewriter_base")
1818

19+
1920
# pylint: disable=missing-docstring,unused-argument,broad-except
2021

2122
class BranchType(Enum):
@@ -29,6 +30,7 @@ class BranchType(Enum):
2930

3031
class CondBranchContext:
3132
"""Context for each branch graph"""
33+
3234
def __init__(self):
3335
self.output = []
3436
self.nodes = set()
@@ -37,12 +39,12 @@ def __init__(self):
3739
class CondContext:
3840
def __init__(self, cond_scope, pred_input, true_branch_context,
3941
false_branch_context, switchs, merges):
40-
self.cond_scope = cond_scope # name scope for this tf.cond
41-
self.pred_input = pred_input # condition input
42+
self.cond_scope = cond_scope # name scope for this tf.cond
43+
self.pred_input = pred_input # condition input
4244
self.true_branch_context = true_branch_context
4345
self.false_branch_context = false_branch_context
4446
self.switchs = set(switchs)
45-
self.merges = merges # list of merges in order
47+
self.merges = merges # list of merges in order
4648

4749

4850
class CondRewriter:
@@ -114,7 +116,7 @@ def _get_output_shape_dtype(self, cond_context):
114116
true_dtype = self.g.get_dtype(true_output)
115117
false_shape = self.g.get_shape(false_output)
116118
false_dtype = self.g.get_dtype(false_output)
117-
if true_shape != false_shape:
119+
if not utils.are_shapes_compatible(true_shape, false_shape):
118120
raise RuntimeError(
119121
"the shape of outputs {} and {} mismatch: {}, {}".format(
120122
true_output,
@@ -132,7 +134,7 @@ def _get_output_shape_dtype(self, cond_context):
132134
false_dtype
133135
)
134136
)
135-
output_shapes.append(true_shape)
137+
output_shapes.append(utils.merge_shapes(true_shape, false_shape))
136138
output_dtypes.append(true_dtype)
137139
return output_shapes, output_dtypes
138140

@@ -243,11 +245,13 @@ def _trace_back_from_one_merge(self, merge_node):
243245
merge_input_1 = merge_node.input[0]
244246
merge_input_2 = merge_node.input[1]
245247
switchs = set()
248+
246249
def stop_at_switch(node):
247250
if self._is_switch(node):
248251
switchs.add(node)
249252
return False
250253
return True
254+
251255
branch_nodes_1 = self.g.extract_sub_graph_nodes(
252256
[merge_input_1],
253257
stop_at_switch

0 commit comments

Comments
 (0)