@@ -96,7 +96,6 @@ def _update_graph(self, old_name, new_name, begin_idx=0):
96
96
self ._live_out [i ].remove (old_name )
97
97
self ._live_out [i ].add (new_name )
98
98
99
-
100
99
def _dataflow_analyze (self ):
101
100
self ._build_graph ()
102
101
live_in = defaultdict (set )
@@ -121,8 +120,8 @@ def _fill_pool(self, i, is_forward):
121
120
]
122
121
if can_optimize :
123
122
for var_name in can_optimize :
124
- cache = (var_name , self ._find_var (
125
- block_desc , var_name , is_forward ).shape ())
123
+ cache = (var_name , self ._find_var (block_desc , var_name ,
124
+ is_forward ).shape ())
126
125
if cache not in self .pool :
127
126
self .pool .append (cache )
128
127
@@ -232,17 +231,22 @@ def compare_shape(x_shape, cache_shape, opt_level):
232
231
]
233
232
for x , x_shape in out_pair :
234
233
if (x , x_shape ) in self .pool :
235
- raise ValueError ("x in pool" )
234
+ raise ValueError ("x in pool, %s, %s" % ( x , x_shape ) )
236
235
# If x is both in uses and defs, it can not be optimized!
237
236
if x in self ._uses [i ]:
238
237
continue
239
238
for index , cache_pair in enumerate (self .pool ):
240
239
cache_var = cache_pair [0 ]
241
240
cache_shape = cache_pair [1 ]
242
241
if not self ._has_var (block_desc , cache_var , is_forward ):
243
- raise ValueError ("cache" , cpt .to_text (cache_var ), " Not exists!" )
242
+ raise ValueError ("cache" ,
243
+ cpt .to_text (cache_var ),
244
+ " Not exists!" )
244
245
if x == cache_var :
245
- raise ValueError ("x : " , cpt .to_text (x ), " cache : " , cpt .to_text (cache_var ), " is same var!" )
246
+ raise ValueError ("x : " ,
247
+ cpt .to_text (x ), " cache : " ,
248
+ cpt .to_text (cache_var ),
249
+ " is same var!" )
246
250
247
251
x_dtype = self ._find_var (block_desc , x ,
248
252
is_forward ).dtype ()
@@ -266,14 +270,14 @@ def compare_shape(x_shape, cache_shape, opt_level):
266
270
# Rename the var to the cache var already with
267
271
# memory allocated in order to reuse the memory.
268
272
_rename_arg_ (self ._ops , x , cache_var , begin_idx = i )
269
- self ._program .block (block_desc .id )._remove_var (cpt .to_text (
270
- x ))
273
+ self ._program .block (block_desc .id ).var (cpt .to_text (
274
+ x )).desc = self ._find_var (block_desc , cache_var ,
275
+ is_forward )
271
276
self ._update_graph (x , cache_var , begin_idx = i )
272
277
break
273
278
self ._fill_pool (i , is_forward )
274
279
275
280
276
-
277
281
def _process_sub_block_pair (pdesc , sub_block_pair ):
278
282
"""Creates a list of tuple each of which tracks info of a subblock.
279
283
@@ -379,7 +383,7 @@ def memory_optimize(input_program, skip_opt_set=None, print_log=False, level=0):
379
383
380
384
Note: it does not support subblocks nested in subblocks.
381
385
382
- :param input_program: Input Program
386
+ :param input_program(str) : Input Program
383
387
:param print_log: whether to print debug log.
384
388
:param level: If level=0, reuse if the shape is completely equal, o
385
389
:return:
0 commit comments