Skip to content

Commit ef60a65

Browse files
committed
"add test"
1 parent 0c1a5d8 commit ef60a65

File tree

2 files changed

+65
-75
lines changed

2 files changed

+65
-75
lines changed

python/paddle/fluid/tests/unittests/test_memory_optimization_transpiler.py

Lines changed: 25 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -66,6 +66,31 @@ def test_inplace_ops(self):
6666
print("after optimization")
6767
print(str(result_program))
6868

69+
class TestMemoryTranspiler3(unittest.TestCase):
    """Memory-optimization test: cascaded reuse of concat gradient vars.

    Builds a program that chains six embedding lookups through a sequence
    of concat ops, appends the backward pass, and then checks which
    concat gradient variables survive in the program desc.
    """

    def setUp(self):
        # Construct the forward + backward program once per test.
        program = Program()
        with program_guard(program, startup_program=Program()):
            word = fluid.layers.data(name='word', shape=[1], dtype='int64')
            tables = [
                fluid.layers.embedding(
                    word, size=[65536, 256], param_attr='emb')
                for _ in range(6)
            ]

            # Fold the six embeddings pairwise into one tensor via concat.
            merged = tables[0]
            for nxt in tables[1:]:
                merged = fluid.layers.concat([merged, nxt])

            loss = fluid.layers.mean(merged)
            fluid.backward.append_backward(loss)
        self.program = program

    def test_cascade_reuse(self):
        block = self.program.block(0)
        # variable reuse in programdesc
        self.assertTrue("concat_4.tmp_0@GRAD" in block.vars)
        self.assertTrue("concat_3.tmp_0@GRAD" not in block.vars)
        self.assertTrue("concat_2.tmp_0@GRAD" not in block.vars)
        self.assertTrue("concat_1.tmp_0@GRAD" not in block.vars)
        self.assertTrue("concat_0.tmp_0@GRAD" not in block.vars)
93+
6994

7095
if __name__ == "__main__":
7196
unittest.main()

python/paddle/fluid/transpiler/memory_optimization_transpiler.py

Lines changed: 40 additions & 75 deletions
Original file line number | Diff line number | Diff line change
@@ -47,6 +47,7 @@
4747
class ControlFlowGraph(object):
4848
def __init__(self, program, ops, forward_num, skip_opt):
4949
self._program = program
50+
self._dup_program = program.clone()
5051
self._ops = ops
5152
self._forward_num = forward_num
5253
self._successors = defaultdict(set)
@@ -56,6 +57,7 @@ def __init__(self, program, ops, forward_num, skip_opt):
5657
self._live_in = defaultdict(set)
5758
self._live_out = defaultdict(set)
5859
self._skip_opt = skip_opt
60+
self.pool = []
5961

6062
def _add_connections(self, connections):
6163
"""Populates _successors and _presuccessors for two neighbor nodes."""
@@ -78,8 +80,6 @@ def _build_graph(self):
7880
self._uses[i].update(self._ops[i].input_arg_names())
7981
self._defs[i].update(self._ops[i].output_arg_names())
8082
self._live_in[i] = self._uses[i]
81-
# print(self._successors)
82-
# print(self._presuccessors)
8383

