Skip to content

Commit 76e3ec6

Browse files
committed
fix cloned op
1 parent 8cb494f commit 76e3ec6

File tree

1 file changed

+10
-9
lines changed

1 file changed

+10
-9
lines changed

python/paddle/fluid/transpiler/distribute_transpiler.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,7 @@ def __op_have_grad_input__(op):
396396
return varname
397397
return ""
398398

399-
def __clone_lr_op_sub_block__(op, program, new_block):
399+
def __clone_lr_op_sub_block__(op, program, lr_block):
400400
if not op.has_attr('sub_block'):
401401
return
402402

@@ -405,17 +405,17 @@ def __clone_lr_op_sub_block__(op, program, new_block):
405405
assert isinstance(origin_block, Block)
406406
# we put the new sub block to new block to follow the block
407407
# hierarchy of the original blocks
408-
new_sub_block = program.create_block(new_block.idx)
408+
new_sub_block = program.create_block(lr_block.idx)
409409

410410
# clone vars
411411
for var in origin_block.vars:
412412
new_sub_block.clone_variable(var)
413413

414414
# clone ops
415-
for op in origin_block.ops:
416-
self._clone_lr_op(program, new_sub_block, op)
415+
for origin_op in origin_block.ops:
416+
cloned_op = self._clone_lr_op(program, new_sub_block, origin_op)
417417
# clone sub_block of op
418-
__clone_lr_op_sub_block__(op, program, new_sub_block)
418+
__clone_lr_op_sub_block__(cloned_op, program, new_sub_block)
419419

420420
# reset the block of op
421421
op.set_attr('sub_block', new_sub_block)
@@ -429,9 +429,10 @@ def __clone_lr_op_sub_block__(op, program, new_block):
429429
pserver_program.num_blocks - 1)
430430
optimize_blocks.append(lr_decay_block)
431431
for _, op in enumerate(lr_ops):
432-
self._append_pserver_non_opt_ops(lr_decay_block, op)
432+
cloned_op = self._append_pserver_non_opt_ops(lr_decay_block, op)
433433
# append sub blocks to pserver_program in lr_decay_op
434-
__clone_lr_op_sub_block__(op, pserver_program, lr_decay_block)
434+
__clone_lr_op_sub_block__(cloned_op, pserver_program,
435+
lr_decay_block)
435436

436437
# append op to the current block
437438
grad_to_block_id = []
@@ -1214,7 +1215,7 @@ def _clone_lr_op(self, program, block, op):
12141215
if var not in program.global_block().vars:
12151216
block.clone_variable(var)
12161217

1217-
block.append_op(
1218+
return block.append_op(
12181219
type=op.type, inputs=inputs, outputs=outputs, attrs=op.attrs)
12191220

12201221
def _append_pserver_non_opt_ops(self, optimize_block, opt_op):
@@ -1252,7 +1253,7 @@ def _append_pserver_non_opt_ops(self, optimize_block, opt_op):
12521253
elif not program.global_block().vars.has_key(var.name):
12531254
program.global_block().clone_variable(var)
12541255

1255-
optimize_block.append_op(
1256+
return optimize_block.append_op(
12561257
type=opt_op.type,
12571258
inputs=inputs,
12581259
outputs=outputs,

0 commit comments

Comments
 (0)