@@ -223,15 +223,15 @@ def get_cfgs(input_program):
223
223
224
224
# Find while/while_grad block pair
225
225
for grad_id in while_grad_sub_block_ids :
226
- parent_id = pdesc .block (grad_id ).parent
227
- if parent_id in while_sub_block_ids :
228
- while_block_id_pair .append ((parent_id , grad_id ))
229
- while_sub_block_ids .remove (parent_id )
226
+ forward_id = pdesc .block (grad_id ).get_forward_block_idx ()
227
+ if forward_id in while_sub_block_ids :
228
+ while_block_id_pair .append ((forward_id , grad_id ))
229
+ while_sub_block_ids .remove (forward_id )
230
230
231
231
# Get while/while_grad block ops
232
- for parent_id , grad_id in while_block_id_pair :
232
+ for forward_id , grad_id in while_block_id_pair :
233
233
while_block_ops = []
234
- while_block = pdesc .block (parent_id )
234
+ while_block = pdesc .block (forward_id )
235
235
while_block_op_size = while_block .op_size ()
236
236
for i in range (while_block_op_size ):
237
237
while_block_ops .append (while_block .op (i ))
@@ -242,21 +242,21 @@ def get_cfgs(input_program):
242
242
while_block_ops .append (while_grad_block .op (i ))
243
243
244
244
while_op_output = set ()
245
- while_op_output .update (while_op_dict [parent_id ].output_arg_names ())
245
+ while_op_output .update (while_op_dict [forward_id ].output_arg_names ())
246
246
while_op_output .update (while_op_dict [grad_id ].output_arg_names ())
247
247
248
248
ops_list .append ((while_block_ops , while_block_op_size , while_op_output ))
249
249
250
250
# Process rest while block ops
251
- for parent_id in while_sub_block_ids :
251
+ for forward_id in while_sub_block_ids :
252
252
while_block_ops = []
253
- while_block = pdesc .block (parent_id )
253
+ while_block = pdesc .block (forward_id )
254
254
while_block_op_size = while_block .op_size ()
255
255
for i in range (while_block_op_size ):
256
256
while_block_ops .append (while_block .op (i ))
257
257
258
258
while_op_output = set ()
259
- while_op_output .update (while_op_dict [parent_id ].output_arg_names ())
259
+ while_op_output .update (while_op_dict [forward_id ].output_arg_names ())
260
260
261
261
ops_list .append ((while_block_ops , while_block_op_size , while_op_output ))
262
262
0 commit comments