@@ -56,6 +56,7 @@ def __init__(self, program, ops, forward_num, skip_opt):
         self._live_in = defaultdict(set)
         self._live_out = defaultdict(set)
         self._skip_opt = skip_opt
+        self.pool = []
 
     def _add_connections(self, connections):
         """Populates _successors and _presuccessors for two neighbor nodes."""
@@ -77,6 +78,7 @@ def _build_graph(self):
         for i in range(self.op_size):
             self._uses[i].update(self._ops[i].input_arg_names())
             self._defs[i].update(self._ops[i].output_arg_names())
+            self._live_in[i] = self._uses[i]
 
     def _update_graph(self, old_name, new_name, begin_idx=0):
         for i in range(begin_idx, self.op_size):
@@ -88,39 +90,39 @@ def _update_graph(self, old_name, new_name, begin_idx=0):
                 self._defs[i].add(new_name)
             if old_name in self._live_in[i]:
                 self._live_in[i].remove(old_name)
-                self._live_out[i].add(new_name)
+                self._live_in[i].add(new_name)
             if old_name in self._live_out[i]:
                 self._live_out[i].remove(old_name)
                 self._live_out[i].add(new_name)
 
-    def _reach_fixed_point(self, live_in, live_out):
-        """Check if the liveness set has stablized."""
-        if len(live_in) != len(self._live_in):
-            return False
-        if len(live_out) != len(self._live_out):
-            return False
-        for i in range(self.op_size):
-            if (live_in[i] != self._live_in[i] or
-                    live_out[i] != self._live_out[i]):
-                return False
-        return True
-
     def _dataflow_analyze(self):
         self._build_graph()
         live_in = defaultdict(set)
-        live_out = defaultdict(set)
-        # Repeatedly apply liveness updates until the algorithm stablize
-        # on a complete set live input vars and live output vars.
-        while True:
-            for i in reversed(list(range(self.op_size))):
-                live_in[i] = set(self._live_in[i])
-                live_out[i] = set(self._live_out[i])
-                for s in self._successors[i]:
-                    self._live_out[i] |= self._live_in[s]
-                self._live_in[i] = self._uses[i] | (
-                    self._live_out[i] - self._defs[i])
-            if self._reach_fixed_point(live_in, live_out):
-                break
+        worklist = list(range(len(self._ops) - 1, -1, -1))
+        while worklist:
+            i = worklist.pop(0)
+            live_in[i] = set(self._live_in[i])
+            for s in self._successors[i]:
+                self._live_out[i] |= self._live_in[s]
+            self._live_in[i] = self._uses[i] | (
+                self._live_out[i] - self._defs[i])
+            if live_in[i] != self._live_in[i]:
+                for d in self._presuccessors[i]:
+                    worklist.append(d)
+
+    def _fill_pool(self, i, is_forward):
+        block_desc = self._ops[i].block()
+        in_diff, _ = self._get_diff(self._live_in[i], self._live_out[i])
+        can_optimize = [
+            x for x in in_diff
+            if self._check_var_validity(block_desc, x, is_forward)
+        ]
+        if can_optimize:
+            for var_name in can_optimize:
+                cache = (var_name, self._find_var(block_desc, var_name,
+                                                  is_forward).shape())
+                if cache not in self.pool:
+                    self.pool.append(cache)
 
     def _get_diff(self, a, b):
         u = a & b
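The two hunks above replace the repeat-until-`_reach_fixed_point` sweep with the standard backward worklist form of liveness analysis: live_in(i) = use(i) | (live_out(i) - def(i)), live_out(i) is the union of live_in(s) over the successors s, and a node's predecessors are re-queued only when its live_in set actually changes (`_build_graph` now also seeds `_live_in[i]` with `_uses[i]`, so the first pass starts from non-empty sets). A minimal standalone sketch of that scheme, using hypothetical use/def/successor tables rather than the transpiler's own members:

from collections import defaultdict

def liveness(num_ops, uses, defs, successors, predecessors):
    """Backward worklist liveness over a small op graph."""
    live_in = defaultdict(set)
    live_out = defaultdict(set)
    worklist = list(range(num_ops - 1, -1, -1))  # start from the last op
    while worklist:
        i = worklist.pop(0)
        old = set(live_in[i])
        for s in successors[i]:
            live_out[i] |= live_in[s]
        live_in[i] = uses[i] | (live_out[i] - defs[i])
        if live_in[i] != old:              # only re-queue on change
            worklist.extend(predecessors[i])
    return live_in, live_out

# Toy chain: op0 defines 'a', op1 reads 'a' and defines 'b', op2 reads 'b'.
uses = {0: set(), 1: {"a"}, 2: {"b"}}
defs = {0: {"a"}, 1: {"b"}, 2: set()}
succ = {0: [1], 1: [2], 2: []}
pred = {0: [], 1: [0], 2: [1]}
live_in, live_out = liveness(3, uses, defs, succ, pred)
print(live_in)   # op2 needs 'b', op1 needs 'a', nothing is live into op0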
@@ -211,7 +213,6 @@ def compare_shape(x_shape, cache_shape, opt_level):
         # update skip set to meet users' demand
         if skip_opt_set:
             self._skip_opt.update(skip_opt_set)
-        self.pool = []
         for i in range(self.op_size):
             op = self._ops[i]
             if op.type() in SUB_BLOCK_OPS:
@@ -234,16 +235,24 @@ def compare_shape(x_shape, cache_shape, opt_level):
                     for index, cache_pair in enumerate(self.pool):
                         cache_var = cache_pair[0]
                         cache_shape = cache_pair[1]
-                        if not compare_shape(x_shape, cache_shape, level):
-                            continue
-
                         if not self._has_var(block_desc, cache_var, is_forward):
+                            if PRINT_LOG:
+                                print("cache %s not exists!" %
+                                      (cpt.to_text(cache_var)))
                             continue
+                        if x == cache_var:
+                            if PRINT_LOG:
+                                print("x : ", cpt.to_text(x), " cache : ",
+                                      cpt.to_text(cache_var), " is same var!")
+                            break
 
                         x_dtype = self._find_var(block_desc, x,
                                                  is_forward).dtype()
                         cache_dtype = self._find_var(block_desc, cache_var,
                                                      is_forward).dtype()
+
+                        if not compare_shape(x_shape, cache_shape, level):
+                            continue
                         # TODO(qijun): actually, we should compare
                         # dtype_to_size[x_dtype] and dtype_to_size[cache_dtype]
                         if x_dtype != cache_dtype:
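With this hunk the cache loop checks, in order: does the cached var still exist, is it the very same var, then the shape comparison at the requested `level`, and finally the dtype check. `compare_shape` itself sits outside these hunks; purely as an illustration (assuming, as the truncated docstring below suggests, that level 0 demands exactly equal shapes while level 1 relaxes that to an element-count comparison), a level test could look like:

from functools import reduce

def compare_shape_sketch(x_shape, cache_shape, opt_level):
    # Hypothetical illustration of the two opt levels, not this file's helper.
    if opt_level == 0:
        # level 0: reuse a cached block only when the shapes match exactly
        return x_shape == cache_shape
    # level 1 (assumed): reuse when the cached block holds at least as many
    # elements; abs() neutralizes a single -1 (unknown batch) dimension.
    size = lambda shape: abs(reduce(lambda a, b: a * b, shape, 1))
    return size(x_shape) <= size(cache_shape)

print(compare_shape_sketch((32, 128), (32, 128), 0))  # True
print(compare_shape_sketch((32, 64), (32, 128), 0))   # False
print(compare_shape_sketch((32, 64), (32, 128), 1))   # True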
@@ -256,8 +265,6 @@ def compare_shape(x_shape, cache_shape, opt_level):
                                    "var shape is %s ") % (index, x, cache_var,
                                                           str(cache_shape)))
                         self.pool.pop(index)
-                        if x == cache_var:
-                            break
                         # Rename the var to the cache var already with
                         # memory allocated in order to reuse the memory.
                         _rename_arg_(self._ops, x, cache_var, begin_idx=i)
@@ -266,16 +273,7 @@ def compare_shape(x_shape, cache_shape, opt_level):
                                                       is_forward)
                         self._update_graph(x, cache_var, begin_idx=i)
                         break
-
-            in_diff, _ = self._get_diff(self._live_in[i], self._live_out[i])
-            can_optimize = [
-                x for x in in_diff
-                if self._check_var_validity(block_desc, x, is_forward)
-            ]
-            if can_optimize:
-                for var_name in can_optimize:
-                    self.pool.append((var_name, self._find_var(
-                        block_desc, var_name, is_forward).shape()))
+            self._fill_pool(i, is_forward)
 
 
 def _process_sub_block_pair(pdesc, sub_block_pair):
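`_fill_pool`, added in the earlier hunk, now owns the bookkeeping that used to sit inline here: after each op, every variable that is live-in but not live-out and passes `_check_var_validity` is recorded as a `(name, shape)` pair, and because `self.pool` is created once in `__init__` rather than reset inside `memory_optimize`, the new `if cache not in self.pool` guard keeps the list duplicate-free. A toy sketch of that pattern with plain tuples (the variable names and shapes below are made up):

pool = []  # list of (var_name, shape) tuples, oldest candidates first

def fill_pool(pool, released_vars):
    for var_name, shape in released_vars:
        cache = (var_name, shape)
        if cache not in pool:  # a var released at several ops is kept once
            pool.append(cache)

fill_pool(pool, [("fc_0.tmp_1", (32, 128)), ("relu_0.tmp_0", (32, 128))])
fill_pool(pool, [("fc_0.tmp_1", (32, 128))])  # duplicate, ignored
print(pool)  # [('fc_0.tmp_1', (32, 128)), ('relu_0.tmp_0', (32, 128))]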
@@ -383,10 +381,13 @@ def memory_optimize(input_program, skip_opt_set=None, print_log=False, level=0):
 
     Note: it doesn't not support subblock nested in subblock.
 
-    :param input_program: Input Program
-    :param print_log: whether to print debug log.
-    :param level: If level=0, reuse if the shape is completely equal, o
-    :return:
+    Args:
+        input_program(Program): Input Program
+        skip_opt_set(set): vars that will be skipped in memory optimize
+        print_log(bool): whether to print debug log.
+        level(int): If level=0, reuse if the shape is completely equal, o
+    Returns:
+        None
     """
     if level != 0 and level != 1:
         raise ValueError("only support opt_level 0 or 1.")
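For completeness, a hedged usage sketch of the updated signature (assuming the function is re-exported as `fluid.memory_optimize`, and with illustrative variable names in the skip set):

import paddle.fluid as fluid

# Build the network first, then transpile the program in place before running.
main_program = fluid.default_main_program()
fluid.memory_optimize(
    main_program,
    skip_opt_set={"loss", "image", "label"},  # vars that must keep their own memory
    print_log=True,   # log every cache hit
    level=0)          # 0: reuse only when shapes are exactly equal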
@@ -407,6 +408,9 @@ def release_memory(input_program, skip_opt_set=None):
 
     Args:
         input_program(Program): The program will be inserted :code:`delete_op`.
+        skip_opt_set(set): vars that will be skipped in memory optimize
+    Returns:
+        None
     """
     cfgs = _get_cfgs(input_program)
     for cfg in cfgs:
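And the same for `release_memory`, which frees variables early by inserting `delete_op` instead of renaming them (again a sketch; the skip-set contents are illustrative):

import paddle.fluid as fluid

main_program = fluid.default_main_program()
# Insert delete_op after the last use of each var so its memory is freed early,
# except for the vars listed in skip_opt_set.
fluid.release_memory(main_program, skip_opt_set={"loss"})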