Commit fb08e16

refine memory usage calc
1 parent a22309a commit fb08e16

1 file changed: python/paddle/fluid/contrib/memory_usage_calc.py (21 additions, 12 deletions)
@@ -70,23 +70,32 @@ def memory_usage(program, batch_size):
     if not isinstance(program, Program):
         raise TypeError(
             "Calculating Memory Usage requires Program as its Parameter."
-            "But you passed in %s" % (type(prgram)))
+            "But you passed in %s" % (type(program)))
     if batch_size <= 0:
         raise ValueError("The batch size need to be positive.")
 
     # Get the var_name list of first block and calculate
     total_memory = 0.0
-    for var in six.itervalues(program.global_block().vars):
-        data_count = 1
-        for x in var.shape:
-            if x == -1:
-                data_count *= batch_size
-            else:
-                data_count *= x
-        var_memory = data_count * dtype_to_size[var.dtype]
-        if DEBUG:
-            print("%s memory usage: %d" % (var.name, var_memory))
-        total_memory += var_memory
+    processed_var_names = set()
+    for op in program.global_block().ops:
+        for var_name in op.output_arg_names:
+            if var_name in processed_var_names:
+                continue
+            processed_var_names.add(var_name)
+            var = program.global_block().vars[var_name]
+            if var.desc.type() != core.VarDesc.VarType.LOD_TENSOR:
+                continue
+
+            data_count = 1
+            for x in var.shape:
+                if x < 0:
+                    data_count *= batch_size * (-x)
+                else:
+                    data_count *= x
+            var_memory = data_count * dtype_to_size[var.dtype]
+            if DEBUG:
+                print("%s memory usage: %d" % (var.name, var_memory))
+            total_memory += var_memory
     if DEBUG:
         print("total memory usage: %.2f" % (total_memory))
 
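The refined calculation only counts variables that appear as op outputs (skipping names it has already processed and anything that is not a LOD_TENSOR), and it now treats a negative dimension -x as contributing batch_size * x elements rather than just batch_size. Below is a standalone sketch of that per-variable arithmetic; the dtype_to_size table and the example shape are illustrative assumptions, not taken from the Paddle source.

# Standalone sketch of the per-variable size arithmetic in the new loop.
# dtype_to_size here is an assumed subset (bytes per element), not the real table.
dtype_to_size = {"float32": 4, "int64": 8}

def estimate_var_memory(shape, dtype, batch_size):
    # Mirror the refined counting: a negative dim -x contributes
    # batch_size * x elements instead of just batch_size.
    data_count = 1
    for x in shape:
        if x < 0:
            data_count *= batch_size * (-x)
        else:
            data_count *= x
    return data_count * dtype_to_size[dtype]

# Example (hypothetical shape): a float32 tensor of shape [-1, 784]
# at batch_size 32 -> 32 * 784 * 4 bytes = 100352 bytes.
print(estimate_var_memory([-1, 784], "float32", batch_size=32))

Iterating over op output names instead of the whole global_block().vars dict also means variables that no op produces, and non-LOD_TENSOR entries, no longer contribute to the total.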
