Commit c89664e

Merge pull request #8831 from jacquesqiao/add-printlog-to-memoryoptimize
add print_log to memory_optimize
2 parents 84aea8a + fe2d590 commit c89664e

2 files changed: +14, -8 lines

python/paddle/fluid/memory_optimization_transpiler.py

Lines changed: 13 additions & 7 deletions
@@ -31,6 +31,8 @@
 
 sub_block_ops = ["while", "while_grad", "parallel_do", "parallel_do_grad"]
 
+PRINT_LOG = False
+
 
 class ControlFlowGraph(object):
     def __init__(self, Program, ops, forward_num, skip_opt):
@@ -171,12 +173,14 @@ def check_var_validity(block_desc, x, is_forward):
                 # TODO(qijun): actually, we should compare dtype_to_size[x_dtype]
                 # and dtype_to_size[cache_dtype]
                 if x_dtype == cache_dtype:
-                    print(("Hit Cache !!!! cache pool index "
-                           "is %d, var name is %s, "
-                           "cached var name is %s, "
-                           "var shape is %s ") %
-                          (index, x, cache_var,
-                           str(cache_shape)))
+                    if PRINT_LOG:
+                        print(
+                            ("Hit Cache !!!! cache pool index "
+                             "is %d, var name is %s, "
+                             "cached var name is %s, "
+                             "var shape is %s ") %
+                            (index, x, cache_var,
+                             str(cache_shape)))
                     self.pool.pop(index)
                     if x == cache_var:
                         break
@@ -277,7 +281,9 @@ def _get_cfgs(input_program):
     return cfgs
 
 
-def memory_optimize(input_program):
+def memory_optimize(input_program, print_log=False):
+    global PRINT_LOG
+    PRINT_LOG = print_log
     cfgs = _get_cfgs(input_program)
     for cfg in cfgs:
         cfg.memory_optimize()

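Taken together, the change routes a new user-facing print_log argument through the module-level PRINT_LOG flag, which gates the existing "Hit Cache" diagnostic in the pool-reuse loop. A minimal usage sketch (assuming a fluid program has already been built, as in the test below; only memory_optimize and its new argument come from this commit):

    import paddle.fluid as fluid

    # ... build the network and cost, then optimizer.minimize(avg_cost) ...

    # With the default print_log=False behaviour is unchanged and nothing
    # extra is printed. Passing print_log=True sets the transpiler's
    # module-level PRINT_LOG flag, so each buffer reuse it finds is
    # reported as a "Hit Cache" line on stdout.
    fluid.memory_optimize(fluid.default_main_program(), print_log=True)
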
python/paddle/fluid/tests/book_memory_optimization/test_memopt_fit_a_line.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.01)
5050
sgd_optimizer.minimize(avg_cost)
5151

52-
fluid.memory_optimize(fluid.default_main_program())
52+
fluid.memory_optimize(fluid.default_main_program(), print_log=True)
5353

5454
BATCH_SIZE = 200
5555
