
Commit a29cb4b

velconia authored and gongweibao committed
Fix decay bug (#11520)
* Add sub_blocks of lr_decay_op to pserver_prog after distribute_transpiler
* Remove unused logs and logic
* 1. Add ops to the new block (considering the nested block condition)
  2. Follow the original hierarchy of blocks
  3. Change the function's name and remove debug lines
1 parent e8f5757 commit a29cb4b

File tree: 3 files changed, +63 -8 lines
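For context, the failure mode this commit addresses shows up in setups like the rough sketch below: a distributed training script whose optimizer takes its learning rate from a decay schedule (here piecewise_decay) that Fluid builds out of control-flow ops carrying a 'sub_block' attribute. Before this fix, get_pserver_program appended the lr decay ops to the pserver program but not their nested blocks. This is a hedged sketch against the Fluid Python API of roughly this era; the endpoint string, layer sizes, and decay boundaries are placeholders, and the exact signatures are not verified against this revision.

import paddle.fluid as fluid

# Toy regression network.
x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
y_pred = fluid.layers.fc(input=x, size=1)
cost = fluid.layers.mean(fluid.layers.square_error_cost(input=y_pred, label=y))

# piecewise_decay builds its schedule with control-flow ops whose 'sub_block'
# attribute points at nested blocks -- the blocks this commit clones onto the pserver.
lr = fluid.layers.piecewise_decay(boundaries=[1000, 2000], values=[0.1, 0.01, 0.001])
fluid.optimizer.SGD(learning_rate=lr).minimize(cost)

# Transpile for one pserver and two trainers (addresses are placeholders).
t = fluid.transpiler.DistributeTranspiler()
t.transpile(trainer_id=0, pservers="127.0.0.1:6174", trainers=2)
pserver_prog = t.get_pserver_program("127.0.0.1:6174")
print(pserver_prog.num_blocks)  # should now count the cloned lr-decay sub blocks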

paddle/fluid/framework/executor.cc

Lines changed: 3 additions & 2 deletions
@@ -295,13 +295,14 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
 
 std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
     const ProgramDesc& program, int block_id) {
-  auto* ctx = new ExecutorPrepareContext(program, block_id);
+  std::unique_ptr<ExecutorPrepareContext> ctx(
+      new ExecutorPrepareContext(program, block_id));
   PADDLE_ENFORCE_LT(static_cast<size_t>(block_id), program.Size());
   auto& block = program.Block(block_id);
   for (auto& op_desc : block.AllOps()) {
     ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc));
   }
-  return std::unique_ptr<ExecutorPrepareContext>(ctx);
+  return ctx;
 }
 
 std::vector<std::shared_ptr<ExecutorPrepareContext>> Executor::Prepare(

python/paddle/fluid/framework.py

Lines changed: 7 additions & 1 deletion
@@ -644,7 +644,13 @@ def attr_type(self, name):
 
     def set_attr(self, name, val):
         self.attrs[name] = val
-        self.desc.set_attr(name, val)
+        if isinstance(val, Block):
+            self.desc.set_block_attr(name, val.desc)
+        elif isinstance(val, core.BlockDesc) or \
+                isinstance(val, core.ProgramDesc):
+            self.desc.set_serialized_attr(name, val.serialize_to_string())
+        else:
+            self.desc.set_attr(name, val)
 
     @property
     def attr_names(self):
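To illustrate the new branching in Operator.set_attr, here is a minimal, self-contained sketch of the same type dispatch using stand-in classes (FakeBlock, FakeBlockDesc, and FakeProgramDesc are invented for this illustration and are not Paddle types): Block-valued attributes such as 'sub_block' go through a block-specific setter, BlockDesc/ProgramDesc values are serialized, and everything else takes the plain attribute path.

class FakeBlock(object):            # stand-in for framework.Block
    def __init__(self, desc):
        self.desc = desc

class FakeBlockDesc(object):        # stand-in for core.BlockDesc
    def serialize_to_string(self):
        return b"block-desc-bytes"

class FakeProgramDesc(object):      # stand-in for core.ProgramDesc
    def serialize_to_string(self):
        return b"program-desc-bytes"

def set_attr(calls, name, val):
    # Mirrors the new Operator.set_attr branching; records which desc setter ran.
    if isinstance(val, FakeBlock):
        calls.append(('set_block_attr', name, val.desc))
    elif isinstance(val, (FakeBlockDesc, FakeProgramDesc)):
        calls.append(('set_serialized_attr', name, val.serialize_to_string()))
    else:
        calls.append(('set_attr', name, val))

calls = []
set_attr(calls, 'sub_block', FakeBlock(FakeBlockDesc()))  # -> set_block_attr
set_attr(calls, 'program', FakeProgramDesc())             # -> set_serialized_attr
set_attr(calls, 'scale', 0.1)                             # -> plain set_attr
print(calls)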

python/paddle/fluid/transpiler/distribute_transpiler.py

Lines changed: 53 additions & 5 deletions
@@ -24,7 +24,7 @@
 1. split variable to multiple blocks, aligned by product(dim[1:]) (width).
 2. rename splited grad variables to add trainer_id suffix ".trainer_%d".
 3. modify trainer program add split_op to each grad variable.
-4. append send_op to send splited variables to server and
+4. append send_op to send splited variables to server and
 5. add recv_op to fetch params(splited blocks or origin param) from server.
 6. append concat_op to merge splited blocks to update local weights.
 
@@ -44,7 +44,7 @@
 from ps_dispatcher import RoundRobin, HashName, PSDispatcher
 from .. import core, framework
 from ..framework import Program, default_main_program, \
-                        default_startup_program, \
+                        default_startup_program, Block, \
                         Variable, Parameter, grad_var_name
 from details import *
 
@@ -471,21 +471,47 @@ def __append_optimize_op__(op, block, grad_to_block_id, merged_var):
                 self._append_pserver_ops(block, op, endpoint, grad_to_block_id,
                                          self.origin_program, merged_var)
             else:
-                self._append_pserver_non_opt_ops(block, op, endpoint)
+                self._append_pserver_non_opt_ops(block, op)
 
         def __op_have_grad_input__(op):
             for varname in op.input_arg_names:
                 if varname.find("@GRAD") >= 0:
                     return varname
             return ""
 
+        def __clone_lr_op_sub_block__(op, program, new_block):
+            if not op.has_attr('sub_block'):
+                return
+
+            origin_block_desc = op.attr('sub_block')
+            origin_block = self.origin_program.block(origin_block_desc.id)
+            assert isinstance(origin_block, Block)
+            # we put the new sub block to new block to follow the block
+            # hierarchy of the original blocks
+            new_sub_block = program.create_block(new_block.idx)
+
+            # clone vars
+            for var in origin_block.vars:
+                new_sub_block.clone_variable(var)
+
+            # clone ops
+            for op in origin_block.ops:
+                self._clone_lr_op(program, new_sub_block, op)
+                # clone sub_block of op
+                __clone_lr_op_sub_block__(op, program, new_sub_block)
+
+            # reset the block of op
+            op.set_attr('sub_block', new_sub_block)
+
         # append lr decay ops to the child block if exists
         lr_ops = self._get_lr_ops()
         if len(lr_ops) > 0:
             lr_decay_block = pserver_program.create_block(
                 pserver_program.num_blocks - 1)
             for _, op in enumerate(lr_ops):
-                self._append_pserver_non_opt_ops(lr_decay_block, op, endpoint)
+                self._append_pserver_non_opt_ops(lr_decay_block, op)
+                # append sub blocks to pserver_program in lr_decay_op
+                __clone_lr_op_sub_block__(op, pserver_program, lr_decay_block)
 
         # append op to the current block
         grad_to_block_id = []
@@ -1116,7 +1142,29 @@ def _is_splited_grad_var(self, var, var_dict):
                     break
         return grad_block
 
-    def _append_pserver_non_opt_ops(self, optimize_block, opt_op, endpoint):
+    def _clone_lr_op(self, program, block, op):
+        inputs = self._get_input_map_from_op(
+            self.origin_program.global_block().vars, op)
+        for key, varlist in inputs.iteritems():
+            if not isinstance(varlist, list):
+                varlist = [varlist]
+            for var in varlist:
+                if var not in program.global_block().vars:
+                    block.clone_variable(var)
+
+        outputs = self._get_output_map_from_op(
+            self.origin_program.global_block().vars, op)
+        for key, varlist in outputs.iteritems():
+            if not isinstance(varlist, list):
+                varlist = [varlist]
+            for var in varlist:
+                if var not in program.global_block().vars:
+                    block.clone_variable(var)
+
+        block.append_op(
+            type=op.type, inputs=inputs, outputs=outputs, attrs=op.attrs)
+
+    def _append_pserver_non_opt_ops(self, optimize_block, opt_op):
         program = optimize_block.program
         # Append the ops for parameters that do not need to be optimized/updated
         inputs = self._get_input_map_from_op(
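To make the "follow the original hierarchy of blocks" step concrete, here is a small self-contained sketch of the recursion that __clone_lr_op_sub_block__ performs, using toy Program/Block/Op classes invented for this illustration (they are not Paddle's): every op that carries a sub block gets a fresh child block created under the block currently being filled, the sub block's ops are copied into it, and the walk repeats for deeper nesting, so the cloned program ends up with the same parent/child block structure as the original.

class Block(object):
    def __init__(self, idx, parent_idx):
        self.idx, self.parent_idx, self.ops = idx, parent_idx, []

class Op(object):
    def __init__(self, type, sub_block=None):
        self.type, self.sub_block = type, sub_block

class Program(object):
    def __init__(self):
        self.blocks = [Block(0, -1)]  # global block

    def create_block(self, parent_idx):
        block = Block(len(self.blocks), parent_idx)
        self.blocks.append(block)
        return block

def clone_op_with_sub_block(op, dst_program, dst_block):
    # Copy op into dst_block and recreate its nested blocks under dst_block.
    new_op = Op(op.type)
    dst_block.ops.append(new_op)
    if op.sub_block is None:
        return
    # Keep the hierarchy: the clone's sub block is a child of dst_block.
    new_sub = dst_program.create_block(dst_block.idx)
    for inner in op.sub_block.ops:
        clone_op_with_sub_block(inner, dst_program, new_sub)
    new_op.sub_block = new_sub

# Source program: an lr-decay-style op whose sub block nests another sub block.
src = Program()
outer = src.create_block(0)
inner = src.create_block(outer.idx)
outer.ops.append(Op('conditional_block', sub_block=inner))
decay = Op('conditional_block', sub_block=outer)
src.blocks[0].ops.append(decay)

# Destination (pserver-like) program: clone into a dedicated lr-decay block.
dst = Program()
lr_decay_block = dst.create_block(0)
clone_op_with_sub_block(decay, dst, lr_decay_block)
print([(b.idx, b.parent_idx) for b in dst.blocks])
# [(0, -1), (1, 0), (2, 1), (3, 2)] -- nesting preserved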
