Skip to content

Commit 39f6274

Browse files
committed
follow comment, optimize code
1 parent e05f4df commit 39f6274

File tree

2 files changed

+14
-19
lines changed

2 files changed

+14
-19
lines changed

python/paddle/fluid/distribute_transpiler.py

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -371,22 +371,18 @@ def get_pserver_program(self, endpoint):
371371
recv_inputs.append(single_trainer_var)
372372

373373
# step 3
374-
# each optimization op will have an optimize block
375-
optimize_block = None
376-
377-
# step 4
378374
# Create a union-find data structure from optimize ops,
379375
# If two ops are connected, we could add these two ops
380376
# into one set.
381377
ufind = self._create_ufind(self.optimize_ops)
382-
# step 4.2
378+
# step 3.2
383379
# Iterate through the ops and append optimize op which
384380
# located on current pserver
385381
opt_op_on_pserver = []
386382
for _, op in enumerate(self.optimize_ops):
387383
if self._is_opt_op(op) and self._is_opt_op_on_pserver(endpoint, op):
388384
opt_op_on_pserver.append(op)
389-
# step 4.3
385+
# step 3.3
390386
# Iterate through the ops, and if an op and the optimize ops
391387
# which located on current pserver are in one set, then
392388
# append it into the sub program.
@@ -420,23 +416,17 @@ def __append_optimize_op__(op, block):
420416
self._append_pserver_non_opt_ops(block, op)
421417

422418
# append lr decay ops to the child block if exists
423-
lr_decay_block = None
424419
lr_ops = self._get_lr_ops()
425420
if len(lr_ops) > 0:
426-
lr_decay_block = pserver_program.create_block(GLOBAL_BLOCK_IDX)
421+
lr_decay_block = pserver_program.create_block(
422+
pserver_program.num_blocks - 1)
427423
for _, op in enumerate(lr_ops):
428424
self._append_pserver_non_opt_ops(lr_decay_block, op)
429425

430426
# append op to the current block
431-
per_opt_block = None
432-
pre_block_idx = GLOBAL_BLOCK_IDX
433-
if lr_decay_block is not None:
434-
pre_block_idx = lr_decay_block.idx
427+
pre_block_idx = pserver_program.num_blocks - 1
435428
for idx, opt_op in enumerate(opt_op_on_pserver):
436429
per_opt_block = pserver_program.create_block(pre_block_idx)
437-
if optimize_block is None:
438-
# first optimize block
439-
optimize_block = per_opt_block
440430
for _, op in enumerate(self.optimize_ops):
441431
# optimizer is connected to itself
442432
if ufind.is_connected(op, opt_op) and op not in global_ops:
@@ -445,9 +435,10 @@ def __append_optimize_op__(op, block):
445435
# append global ops
446436
opt_state_block = None
447437
if global_ops:
448-
opt_state_block = pserver_program.create_block(per_opt_block.idx)
449-
for glb_op in global_ops:
450-
__append_optimize_op__(glb_op, opt_state_block)
438+
opt_state_block = pserver_program.create_block(
439+
pserver_program.num_blocks - 1)
440+
for glb_op in global_ops:
441+
__append_optimize_op__(glb_op, opt_state_block)
451442

452443
# NOT USED: single block version:
453444
#
@@ -481,7 +472,7 @@ def __append_optimize_op__(op, block):
481472
inputs={'X': recv_inputs},
482473
outputs={},
483474
attrs={
484-
"OptimizeBlock": optimize_block,
475+
"OptimizeBlock": pserver_program.block(1),
485476
"endpoint": endpoint,
486477
"Fanin": self.trainer_num,
487478
"PrefetchBlock": prefetch_block

python/paddle/fluid/framework.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1107,6 +1107,10 @@ def parse_from_string(binary_str):
11071107
def random_seed(self):
11081108
return self._seed
11091109

1110+
@property
1111+
def num_blocks(self):
1112+
return self.desc.num_blocks()
1113+
11101114
@random_seed.setter
11111115
def random_seed(self, seed):
11121116
if not isinstance(seed, int):

0 commit comments

Comments
 (0)