@@ -324,6 +324,8 @@ def _process_sub_block_pair(pdesc, sub_block_pair):
324
324
sub_op_output = set ()
325
325
sub_op_output .update (sub_op_dict [fwd_id ].output_arg_names ())
326
326
sub_op_output .update (sub_op_dict [grad_id ].output_arg_names ())
327
+ sub_op_output .update (sub_op_dict [fwd_id ].input_arg_names ())
328
+ sub_op_output .update (sub_op_dict [grad_id ].input_arg_names ())
327
329
ops_list .append ((sub_block_ops , block_op_size , sub_op_output ))
328
330
329
331
# Process rest fwd_op block ops
@@ -335,6 +337,7 @@ def _process_sub_block_pair(pdesc, sub_block_pair):
335
337
sub_block_ops .append (sub_block .op (i ))
336
338
sub_op_output = set ()
337
339
sub_op_output .update (sub_op_dict [fwd_id ].output_arg_names ())
340
+ sub_op_output .update (sub_op_dict [fwd_id ].input_arg_names ())
338
341
ops_list .append ((sub_block_ops , sub_block_op_size , sub_op_output ))
339
342
return ops_list
340
343
@@ -349,13 +352,17 @@ def _get_cfgs(input_program):
349
352
pdesc = input_program .get_desc ()
350
353
block_desc = pdesc .block (0 )
351
354
op_size = block_desc .op_size ()
352
- # Get global block ops
353
- ops_list .append (
354
- ([block_desc .op (i ) for i in range (op_size )], op_size , set ()))
355
355
356
356
# Only process one level of nested subblock.
357
357
ops_list .extend (_process_sub_block_pair (pdesc , SUB_BLOCK_PAIR ))
358
358
359
+ skip_opt_set = set ()
360
+ for _ , _ , skip_opt in ops_list :
361
+ skip_opt_set .update (skip_opt )
362
+
363
+ # Get global block ops
364
+ ops_list .insert (
365
+ 0 , ([block_desc .op (i ) for i in range (op_size )], op_size , skip_opt_set ))
359
366
cfgs = [
360
367
ControlFlowGraph (input_program , ops , forward_num , skip_opt )
361
368
for ops , forward_num , skip_opt in ops_list
0 commit comments