Skip to content

Commit 0764211

Browse files
committed
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into add-mkldnn-to-paddle-lib
2 parents de3c517 + dbbeccc commit 0764211

File tree

18 files changed

+266
-74
lines changed

18 files changed

+266
-74
lines changed

benchmark/fluid/mnist.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ def run_benchmark(model, args):
159159
paddle.dataset.mnist.train(), batch_size=args.batch_size)
160160

161161
accuracy = fluid.metrics.Accuracy()
162+
train_exe = fluid.ParallelExecutor(use_cuda=True, loss_name=avg_cost.name)
162163
iters, num_samples, start_time = 0, 0, time.time()
163164
for pass_id in range(args.pass_num):
164165
accuracy.reset()
@@ -175,17 +176,20 @@ def run_benchmark(model, args):
175176
y_data = np.array(map(lambda x: x[1], data)).astype("int64")
176177
y_data = y_data.reshape([len(y_data), 1])
177178

178-
outs = exe.run(
179-
fluid.default_main_program(),
179+
outs = train_exe.run(
180180
feed={"pixel": img_data,
181181
"label": y_data},
182-
fetch_list=[avg_cost, batch_acc, batch_size_tensor]
182+
fetch_list=[
183+
avg_cost.name, batch_acc.name, batch_size_tensor.name
184+
]
183185
) # The accuracy is the accumulation of batches, but not the current batch.
184-
accuracy.update(value=outs[1], weight=outs[2])
186+
accuracy.update(
187+
value=np.array(np.mean(outs[1])),
188+
weight=np.mean(np.array(outs[2])))
185189
iters += 1
186190
num_samples += len(y_data)
187-
loss = np.array(outs[0])
188-
acc = np.array(outs[1])
191+
loss = np.mean(np.array(outs[0]))
192+
acc = np.mean(np.array(outs[1]))
189193
train_losses.append(loss)
190194
train_accs.append(acc)
191195
print("Pass: %d, Iter: %d, Loss: %f, Accuracy: %f" %

benchmark/fluid/resnet.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,7 @@ def test(exe):
241241
exe = fluid.Executor(place)
242242
exe.run(fluid.default_startup_program())
243243
accuracy = fluid.average.WeightedAverage()
244+
train_exe = fluid.ParallelExecutor(use_cuda=True, loss_name=avg_cost.name)
244245
if args.use_fake_data:
245246
data = train_reader().next()
246247
image = np.array(map(lambda x: x[0].reshape(dshape), data)).astype(
@@ -264,14 +265,17 @@ def test(exe):
264265
data)).astype('float32')
265266
label = np.array(map(lambda x: x[1], data)).astype('int64')
266267
label = label.reshape([-1, 1])
267-
loss, acc, weight = exe.run(
268-
fluid.default_main_program(),
268+
loss, acc, weight = train_exe.run(
269269
feed={'data': image,
270270
'label': label},
271-
fetch_list=[avg_cost, batch_acc, batch_size_tensor])
271+
fetch_list=[
272+
avg_cost.name, batch_acc.name, batch_size_tensor.name
273+
])
272274
iters += 1
273275
num_samples += len(label)
274-
accuracy.add(value=acc, weight=weight)
276+
accuracy.add(value=np.array(np.mean(acc)), weight=np.mean(weight))
277+
loss = np.mean(np.array(loss))
278+
acc = np.mean(np.array(acc))
275279
train_losses.append(loss)
276280
train_accs.append(acc)
277281
print("Pass: %d, Iter: %d, Loss: %f, Accuracy: %f" %

benchmark/fluid/vgg.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,7 @@ def test(exe):
169169

170170
iters, num_samples, start_time = 0, 0, time.time()
171171
accuracy = fluid.average.WeightedAverage()
172+
train_exe = fluid.ParallelExecutor(use_cuda=True, loss_name=avg_cost.name)
172173
for pass_id in range(args.pass_num):
173174
accuracy.reset()
174175
train_accs = []
@@ -184,14 +185,17 @@ def test(exe):
184185
y_data = np.array(map(lambda x: x[1], data)).astype("int64")
185186
y_data = y_data.reshape([-1, 1])
186187

187-
loss, acc, weight = exe.run(
188-
fluid.default_main_program(),
188+
loss, acc, weight = train_exe.run(
189189
feed={"pixel": img_data,
190190
"label": y_data},
191-
fetch_list=[avg_cost, batch_acc, batch_size_tensor])
192-
accuracy.add(value=acc, weight=weight)
191+
fetch_list=[
192+
avg_cost.name, batch_acc.name, batch_size_tensor.name
193+
])
194+
accuracy.add(value=np.array(np.mean(acc)), weight=np.mean(weight))
193195
iters += 1
194196
num_samples += len(y_data)
197+
loss = np.mean(np.array(loss))
198+
acc = np.mean(np.array(acc))
195199
print(
196200
"Pass = %d, Iter = %d, Loss = %f, Accuracy = %f" %
197201
(pass_id, iters, loss, acc)

cmake/inference_lib.cmake

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,4 +156,10 @@ copy(string_lib
156156
DSTS ${dst_dir}/${module} ${dst_dir}/${module}/tinyformat
157157
)
158158

159+
set(module "pybind")
160+
copy(pybind_lib
161+
SRCS ${CMAKE_CURRENT_BINARY_DIR}/paddle/fluid/${module}/pybind.h
162+
DSTS ${dst_dir}/${module}
163+
)
164+
159165
add_custom_target(inference_lib_dist DEPENDS ${inference_lib_dist_dep})

doc/fluid/design/concepts/functions_operators_layers.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ template <typename T>
4040
class FCOp : public OperatorBase {
4141
public:
4242
void Run(...) {
43-
add(mul(Input<T>("X"), Input<T>("W")), Input<T>("b");
43+
add(mul(Input<T>("X"), Input<T>("W")), Input<T>("b"));
4444
}
4545
};
4646
REGISTER_OP(FCOp, "fc");

paddle/fluid/framework/details/op_handle_base.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,14 @@ class OpHandleBase {
7070

7171
const std::vector<VarHandleBase *> &Inputs() const { return inputs_; }
7272

73+
size_t NoDupInputSize() const {
74+
std::unordered_set<VarHandleBase *> res;
75+
for (auto *var : inputs_) {
76+
res.emplace(var);
77+
}
78+
return res.size();
79+
}
80+
7381
const std::vector<VarHandleBase *> &Outputs() const { return outputs_; }
7482

7583
protected:

paddle/fluid/framework/details/threaded_ssa_graph_executor.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
174174
void ThreadedSSAGraphExecutor::InsertPendingOp(
175175
std::unordered_map<OpHandleBase *, size_t> *pending_ops,
176176
OpHandleBase *op_instance) const {
177-
pending_ops->insert({op_instance, op_instance->Inputs().size()});
177+
pending_ops->insert({op_instance, op_instance->NoDupInputSize()});
178178
}
179179

180180
void ThreadedSSAGraphExecutor::InsertPendingVar(

paddle/fluid/inference/tensorrt/convert/op_converter.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ class OpConverter {
4949
// convert fluid block to tensorrt network
5050
void ConvertBlock(const framework::proto::BlockDesc& block,
5151
TensorRTEngine* engine) {
52-
for (size_t i = 0; i < block.ops_size(); i++) {
52+
for (int i = 0; i < block.ops_size(); i++) {
5353
const auto& op = block.ops(i);
5454
OpConverter::Run(op, engine);
5555
}

paddle/fluid/operators/smooth_l1_loss_op.cc

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ class SmoothL1LossGradOp : public framework::OperatorWithKernel {
105105
using framework::OperatorWithKernel::OperatorWithKernel;
106106

107107
void InferShape(framework::InferShapeContext* ctx) const override {
108-
auto in_dims = ctx->GetInputDim("X");
108+
auto in_dims = ctx->GetInputDim("Diff");
109109
auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
110110

111111
PADDLE_ENFORCE_GE(out_dims.size(), 2,
@@ -127,12 +127,33 @@ class SmoothL1LossGradOp : public framework::OperatorWithKernel {
127127
}
128128
};
129129

130+
class SmoothL1LossGradMaker : public framework::SingleGradOpDescMaker {
131+
public:
132+
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
133+
134+
protected:
135+
std::unique_ptr<framework::OpDesc> Apply() const override {
136+
auto* op = new framework::OpDesc();
137+
op->SetType("smooth_l1_loss_grad");
138+
op->SetInput("InsideWeight", Input("InsideWeight"));
139+
op->SetInput("OutsideWeight", Input("OutsideWeight"));
140+
op->SetInput("Diff", Output("Diff"));
141+
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
142+
143+
op->SetAttrMap(Attrs());
144+
145+
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
146+
op->SetOutput(framework::GradVarName("Y"), InputGrad("Y"));
147+
return std::unique_ptr<framework::OpDesc>(op);
148+
}
149+
};
150+
130151
} // namespace operators
131152
} // namespace paddle
132153

133154
namespace ops = paddle::operators;
134155
REGISTER_OPERATOR(smooth_l1_loss, ops::SmoothL1LossOp, ops::SmoothL1LossOpMaker,
135-
paddle::framework::DefaultGradOpDescMaker<true>);
156+
ops::SmoothL1LossGradMaker);
136157
REGISTER_OPERATOR(smooth_l1_loss_grad, ops::SmoothL1LossGradOp);
137158
REGISTER_OP_CPU_KERNEL(
138159
smooth_l1_loss,

paddle/scripts/paddle_build.sh

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,19 +20,15 @@
2020
#=================================================
2121

2222
function print_usage() {
23-
RED='\033[0;31m'
24-
BLUE='\033[0;34m'
25-
BOLD='\033[1m'
26-
NONE='\033[0m'
27-
2823
echo -e "\n${RED}Usage${NONE}:
29-
${BOLD}$0${NONE} [OPTION]"
24+
${BOLD}${SCRIPT_NAME}${NONE} [OPTION]"
3025

3126
echo -e "\n${RED}Options${NONE}:
3227
${BLUE}build${NONE}: run build for x86 platform
3328
${BLUE}build_android${NONE}: run build for android platform
3429
${BLUE}build_ios${NONE}: run build for ios platform
3530
${BLUE}test${NONE}: run all unit tests
31+
${BLUE}single_test${NONE}: run a single unit test
3632
${BLUE}bind_test${NONE}: parallel tests bind to different GPU
3733
${BLUE}doc${NONE}: generate paddle documents
3834
${BLUE}html${NONE}: convert C++ source code into HTML
@@ -45,7 +41,15 @@ function print_usage() {
4541
}
4642

4743
function init() {
44+
RED='\033[0;31m'
45+
BLUE='\033[0;34m'
46+
BOLD='\033[1m'
47+
NONE='\033[0m'
48+
4849
PADDLE_ROOT="$( cd "$( dirname "${BASH_SOURCE[0]}")/../../" && pwd )"
50+
if [ -z "${SCRIPT_NAME}" ]; then
51+
SCRIPT_NAME=$0
52+
fi
4953
}
5054

5155
function cmake_gen() {
@@ -309,6 +313,25 @@ EOF
309313
fi
310314
}
311315

316+
function single_test() {
317+
TEST_NAME=$1
318+
if [ -z "${TEST_NAME}" ]; then
319+
echo -e "${RED}Usage:${NONE}"
320+
echo -e "${BOLD}${SCRIPT_NAME}${NONE} ${BLUE}single_test${NONE} [test_name]"
321+
exit 1
322+
fi
323+
mkdir -p ${PADDLE_ROOT}/build
324+
cd ${PADDLE_ROOT}/build
325+
if [ ${WITH_TESTING:-ON} == "ON" ] ; then
326+
cat <<EOF
327+
========================================
328+
Running ${TEST_NAME} ...
329+
========================================
330+
EOF
331+
ctest --output-on-failure -R ${TEST_NAME}
332+
fi
333+
}
334+
312335
function bind_test() {
313336
# the number of process to run tests
314337
NUM_PROC=6
@@ -480,6 +503,7 @@ function main() {
480503
build)
481504
cmake_gen ${PYTHON_ABI:-""}
482505
build
506+
gen_dockerfile
483507
;;
484508
build_android)
485509
build_android
@@ -490,6 +514,9 @@ function main() {
490514
test)
491515
run_test
492516
;;
517+
single_test)
518+
single_test $2
519+
;;
493520
bind_test)
494521
bind_test
495522
;;

0 commit comments

Comments
 (0)