Skip to content

Commit abe32a4

Browse files
authored
"fix memory optimize bug in lr decay" (#12299)
1 parent c2fe067 commit abe32a4

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

python/paddle/fluid/transpiler/memory_optimization_transpiler.py

Lines changed: 10 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -324,6 +324,8 @@ def _process_sub_block_pair(pdesc, sub_block_pair):
324324
sub_op_output = set()
325325
sub_op_output.update(sub_op_dict[fwd_id].output_arg_names())
326326
sub_op_output.update(sub_op_dict[grad_id].output_arg_names())
327+
sub_op_output.update(sub_op_dict[fwd_id].input_arg_names())
328+
sub_op_output.update(sub_op_dict[grad_id].input_arg_names())
327329
ops_list.append((sub_block_ops, block_op_size, sub_op_output))
328330

329331
# Process rest fwd_op block ops
@@ -335,6 +337,7 @@ def _process_sub_block_pair(pdesc, sub_block_pair):
335337
sub_block_ops.append(sub_block.op(i))
336338
sub_op_output = set()
337339
sub_op_output.update(sub_op_dict[fwd_id].output_arg_names())
340+
sub_op_output.update(sub_op_dict[fwd_id].input_arg_names())
338341
ops_list.append((sub_block_ops, sub_block_op_size, sub_op_output))
339342
return ops_list
340343

@@ -349,13 +352,17 @@ def _get_cfgs(input_program):
349352
pdesc = input_program.get_desc()
350353
block_desc = pdesc.block(0)
351354
op_size = block_desc.op_size()
352-
# Get global block ops
353-
ops_list.append(
354-
([block_desc.op(i) for i in range(op_size)], op_size, set()))
355355

356356
# Only process one level of nested subblock.
357357
ops_list.extend(_process_sub_block_pair(pdesc, SUB_BLOCK_PAIR))
358358

359+
skip_opt_set = set()
360+
for _, _, skip_opt in ops_list:
361+
skip_opt_set.update(skip_opt)
362+
363+
# Get global block ops
364+
ops_list.insert(
365+
0, ([block_desc.op(i) for i in range(op_size)], op_size, skip_opt_set))
359366
cfgs = [
360367
ControlFlowGraph(input_program, ops, forward_num, skip_opt)
361368
for ops, forward_num, skip_opt in ops_list

0 commit comments

Comments
 (0)