
Commit a6da470

add memory optimization transpiler demo (#7443)
* add memory optimization transpiler demo
* add memory benchmark compile option
* add gflags instead of macro
* refine code
1 parent 610ad49 commit a6da470
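
The new do_memory_benchmark gflag is surfaced to Python through read_env_flags (see the python/paddle/v2/fluid/__init__.py diff below), so it can be switched on from the environment before Fluid is imported. A minimal sketch of how a run might be set up; the glog variables follow standard glog/gflags conventions and are assumptions, not part of this commit:

    import os

    # Assumed glog/gflags conventions: VLOG(2) output needs GLOG_v >= 2,
    # and gflags' --tryfromenv picks the flag up from the
    # FLAGS_do_memory_benchmark environment variable.
    os.environ['GLOG_v'] = '2'
    os.environ['GLOG_logtostderr'] = '1'
    os.environ['FLAGS_do_memory_benchmark'] = 'true'

    import paddle.v2.fluid as fluid  # __bootstrap__ reads the flag at import time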

7 files changed (+213, -3 lines)


paddle/framework/executor.cc

Lines changed: 11 additions & 0 deletions
@@ -23,6 +23,7 @@ limitations under the License. */
 #include "paddle/framework/op_registry.h"
 #include "paddle/platform/place.h"
 
+DECLARE_bool(do_memory_benchmark);
 DEFINE_bool(check_nan_inf, false,
             "Checking whether operator produce NAN/INF or not. It will be "
             "extremely slow so please use this flag wisely.");
@@ -117,6 +118,10 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
     auto op = paddle::framework::OpRegistry::CreateOp(*op_desc);
     VLOG(3) << op->DebugStringEx(local_scope);
     op->Run(*local_scope, place_);
+    if (FLAGS_do_memory_benchmark) {
+      VLOG(2) << "Memory used after operator " + op->Type() + " running: "
+              << memory::memory_usage(place_);
+    }
     if (FLAGS_check_nan_inf) {
       for (auto& vname : op->OutputVars(true)) {
         auto* var = local_scope->FindVar(vname);
@@ -130,6 +135,12 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
   if (create_vars && create_local_scope) {
     scope->DeleteScope(local_scope);
   }
+  if (FLAGS_do_memory_benchmark) {
+    VLOG(2) << "-------------------------------------------------------";
+    VLOG(2) << "Memory used after deleting local scope: "
+            << memory::memory_usage(place_);
+    VLOG(2) << "-------------------------------------------------------";
+  }
 }
 
 }  // namespace framework

paddle/framework/scope.cc

Lines changed: 10 additions & 2 deletions
@@ -20,6 +20,10 @@ limitations under the License. */
 #include "paddle/framework/threadpool.h"
 #include "paddle/string/printf.h"
 
+DEFINE_bool(do_memory_benchmark, false,
+            "Doing memory benchmark. It will make deleting scope synchronized, "
+            "and add some memory usage logs");
+
 namespace paddle {
 namespace framework {
 
@@ -88,8 +92,12 @@ void Scope::DeleteScope(Scope* scope) {
   auto it = std::find(this->kids_.begin(), this->kids_.end(), scope);
   PADDLE_ENFORCE(it != this->kids_.end(), "Cannot find %p as kid scope", scope);
   this->kids_.erase(it);
-  // Make delete async.
-  Async([scope] { delete scope; });
+  // When making memory benchmark on Fluid, we have to delete scope sync.
+  if (FLAGS_do_memory_benchmark) {
+    delete scope;
+  } else {
+    Async([scope] { delete scope; });
+  }
 }
 
 void Scope::Rename(const std::string& origin_name,

python/paddle/v2/fluid/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -86,7 +86,9 @@ def __bootstrap__():
 
     os.environ['OMP_NUM_THREADS'] = str(num_threads)
 
-    read_env_flags = ['use_pinned_memory', 'check_nan_inf']
+    read_env_flags = [
+        'use_pinned_memory', 'check_nan_inf', 'do_memory_benchmark'
+    ]
     if core.is_compile_gpu():
         read_env_flags += ['fraction_of_gpu_memory_to_use', 'op_sync']
     core.init_gflags([sys.argv[0]] +

python/paddle/v2/fluid/tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -6,3 +6,4 @@ endforeach()
 
 add_subdirectory(book)
 add_subdirectory(book_distribute)
+add_subdirectory(book_memory_optimization)
python/paddle/v2/fluid/tests/book_memory_optimization/CMakeLists.txt

Lines changed: 11 additions & 0 deletions

@@ -0,0 +1,11 @@
+file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py")
+string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
+
+list(REMOVE_ITEM TEST_OPS test_memopt_image_classification_train)
+py_test(test_memopt_image_classification_train_resnet SRCS test_memopt_image_classification_train.py ARGS resnet)
+py_test(test_memopt_image_classification_train_vgg SRCS test_memopt_image_classification_train.py ARGS vgg)
+
+# default test
+foreach(src ${TEST_OPS})
+  py_test(${src} SRCS ${src}.py)
+endforeach()
python/paddle/v2/fluid/tests/book_memory_optimization/test_memopt_fit_a_line.py

Lines changed: 44 additions & 0 deletions

@@ -0,0 +1,44 @@
+import numpy as np
+import paddle.v2 as paddle
+import paddle.v2.fluid as fluid
+
+x = fluid.layers.data(name='x', shape=[13], dtype='float32')
+
+y_predict = fluid.layers.fc(input=x, size=1, act=None)
+
+y = fluid.layers.data(name='y', shape=[1], dtype='float32')
+
+cost = fluid.layers.square_error_cost(input=y_predict, label=y)
+avg_cost = fluid.layers.mean(x=cost)
+
+sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.1)
+sgd_optimizer.minimize(avg_cost)
+
+# memopt_program = fluid.default_main_program()
+memopt_program = fluid.memory_optimize(fluid.default_main_program())
+
+BATCH_SIZE = 200
+
+train_reader = paddle.batch(
+    paddle.reader.shuffle(
+        paddle.dataset.uci_housing.train(), buf_size=500),
+    batch_size=BATCH_SIZE)
+
+place = fluid.CPUPlace()
+feeder = fluid.DataFeeder(place=place, feed_list=[x, y])
+exe = fluid.Executor(place)
+
+exe.run(fluid.default_startup_program())
+
+PASS_NUM = 100
+for pass_id in range(PASS_NUM):
+    fluid.io.save_persistables(exe, "./fit_a_line.model/")
+    fluid.io.load_persistables(exe, "./fit_a_line.model/")
+    for data in train_reader():
+        avg_loss_value, = exe.run(memopt_program,
+                                  feed=feeder.feed(data),
+                                  fetch_list=[avg_cost])
+
+        if avg_loss_value[0] < 10.0:
+            exit(0)  # if avg cost less than 10.0, we think our code is good.
+exit(1)
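
The commented-out `memopt_program = fluid.default_main_program()` line above is the baseline switch: running the script once with each assignment lets the memory-benchmark logs be compared with and without the transpiler pass. A minimal sketch of that toggle, using a hypothetical USE_MEMOPT environment variable that is not part of this commit:

    import os

    import paddle.v2.fluid as fluid

    # Hypothetical toggle (not in the commit): choose the baseline or the
    # transpiled program so the two memory-usage logs can be compared.
    memopt_program = fluid.default_main_program()
    if os.environ.get('USE_MEMOPT', '1') == '1':
        memopt_program = fluid.memory_optimize(memopt_program)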
python/paddle/v2/fluid/tests/book_memory_optimization/test_memopt_image_classification_train.py

Lines changed: 133 additions & 0 deletions

@@ -0,0 +1,133 @@
+from __future__ import print_function
+
+import sys
+
+import paddle.v2 as paddle
+import paddle.v2.fluid as fluid
+
+
+def resnet_cifar10(input, depth=32):
+    def conv_bn_layer(input, ch_out, filter_size, stride, padding, act='relu'):
+        tmp = fluid.layers.conv2d(
+            input=input,
+            filter_size=filter_size,
+            num_filters=ch_out,
+            stride=stride,
+            padding=padding,
+            act=None,
+            bias_attr=False)
+        return fluid.layers.batch_norm(input=tmp, act=act)
+
+    def shortcut(input, ch_in, ch_out, stride):
+        if ch_in != ch_out:
+            return conv_bn_layer(input, ch_out, 1, stride, 0, None)
+        else:
+            return input
+
+    def basicblock(input, ch_in, ch_out, stride):
+        tmp = conv_bn_layer(input, ch_out, 3, stride, 1)
+        tmp = conv_bn_layer(tmp, ch_out, 3, 1, 1, act=None)
+        short = shortcut(input, ch_in, ch_out, stride)
+        return fluid.layers.elementwise_add(x=tmp, y=short, act='relu')
+
+    def layer_warp(block_func, input, ch_in, ch_out, count, stride):
+        tmp = block_func(input, ch_in, ch_out, stride)
+        for i in range(1, count):
+            tmp = block_func(tmp, ch_out, ch_out, 1)
+        return tmp
+
+    assert (depth - 2) % 6 == 0
+    n = (depth - 2) / 6
+    conv1 = conv_bn_layer(
+        input=input, ch_out=16, filter_size=3, stride=1, padding=1)
+    res1 = layer_warp(basicblock, conv1, 16, 16, n, 1)
+    res2 = layer_warp(basicblock, res1, 16, 32, n, 2)
+    res3 = layer_warp(basicblock, res2, 32, 64, n, 2)
+    pool = fluid.layers.pool2d(
+        input=res3, pool_size=8, pool_type='avg', pool_stride=1)
+    return pool
+
+
+def vgg16_bn_drop(input):
+    def conv_block(input, num_filter, groups, dropouts):
+        return fluid.nets.img_conv_group(
+            input=input,
+            pool_size=2,
+            pool_stride=2,
+            conv_num_filter=[num_filter] * groups,
+            conv_filter_size=3,
+            conv_act='relu',
+            conv_with_batchnorm=True,
+            conv_batchnorm_drop_rate=dropouts,
+            pool_type='max')
+
+    conv1 = conv_block(input, 64, 2, [0.3, 0])
+    conv2 = conv_block(conv1, 128, 2, [0.4, 0])
+    conv3 = conv_block(conv2, 256, 3, [0.4, 0.4, 0])
+    conv4 = conv_block(conv3, 512, 3, [0.4, 0.4, 0])
+    conv5 = conv_block(conv4, 512, 3, [0.4, 0.4, 0])
+
+    drop = fluid.layers.dropout(x=conv5, dropout_prob=0.5)
+    fc1 = fluid.layers.fc(input=drop, size=512, act=None)
+    bn = fluid.layers.batch_norm(input=fc1, act='relu')
+    drop2 = fluid.layers.dropout(x=bn, dropout_prob=0.5)
+    fc2 = fluid.layers.fc(input=drop2, size=512, act=None)
+    return fc2
+
+
+classdim = 10
+data_shape = [3, 32, 32]
+
+images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32')
+label = fluid.layers.data(name='label', shape=[1], dtype='int64')
+
+net_type = "vgg"
+if len(sys.argv) >= 2:
+    net_type = sys.argv[1]
+
+if net_type == "vgg":
+    print("train vgg net")
+    net = vgg16_bn_drop(images)
+elif net_type == "resnet":
+    print("train resnet")
+    net = resnet_cifar10(images, 32)
+else:
+    raise ValueError("%s network is not supported" % net_type)
+
+predict = fluid.layers.fc(input=net, size=classdim, act='softmax')
+cost = fluid.layers.cross_entropy(input=predict, label=label)
+avg_cost = fluid.layers.mean(x=cost)
+
+optimizer = fluid.optimizer.Adam(learning_rate=0.001)
+opts = optimizer.minimize(avg_cost)
+
+accuracy = fluid.evaluator.Accuracy(input=predict, label=label)
+
+# memopt_program = fluid.default_main_program()
+memopt_program = fluid.memory_optimize(fluid.default_main_program())
+
+BATCH_SIZE = 128
+PASS_NUM = 1
+
+train_reader = paddle.batch(
+    paddle.reader.shuffle(
+        paddle.dataset.cifar.train10(), buf_size=128 * 10),
+    batch_size=BATCH_SIZE)
+
+place = fluid.CPUPlace()
+exe = fluid.Executor(place)
+feeder = fluid.DataFeeder(place=place, feed_list=[images, label])
+exe.run(fluid.default_startup_program())
+
+for pass_id in range(PASS_NUM):
+    accuracy.reset(exe)
+    for data in train_reader():
+        loss, acc = exe.run(memopt_program,
+                            feed=feeder.feed(data),
+                            fetch_list=[avg_cost] + accuracy.metrics)
+        pass_acc = accuracy.eval(exe)
+        print("loss:" + str(loss) + " acc:" + str(acc) + " pass_acc:" + str(
+            pass_acc))
+        # this model is slow, so if we can train two mini batch, we think it works properly.
+        exit(0)
+exit(1)
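
The CMake rules above register this script twice, passing "resnet" or "vgg" through ARGS, which the script reads from sys.argv[1]. A sketch of driving both variants by hand with the benchmark flag enabled; the working directory, relative path, and environment setup are assumptions:

    import os
    import subprocess

    # Assumed to run from the book_memory_optimization directory, with the
    # flag and glog verbosity exported so per-operator memory logs appear.
    env = dict(os.environ,
               FLAGS_do_memory_benchmark='true',
               GLOG_v='2',
               GLOG_logtostderr='1')
    for net in ('vgg', 'resnet'):
        subprocess.check_call(
            ['python', 'test_memopt_image_classification_train.py', net],
            env=env)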
