Commit ad2e420

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into feature/matmul_support_float16_double

2 parents: 2719729 + ba57348

37 files changed: +834 -388 lines

.copyright.hook

Lines changed: 1 addition & 1 deletion

@@ -9,7 +9,7 @@ import subprocess
 import platform
 
 COPYRIGHT = '''
-Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.

contrib/inference/README.md

Lines changed: 27 additions & 0 deletions

# Embed Paddle Inference in Your Application

Paddle inference offers APIs in the `C` and `C++` languages.

One can easily deploy a model trained by Paddle by following the steps below:

1. Optimize the native model;
2. Write some code for deployment.

Let's explain the steps in detail.

## Optimize the native Fluid Model

The native model obtained from the training phase needs to be optimized for inference:

- Clean up noise such as cost operators that are not needed for inference;
- Prune unnecessary computation branches that have nothing to do with the output;
- Remove extraneous variables;
- Reuse memory in the native Fluid executor;
- Translate the model storage format to a third-party engine's, so that the inference API can use the engine for acceleration.

We have an official tool to do the optimization; run `paddle_inference_optimize --help` for more information.

## Write some code

Read `paddle_inference_api.h` for more information.

contrib/inference/paddle_inference_api.h

Lines changed: 69 additions & 0 deletions

/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <string>
#include <vector>

namespace paddle {

class Predictor {
 public:
  struct Attr;
  Predictor() = default;

  // Build the network before inference.
  bool Init(const Attr& attr);

  // Predict a record.
  // Arguments:
  //   inputs: the names of the input variables.
  //   outputs: the names of the output variables.
  //   input_shapes: the shapes of the input variables.
  //   output_shapes: the shapes of the output variables.
  //   input_data: the data of the input variables.
  //   output_data: the data of the output variables.
  bool Run(const std::vector<std::string>& inputs,
           const std::vector<std::string>& outputs,
           const std::vector<std::vector<int>>& input_shapes,
           const std::vector<std::vector<int>>& output_shapes,
           const std::vector<std::vector<float>>& input_data,
           std::vector<std::vector<float>>* output_data);

  // Clone a predictor that shares the model weights.
  Predictor* Clone();

  // Destroy the Predictor.
  ~Predictor();

  struct Attr {
    enum class EngineKind;

    std::string model_dir;      // path to the model directory.
    bool enable_engine{false};  // Enable to execute (part of) the model on
                                // third-party engines.
    EngineKind engine_kind{Attr::EngineKind::kNone};

    enum class EngineKind {
      kNone = -1,          // Use the native Fluid facility.
      kAnakin,             // Use Anakin for inference.
      kTensorRT,           // Use TensorRT for inference.
      kAutoMixedAnakin,    // Automatically mix Fluid with Anakin.
      kAutoMixedTensorRT,  // Automatically mix Fluid with TensorRT.
    };
  };
};

}  // namespace paddle
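
To make the API above concrete, here is a minimal, hypothetical usage sketch of the `Predictor` class declared in this header; the model directory and the variable names `x`/`y` are placeholders, and the shapes are illustrative only:

```cpp
#include <iostream>
#include <string>
#include <vector>

// #include "contrib/inference/paddle_inference_api.h"  // the header shown above

int main() {
  paddle::Predictor predictor;

  // Configure the predictor; "my_model_dir" is a placeholder for a model
  // directory produced by the optimization step described in the README.
  paddle::Predictor::Attr attr;
  attr.model_dir = "my_model_dir";
  attr.enable_engine = false;  // stay on the native Fluid facility
  attr.engine_kind = paddle::Predictor::Attr::EngineKind::kNone;

  if (!predictor.Init(attr)) {
    std::cerr << "failed to initialize the predictor" << std::endl;
    return 1;
  }

  // Feed one input variable "x" of shape [1, 3]; fetch one output variable "y".
  std::vector<std::vector<float>> input_data = {{1.f, 2.f, 3.f}};
  std::vector<std::vector<float>> output_data;
  bool ok = predictor.Run({"x"}, {"y"},
                          /*input_shapes=*/{{1, 3}},
                          /*output_shapes=*/{{1, 1}},
                          input_data, &output_data);
  if (ok && !output_data.empty()) {
    std::cout << "first output element: " << output_data[0][0] << std::endl;
  }
  return 0;
}
```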

doc/fluid/design/motivation/api.md

Lines changed: 1 addition & 2 deletions

@@ -77,8 +77,7 @@ print "The sematic-vector of testA: ", paddle.infer(fA, parameters, testA)
 
 ### Example 2. Sharing Parameters between "Models"
 
-We use [GAN](https://github.com/PaddlePaddle/book/tree/develop/gan) in
-this example. In the following example program, `d0` and `d1`
+We use GAN in this example. In the following example program, `d0` and `d1`
 correspond to the two networks in the following figure:
 
 <img src="https://github.com/wangyang59/book/raw/00036f4b0da5225041a6824587c1a01cf20159b1/gan/image/gan_ig.png" width=400 />

doc/fluid/design/multi_devices/operator_kernel_type.md

Lines changed: 1 addition & 1 deletion

@@ -75,7 +75,7 @@ Different layout leads to different implementation of the operator kernel. There
 
 - The inference of Layout is at run-time, not at compile-time.
 
-- Every operator has to implement different kernels for different layouts. Let's take MKLDNN as an example. If we want to implement an MKLDNN convolution operator, we have to implement all the kernels for different layouts, which are listed [here](http://01org.github.io/mkl-dnn/structmkldnn_1_1memory.html). And we will have a special macro to register kernels for MKLDNN operators.
+- Every operator has to implement different kernels for different layouts. Let's take MKLDNN as an example. If we want to implement an MKLDNN convolution operator, we have to implement all the kernels for different layouts, which are listed [here](http://intel.github.io/mkl-dnn/structmkldnn_1_1memory.html). And we will have a special macro to register kernels for MKLDNN operators.
 
 `Layout` is also defined as a enum variable:

Lines changed: 110 additions & 0 deletions

# Distributed Training with NCCL2 and RDMA

When doing distributed multi-GPU training, network bandwidth often becomes the
bottleneck. We introduce a way to use NCCL2 for such training jobs to achieve
the best performance.

## Prepare Hardware with RDMA and Multiple GPUs

I'm using two Linux servers, each installed with 8 GPUs and one 100Gb RDMA card.
The base environment is:

* OS: CentOS 7.4
* RDMA device: "Mellanox Technologies MT27700 Family [ConnectX-4]"
* Kernel version: `4.4.88-1.el7.elrepo.x86_64`
* Docker version: `1.12.6`
* Docker storage driver: `overlay2`
* IP addresses: 192.168.16.30, 192.168.16.34

In general, the steps include:

1. Install GPU drivers.
1. Install RDMA drivers.
1. Install "InfiniBand Support".
1. Use Docker to run tests and make sure GPUs and RDMA work inside the container.

I'll omit the "Install GPU drivers" section because it is easy to find elsewhere.

### Install RDMA drivers

In my case, I have two machines with the device
"Mellanox Technologies MT27700 Family [ConnectX-4]" installed. The OS was
CentOS 7.4, and I updated the kernel to version 4.4 so that Docker can work
with the latest overlay2 filesystem.

***NOTE: before you start, make sure you have a way to get a console on the
server other than ssh, because we may need to re-configure the network device.***

1. Go to http://www.mellanox.com/page/products_dyn?product_family=26,
   download the `MLNX_OFED` software at the bottom of the page, and upload it
   onto the server.
1. Run `./mlnxofedinstall --add-kernel-support` in the software package.
1. Run `/etc/init.d/openibd restart` to make everything work. Note that this
   operation may cause the network to go down if you are using this RDMA device
   as the default network device and using ssh to log in to the server.
1. Re-configure the network interface, for example:
   `ifconfig eth2 192.168.16.30/20 up`, then add routes if needed:
   `ip route add default via 192.168.16.1 dev eth2`.
1. Do the same thing on the other node.
1. Use `ping` to test that the two nodes have normal ICMP connectivity.
1. Use either `udaddy` or `ib_write_bw` to test that the network connection is
   ready and has the desired bandwidth; a sample `ib_write_bw` run is shown below.
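
For instance, a quick bandwidth check with `ib_write_bw` (from the perftest package) can be run between the two nodes; the exact flags and device names depend on your installation, so treat this as a sketch:

```bash
# On node 1 (192.168.16.30): start the ib_write_bw server side.
ib_write_bw

# On node 2 (192.168.16.34): connect to node 1 and measure RDMA write bandwidth.
ib_write_bw 192.168.16.30
```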

### Prepare a Docker Image to Run RDMA Programs

1. Build a Docker image from a CUDA base image such as
   `nvidia/cuda:8.0-cudnn5-devel-ubuntu16.04` and install the paddlepaddle whl
   package in it.
1. Start a Docker container and mount the GPU driver libs into it (you can
   skip this step if you are using nvidia-docker).
1. Mount the RDMA drivers and libs into the container (see the next section),
   and also `udaddy` and `ib_write_bw` if needed.
1. Mount GPU devices and RDMA devices into the container using `--device`,
   or just use privileged mode `--privileged`.
1. Start the container using host network mode: `--net=host`.

### RDMA Library Files Needed

Usually, `MLNX_OFED` installs the latest supported libs under
`/usr/lib64/mlnx_ofed/valgrind`. The other libs needed to run RDMA programs
are listed below. These libs must be mounted into the docker container.

* Libs under `/usr/lib64/mlnx_ofed/valgrind`
  * libibcm.so
  * libibverbs.so
  * libmlx4.so
  * libmlx5.so
  * libmlx5-rdmav2.so
  * librdmacm.so
* Other libs:
  * libnl-3.so.200
  * libnl-route-3.so.200
  * libnuma.so.1
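
Putting the container-preparation steps together, a `docker run` invocation might look roughly like the sketch below; the image name `paddle-rdma:dev` is a hypothetical placeholder, and the host paths of the extra libs may differ on your system:

```bash
# Privileged container on the host network, with the RDMA libs mounted read-only.
docker run -it --net=host --privileged \
  -v /usr/lib64/mlnx_ofed/valgrind:/usr/lib64/mlnx_ofed/valgrind:ro \
  -v /usr/lib64/libnl-3.so.200:/usr/lib64/libnl-3.so.200:ro \
  -v /usr/lib64/libnl-route-3.so.200:/usr/lib64/libnl-route-3.so.200:ro \
  -v /usr/lib64/libnuma.so.1:/usr/lib64/libnuma.so.1:ro \
  paddle-rdma:dev /bin/bash
```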

## Start to Run the Training Job

Set the NCCL environment variables to turn NCCL switches on and off:

| Env Name | Description |
| --- | --- |
| NCCL_SOCKET_IFNAME | The RDMA device, e.g. eth2 |
| NCCL_P2P_DISABLE | Set to 1 to disable P2P transfer between GPUs |
| NCCL_IB_DISABLE | Set to 1 to disable using RDMA |
| NCCL_IB_CUDA_SUPPORT | Set to 1 to enable GPU Direct if supported |
| NCCL_DEBUG | Set the debug level: VERSION, WARN, INFO |
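
For example, a typical set of exports before launching the trainers might look like the following; the interface name and values are illustrative and depend on your setup:

```bash
export NCCL_SOCKET_IFNAME=eth2   # the RDMA-capable network interface
export NCCL_IB_DISABLE=0         # keep RDMA enabled
export NCCL_IB_CUDA_SUPPORT=1    # enable GPU Direct if the hardware supports it
export NCCL_P2P_DISABLE=0        # keep P2P transfers between GPUs enabled
export NCCL_DEBUG=INFO           # print NCCL setup information for debugging
```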

My two servers are `192.168.16.30` and `192.168.16.34`. On node 1, run:

```bash
PADDLE_TRAINER_ID=0 PADDLE_PORT=48372 PADDLE_WORKERS=192.168.16.30,192.168.16.34 POD_IP=192.168.16.30 stdbuf -oL python vgg16.py
```

On node 2, run:

```bash
PADDLE_TRAINER_ID=1 PADDLE_PORT=48372 PADDLE_WORKERS=192.168.16.30,192.168.16.34 POD_IP=192.168.16.34 stdbuf -oL python vgg16.py
```

paddle/fluid/framework/details/broadcast_op_handle.cc

Lines changed: 32 additions & 28 deletions

@@ -38,9 +38,7 @@ void BroadcastOpHandle::RunImpl() {
       out_var_handles.size(), places_.size(),
       "The number of output should equal to the number of places.");
 
-  // Wait input done, this Wait is asynchronous operation platform::Place
-  // &in_place;
-  WaitInputVarGenerated(*in_var_handle);
+  WaitInputVarGenerated();
 
   std::vector<const Scope *> var_scopes;
   for (auto *s : local_scopes_) {
@@ -50,29 +48,9 @@ void BroadcastOpHandle::RunImpl() {
   auto *in_var =
       var_scopes.at(in_var_handle->scope_idx_)->FindVar(in_var_handle->name_);
   PADDLE_ENFORCE_NOT_NULL(in_var);
-
   Tensor &in_tensor = VariableVisitor::GetMutableTensor(in_var);
 
-  // NOTE: The tensors' Place of input and output must be all on GPU or all on
-  // CPU.
-  for (auto *out_var_handle : out_var_handles) {
-    if (out_var_handle->IsTheSameVar(*in_var_handle)) {
-      continue;
-    }
-    auto t_out_p = out_var_handle->place_;
-    auto *out_var = var_scopes.at(out_var_handle->scope_idx_)
-                        ->FindVar(out_var_handle->name_);
-    PADDLE_ENFORCE_NOT_NULL(out_var);
-    if (platform::is_gpu_place(in_tensor.place())) {
-      PADDLE_ENFORCE(platform::is_gpu_place(t_out_p),
-                     "Places of input and output must be all on GPU.");
-    } else {
-      t_out_p = platform::CPUPlace();
-    }
-    VariableVisitor::ShareDimsAndLoD(*in_var, out_var);
-    VariableVisitor::GetMutableTensor(out_var).mutable_data(t_out_p,
-                                                            in_tensor.type());
-  }
+  InitOutputValue(*in_var_handle, out_var_handles);
 
   if (platform::is_cpu_place(in_tensor.place())) {
     for (auto *out_var_handle : out_var_handles) {
@@ -147,11 +125,37 @@ void BroadcastOpHandle::RunImpl() {
   }
 }
 
-void BroadcastOpHandle::WaitInputVarGenerated(const VarHandle &in_var) {
-  if (in_var.generated_op_) {
-    for (auto &pair : dev_ctxes_) {
-      in_var.generated_op_->Wait(pair.second);
+void BroadcastOpHandle::InitOutputValue(
+    const VarHandle &in_var_handle,
+    const std::vector<VarHandle *> &out_var_handles) const {
+  std::vector<const Scope *> var_scopes;
+  for (auto *s : local_scopes_) {
+    var_scopes.emplace_back(s->FindVar(kLocalExecScopeName)->Get<Scope *>());
+  }
+  auto *in_var =
+      var_scopes.at(in_var_handle.scope_idx_)->FindVar(in_var_handle.name_);
+
+  Tensor &in_tensor = VariableVisitor::GetMutableTensor(in_var);
+
+  // NOTE: The tensors' Place of input and output must be all on GPU or all on
+  // CPU.
+  for (auto *out_var_handle : out_var_handles) {
+    if (out_var_handle->IsTheSameVar(in_var_handle)) {
+      continue;
     }
+    auto t_out_p = out_var_handle->place_;
+    auto *out_var = var_scopes.at(out_var_handle->scope_idx_)
+                        ->FindVar(out_var_handle->name_);
+    PADDLE_ENFORCE_NOT_NULL(out_var);
+    if (is_gpu_place(in_tensor.place())) {
+      PADDLE_ENFORCE(platform::is_gpu_place(t_out_p),
+                     "Places of input and output must be all on GPU.");
+    } else {
+      t_out_p = platform::CPUPlace();
+    }
+    VariableVisitor::ShareDimsAndLoD(*in_var, out_var);
+    VariableVisitor::GetMutableTensor(out_var).mutable_data(t_out_p,
+                                                            in_tensor.type());
   }
 }

paddle/fluid/framework/details/broadcast_op_handle.h

Lines changed: 3 additions & 1 deletion

@@ -57,14 +57,16 @@ struct BroadcastOpHandle : public OpHandleBase {
 
  protected:
   void RunImpl() override;
-  void WaitInputVarGenerated(const VarHandle &in_var);
 
  private:
   const std::vector<Scope *> &local_scopes_;
   const std::vector<platform::Place> &places_;
 #ifdef PADDLE_WITH_CUDA
   const platform::NCCLContextMap *nccl_ctxs_;
 #endif
+
+  void InitOutputValue(const VarHandle &in_var_handle,
+                       const std::vector<VarHandle *> &out_var_handles) const;
 };
 }  // namespace details
 }  // namespace framework

paddle/fluid/framework/details/computation_op_handle.cc

Lines changed: 8 additions & 8 deletions

@@ -26,20 +26,20 @@ ComputationOpHandle::ComputationOpHandle(const OpDesc &op_desc, Scope *scope,
       place_(place) {}
 
 void ComputationOpHandle::RunImpl() {
-  auto *cur_ctx = dev_ctxes_[place_];
-  for (auto *in : inputs_) {
-    bool need_wait = in->generated_op_ &&
-                     in->generated_op_->DeviceContext(place_) != cur_ctx;
-    if (need_wait) {
-      in->generated_op_->Wait(cur_ctx);
-    }
-  }
+  WaitInputVarGenerated(place_);
 
   this->RunAndRecordEvent([this] {
     op_->Run(*scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(), place_);
   });
 }
 
+bool ComputationOpHandle::NeedWait(VarHandleBase *in_var) {
+  bool need_wait =
+      in_var && in_var->generated_op_ &&
+      in_var->generated_op_->DeviceContext(place_) != dev_ctxes_[place_];
+  return need_wait;
+}
+
 std::string ComputationOpHandle::Name() const { return op_->Type(); }
 }  // namespace details
 }  // namespace framework

paddle/fluid/framework/details/computation_op_handle.h

Lines changed: 2 additions & 0 deletions

@@ -36,6 +36,8 @@ struct ComputationOpHandle : public OpHandleBase {
  protected:
   void RunImpl() override;
 
+  virtual bool NeedWait(VarHandleBase *in_var);
+
  private:
   std::unique_ptr<OperatorBase> op_;
   Scope *scope_;
