@@ -89,10 +89,11 @@ def run(self):
89
89
continue
90
90
91
91
self ._cut_off_connection (cond_context )
92
- self ._create_if_node (cond_context )
92
+ if_node = self ._create_if_node (cond_context )
93
93
# remove nodes in If branches explicitly
94
- for n in list (cond_context .true_branch_context .nodes ) + list (cond_context .false_branch_context .nodes ):
95
- self .g .remove_node (n .name )
94
+ if if_node is not None :
95
+ for n in list (cond_context .true_branch_context .nodes ) + list (cond_context .false_branch_context .nodes ):
96
+ self .g .remove_node (n .name )
96
97
logger .debug ("cond pre rewrite done" )
97
98
98
99
return self .g .get_nodes ()
@@ -136,6 +137,19 @@ def _get_output_shape_dtype(self, cond_context):
136
137
137
138
def _create_if_node (self , cond_context ):
138
139
output_shapes , output_dtypes = self ._get_output_shape_dtype (cond_context )
140
+ pred_node = self .g .get_node_by_output (cond_context .pred_input )
141
+ while pred_node .type == "Identity" :
142
+ pred_node = pred_node .inputs [0 ]
143
+ if pred_node .is_const ():
144
+ # Constant folding for if node
145
+ if pred_node .get_tensor_value ():
146
+ branch_outputs = cond_context .true_branch_context .output
147
+ else :
148
+ branch_outputs = cond_context .false_branch_context .output
149
+ for merge , out in zip (cond_context .merges , branch_outputs ):
150
+ self .g .replace_all_inputs (merge .output [0 ], out )
151
+ return None
152
+
139
153
true_graph = utils .construct_graph_from_nodes (
140
154
self .g ,
141
155
list (cond_context .true_branch_context .nodes ),
0 commit comments