Skip to content

Commit dd7a48b

Browse files
authored
Merge pull request #10123 from jacquesqiao/split-optimize-op-into-signle-blocks
Split optimization ops on pserver into independent blocks
2 parents d89a306 + ba1e68d commit dd7a48b

File tree

2 files changed

+27
-22
lines changed

2 files changed

+27
-22
lines changed

python/paddle/fluid/distribute_transpiler.py

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -368,21 +368,19 @@ def get_pserver_program(self, endpoint):
368368
else:
369369
recv_inputs.append(single_trainer_var)
370370

371-
# step3
372-
optimize_block = pserver_program.create_block(0)
373-
# step 4
371+
# step 3
374372
# Create a union-find data structure from optimize ops,
375373
# If two ops are connected, we could add these two ops
376374
# into one set.
377375
ufind = self._create_ufind(self.optimize_ops)
378-
# step 4.2
376+
# step 3.2
379377
# Iterate through the ops and append optimize op which
380378
# is located on the current pserver
381379
opt_op_on_pserver = []
382380
for _, op in enumerate(self.optimize_ops):
383381
if self._is_opt_op(op) and self._is_opt_op_on_pserver(endpoint, op):
384382
opt_op_on_pserver.append(op)
385-
# step 4.3
383+
# step 3.3
386384
# Iterate through the ops, and if an op and the optimize ops
387385
# which are located on the current pserver are in one set, then
388386
# append it into the sub program.
@@ -415,29 +413,30 @@ def __append_optimize_op__(op, block):
415413
else:
416414
self._append_pserver_non_opt_ops(block, op)
417415

418-
append_block = optimize_block
419416
# append lr decay ops to the child block if exists
420417
lr_ops = self._get_lr_ops()
421418
if len(lr_ops) > 0:
419+
lr_decay_block = pserver_program.create_block(
420+
pserver_program.num_blocks - 1)
422421
for _, op in enumerate(lr_ops):
423-
self._append_pserver_non_opt_ops(append_block, op)
424-
425-
append_block = pserver_program.create_block(append_block.idx)
422+
self._append_pserver_non_opt_ops(lr_decay_block, op)
426423

427424
# append op to the current block
428-
per_opt_block = append_block
425+
pre_block_idx = pserver_program.num_blocks - 1
429426
for idx, opt_op in enumerate(opt_op_on_pserver):
427+
per_opt_block = pserver_program.create_block(pre_block_idx)
430428
for _, op in enumerate(self.optimize_ops):
431429
# optimizer is connected to itself
432-
if ufind.is_connected(op, opt_op) and \
433-
op not in global_ops:
430+
if ufind.is_connected(op, opt_op) and op not in global_ops:
434431
__append_optimize_op__(op, per_opt_block)
435-
if idx == len(opt_op_on_pserver) - 1 and global_ops:
436-
per_opt_block = pserver_program.create_block(append_block.idx)
437432

438433
# append global ops
439-
for glb_op in global_ops:
440-
__append_optimize_op__(glb_op, per_opt_block)
434+
opt_state_block = None
435+
if global_ops:
436+
opt_state_block = pserver_program.create_block(
437+
pserver_program.num_blocks - 1)
438+
for glb_op in global_ops:
439+
__append_optimize_op__(glb_op, opt_state_block)
441440

442441
# NOT USED: single block version:
443442
#
@@ -451,10 +450,10 @@ def __append_optimize_op__(op, block):
451450
prefetch_block = None
452451
if self.has_distributed_lookup_table:
453452
pserver_index = self.pserver_endpoints.index(endpoint)
454-
self._create_table_optimize_block(pserver_index, pserver_program,
455-
append_block)
453+
table_opt_block = self._create_table_optimize_block(
454+
pserver_index, pserver_program, pre_block_idx)
456455
prefetch_block = self._create_prefetch_block(
457-
pserver_index, pserver_program, optimize_block)
456+
pserver_index, pserver_program, table_opt_block)
458457

459458
# NOTE: if has_distributed_lookup_table is False, then prefetch_block will
460459
# not be executed, so it's safe to use optimize_block to hold the place
@@ -470,7 +469,7 @@ def __append_optimize_op__(op, block):
470469
inputs={'X': recv_inputs},
471470
outputs={},
472471
attrs={
473-
"OptimizeBlock": optimize_block,
472+
"OptimizeBlock": pserver_program.block(1),
474473
"endpoint": endpoint,
475474
"Fanin": self.trainer_num,
476475
"PrefetchBlock": prefetch_block
@@ -663,7 +662,7 @@ def _create_prefetch_block(self, pserver_index, pserver_program,
663662
return prefetch_block
664663

665664
def _create_table_optimize_block(self, pserver_index, pserver_program,
666-
append_block):
665+
pre_block_idx):
667666
def _clone_var(block, var, persistable=True):
668667
assert isinstance(var, Variable)
669668
return block.create_var(
@@ -700,7 +699,7 @@ def _clone_var(block, var, persistable=True):
700699
op for op in self.optimize_ops
701700
if op.input("Param")[0] == self.table_name
702701
][0]
703-
table_opt_block = pserver_program.create_block(append_block.idx)
702+
table_opt_block = pserver_program.create_block(pre_block_idx)
704703
# only support sgd now
705704
assert table_opt_op.type == "sgd"
706705

@@ -724,6 +723,8 @@ def _clone_var(block, var, persistable=True):
724723
outputs=outputs,
725724
attrs=table_opt_op.attrs)
726725

726+
return table_opt_block
727+
727728
# ====================== private transpiler functions =====================
728729
def _create_vars_from_blocklist(self,
729730
program,

python/paddle/fluid/framework.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1107,6 +1107,10 @@ def parse_from_string(binary_str):
11071107
def random_seed(self):
11081108
return self._seed
11091109

1110+
@property
1111+
def num_blocks(self):
1112+
return self.desc.num_blocks()
1113+
11101114
@random_seed.setter
11111115
def random_seed(self, seed):
11121116
if not isinstance(seed, int):

0 commit comments

Comments
 (0)