Skip to content

Commit 0c1a5d8

Browse files
committed
"debug version"
1 parent ca97313 commit 0c1a5d8

File tree

1 file changed

+63
-19
lines changed

1 file changed

+63
-19
lines changed

python/paddle/fluid/transpiler/memory_optimization_transpiler.py

Lines changed: 63 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,9 @@ def _build_graph(self):
7777
for i in range(self.op_size):
7878
self._uses[i].update(self._ops[i].input_arg_names())
7979
self._defs[i].update(self._ops[i].output_arg_names())
80+
self._live_in[i] = self._uses[i]
81+
# print(self._successors)
82+
# print(self._presuccessors)
8083

8184
def _update_graph(self, old_name, new_name, begin_idx=0):
8285
for i in range(begin_idx, self.op_size):
@@ -86,12 +89,18 @@ def _update_graph(self, old_name, new_name, begin_idx=0):
8689
if old_name in self._defs[i]:
8790
self._defs[i].remove(old_name)
8891
self._defs[i].add(new_name)
92+
# for i in range(begin_idx, -1, -1):
8993
if old_name in self._live_in[i]:
9094
self._live_in[i].remove(old_name)
91-
self._live_out[i].add(new_name)
95+
self._live_in[i].add(new_name)
96+
# if old_name == "concat_3.tmp_0@GRAD":
97+
# print("new_name", new_name)
98+
# print("live_in ", i , self._live_in[i])
9299
if old_name in self._live_out[i]:
93100
self._live_out[i].remove(old_name)
94101
self._live_out[i].add(new_name)
102+
# if old_name == "concat_3.tmp_0@GRAD":
103+
# print("live_out ", i , self._live_out[i])
95104

96105
def _reach_fixed_point(self, live_in, live_out):
97106
"""Check if the liveness set has stablized."""
@@ -105,22 +114,40 @@ def _reach_fixed_point(self, live_in, live_out):
105114
return False
106115
return True
107116

117+
# def _dataflow_analyze(self):
118+
# self._build_graph()
119+
# live_in = defaultdict(set)
120+
# live_out = defaultdict(set)
121+
# # Repeatedly apply liveness updates until the algorithm stablize
122+
# # on a complete set live input vars and live output vars.
123+
# counter = 0
124+
# print(self._successors)
125+
# while True:
126+
# counter += 1
127+
# for i in reversed(list(range(self.op_size))):
128+
# live_in[i] = set(self._live_in[i])
129+
# live_out[i] = set(self._live_out[i])
130+
# for s in self._successors[i]:
131+
# self._live_out[i] |= self._live_in[s]
132+
# self._live_in[i] = self._uses[i] | (
133+
# self._live_out[i] - self._defs[i])
134+
# if self._reach_fixed_point(live_in, live_out):
135+
# break
136+
108137
def _dataflow_analyze(self):
    """Compute per-op liveness sets ``self._live_in`` / ``self._live_out``.

    Runs a standard backward worklist algorithm over the op graph built by
    ``self._build_graph()``:

        live_out[i] = union of live_in[s] for each successor s of i
        live_in[i]  = uses[i] | (live_out[i] - defs[i])

    Whenever ``live_in[i]`` changes, every predecessor of ``i`` is re-queued,
    and the loop runs until a fixed point is reached.  The liveness sets only
    ever grow (the transfer functions are monotone), so the processing order
    does not affect the final fixed point.
    """
    self._build_graph()
    # Seed the worklist with every op index.  Liveness flows backward, so
    # processing ops from last to first tends to converge in few passes.
    worklist = list(range(len(self._ops) - 1, -1, -1))
    # Membership set so an op is never queued twice at the same time.  The
    # original debug version popped from the front of a list (O(n) per pop)
    # and appended predecessors unconditionally, which could re-process the
    # same op many redundant times.
    pending = set(worklist)
    while worklist:
        i = worklist.pop()  # O(1) pop from the end; order is irrelevant
        pending.discard(i)
        # Snapshot live_in[i] so we can detect whether this pass changed it.
        old_live_in = set(self._live_in[i])
        for s in self._successors[i]:
            self._live_out[i] |= self._live_in[s]
        self._live_in[i] = self._uses[i] | (
            self._live_out[i] - self._defs[i])
        if old_live_in != self._live_in[i]:
            # live_in[i] grew, so each predecessor's live_out may grow too;
            # re-queue predecessors that are not already pending.
            for d in self._presuccessors[i]:
                if d not in pending:
                    pending.add(d)
                    worklist.append(d)
124151

125152
def _get_diff(self, a, b):
126153
u = a & b
@@ -218,6 +245,17 @@ def compare_shape(x_shape, cache_shape, opt_level):
218245
continue
219246
block_desc = op.block()
220247
is_forward = i < self._forward_num
248+
in_diff, _ = self._get_diff(self._live_in[i], self._live_out[i])
249+
can_optimize = [
250+
x for x in in_diff
251+
if self._check_var_validity(block_desc, x, is_forward)
252+
]
253+
if can_optimize:
254+
for var_name in can_optimize:
255+
self.pool.append((var_name, self._find_var(
256+
block_desc, var_name, is_forward).shape()))
257+
# print(op.type(), i, self.pool)
258+
# print(self._live_in[i])
221259
if self.pool:
222260
defs_can_optimize = [
223261
x for x in self._defs[i]
@@ -249,21 +287,24 @@ def compare_shape(x_shape, cache_shape, opt_level):
249287
if x_dtype != cache_dtype:
250288
continue
251289

290+
self.pool.pop(index)
291+
if x == cache_var:
292+
break
293+
252294
if PRINT_LOG:
253295
print(("Hit Cache !!!! cache pool index "
254296
"is %d, var name is %s, "
255297
"cached var name is %s, "
256298
"var shape is %s ") % (index, x, cache_var,
257299
str(cache_shape)))
258-
self.pool.pop(index)
259-
if x == cache_var:
260-
break
261300
# Rename the var to the cache var already with
262301
# memory allocated in order to reuse the memory.
263302
_rename_arg_(self._ops, x, cache_var, begin_idx=i)
264303
self._program.block(block_desc.id).var(cpt.to_text(
265304
x)).desc = self._find_var(block_desc, cache_var,
266305
is_forward)
306+
if x == "concat_3.tmp_0@GRAD":
307+
print("Update Graph", i)
267308
self._update_graph(x, cache_var, begin_idx=i)
268309
break
269310

@@ -272,10 +313,13 @@ def compare_shape(x_shape, cache_shape, opt_level):
272313
x for x in in_diff
273314
if self._check_var_validity(block_desc, x, is_forward)
274315
]
316+
keys = set([key for key,shape in self.pool])
275317
if can_optimize:
276318
for var_name in can_optimize:
277-
self.pool.append((var_name, self._find_var(
278-
block_desc, var_name, is_forward).shape()))
319+
if var_name not in keys:
320+
self.pool.append((var_name, self._find_var(
321+
block_desc, var_name, is_forward).shape()))
322+
# print(op.type(), i, self.pool)
279323

280324

281325
def _process_sub_block_pair(pdesc, sub_block_pair):

0 commit comments

Comments
 (0)