Commit 89d1edb

Fix memory optimization with dist train (#13535)
* show detail error log on ci
* test
* fix memopt and dist
* update apispec
* will fix different batch issue test=develop
1 parent 8d16de7 commit 89d1edb
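
For context, a minimal sketch (not part of the commit) of how the new skip_grads flag is meant to be used on the trainer side of a distributed job. The transpiler arguments below are hypothetical placeholders; only the memory_optimize call reflects the API added in this commit, and the signatures used come from API.spec further down.

import paddle.fluid as fluid

# Trainer side: reuse variable memory, but keep gradient variables out of the
# reuse pool so they are still intact when sent to the parameter server.
fluid.memory_optimize(fluid.default_main_program(), skip_grads=True)

# Transpile for distributed training (endpoints and trainer_id are placeholders).
t = fluid.DistributeTranspiler()
t.transpile(trainer_id=0, pservers="127.0.0.1:6174", trainers=2)
trainer_prog = t.get_trainer_program()

# Parameter-server side: do not call memory_optimize at all
# (see the test_dist_base.py change below).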

4 files changed, 112 insertions(+), 24 deletions(-)
paddle/fluid/API.spec

Lines changed: 2 additions & 2 deletions
@@ -21,7 +21,7 @@ paddle.fluid.DistributeTranspiler.get_pserver_programs ArgSpec(args=['self', 'en
 paddle.fluid.DistributeTranspiler.get_startup_program ArgSpec(args=['self', 'endpoint', 'pserver_program', 'startup_program'], varargs=None, keywords=None, defaults=(None, None))
 paddle.fluid.DistributeTranspiler.get_trainer_program ArgSpec(args=['self', 'wait_port'], varargs=None, keywords=None, defaults=(True,))
 paddle.fluid.DistributeTranspiler.transpile ArgSpec(args=['self', 'trainer_id', 'program', 'pservers', 'trainers', 'sync_mode', 'startup_program', 'current_endpoint'], varargs=None, keywords=None, defaults=(None, '127.0.0.1:6174', 1, True, None, '127.0.0.1:6174'))
-paddle.fluid.memory_optimize ArgSpec(args=['input_program', 'skip_opt_set', 'print_log', 'level'], varargs=None, keywords=None, defaults=(None, False, 0))
+paddle.fluid.memory_optimize ArgSpec(args=['input_program', 'skip_opt_set', 'print_log', 'level', 'skip_grads'], varargs=None, keywords=None, defaults=(None, False, 0, False))
 paddle.fluid.release_memory ArgSpec(args=['input_program', 'skip_opt_set'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.DistributeTranspilerConfig.__init__
 paddle.fluid.ParallelExecutor.__init__ ArgSpec(args=['self', 'use_cuda', 'loss_name', 'main_program', 'share_vars_from', 'exec_strategy', 'build_strategy', 'num_trainers', 'trainer_id', 'scope'], varargs=None, keywords=None, defaults=(None, None, None, None, None, 1, 0, None))
@@ -304,7 +304,7 @@ paddle.fluid.transpiler.DistributeTranspiler.get_pserver_programs ArgSpec(args=[
 paddle.fluid.transpiler.DistributeTranspiler.get_startup_program ArgSpec(args=['self', 'endpoint', 'pserver_program', 'startup_program'], varargs=None, keywords=None, defaults=(None, None))
 paddle.fluid.transpiler.DistributeTranspiler.get_trainer_program ArgSpec(args=['self', 'wait_port'], varargs=None, keywords=None, defaults=(True,))
 paddle.fluid.transpiler.DistributeTranspiler.transpile ArgSpec(args=['self', 'trainer_id', 'program', 'pservers', 'trainers', 'sync_mode', 'startup_program', 'current_endpoint'], varargs=None, keywords=None, defaults=(None, '127.0.0.1:6174', 1, True, None, '127.0.0.1:6174'))
-paddle.fluid.transpiler.memory_optimize ArgSpec(args=['input_program', 'skip_opt_set', 'print_log', 'level'], varargs=None, keywords=None, defaults=(None, False, 0))
+paddle.fluid.transpiler.memory_optimize ArgSpec(args=['input_program', 'skip_opt_set', 'print_log', 'level', 'skip_grads'], varargs=None, keywords=None, defaults=(None, False, 0, False))
 paddle.fluid.transpiler.release_memory ArgSpec(args=['input_program', 'skip_opt_set'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.transpiler.HashName.__init__ ArgSpec(args=['self', 'pserver_endpoints'], varargs=None, keywords=None, defaults=None)
 paddle.fluid.transpiler.HashName.dispatch ArgSpec(args=['self', 'varlist'], varargs=None, keywords=None, defaults=None)

python/paddle/fluid/tests/unittests/test_dist_base.py

Lines changed: 2 additions & 3 deletions
@@ -47,8 +47,7 @@ def run_pserver(self, args):
         import paddle
         import paddle.fluid as fluid
         self.get_model(batch_size=2)
-        if args.mem_opt:
-            fluid.memory_optimize(fluid.default_main_program())
+        # NOTE: pserver should not call memory optimize
         t = self.get_transpiler(args.trainer_id,
                                 fluid.default_main_program(), args.endpoints,
                                 args.trainers, args.sync_mode)
@@ -68,7 +67,7 @@ def run_trainer(self, use_cuda, args):
         test_program, avg_cost, train_reader, test_reader, batch_acc, predict = \
             self.get_model(batch_size=2)
         if args.mem_opt:
-            fluid.memory_optimize(fluid.default_main_program())
+            fluid.memory_optimize(fluid.default_main_program(), skip_grads=True)
         if args.is_dist:
             t = self.get_transpiler(args.trainer_id,
                                     fluid.default_main_program(),

python/paddle/fluid/tests/unittests/test_dist_se_resnext.py

Lines changed: 7 additions & 8 deletions
@@ -25,14 +25,13 @@ def test_dist_train(self):
         self.check_with_place("dist_se_resnext.py", delta=1e-7)


-# TODO(typhoonzero): fix this test
-# class TestDistseResnXt2x2WithMemopt(TestDistBase):
-#     def _setup_config(self):
-#         self._sync_mode = True
-#         self._mem_opt = True
-
-#     def test_dist_train(self):
-#         self.check_with_place("dist_se_resnext.py", delta=1e-7)
+class TestDistseResnXt2x2WithMemopt(TestDistBase):
+    def _setup_config(self):
+        self._sync_mode = True
+        self._mem_opt = True
+
+    def test_dist_train(self):
+        self.check_with_place("dist_se_resnext.py", delta=100)


 class TestDistSeResneXt2x2Async(TestDistBase):

python/paddle/fluid/transpiler/memory_optimization_transpiler.py

Lines changed: 101 additions & 11 deletions
@@ -14,10 +14,10 @@

 from __future__ import print_function

-from collections import defaultdict, OrderedDict, Callable
+from collections import defaultdict, MutableSet
 from .. import core
 from ... import compat as cpt
-from ..framework import Program, default_main_program, Parameter, Variable
+from ..framework import Program, default_main_program, Parameter, Variable, core
 from ..backward import _rename_arg_
 from functools import reduce
 from six.moves import range
@@ -44,17 +44,82 @@
 PRINT_LOG = False


+class OrderedSet(MutableSet):
+    def __init__(self, iterable=None):
+        self.end = end = []
+        end += [None, end, end]  # sentinel node for doubly linked list
+        self.map = {}  # key --> [key, prev, next]
+        if iterable is not None:
+            self |= iterable
+
+    def __len__(self):
+        return len(self.map)
+
+    def __contains__(self, key):
+        return key in self.map
+
+    def add(self, key):
+        if key not in self.map:
+            end = self.end
+            curr = end[1]
+            curr[2] = end[1] = self.map[key] = [key, curr, end]
+
+    def update(self, other):
+        for e in other:
+            self.add(e)
+
+    def discard(self, key):
+        if key in self.map:
+            key, prev, next = self.map.pop(key)
+            prev[2] = next
+            next[1] = prev
+
+    def remove(self, key):
+        self.discard(key)
+
+    def __iter__(self):
+        end = self.end
+        curr = end[2]
+        while curr is not end:
+            yield curr[0]
+            curr = curr[2]
+
+    def __reversed__(self):
+        end = self.end
+        curr = end[1]
+        while curr is not end:
+            yield curr[0]
+            curr = curr[1]
+
+    def pop(self, last=True):
+        if not self:
+            raise KeyError('set is empty')
+        key = self.end[1][0] if last else self.end[2][0]
+        self.discard(key)
+        return key
+
+    def __repr__(self):
+        if not self:
+            return '%s()' % (self.__class__.__name__, )
+        return '%s(%r)' % (self.__class__.__name__, list(self))
+
+    def __eq__(self, other):
+        if isinstance(other, OrderedSet):
+            return len(self) == len(other) and list(self) == list(other)
+        return set(self) == set(other)
+
+
 class ControlFlowGraph(object):
     def __init__(self, program, ops, forward_num, skip_opt):
         self._program = program
         self._ops = ops
         self._forward_num = forward_num
-        self._successors = defaultdict(set)
-        self._presuccessors = defaultdict(set)
-        self._uses = defaultdict(set)
-        self._defs = defaultdict(set)
-        self._live_in = defaultdict(set)
-        self._live_out = defaultdict(set)
+        self._successors = defaultdict(OrderedSet)
+        self._presuccessors = defaultdict(OrderedSet)
+        self._uses = defaultdict(OrderedSet)
+        self._defs = defaultdict(OrderedSet)
+        self._live_in = defaultdict(OrderedSet)
+        self._live_out = defaultdict(OrderedSet)
         self._skip_opt = skip_opt
         self.pool = []

@@ -116,7 +181,7 @@ def _fill_pool(self, i, is_forward):
         # NOTE: must sort the in_diff set for cases that get different cache var.
         # FIXME(typhoonzero): maybe use a "sorted set" is better than this.
         can_optimize = [
-            x for x in sorted(list(in_diff))
+            x for x in in_diff
             if self._check_var_validity(block_desc, x, is_forward)
         ]
         if can_optimize:
@@ -224,7 +289,7 @@ def compare_shape(x_shape, cache_shape, opt_level):
            if self.pool:
                # NOTE: must sort the in_diff set for cases that get different cache var.
                defs_can_optimize = [
-                    x for x in sorted(list(self._defs[i]))
+                    x for x in self._defs[i]
                    if self._check_var_validity(block_desc, x, is_forward)
                ]
                out_pair = [
@@ -381,7 +446,19 @@ def _get_cfgs(input_program):
     return cfgs


-def memory_optimize(input_program, skip_opt_set=None, print_log=False, level=0):
+def _is_opt_role_op(op):
+    op_maker = core.op_proto_and_checker_maker
+    optimize_role = core.op_proto_and_checker_maker.OpRole.Optimize
+    if op_maker.kOpRoleAttrName() in op.attr_names and \
+            int(op.all_attrs()[op_maker.kOpRoleAttrName()]) == int(optimize_role):
+        return True
+
+
+def memory_optimize(input_program,
+                    skip_opt_set=None,
+                    print_log=False,
+                    level=0,
+                    skip_grads=False):
     """Optimize memory by reusing var memory.

       Note: it doesn't not support subblock nested in subblock.
@@ -398,6 +475,19 @@ def memory_optimize(input_program, skip_opt_set=None, print_log=False, level=0):
         raise ValueError("only support opt_level 0 or 1.")
     global PRINT_LOG
     PRINT_LOG = print_log
+    if skip_grads:
+        grad_set = set()
+        OP_ROLE_VAR = core.op_proto_and_checker_maker.kOpRoleVarAttrName()
+        for op in input_program.global_block().ops:
+            if _is_opt_role_op(op):
+                if op.attr(OP_ROLE_VAR):
+                    grad_name = op.attr(OP_ROLE_VAR)[1]
+                    grad_set.add(grad_name)
+        if not skip_opt_set:
+            skip_opt_set = grad_set
+        else:
+            skip_opt_set.update(grad_set)
+
     cfgs = _get_cfgs(input_program)
     for cfg in cfgs:
         cfg.memory_optimize(skip_opt_set=skip_opt_set, level=level)
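
Worth noting: the OrderedSet introduced above appears to exist so the liveness sets iterate in deterministic insertion order, letting the optimizer pick the same cache variable every run instead of relying on sorted(list(...)) (see the NOTE/FIXME comments in the hunks above). A small standalone sketch of that behavior follows; the variable names are made up for illustration.

# OrderedSet is the module-level class added in the diff above.
from paddle.fluid.transpiler.memory_optimization_transpiler import OrderedSet

live = OrderedSet()
for name in ["fc_0.tmp_1", "conv2d_0.tmp_0", "fc_0.tmp_1"]:
    live.add(name)              # duplicates ignored, insertion order kept

print(list(live))               # ['fc_0.tmp_1', 'conv2d_0.tmp_0'] -- stable order
live.discard("conv2d_0.tmp_0")  # removal keeps the order of the remaining keys
print(list(live))               # ['fc_0.tmp_1']

# A builtin set gives no ordering guarantee, which is why the old code had to
# call sorted(list(in_diff)) before choosing a cache variable.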
