
Commit b2435a3

Merge pull request #12374 from chenwhql/py_calc_memory
Add memory usage estimate API
2 parents 66be532 + 8627ef3

File tree

4 files changed: +175 -1 lines changed


paddle/fluid/API.spec

Lines changed: 1 addition & 0 deletions
@@ -336,6 +336,7 @@ paddle.fluid.contrib.BeamSearchDecoder.decode ArgSpec(args=['self'], varargs=Non
 paddle.fluid.contrib.BeamSearchDecoder.early_stop ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
 paddle.fluid.contrib.BeamSearchDecoder.read_array ArgSpec(args=['self', 'init', 'is_ids', 'is_scores'], varargs=None, keywords=None, defaults=(False, False))
 paddle.fluid.contrib.BeamSearchDecoder.update_array ArgSpec(args=['self', 'array', 'value'], varargs=None, keywords=None, defaults=None)
+paddle.fluid.contrib.memory_usage ArgSpec(args=['program', 'batch_size'], varargs=None, keywords=None, defaults=None)
 paddle.fluid.transpiler.DistributeTranspiler.__init__ ArgSpec(args=['self', 'config'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.transpiler.DistributeTranspiler.create_splited_vars ArgSpec(args=['self', 'source_var', 'block', 'tag'], varargs=None, keywords=None, defaults=None)
 paddle.fluid.transpiler.DistributeTranspiler.get_pserver_program ArgSpec(args=['self', 'endpoint'], varargs=None, keywords=None, defaults=None)

python/paddle/fluid/contrib/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -14,5 +14,7 @@
 
 import decoder
 from decoder import *
+import memory_usage_calc
+from memory_usage_calc import *
 
-__all__ = decoder.__all__
+__all__ = decoder.__all__ + memory_usage_calc.__all__
python/paddle/fluid/contrib/memory_usage_calc.py

Lines changed: 102 additions & 0 deletions
@@ -0,0 +1,102 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This module provides a memory usage estimation function for users.
+The purpose of this API is to allow users to estimate the memory usage
+of a program under a specific batch size, so that they can choose an
+appropriate batch size to fully utilize a GPU.
+
+This API is still under active development and may change drastically.
+"""
+
+from .. import core
+from ..framework import Program, Variable
+
+__all__ = ['memory_usage']
+
+dtype_to_size = {
+    core.VarDesc.VarType.FP16: 2,
+    core.VarDesc.VarType.FP32: 4,
+    core.VarDesc.VarType.FP64: 8,
+    core.VarDesc.VarType.INT16: 2,
+    core.VarDesc.VarType.INT32: 4,
+    core.VarDesc.VarType.INT64: 8,
+    core.VarDesc.VarType.BOOL: 1,
+    core.VarDesc.VarType.UINT8: 1,
+}
+
+DEBUG = False
+
+
+def memory_usage(program, batch_size):
+    """
+    Get the estimated memory usage of a program with the given batch size.
+
+    Args:
+        program(Program): The current Program.
+        batch_size(int): The current input data batch_size.
+
+    Returns:
+        min_total_memory(float): the estimated memory usage lower bound.
+        max_total_memory(float): the estimated memory usage upper bound.
+        unit_str(string): the unit of the estimated usage result.
+
+    Examples:
+
+        >>> import paddle.fluid as fluid
+        >>> lower_usage, upper_usage, unit = fluid.contrib.memory_usage(
+                fluid.default_main_program(), batch_size=10)
+        >>> print "memory usage is about %.3f - %.3f %s" % \
+                (lower_usage, upper_usage, unit)
+
+    """
+
+    # Parameter check
+    if not isinstance(program, Program):
+        raise TypeError(
+            "Calculating memory usage requires a Program as its parameter, "
+            "but you passed in %s" % (type(program)))
+    if batch_size <= 0:
+        raise ValueError("The batch size needs to be positive.")
+
+    # Walk the variables of the first (global) block and sum their sizes
+    total_memory = 0.0
+    for var in program.global_block().vars.itervalues():
+        data_count = 1
+        for x in var.shape:
+            if x == -1:
+                data_count *= batch_size
+            else:
+                data_count *= x
+        var_memory = data_count * dtype_to_size[var.dtype]
+        if DEBUG:
+            print "%s memory usage: %d" % (var.name, var_memory)
+        total_memory += var_memory
+    if DEBUG:
+        print "total memory usage: %.2f" % (total_memory)
+
+    # Convert to an appropriate unit
+    unit_str = "B"
+    if total_memory > 1024:
+        total_memory /= 1024
+        unit_str = "KB"
+        if total_memory > 1024:
+            total_memory /= 1024
+            unit_str = "MB"
+
+    # Add a margin for extra memory consumption (5% - 10%)
+    min_total_memory = total_memory * 1.05
+    max_total_memory = total_memory * 1.1
+
+    return min_total_memory, max_total_memory, unit_str
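To make the arithmetic concrete, here is a small Paddle-free sketch of the per-variable estimate above. The shape [-1, 13], the FP32 dtype, and batch_size=10 are illustrative values, not taken from the commit:

# Mirrors the element-count loop in memory_usage(): a -1 dimension is
# the batch dimension and is replaced by the requested batch size.
FP32_BYTES = 4  # same value as dtype_to_size[core.VarDesc.VarType.FP32]

def var_bytes(shape, batch_size, dtype_size):
    count = 1
    for dim in shape:
        count *= batch_size if dim == -1 else dim
    return count * dtype_size

total = var_bytes([-1, 13], batch_size=10, dtype_size=FP32_BYTES)
print(total)  # 10 * 13 * 4 = 520, under 1024, so unit_str stays "B"

# Applying the function's 5% - 10% overhead margin:
print("%.1f - %.1f B" % (total * 1.05, total * 1.1))  # 546.0 - 572.0 B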
python/paddle/fluid/tests/unittests/test_memory_usage.py

Lines changed: 69 additions & 0 deletions
@@ -0,0 +1,69 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+import paddle
+import paddle.fluid as fluid
+import contextlib
+import unittest
+
+
+def train_simulator(test_batch_size=10):
+    if test_batch_size <= 0:
+        raise ValueError("batch_size should be a positive integer value, "
+                         "but got batch_size={}".format(test_batch_size))
+
+    x = fluid.layers.data(name='x', shape=[13], dtype='float32')
+    y_predict = fluid.layers.fc(input=x, size=1, act=None)
+    y = fluid.layers.data(name='y', shape=[1], dtype='float32')
+
+    cost = fluid.layers.square_error_cost(input=y_predict, label=y)
+    avg_cost = fluid.layers.mean(cost)
+
+    sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001)
+    sgd_optimizer.minimize(avg_cost)
+
+    # Calculate memory usage with the current network configuration
+    lower_usage, upper_usage, unit = fluid.contrib.memory_usage(
+        fluid.default_main_program(), batch_size=test_batch_size)
+
+    print("memory usage is about %.3f - %.3f %s" %
+          (lower_usage, upper_usage, unit))
+
+
+class TestMemoryUsage(unittest.TestCase):
+    def test_with_unit_B(self):
+        with self.program_scope_guard():
+            train_simulator()
+
+    def test_with_unit_KB(self):
+        with self.program_scope_guard():
+            train_simulator(test_batch_size=1000)
+
+    def test_with_unit_MB(self):
+        with self.program_scope_guard():
+            train_simulator(test_batch_size=100000)
+
+    @contextlib.contextmanager
+    def program_scope_guard(self):
+        prog = fluid.Program()
+        startup_prog = fluid.Program()
+        scope = fluid.core.Scope()
+        with fluid.scope_guard(scope):
+            with fluid.program_guard(prog, startup_prog):
+                yield
+
+
+if __name__ == '__main__':
+    unittest.main()
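As a rough back-of-envelope for why these three batch sizes land in the B, KB, and MB ranges: counting only the batch-sized FP32 variables of this one-layer network (the parameters and the gradient variables added by minimize() push the real estimate somewhat higher):

FP32_BYTES = 4

for batch in (10, 1000, 100000):
    # x: [batch, 13]; y_predict, y, cost: [batch, 1] each
    elems = batch * (13 + 1 + 1 + 1)
    print("batch=%d -> %d bytes" % (batch, elems * FP32_BYTES))

# batch=10     ->     640 bytes  (unit "B")
# batch=1000   ->   64000 bytes, about 62.5 KB
# batch=100000 -> 6400000 bytes, about 6.1 MB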
