Skip to content

Commit 34ac0eb

Browse files
QiJunedzhwinter
authored andcommitted
enhance memory optimization transpiler to support user defined skip_opt_set (#11372)
* fix mac build error * enhance memory optimize transpiler to let users set to some skip opt set of variables
1 parent 745ea4d commit 34ac0eb

File tree

1 file changed

+11
-6
lines changed

1 file changed

+11
-6
lines changed

python/paddle/fluid/transpiler/memory_optimization_transpiler.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -157,9 +157,11 @@ def _update_skip_opt_set(self):
157157
if op.type() == "fill_constant" and op.attr("force_cpu") == True:
158158
self._skip_opt.update(op.output_arg_names())
159159

160-
def release_memory(self):
160+
def release_memory(self, skip_opt_set=None):
161161
self._dataflow_analyze()
162162
self._update_skip_opt_set()
163+
if skip_opt_set:
164+
self._skip_opt.update(skip_opt_set)
163165
fwd_id = 0
164166
bwd_id = 0
165167
for i in range(self.op_size):
@@ -183,7 +185,7 @@ def release_memory(self):
183185
else:
184186
bwd_id += 1
185187

186-
def memory_optimize(self, level=0):
188+
def memory_optimize(self, skip_opt_set=None, level=0):
187189
def compare_shape(x_shape, cache_shape, opt_level):
188190
if opt_level == 0:
189191
return x_shape == cache_shape
@@ -200,6 +202,9 @@ def compare_shape(x_shape, cache_shape, opt_level):
200202

201203
self._dataflow_analyze()
202204
self._update_skip_opt_set()
205+
# update skip set to meet users' demand
206+
if skip_opt_set:
207+
self._skip_opt.update(skip_opt_set)
203208
self.pool = []
204209
for i in range(self.op_size):
205210
op = self._ops[i]
@@ -358,7 +363,7 @@ def _get_cfgs(input_program):
358363
return cfgs
359364

360365

361-
def memory_optimize(input_program, print_log=False, level=0):
366+
def memory_optimize(input_program, skip_opt_set=None, print_log=False, level=0):
362367
"""Optimize memory by reusing var memory.
363368
364369
Note: it doesn't not support subblock nested in subblock.
@@ -374,10 +379,10 @@ def memory_optimize(input_program, print_log=False, level=0):
374379
PRINT_LOG = print_log
375380
cfgs = _get_cfgs(input_program)
376381
for cfg in cfgs:
377-
cfg.memory_optimize(level)
382+
cfg.memory_optimize(skip_opt_set=skip_opt_set, level=level)
378383

379384

380-
def release_memory(input_program):
385+
def release_memory(input_program, skip_opt_set=None):
381386
cfgs = _get_cfgs(input_program)
382387
for cfg in cfgs:
383-
cfg.release_memory()
388+
cfg.release_memory(skip_opt_set=skip_opt_set)

0 commit comments

Comments
 (0)