47
47
class ControlFlowGraph (object ):
48
48
def __init__ (self , program , ops , forward_num , skip_opt ):
49
49
self ._program = program
50
- self ._dup_program = program .clone ()
51
50
self ._ops = ops
52
51
self ._forward_num = forward_num
53
52
self ._successors = defaultdict (set )
@@ -230,23 +229,22 @@ def compare_shape(x_shape, cache_shape, opt_level):
230
229
for x in defs_can_optimize
231
230
]
232
231
for x , x_shape in out_pair :
233
- if (x , x_shape ) in self .pool :
234
- raise ValueError ("x in pool, %s, %s" % (x , x_shape ))
235
232
# If x is both in uses and defs, it can not be optimized!
236
233
if x in self ._uses [i ]:
237
234
continue
238
235
for index , cache_pair in enumerate (self .pool ):
239
236
cache_var = cache_pair [0 ]
240
237
cache_shape = cache_pair [1 ]
241
238
if not self ._has_var (block_desc , cache_var , is_forward ):
242
- raise ValueError ("cache" ,
243
- cpt .to_text (cache_var ),
244
- " Not exists!" )
239
+ if PRINT_LOG :
240
+ print ("cache %s not exists!" %
241
+ (cpt .to_text (cache_var )))
242
+ continue
245
243
if x == cache_var :
246
- raise ValueError ( "x : " ,
247
- cpt .to_text (x ), " cache : " ,
248
- cpt .to_text (cache_var ),
249
- " is same var!" )
244
+ if PRINT_LOG :
245
+ print ( "x : " , cpt .to_text (x ), " cache : " ,
246
+ cpt .to_text (cache_var ), " is same var!" )
247
+ break
250
248
251
249
x_dtype = self ._find_var (block_desc , x ,
252
250
is_forward ).dtype ()
@@ -383,10 +381,13 @@ def memory_optimize(input_program, skip_opt_set=None, print_log=False, level=0):
383
381
384
382
Note: it doesn't not support subblock nested in subblock.
385
383
386
- :param input_program(str): Input Program
387
- :param print_log: whether to print debug log.
388
- :param level: If level=0, reuse if the shape is completely equal, o
389
- :return:
384
+ Args:
385
+ input_program(str): Input Program
386
+ skip_opt_set(set): vars wil be skipped in memory optimze
387
+ print_log(bool): whether to print debug log.
388
+ level(int): If level=0, reuse if the shape is completely equal, o
389
+ Returns:
390
+ None
390
391
"""
391
392
if level != 0 and level != 1 :
392
393
raise ValueError ("only support opt_level 0 or 1." )
@@ -407,6 +408,9 @@ def release_memory(input_program, skip_opt_set=None):
407
408
408
409
Args:
409
410
input_program(Program): The program will be inserted :code:`delete_op`.
411
+ skip_opt_set(set): vars wil be skipped in memory optimze
412
+ Returns:
413
+ None
410
414
"""
411
415
cfgs = _get_cfgs (input_program )
412
416
for cfg in cfgs :
0 commit comments