@@ -70,23 +70,37 @@ def memory_usage(program, batch_size):
70
70
if not isinstance (program , Program ):
71
71
raise TypeError (
72
72
"Calculating Memory Usage requires Program as its Parameter."
73
- "But you passed in %s" % (type (prgram )))
73
+ "But you passed in %s" % (type (program )))
74
74
if batch_size <= 0 :
75
75
raise ValueError ("The batch size need to be positive." )
76
76
77
77
# Get the var_name list of first block and calculate
78
78
total_memory = 0.0
79
- for var in six .itervalues (program .global_block ().vars ):
80
- data_count = 1
81
- for x in var .shape :
82
- if x == - 1 :
83
- data_count *= batch_size
84
- else :
85
- data_count *= x
86
- var_memory = data_count * dtype_to_size [var .dtype ]
87
- if DEBUG :
88
- print ("%s memory usage: %d" % (var .name , var_memory ))
89
- total_memory += var_memory
79
+ processed_var_names = set ()
80
+ for op in program .global_block ().ops :
81
+ for var_name in op .output_arg_names :
82
+ if var_name in processed_var_names :
83
+ continue
84
+ processed_var_names .add (var_name )
85
+ var = program .global_block ().vars [var_name ]
86
+ if var .desc .type () != core .VarDesc .VarType .LOD_TENSOR :
87
+ continue
88
+
89
+ data_count = 1
90
+ neg_dim_count = 0
91
+ for x in var .shape :
92
+ if x < 0 :
93
+ if neg_dim_count >= 1 :
94
+ raise ValueError ("Var %s has more than one negtive dim."
95
+ % (var_name ))
96
+ neg_dim_count += 1
97
+ data_count *= batch_size * (- x )
98
+ else :
99
+ data_count *= x
100
+ var_memory = data_count * dtype_to_size [var .dtype ]
101
+ if DEBUG :
102
+ print ("%s memory usage: %d" % (var .name , var_memory ))
103
+ total_memory += var_memory
90
104
if DEBUG :
91
105
print ("total memory usage: %.2f" % (total_memory ))
92
106
0 commit comments