Commit 2f38264

Merge pull request #9905 from panyx0718/mem-opt

Polish memory optimization transpiler

2 parents: b48cf17 + d4024a6

1 file changed, +88 -47 lines
python/paddle/fluid/memory_optimization_transpiler.py

Lines changed: 88 additions & 47 deletions
@@ -29,17 +29,20 @@
     core.VarDesc.VarType.BOOL: 1
 }
 
-sub_block_ops = [
+SUB_BLOCK_OPS = [
     "while", "while_grad", "parallel_do", "parallel_do_grad",
     "conditional_block", "conditional_block_grad"
 ]
 
+SUB_BLOCK_PAIR = [("while", "while_grad"), ("parallel_do", "parallel_do_grad"),
+                  ("conditional_block", "conditional_block_grad")]
+
 PRINT_LOG = False
 
 
 class ControlFlowGraph(object):
-    def __init__(self, Program, ops, forward_num, skip_opt):
-        self._program = Program
+    def __init__(self, program, ops, forward_num, skip_opt):
+        self._program = program
         self._ops = ops
         self._forward_num = forward_num
         self._successors = defaultdict(set)
@@ -51,14 +54,19 @@ def __init__(self, Program, ops, forward_num, skip_opt):
         self._skip_opt = skip_opt
 
     def _add_connections(self, connections):
+        """Populates _successors and _presuccessors for two neighbor nodes."""
         for node1, node2 in connections:
             self._add(node1, node2)
 
     def _add(self, node1, node2):
         self._successors[node1].add(node2)
         self._presuccessors[node2].add(node1)
 
+    # TODO(panyx0718): We need to have a unified way of building intermediate
+    # representation.
     def _build_graph(self):
+        """Build a graph based on op sequence.
+        """
         self.op_size = len(self._ops)
         op_node_connections = [(i, i + 1) for i in range(self.op_size - 1)]
         self._add_connections(op_node_connections)
@@ -82,22 +90,23 @@ def _update_graph(self, old_name, new_name, begin_idx=0):
                 self._live_out[i].add(new_name)
 
     def _reach_fixed_point(self, live_in, live_out):
+        """Check if the liveness set has stablized."""
         if len(live_in) != len(self._live_in):
             return False
         if len(live_out) != len(self._live_out):
             return False
         for i in range(self.op_size):
-            if live_in[i] != self._live_in[i]:
-                return False
-        for i in range(self.op_size):
-            if live_out[i] != self._live_out[i]:
+            if (live_in[i] != self._live_in[i] or
+                    live_out[i] != self._live_out[i]):
                 return False
         return True
 
     def _dataflow_analyze(self):
         self._build_graph()
         live_in = defaultdict(set)
         live_out = defaultdict(set)
+        # Repeatedly apply liveness updates until the algorithm stablize
+        # on a complete set live input vars and live output vars.
         while True:
             for i in range(self.op_size, 0, -1):
                 live_in[i] = set(self._live_in[i])
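
Aside (not part of the commit): the hunk above is a standard backward liveness fixed point. A self-contained sketch of that iteration, using hypothetical uses/defs/successors maps instead of the ControlFlowGraph fields, might look like this:

from collections import defaultdict

def liveness(op_count, uses, defs, successors):
    """Backward dataflow iteration to a fixed point:
    live_in[i]  = uses[i] | (live_out[i] - defs[i])
    live_out[i] = union of live_in[s] over the successors s of op i
    """
    live_in, live_out = defaultdict(set), defaultdict(set)
    while True:
        prev_in = {i: set(live_in[i]) for i in range(op_count)}
        prev_out = {i: set(live_out[i]) for i in range(op_count)}
        for i in reversed(range(op_count)):
            live_out[i] = set().union(*(live_in[s] for s in successors[i]))
            live_in[i] = set(uses[i]) | (live_out[i] - set(defs[i]))
        # Stop once neither set changed in a full backward pass.
        if (prev_in == {i: live_in[i] for i in range(op_count)} and
                prev_out == {i: live_out[i] for i in range(op_count)}):
            return live_in, live_out

# Tiny example: op0 defines 'a', op1 reads 'a' and defines 'b', op2 reads 'b'.
uses = {0: set(), 1: {'a'}, 2: {'b'}}
defs = {0: {'a'}, 1: {'b'}, 2: set()}
succ = {0: {1}, 1: {2}, 2: set()}
live_in, _ = liveness(3, uses, defs, succ)
print(live_in[1], live_in[2])  # {'a'} {'b'}: 'a' dies after op1, 'b' after op2
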
@@ -141,6 +150,8 @@ def _check_var_validity(self, block_desc, x, is_forward):
                 return False
         return True
 
+    # TODO(panyx0718): This needs to be less hacky. It seems memory optimization
+    # doesn't consider vars copied between cpu and gpu.
     def _update_skip_opt_set(self):
         for i in range(self.op_size):
             op = self._ops[i]
@@ -154,7 +165,7 @@ def release_memory(self):
         bwd_id = 0
         for i in range(self.op_size):
             op = self._ops[i]
-            if op.type() in sub_block_ops:
+            if op.type() in SUB_BLOCK_OPS:
                 continue
             block_desc = op.block()
             is_forward = i < self._forward_num
@@ -177,24 +188,25 @@ def memory_optimize(self, level=0):
         def compare_shape(x_shape, cache_shape, opt_level):
             if opt_level == 0:
                 return x_shape == cache_shape
-            if opt_level == 1:
+            elif opt_level == 1:
                 if (x_shape[0] == -1) ^ (cache_shape[0] == -1):
                     return False
                 x_size = abs(reduce(lambda x, y: x * y, x_shape))
                 cache_size = abs(reduce(lambda x, y: x * y, cache_shape))
                 if x_size <= cache_size:
                     return True
+            else:
+                raise ValueError("only support opt_level 0 or 1.")
             return False
 
         self._dataflow_analyze()
         self._update_skip_opt_set()
         self.pool = []
         for i in range(self.op_size):
             op = self._ops[i]
-            if op.type() in sub_block_ops:
+            if op.type() in SUB_BLOCK_OPS:
                 continue
             block_desc = op.block()
-            self.current_block_desc = block_desc
             is_forward = i < self._forward_num
             if self.pool:
                 defs_can_optimize = filter(
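
Aside (illustration, not part of the commit): the two opt levels accepted by compare_shape above behave as follows; the shapes are made-up examples. Level 0 requires an exact shape match, while level 1 only requires that the batch-dimension (-1) flags agree and that the cached var is at least as large.

from functools import reduce

def compare_shape(x_shape, cache_shape, opt_level):
    # Standalone copy of the helper in the hunk above, for experimentation.
    if opt_level == 0:
        return x_shape == cache_shape
    elif opt_level == 1:
        if (x_shape[0] == -1) ^ (cache_shape[0] == -1):
            return False
        x_size = abs(reduce(lambda x, y: x * y, x_shape))
        cache_size = abs(reduce(lambda x, y: x * y, cache_shape))
        if x_size <= cache_size:
            return True
    else:
        raise ValueError("only support opt_level 0 or 1.")
    return False

print(compare_shape([-1, 64], [-1, 64], 0))   # True: shapes identical
print(compare_shape([-1, 64], [-1, 128], 0))  # False: level 0 needs an exact match
print(compare_shape([-1, 64], [-1, 128], 1))  # True: cached var is large enough
print(compare_shape([-1, 64], [32, 64], 1))   # False: batch-dim flags disagree
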
@@ -211,37 +223,40 @@ def compare_shape(x_shape, cache_shape, opt_level):
                     for index, cache_pair in enumerate(self.pool):
                         cache_var = cache_pair[0]
                         cache_shape = cache_pair[1]
-                        if compare_shape(x_shape, cache_shape, level):
-                            if self._has_var(block_desc, cache_var, is_forward):
-                                x_dtype = self._find_var(block_desc, x,
-                                                         is_forward).dtype()
-                                cache_dtype = self._find_var(
-                                    block_desc, cache_var, is_forward).dtype()
-                                # TODO(qijun): actually, we should compare dtype_to_size[x_dtype]
-                                # and dtype_to_size[cache_dtype]
-                                if x_dtype == cache_dtype:
-                                    if PRINT_LOG:
-                                        print(
-                                            ("Hit Cache !!!! cache pool index "
-                                             "is %d, var name is %s, "
-                                             "cached var name is %s, "
-                                             "var shape is %s ") %
-                                            (index, x, cache_var,
-                                             str(cache_shape)))
-                                    self.pool.pop(index)
-                                    if x == cache_var:
-                                        break
-                                    _rename_arg_(
-                                        self._ops, x, cache_var, begin_idx=i)
-                                    self._program.block(block_desc.id).var(
-                                        str(x)).desc = self._find_var(
-                                            block_desc, cache_var, is_forward)
-                                    self._update_graph(
-                                        x, cache_var, begin_idx=i)
-                                    break
-
-            in_diff, out_diff = self._get_diff(self._live_in[i],
-                                               self._live_out[i])
+                        if not compare_shape(x_shape, cache_shape, level):
+                            continue
+
+                        if not self._has_var(block_desc, cache_var, is_forward):
+                            continue
+
+                        x_dtype = self._find_var(block_desc, x,
+                                                 is_forward).dtype()
+                        cache_dtype = self._find_var(block_desc, cache_var,
+                                                     is_forward).dtype()
+                        # TODO(qijun): actually, we should compare
+                        # dtype_to_size[x_dtype] and dtype_to_size[cache_dtype]
+                        if x_dtype != cache_dtype:
+                            continue
+
+                        if PRINT_LOG:
+                            print(("Hit Cache !!!! cache pool index "
+                                   "is %d, var name is %s, "
+                                   "cached var name is %s, "
+                                   "var shape is %s ") % (index, x, cache_var,
+                                                          str(cache_shape)))
+                        self.pool.pop(index)
+                        if x == cache_var:
+                            break
+                        # Rename the var to the cache var already with
+                        # memory allocated in order to reuse the memory.
+                        _rename_arg_(self._ops, x, cache_var, begin_idx=i)
+                        self._program.block(block_desc.id).var(str(
+                            x)).desc = self._find_var(block_desc, cache_var,
+                                                      is_forward)
+                        self._update_graph(x, cache_var, begin_idx=i)
+                        break
+
+            in_diff, _ = self._get_diff(self._live_in[i], self._live_out[i])
             can_optimize = filter(
                 lambda x: self._check_var_validity(block_desc, x, is_forward),
                 in_diff)
@@ -252,6 +267,19 @@ def compare_shape(x_shape, cache_shape, opt_level):
 
 
 def _process_sub_block_pair(pdesc, sub_block_pair):
+    """Creates a list of tuple each of which tracks info of a subblock.
+
+    Note: this function doesn't handle nested subblocks yet.
+    TODO(panyx0718): assert if case nested subblocks happen.
+
+    :param pdesc: ProgramDesc.
+    :param sub_block_pair: A list op pairs. Each op pair is the forward
+        op and backward op. The ops in the list are special that they contain
+        a subblock of ops.
+    :return: A list of tuples, each tuple is (all ops in a subblock pair
+        including forward and backward, number of forward ops,
+        all output args names of the ops in the subblock pairs).
+    """
     ops_list = []
     block_desc = pdesc.block(0)
     op_size = block_desc.op_size()
@@ -308,6 +336,11 @@ def _process_sub_block_pair(pdesc, sub_block_pair):
 
 
 def _get_cfgs(input_program):
+    """Process each block and create ControlFlowGraph for each of them.
+
+    :param input_program: Program object.
+    :return: A list of ControlFlowGraph, each corresponds to a block.
+    """
     ops_list = []
     pdesc = input_program.get_desc()
     block_desc = pdesc.block(0)
@@ -316,11 +349,8 @@ def _get_cfgs(input_program):
     ops_list.append(
         ([block_desc.op(i) for i in range(op_size)], op_size, set()))
 
-    sub_block_pair = [("while", "while_grad"), ("parallel_do",
-                                                "parallel_do_grad"),
-                      ("conditional_block", "conditional_block_grad")]
-
-    ops_list.extend(_process_sub_block_pair(pdesc, sub_block_pair))
+    # Only process one level of nested subblock.
+    ops_list.extend(_process_sub_block_pair(pdesc, SUB_BLOCK_PAIR))
 
     cfgs = [
         ControlFlowGraph(input_program, ops, forward_num, skip_opt)
@@ -330,6 +360,17 @@ def _get_cfgs(input_program):
 
 
 def memory_optimize(input_program, print_log=False, level=0):
+    """Optimize memory by reusing var memory.
+
+    Note: it doesn't not support subblock nested in subblock.
+
+    :param input_program: Input Program
+    :param print_log: whether to print debug log.
+    :param level: If level=0, reuse if the shape is completely equal, o
+    :return:
+    """
+    if level != 0 and level != 1:
+        raise ValueError("only support opt_level 0 or 1.")
     global PRINT_LOG
     PRINT_LOG = print_log
     cfgs = _get_cfgs(input_program)
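
For context, a hedged sketch of how the transpiler touched by this commit is typically applied: import memory_optimize from the file in this diff and run it on the main program before execution. The small network below (data/fc/mean layers plus append_backward) is illustrative only; any fluid program with backward ops would do.

import paddle.fluid as fluid
from paddle.fluid.memory_optimization_transpiler import memory_optimize

# A toy program whose intermediate vars become dead and can be reused.
x = fluid.layers.data(name='x', shape=[13], dtype='float32')
hidden = fluid.layers.fc(input=x, size=128, act='relu')
loss = fluid.layers.mean(fluid.layers.fc(input=hidden, size=1))
fluid.backward.append_backward(loss)

# level=0 reuses memory only on exact shape matches; level=1 also lets a var
# reuse a cached var that is at least as large; any other level now raises
# ValueError. print_log=True prints the "Hit Cache" messages shown above.
memory_optimize(fluid.default_main_program(), print_log=True, level=0)
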
