Skip to content

Commit 7ffd50b

Browse files
committed
Merge branch 'feature/add_fwd_block_id' of github.com:reyoung/Paddle into feature/add_fwd_block_id
2 parents aaf818f + bf92706 commit 7ffd50b

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

python/paddle/v2/fluid/memory_optimization_transpiler.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -223,15 +223,15 @@ def get_cfgs(input_program):
223223

224224
# Find while/while_grad block pair
225225
for grad_id in while_grad_sub_block_ids:
226-
parent_id = pdesc.block(grad_id).parent
227-
if parent_id in while_sub_block_ids:
228-
while_block_id_pair.append((parent_id, grad_id))
229-
while_sub_block_ids.remove(parent_id)
226+
forward_id = pdesc.block(grad_id).get_forward_block_idx()
227+
if forward_id in while_sub_block_ids:
228+
while_block_id_pair.append((forward_id, grad_id))
229+
while_sub_block_ids.remove(forward_id)
230230

231231
# Get while/while_grad block ops
232-
for parent_id, grad_id in while_block_id_pair:
232+
for forward_id, grad_id in while_block_id_pair:
233233
while_block_ops = []
234-
while_block = pdesc.block(parent_id)
234+
while_block = pdesc.block(forward_id)
235235
while_block_op_size = while_block.op_size()
236236
for i in range(while_block_op_size):
237237
while_block_ops.append(while_block.op(i))
@@ -242,21 +242,21 @@ def get_cfgs(input_program):
242242
while_block_ops.append(while_grad_block.op(i))
243243

244244
while_op_output = set()
245-
while_op_output.update(while_op_dict[parent_id].output_arg_names())
245+
while_op_output.update(while_op_dict[forward_id].output_arg_names())
246246
while_op_output.update(while_op_dict[grad_id].output_arg_names())
247247

248248
ops_list.append((while_block_ops, while_block_op_size, while_op_output))
249249

250250
# Process rest while block ops
251-
for parent_id in while_sub_block_ids:
251+
for forward_id in while_sub_block_ids:
252252
while_block_ops = []
253-
while_block = pdesc.block(parent_id)
253+
while_block = pdesc.block(forward_id)
254254
while_block_op_size = while_block.op_size()
255255
for i in range(while_block_op_size):
256256
while_block_ops.append(while_block.op(i))
257257

258258
while_op_output = set()
259-
while_op_output.update(while_op_dict[parent_id].output_arg_names())
259+
while_op_output.update(while_op_dict[forward_id].output_arg_names())
260260

261261
ops_list.append((while_block_ops, while_block_op_size, while_op_output))
262262

0 commit comments

Comments
 (0)