Skip to content

Commit ea548a7

Browse files
committed
refactor: simplify class to function
1 parent 999d097 commit ea548a7

File tree

1 file changed

+54
-51
lines changed

1 file changed

+54
-51
lines changed

python/paddle/fluid/contrib/memory_usage_calc.py

Lines changed: 54 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,7 @@
2323
from .. import core
2424
from ..framework import Program, Variable
2525

26-
__all__ = ['MemoryInfo']
27-
28-
DEBUG = False
26+
__all__ = ['memory_usage']
2927

3028
dtype_to_size = {
3129
core.VarDesc.VarType.FP16: 2,
@@ -38,62 +36,67 @@
3836
core.VarDesc.VarType.UINT8: 1,
3937
}
4038

39+
DEBUG = False
4140

42-
class MemoryInfo(object):
43-
def __init__(self, program):
44-
if not isinstance(program, Program):
45-
raise TypeError(
46-
"Calculating Memory Usage requires Program as its Parameter."
47-
"But you passed in %s" % (type(prgram)))
48-
self._program = program
4941

50-
def _has_var(self, block, var_name):
51-
return block.has_var(str(var_name))
42+
def memory_usage(program, batch_size):
43+
"""
44+
Get the estimate memory usage of program with input batch size.
5245
53-
def _find_var(self, block, var_name):
54-
return block.var(str(var_name))
46+
Args:
47+
program(Program): The current Program.
48+
batch_size(int): The current input data batch_size.
49+
50+
Returns:
51+
min_total_memory(float): the estimate memory usage lower bound.
52+
max_total_memory(float): the estimate memory usage upper bound.
53+
unit_str(string): the unit of estimate usage result.
54+
55+
Examples:
56+
57+
>>> import paddle.fluid as fluid
58+
>>> lower_usage, upper_usage, unit = fluid.contrib.memory_usage(
59+
fluid.default_main_program(), batch_size=10)
60+
>>> print "memory usage is about %.3f - %.3f %s" % \
61+
(lower_usage, upper_usage, unit)
5562
56-
def get_memory_usage(self, batch_size, with_details=False):
63+
"""
5764

58-
# get the first block of program
59-
first_block = self._program.global_block()
65+
# Parameters check
66+
if not isinstance(program, Program):
67+
raise TypeError(
68+
"Calculating Memory Usage requires Program as its Parameter."
69+
"But you passed in %s" % (type(prgram)))
70+
if batch_size <= 0:
71+
raise ValueError("The batch size need to be positive.")
6072

61-
# get the var_name list of first block
62-
# TODO(chenweihang): not find the API get block's var list directly
63-
total_memory = 0.0
64-
for var in self._program.list_vars():
65-
if DEBUG:
66-
print "All Block's Var: %s" % (var.name)
67-
# TODO(chenweihang): why not used program.list_vars()
68-
# calculate all variable's memory directly?
69-
if self._has_var(first_block, var.name):
70-
if DEBUG:
71-
print "First Block's Var: %s" % (var.name)
72-
print "Var's shape: ", var.shape
73-
print "Var's dtype: ", var.dtype
74-
data_count = 1
75-
for x in var.shape:
76-
if x == -1:
77-
data_count *= batch_size
78-
else:
79-
data_count *= x
80-
var_memory = data_count * dtype_to_size[var.dtype]
81-
if DEBUG:
82-
print "Var's memory: %d" % (var_memory)
83-
total_memory += var_memory
73+
# Get the var_name list of first block and calculate
74+
total_memory = 0.0
75+
for var in program.global_block().vars.itervalues():
76+
data_count = 1
77+
for x in var.shape:
78+
if x == -1:
79+
data_count *= batch_size
80+
else:
81+
data_count *= x
82+
var_memory = data_count * dtype_to_size[var.dtype]
83+
if DEBUG:
84+
print "%s memory usage: %d" % (var.name, var_memory)
85+
total_memory += var_memory
86+
if DEBUG:
87+
print "total memory usage: %.2f" % (total_memory)
8488

85-
# Convert unit and make result string
86-
result_str = "- With current batch size, memory usage is about "
87-
unit_str = " B."
89+
# Convert appropriate unit
90+
unit_str = "B"
91+
if total_memory > 1024:
92+
total_memory /= 1024
93+
unit_str = "KB"
8894
if total_memory > 1024:
8995
total_memory /= 1024
90-
unit_str = " KB."
91-
if total_memory > 1024:
92-
total_memory /= 1024
93-
unit_str = " MB."
96+
unit_str = "MB"
9497

95-
# Append extra memory consumption (5% - 10%)
96-
result_str += str(round(total_memory * 1.05, 3)) + " - " \
97-
+ str(round(total_memory * 1.10, 3)) + unit_str
98+
# Append extra memory consumption (5% - 10%)
99+
min_total_memory = total_memory * 1.05
100+
max_total_memory = total_memory * 1.1
98101

99-
return result_str
102+
return min_total_memory, max_total_memory, unit_str

0 commit comments

Comments
 (0)