16
16
logging .basicConfig (level = logging .INFO )
17
17
log = logging .getLogger ("tf2onnx.rewriter.cond_rewriter_base" )
18
18
19
+
19
20
# pylint: disable=missing-docstring,unused-argument,broad-except
20
21
21
22
class BranchType (Enum ):
@@ -29,6 +30,7 @@ class BranchType(Enum):
29
30
30
31
class CondBranchContext :
31
32
"""Context for each branch graph"""
33
+
32
34
def __init__ (self ):
33
35
self .output = []
34
36
self .nodes = set ()
@@ -37,12 +39,12 @@ def __init__(self):
37
39
class CondContext :
38
40
def __init__ (self , cond_scope , pred_input , true_branch_context ,
39
41
false_branch_context , switchs , merges ):
40
- self .cond_scope = cond_scope # name scope for this tf.cond
41
- self .pred_input = pred_input # condition input
42
+ self .cond_scope = cond_scope # name scope for this tf.cond
43
+ self .pred_input = pred_input # condition input
42
44
self .true_branch_context = true_branch_context
43
45
self .false_branch_context = false_branch_context
44
46
self .switchs = set (switchs )
45
- self .merges = merges # list of merges in order
47
+ self .merges = merges # list of merges in order
46
48
47
49
48
50
class CondRewriter :
@@ -114,7 +116,7 @@ def _get_output_shape_dtype(self, cond_context):
114
116
true_dtype = self .g .get_dtype (true_output )
115
117
false_shape = self .g .get_shape (false_output )
116
118
false_dtype = self .g .get_dtype (false_output )
117
- if true_shape != false_shape :
119
+ if not utils . are_shapes_compatible ( true_shape , false_shape ) :
118
120
raise RuntimeError (
119
121
"the shape of outputs {} and {} mismatch: {}, {}" .format (
120
122
true_output ,
@@ -132,7 +134,7 @@ def _get_output_shape_dtype(self, cond_context):
132
134
false_dtype
133
135
)
134
136
)
135
- output_shapes .append (true_shape )
137
+ output_shapes .append (utils . merge_shapes ( true_shape , false_shape ) )
136
138
output_dtypes .append (true_dtype )
137
139
return output_shapes , output_dtypes
138
140
@@ -243,11 +245,13 @@ def _trace_back_from_one_merge(self, merge_node):
243
245
merge_input_1 = merge_node .input [0 ]
244
246
merge_input_2 = merge_node .input [1 ]
245
247
switchs = set ()
248
+
246
249
def stop_at_switch (node ):
247
250
if self ._is_switch (node ):
248
251
switchs .add (node )
249
252
return False
250
253
return True
254
+
251
255
branch_nodes_1 = self .g .extract_sub_graph_nodes (
252
256
[merge_input_1 ],
253
257
stop_at_switch
0 commit comments