Skip to content

Commit 7907b6a

Browse files
committed
split optimization ops on pserver to independenty blocks
1 parent d89a306 commit 7907b6a

File tree

1 file changed

+25
-15
lines changed

1 file changed

+25
-15
lines changed

python/paddle/fluid/distribute_transpiler.py

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
LOOKUP_TABLE_GRAD_TYPE = "lookup_table_grad"
2626
RPC_CLIENT_VAR_NAME = "RPC_CLIENT_VAR"
2727

28+
GLOBAL_BLOCK_IDX = 0
29+
2830

2931
class VarBlock:
3032
def __init__(self, varname, offset, size):
@@ -368,8 +370,8 @@ def get_pserver_program(self, endpoint):
368370
else:
369371
recv_inputs.append(single_trainer_var)
370372

371-
# step3
372-
optimize_block = pserver_program.create_block(0)
373+
optimize_block = None
374+
373375
# step 4
374376
# Create a union-find data structure from optimize ops,
375377
# If two ops are connected, we could add these two ops
@@ -415,29 +417,34 @@ def __append_optimize_op__(op, block):
415417
else:
416418
self._append_pserver_non_opt_ops(block, op)
417419

418-
append_block = optimize_block
419420
# append lr decay ops to the child block if exists
421+
lr_decay_block = None
420422
lr_ops = self._get_lr_ops()
421423
if len(lr_ops) > 0:
424+
lr_decay_block = pserver_program.create_block(GLOBAL_BLOCK_IDX)
422425
for _, op in enumerate(lr_ops):
423-
self._append_pserver_non_opt_ops(append_block, op)
424-
425-
append_block = pserver_program.create_block(append_block.idx)
426+
self._append_pserver_non_opt_ops(lr_decay_block, op)
426427

427428
# append op to the current block
428-
per_opt_block = append_block
429+
per_opt_block = None
430+
pre_block_idx = GLOBAL_BLOCK_IDX
431+
if lr_decay_block is not None:
432+
pre_block_idx = lr_decay_block.idx
429433
for idx, opt_op in enumerate(opt_op_on_pserver):
434+
per_opt_block = pserver_program.create_block(pre_block_idx)
435+
if optimize_block is None:
436+
optimize_block = per_opt_block
430437
for _, op in enumerate(self.optimize_ops):
431438
# optimizer is connected to itself
432-
if ufind.is_connected(op, opt_op) and \
433-
op not in global_ops:
439+
if ufind.is_connected(op, opt_op) and op not in global_ops:
434440
__append_optimize_op__(op, per_opt_block)
435-
if idx == len(opt_op_on_pserver) - 1 and global_ops:
436-
per_opt_block = pserver_program.create_block(append_block.idx)
437441

438442
# append global ops
443+
opt_state_block = None
444+
if global_ops:
445+
opt_state_block = pserver_program.create_block(per_opt_block.idx)
439446
for glb_op in global_ops:
440-
__append_optimize_op__(glb_op, per_opt_block)
447+
__append_optimize_op__(glb_op, opt_state_block)
441448

442449
# NOT USED: single block version:
443450
#
@@ -451,10 +458,11 @@ def __append_optimize_op__(op, block):
451458
prefetch_block = None
452459
if self.has_distributed_lookup_table:
453460
pserver_index = self.pserver_endpoints.index(endpoint)
454-
self._create_table_optimize_block(pserver_index, pserver_program,
455-
append_block)
461+
table_opt_block = self._create_table_optimize_block(
462+
pserver_index, pserver_program, opt_state_block or
463+
pserver_program.global_block())
456464
prefetch_block = self._create_prefetch_block(
457-
pserver_index, pserver_program, optimize_block)
465+
pserver_index, pserver_program, table_opt_block)
458466

459467
# NOTE: if has_distributed_lookup_table is False, then prefetch_block will
460468
# not be executed, so it's safe to use optimize_block to hold the place
@@ -724,6 +732,8 @@ def _clone_var(block, var, persistable=True):
724732
outputs=outputs,
725733
attrs=table_opt_op.attrs)
726734

735+
return table_opt_block
736+
727737
# ====================== private transpiler functions =====================
728738
def _create_vars_from_blocklist(self,
729739
program,

0 commit comments

Comments
 (0)