@@ -157,9 +157,11 @@ def _update_skip_opt_set(self):
157
157
if op .type () == "fill_constant" and op .attr ("force_cpu" ) == True :
158
158
self ._skip_opt .update (op .output_arg_names ())
159
159
160
- def release_memory (self ):
160
+ def release_memory (self , skip_opt_set = None ):
161
161
self ._dataflow_analyze ()
162
162
self ._update_skip_opt_set ()
163
+ if skip_opt_set :
164
+ self ._skip_opt .update (skip_opt_set )
163
165
fwd_id = 0
164
166
bwd_id = 0
165
167
for i in range (self .op_size ):
@@ -183,7 +185,7 @@ def release_memory(self):
183
185
else :
184
186
bwd_id += 1
185
187
186
- def memory_optimize (self , level = 0 ):
188
+ def memory_optimize (self , skip_opt_set = None , level = 0 ):
187
189
def compare_shape (x_shape , cache_shape , opt_level ):
188
190
if opt_level == 0 :
189
191
return x_shape == cache_shape
@@ -200,6 +202,9 @@ def compare_shape(x_shape, cache_shape, opt_level):
200
202
201
203
self ._dataflow_analyze ()
202
204
self ._update_skip_opt_set ()
205
+ # update skip set to meet users' demand
206
+ if skip_opt_set :
207
+ self ._skip_opt .update (skip_opt_set )
203
208
self .pool = []
204
209
for i in range (self .op_size ):
205
210
op = self ._ops [i ]
@@ -358,7 +363,7 @@ def _get_cfgs(input_program):
358
363
return cfgs
359
364
360
365
361
- def memory_optimize (input_program , print_log = False , level = 0 ):
366
+ def memory_optimize (input_program , skip_opt_set = None , print_log = False , level = 0 ):
362
367
"""Optimize memory by reusing var memory.
363
368
364
369
Note: it doesn't not support subblock nested in subblock.
@@ -374,10 +379,10 @@ def memory_optimize(input_program, print_log=False, level=0):
374
379
PRINT_LOG = print_log
375
380
cfgs = _get_cfgs (input_program )
376
381
for cfg in cfgs :
377
- cfg .memory_optimize (level )
382
+ cfg .memory_optimize (skip_opt_set = skip_opt_set , level = level )
378
383
379
384
380
- def release_memory (input_program ):
385
+ def release_memory (input_program , skip_opt_set = None ):
381
386
cfgs = _get_cfgs (input_program )
382
387
for cfg in cfgs :
383
- cfg .release_memory ()
388
+ cfg .release_memory (skip_opt_set = skip_opt_set )
0 commit comments