Skip to content

Commit 142fac1

Browse files
committed
add print_log to memory_optimize
1 parent 84aea8a commit 142fac1

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

python/paddle/fluid/memory_optimization_transpiler.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131

3232
sub_block_ops = ["while", "while_grad", "parallel_do", "parallel_do_grad"]
3333

34+
PRINT_LOG = False
35+
3436

3537
class ControlFlowGraph(object):
3638
def __init__(self, Program, ops, forward_num, skip_opt):
@@ -170,7 +172,7 @@ def check_var_validity(block_desc, x, is_forward):
170172
block_desc, cache_var, is_forward).dtype()
171173
# TODO(qijun): actually, we should compare dtype_to_size[x_dtype]
172174
# and dtype_to_size[cache_dtype]
173-
if x_dtype == cache_dtype:
175+
if x_dtype == cache_dtype and PRINT_LOG:
174176
print(("Hit Cache !!!! cache pool index "
175177
"is %d, var name is %s, "
176178
"cached var name is %s, "
@@ -277,7 +279,9 @@ def _get_cfgs(input_program):
277279
return cfgs
278280

279281

280-
def memory_optimize(input_program):
282+
def memory_optimize(input_program, print_log=False):
283+
global PRINT_LOG
284+
PRINT_LOG = print_log
281285
cfgs = _get_cfgs(input_program)
282286
for cfg in cfgs:
283287
cfg.memory_optimize()

python/paddle/fluid/tests/book_memory_optimization/test_memopt_fit_a_line.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.01)
5050
sgd_optimizer.minimize(avg_cost)
5151

52-
fluid.memory_optimize(fluid.default_main_program())
52+
fluid.memory_optimize(fluid.default_main_program(), print_log=True)
5353

5454
BATCH_SIZE = 200
5555

0 commit comments

Comments
 (0)