Skip to content

Commit eaea82f

Browse files
committed
feat: Add memory usage estimate function
1 parent 6169d72 commit eaea82f

File tree

2 files changed

+102
-1
lines changed

2 files changed

+102
-1
lines changed

python/paddle/fluid/contrib/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,7 @@
1414

1515
import decoder
1616
from decoder import *
17+
import memory_usage_calc
18+
from memory_usage_calc import *
1719

18-
__all__ = decoder.__all__
20+
__all__ = decoder.__all__ + memory_usage_calc.__all__
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""
15+
This module privides a memory usage calculate function for user.
16+
The purpose of this API is to allow users to estimate memory usage of
17+
a program under a special batch size, then user can set appropriate
18+
batch size to fully utilize a GPU.
19+
20+
This API is still under active development and may change drastically.
21+
"""
22+
23+
from .. import core
24+
from ..framework import Program, Variable
25+
26+
__all__ = ['MemoryInfo']
27+
28+
DEBUG = False
29+
30+
dtype_to_size = {
31+
core.VarDesc.VarType.FP16: 2,
32+
core.VarDesc.VarType.FP32: 4,
33+
core.VarDesc.VarType.FP64: 8,
34+
core.VarDesc.VarType.INT16: 2,
35+
core.VarDesc.VarType.INT32: 4,
36+
core.VarDesc.VarType.INT64: 8,
37+
core.VarDesc.VarType.BOOL: 1,
38+
core.VarDesc.VarType.UINT8: 1,
39+
}
40+
41+
42+
class MemoryInfo(object):
43+
def __init__(self, program):
44+
if not isinstance(program, Program):
45+
raise TypeError(
46+
"Calculating Memory Usage requires Program as its Parameter."
47+
"But you passed in %s" % (type(prgram)))
48+
self._program = program
49+
50+
def _has_var(self, block, var_name):
51+
return block.has_var(str(var_name))
52+
53+
def _find_var(self, block, var_name):
54+
return block.var(str(var_name))
55+
56+
def get_memory_usage(self, batch_size, with_details=False):
57+
58+
# get the first block of program
59+
first_block = self._program.global_block()
60+
61+
# get the var_name list of first block
62+
# TODO(chenweihang): not find the API get block's var list directly
63+
total_memory = 0.0
64+
for var in self._program.list_vars():
65+
if DEBUG:
66+
print "All Block's Var: %s" % (var.name)
67+
# TODO(chenweihang): why not used program.list_vars()
68+
# calculate all variable's memory directly?
69+
if self._has_var(first_block, var.name):
70+
if DEBUG:
71+
print "First Block's Var: %s" % (var.name)
72+
print "Var's shape: ", var.shape
73+
print "Var's dtype: ", var.dtype
74+
data_count = 1
75+
for x in var.shape:
76+
if x == -1:
77+
data_count *= batch_size
78+
else:
79+
data_count *= x
80+
var_memory = data_count * dtype_to_size[var.dtype]
81+
if DEBUG:
82+
print "Var's memory: %d" % (var_memory)
83+
total_memory += var_memory
84+
85+
# Convert unit and make result string
86+
result_str = "- With current batch size, memory usage is about "
87+
unit_str = " B."
88+
if total_memory > 1024:
89+
total_memory /= 1024
90+
unit_str = " KB."
91+
if total_memory > 1024:
92+
total_memory /= 1024
93+
unit_str = " MB."
94+
95+
# Append extra memory consumption (5% - 10%)
96+
result_str += str(round(total_memory * 1.05, 3)) + " - " \
97+
+ str(round(total_memory * 1.10, 3)) + unit_str
98+
99+
return result_str

0 commit comments

Comments
 (0)