8484
def _update_graph(self, old_name, new_name, begin_idx=0):
8585
for i in range(begin_idx, self.op_size):
@@ -89,50 +89,13 @@ def _update_graph(self, old_name, new_name, begin_idx=0):
8989
if old_name in self._defs[i]:
9090
self._defs[i].remove(old_name)
9191
self._defs[i].add(new_name)
92-
# for i in range(begin_idx, -1, -1):
9392
if old_name in self._live_in[i]:
9493
self._live_in[i].remove(old_name)
9594
self._live_in[i].add(new_name)
96-
# if old_name == "concat_3.tmp_0@GRAD":
97-
# print("new_name", new_name)
98-
# print("live_in ", i , self._live_in[i])
9995
if old_name in self._live_out[i]:
10096
self._live_out[i].remove(old_name)
10197
self._live_out[i].add(new_name)
102-
# if old_name == "concat_3.tmp_0@GRAD":
103-
# print("live_out ", i , self._live_out[i])
104-
105-
def _reach_fixed_point(self, live_in, live_out):
106-
"""Check if the liveness set has stablized."""
107-
if len(live_in) != len(self._live_in):
108-
return False
109-
if len(live_out) != len(self._live_out):
110-
return False
111-
for i in range(self.op_size):
112-
if (live_in[i] != self._live_in[i] or
113-
live_out[i] != self._live_out[i]):
114-
return False
115-
return True
11698

117-
# def _dataflow_analyze(self):
118-
# self._build_graph()
119-
# live_in = defaultdict(set)
120-
# live_out = defaultdict(set)
121-
# # Repeatedly apply liveness updates until the algorithm stablize
122-
# # on a complete set live input vars and live output vars.
123-
# counter = 0
124-
# print(self._successors)
125-
# while True:
126-
# counter += 1
127-
# for i in reversed(list(range(self.op_size))):
128-
# live_in[i] = set(self._live_in[i])
129-
# live_out[i] = set(self._live_out[i])
130-
# for s in self._successors[i]:
131-
# self._live_out[i] |= self._live_in[s]
132-
# self._live_in[i] = self._uses[i] | (
133-
# self._live_out[i] - self._defs[i])
134-
# if self._reach_fixed_point(live_in, live_out):
135-
# break
13699

137100
def _dataflow_analyze(self):
138101
self._build_graph()
@@ -149,6 +112,20 @@ def _dataflow_analyze(self):
149112
for d in self._presuccessors[i]:
150113
worklist.append(d)
151114

115+
def _fill_pool(self, i, is_forward):
    """Offer vars that die at op *i* to the memory-reuse pool.

    A var that is live-in but not live-out at op *i* is dead afterwards,
    so its memory can be reused. Each pool entry is a
    ``(var_name, shape)`` pair; duplicates are skipped so the same cache
    slot is never offered twice.

    Args:
        i (int): index of the op in ``self._ops``.
        is_forward (bool): whether op *i* belongs to the forward pass
            (affects where ``_find_var``/``_check_var_validity`` look).
    """
    block_desc = self._ops[i].block()
    # Vars in live_in but not live_out are exactly the ones dying here.
    in_diff, _ = self._get_diff(self._live_in[i], self._live_out[i])
    can_optimize = [
        x for x in in_diff
        if self._check_var_validity(block_desc, x, is_forward)
    ]
    # No `if can_optimize:` guard needed — iterating an empty list is a
    # no-op.
    for var_name in can_optimize:
        cache = (var_name,
                 self._find_var(block_desc, var_name, is_forward).shape())
        if cache not in self.pool:
            self.pool.append(cache)
128+
152129
def _get_diff(self, a, b):
153130
u = a & b
154131
return a - u, b - u
@@ -238,24 +215,15 @@ def compare_shape(x_shape, cache_shape, opt_level):
238215
# update skip set to meet users' demand
239216
if skip_opt_set:
240217
self._skip_opt.update(skip_opt_set)
241-
self.pool = []
218+
# self.pool = []
242219
for i in range(self.op_size):
243220
op = self._ops[i]
244221
if op.type() in SUB_BLOCK_OPS:
245222
continue
246223
block_desc = op.block()
247224
is_forward = i < self._forward_num
248-
in_diff, _ = self._get_diff(self._live_in[i], self._live_out[i])
249-
can_optimize = [
250-
x for x in in_diff
251-
if self._check_var_validity(block_desc, x, is_forward)
252-
]
253-
if can_optimize:
254-
for var_name in can_optimize:
255-
self.pool.append((var_name, self._find_var(
256-
block_desc, var_name, is_forward).shape()))
225+
self._fill_pool(i, is_forward)
257226
# print(op.type(), i, self.pool)
258-
# print(self._live_in[i])
259227
if self.pool:
260228
defs_can_optimize = [
261229
x for x in self._defs[i]
@@ -266,60 +234,57 @@ def compare_shape(x_shape, cache_shape, opt_level):
266234
for x in defs_can_optimize
267235
]
268236
for x, x_shape in out_pair:
237+
if (x, x_shape) in self.pool:
238+
raise ValueError("x in pool")
269239
# If x is both in uses and defs, it can not be optimized!
270240
if x in self._uses[i]:
241+
# print(self.pool, op.type(), cpt.to_text(x))
242+
# raise ValueError("x in use!", cpt.to_text(x))
271243
continue
272244
for index, cache_pair in enumerate(self.pool):
273245
cache_var = cache_pair[0]
274246
cache_shape = cache_pair[1]
275-
if not compare_shape(x_shape, cache_shape, level):
276-
continue
277-
278247
if not self._has_var(block_desc, cache_var, is_forward):
279-
continue
248+
raise ValueError("cache", cpt.to_text(cache_var), " Not exists!")
249+
if x == cache_var:
250+
raise ValueError("x : ", cpt.to_text(x), " cache : ", cpt.to_text(cache_var), " is same var!")
280251

281252
x_dtype = self._find_var(block_desc, x,
282253
is_forward).dtype()
283254
cache_dtype = self._find_var(block_desc, cache_var,
284255
is_forward).dtype()
256+
257+
if not compare_shape(x_shape, cache_shape, level):
258+
continue
285259
# TODO(qijun): actually, we should compare
286260
# dtype_to_size[x_dtype] and dtype_to_size[cache_dtype]
287261
if x_dtype != cache_dtype:
288262
continue
289263

290-
self.pool.pop(index)
291-
if x == cache_var:
292-
break
293-
294264
if PRINT_LOG:
295265
print(("Hit Cache !!!! cache pool index "
296266
"is %d, var name is %s, "
297267
"cached var name is %s, "
298268
"var shape is %s ") % (index, x, cache_var,
299269
str(cache_shape)))
270+
self.pool.pop(index)
300271
# Rename the var to the cache var already with
301272
# memory allocated in order to reuse the memory.
302273
_rename_arg_(self._ops, x, cache_var, begin_idx=i)
303-
self._program.block(block_desc.id).var(cpt.to_text(
304-
x)).desc = self._find_var(block_desc, cache_var,
305-
is_forward)
306-
if x == "concat_3.tmp_0@GRAD":
307-
print("Update Graph", i)
274+
self._program.block(block_desc.id)._remove_var(cpt.to_text(
275+
x))
276+
# if str(self._program) != str(self._dup_program):
277+
# with open("./program_middle", "w") as f:
278+
# f.write(str(self._program))
279+
# f.flush()
280+
# exit(0)
281+
# self._program.block(block_desc.id).var(cpt.to_text(
282+
# x)).desc = self._find_var(block_desc, cache_var,
283+
# is_forward)
308284
self._update_graph(x, cache_var, begin_idx=i)
309285
break
286+
# self._fill_pool(i, is_forward)
310287

311-
in_diff, _ = self._get_diff(self._live_in[i], self._live_out[i])
312-
can_optimize = [
313-
x for x in in_diff
314-
if self._check_var_validity(block_desc, x, is_forward)
315-
]
316-
keys = set([key for key,shape in self.pool])
317-
if can_optimize:
318-
for var_name in can_optimize:
319-
if var_name not in keys:
320-
self.pool.append((var_name, self._find_var(
321-
block_desc, var_name, is_forward).shape()))
322-
# print(op.type(), i, self.pool)
323288

324289

325290
def _process_sub_block_pair(pdesc, sub_block_pair):

0 commit comments

Comments
 (0)