
Commit 56b04e5

Committed by weixing02
Parents: 9394064 + b1224da

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into initializer

File tree: 8 files changed (+122, -105 lines)

paddle/fluid/framework/details/multi_devices_graph_builder.cc

Lines changed: 13 additions & 8 deletions
@@ -59,7 +59,11 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
   auto graph = new SSAGraph();
   SSAGraph &result = *graph;
   std::unordered_set<std::string> og_has_been_broadcast;
-  result.vars_.resize(places_.size());
+
+  // We cannot invoke resize. It is a bug of GCC 4.8
+  result.vars_ = std::vector<
+      std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>(
+      places_.size());
 
   bool is_forwarding = true;
   for (auto *op : program.Block(0).AllOps()) {
@@ -147,15 +151,16 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
         if (vars.empty()) {  // This device has no data. continue.
           continue;
         }
-        auto *prev_grad = &vars[vars.size() - 1];
-        op_handle->AddInput(prev_grad);
+        auto &prev_grad = vars[vars.size() - 1];
+        op_handle->AddInput(prev_grad.get());
 
-        auto &var = vars[vars.size()];
-        var.place_ = p;
-        var.name_ = og;
-        var.version_ = vars.size() - 1;
+        vars.emplace_back(new VarHandle);
+        auto &var = vars.back();
+        var->place_ = p;
+        var->name_ = og;
+        var->version_ = vars.size() - 1;
 
-        op_handle->AddOutput(&var);
+        op_handle->AddOutput(var.get());
       }
 #else
       PADDLE_ENFORCE("Not implemented");
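
Note on the hunk above: per the in-diff comment, vars_.resize() trips a GCC 4.8
bug for this element type, so the commit constructs a correctly sized vector and
assigns it instead. A minimal, self-contained sketch of that workaround pattern
(VarHandleStub and PerPlaceVars are illustrative stand-ins, not Paddle's real
types):

#include <cstddef>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

struct VarHandleStub {};  // stand-in for paddle::framework::details::VarHandle

// One entry per device/place: variable name -> versioned handles.
using PerPlaceVars =
    std::unordered_map<std::string,
                       std::vector<std::unique_ptr<VarHandleStub>>>;

int main() {
  std::vector<PerPlaceVars> vars;
  const std::size_t num_places = 4;

  // vars.resize(num_places);  // avoided: the commit comment says this hits a
  //                           // GCC 4.8 bug for this element type
  vars = std::vector<PerPlaceVars>(num_places);  // construct a sized vector
                                                 // and assign it instead
  return vars.size() == num_places ? 0 : 1;
}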

paddle/fluid/framework/details/ssa_graph.h

Lines changed: 5 additions & 1 deletion
@@ -16,6 +16,8 @@
 
 #include <map>
 #include <string>
+#include <vector>
+
 #include "paddle/fluid/framework/details/op_handle_base.h"
 #include "paddle/fluid/framework/details/var_handle.h"
 
@@ -24,7 +26,9 @@ namespace framework {
 namespace details {
 
 struct SSAGraph {
-  std::vector<std::unordered_map<std::string, std::map<int, VarHandle>>> vars_;
+  std::vector<
+      std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>
+      vars_;
   // aux variables to represent dependency. Useful to resolve data hazard.
   std::unordered_set<std::unique_ptr<VarHandleBase>> dep_vars_;
   std::vector<std::unique_ptr<OpHandleBase>> ops_;
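
The layout change here is that each variable's version list moves from
std::map<int, VarHandle> (version number as key, handles stored by value) to
std::vector<std::unique_ptr<VarHandle>> (the vector index is the version, and
each handle sits behind a unique_ptr so its address stays stable while raw
pointers to it are held elsewhere in the graph). A small sketch of that access
pattern, using an illustrative stub type rather than the real VarHandle:

#include <memory>
#include <vector>

struct VarHandleStub {
  int version_ = 0;  // stand-in for the real VarHandle fields
};

int main() {
  std::vector<std::unique_ptr<VarHandleStub>> versions;

  // Appending a new version: the index is the version number; the vector owns
  // the handle and callers keep raw, non-owning pointers via get().
  versions.emplace_back(new VarHandleStub);
  versions.back()->version_ = static_cast<int>(versions.size()) - 1;

  // Latest version: replaces &var_holder.rbegin()->second from the old map.
  VarHandleStub* latest = versions.rbegin()->get();
  return latest->version_;  // 0 for the single element above
}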

paddle/fluid/framework/details/ssa_graph_builder.cc

Lines changed: 16 additions & 14 deletions
@@ -27,8 +27,8 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(SSAGraph *graph) {
       auto it_old = name_pair.second.rbegin();
       ++it_old;
       for (; it_old != name_pair.second.rend(); it_new = it_old, ++it_old) {
-        auto *write_op = it_new->second.generated_op_;
-        auto &read_ops = it_old->second.pending_ops_;
+        auto *write_op = (*it_new)->generated_op_;
+        auto &read_ops = (*it_old)->pending_ops_;
 
         for (auto *read_op : read_ops) {
           // Manually add a dependency var from read_op to write_op;
@@ -54,14 +54,15 @@ VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
   auto &var_holder = var_holders[each_var_name];
   VarHandle *var = nullptr;
   if (var_holder.empty()) {
+    var_holder.emplace_back(new VarHandle);
     auto &init_var = var_holder[0];
-    init_var.place_ = place;
-    init_var.name_ = each_var_name;
-    init_var.generated_op_ = nullptr;
-    init_var.version_ = 0;
-    var = &init_var;
+    init_var->place_ = place;
+    init_var->name_ = each_var_name;
+    init_var->generated_op_ = nullptr;
+    init_var->version_ = 0;
+    var = init_var.get();
   } else {
-    var = &var_holder.rbegin()->second;
+    var = var_holder.rbegin()->get();
  }
   return var;
 }
@@ -72,19 +73,20 @@ void SSAGraphBuilder::CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle,
                                      size_t place_offset) {
   auto &vars = graph->vars_[place_offset][each_var_name];
   size_t version = vars.size();
-  auto &var = vars[version];
-  var.version_ = version;
-  var.name_ = each_var_name;
-  var.place_ = place;
-  op_handle->AddOutput(&var);
+  vars.emplace_back(new VarHandle());
+  auto &var = vars.back();
+  var->version_ = version;
+  var->name_ = each_var_name;
+  var->place_ = place;
+  op_handle->AddOutput(var.get());
 }
 
 template <typename Callback>
 void IterAllVar(const SSAGraph &graph, Callback callback) {
   for (auto &each : graph.vars_) {
     for (auto &pair1 : each) {
       for (auto &pair2 : pair1.second) {
-        callback(pair2.second);
+        callback(*pair2);
       }
     }
   }

paddle/fluid/framework/details/threaded_ssa_graph_executor.cc

Lines changed: 2 additions & 2 deletions
@@ -69,7 +69,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
   for (auto &var_map : graph_->vars_) {
     for (auto &name_pair : var_map) {
       for (auto &version_pair : name_pair.second) {
-        InsertPendingVar(version_pair.second);
+        InsertPendingVar(*version_pair);
       }
     }
   }
@@ -95,7 +95,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
     for (auto &var_map : graph_->vars_) {
       auto it = var_map.find(fetch_var_name);
       if (it != var_map.end()) {
-        fetched_vars[fetch_var_name].push_back(&it->second.rbegin()->second);
+        fetched_vars[fetch_var_name].push_back(it->second.rbegin()->get());
       }
     }
   }

paddle/fluid/operators/elementwise_op_function.h

Lines changed: 75 additions & 25 deletions
@@ -13,14 +13,15 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
+#include <algorithm>
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/platform/transform.h"
 
 #ifdef __NVCC__
+#include <cuda.h>
 #include <thrust/iterator/iterator_adaptor.h>
-#include "paddle/fluid/platform/cuda_helper.h"
 constexpr int ELEMWISE_MAX_BLOCK_DIM = 1024;
 #endif
 
@@ -43,35 +44,35 @@ namespace operators {
 */
 inline void get_mid_dims(const framework::DDim& x_dims,
                          const framework::DDim& y_dims, const int axis,
-                         int& pre, int& n, int& post) {
-  pre = 1;
-  n = 1;
-  post = 1;
+                         int* pre, int* n, int* post) {
+  *pre = 1;
+  *n = 1;
+  *post = 1;
   for (int i = 0; i < axis; ++i) {
-    pre *= x_dims[i];
+    (*pre) *= x_dims[i];
   }
 
   for (int i = 0; i < y_dims.size(); ++i) {
     PADDLE_ENFORCE_EQ(x_dims[i + axis], y_dims[i],
                       "Broadcast dimension mismatch.");
-    n *= y_dims[i];
+    (*n) *= y_dims[i];
   }
 
   for (int i = axis + y_dims.size(); i < x_dims.size(); ++i) {
-    post *= x_dims[i];
+    (*post) *= x_dims[i];
   }
 }
 
-inline void trim_trailing_singular_dims(framework::DDim& dims) {
+inline void trim_trailing_singular_dims(framework::DDim* dims) {
   // Remove trailing dimensions of size 1 for y
-  auto actual_dims_size = dims.size();
+  auto actual_dims_size = dims->size();
   for (; actual_dims_size != 0; --actual_dims_size) {
-    if (dims[actual_dims_size - 1] != 1) break;
+    if ((*dims)[actual_dims_size - 1] != 1) break;
   }
-  if (actual_dims_size != dims.size()) {
-    auto actual_dims = framework::vectorize(dims);
+  if (actual_dims_size != dims->size()) {
+    auto actual_dims = framework::vectorize(*dims);
     actual_dims.resize(actual_dims_size);
-    dims = framework::make_ddim(actual_dims);
+    *dims = framework::make_ddim(actual_dims);
   }
 }
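
The rewritten get_mid_dims and trim_trailing_singular_dims return results
through pointer parameters instead of non-const references, so call sites now
pass &pre, &n, &post (and &y_dim), consistent with the common C++ style rule
(e.g., the Google style guide) that output parameters be pointers so mutation
is visible at the call site. A stand-alone sketch of the same pre/n/post
decomposition, with plain vectors standing in for framework::DDim
(get_mid_dims_sketch is a hypothetical name, not Paddle's API):

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

void get_mid_dims_sketch(const std::vector<std::int64_t>& x_dims,
                         const std::vector<std::int64_t>& y_dims, int axis,
                         int* pre, int* n, int* post) {
  *pre = 1;
  *n = 1;
  *post = 1;
  for (int i = 0; i < axis; ++i) (*pre) *= x_dims[i];  // dims before y starts
  for (std::size_t i = 0; i < y_dims.size(); ++i) {
    assert(x_dims[i + axis] == y_dims[i]);  // broadcast shapes must match
    (*n) *= y_dims[i];                      // dims covered by y
  }
  for (std::size_t i = axis + y_dims.size(); i < x_dims.size(); ++i)
    (*post) *= x_dims[i];                   // dims after y ends
}

int main() {
  int pre = 0, n = 0, post = 0;
  // x: [2, 3, 4, 5] broadcast against y: [3, 4] at axis 1
  // -> pre = 2, n = 3 * 4 = 12, post = 5.
  get_mid_dims_sketch({2, 3, 4, 5}, {3, 4}, 1, &pre, &n, &post);
  return (pre == 2 && n == 12 && post == 5) ? 0 : 1;
}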

@@ -159,7 +160,7 @@ class RowwiseTransformIterator<T, platform::CUDADeviceContext>
           RowwiseTransformIterator<T, platform::CUDADeviceContext>, const T*>
       super_t;
   HOSTDEVICE RowwiseTransformIterator(const T* x, int n)
-      : super_t(x), begin_(x), n_(n){};
+      : super_t(x), begin_(x), n_(n) {}
   friend class thrust::iterator_core_access;
 
  private:
@@ -179,7 +180,7 @@ class MidWiseTransformIterator<T, platform::CUDADeviceContext>
           MidWiseTransformIterator<T, platform::CUDADeviceContext>, const T*>
       super_t;
   HOSTDEVICE MidWiseTransformIterator(const T* x, int n, int post)
-      : super_t(x), begin_(x), n_(n), post_(post){};
+      : super_t(x), begin_(x), n_(n), post_(post) {}
   friend class thrust::iterator_core_access;
 
  private:
@@ -333,6 +334,55 @@ static void ElemwiseGradBroadcast1CPU(const T* x, const T* y, const T* out,
   }
 }
 #ifdef __NVCC__
+
+// __shfl_down has been deprecated as of CUDA 9.0.
+#if CUDA_VERSION < 9000
+template <typename T>
+__forceinline__ __device__ T __shfl_down_sync(unsigned, T val, int delta) {
+  return __shfl_down(val, delta);
+}
+#define CREATE_SHFL_MASK(mask, predicate) mask = 0u;
+#else
+#define FULL_WARP_MASK 0xFFFFFFFF
+#define CREATE_SHFL_MASK(mask, predicate) \
+  mask = __ballot_sync(FULL_WARP_MASK, (predicate))
+#endif
+
+template <typename T>
+__device__ T reduceSum(T val, int tid, int len) {
+  // TODO(zcd): The warp size should be taken from the
+  // parameters of the GPU but not specified as 32 simply.
+  // To make the reduceSum more efficiently,
+  // I use Warp-Level Parallelism and assume the Warp size
+  // is 32 which may be different for different GPU,
+  // but most card's warp size is 32.
+  __shared__ T shm[32];
+  const int warpSize = 32;
+  unsigned mask = 0u;
+  CREATE_SHFL_MASK(mask, tid < len);
+
+  for (int offset = warpSize / 2; offset > 0; offset /= 2)
+    val += __shfl_down_sync(mask, val, offset);
+
+  if (tid < warpSize) shm[tid] = 0;
+
+  __syncthreads();
+
+  if (tid % warpSize == 0) {
+    shm[tid / warpSize] = val;
+  }
+
+  CREATE_SHFL_MASK(mask, tid < warpSize);
+
+  if (tid < warpSize) {
+    val = shm[tid];
+    for (int offset = warpSize / 2; offset > 0; offset /= 2)
+      val += __shfl_down_sync(mask, val, offset);
+  }
+
+  return val;
+}
+
 template <typename T, typename DX_OP, typename DY_OP>
 static __global__ void ElemwiseGradBroadcast1CUDAKernel(
     const T* x, const T* y, const T* out, const T* dout, int h, int w,
@@ -355,7 +405,7 @@ static __global__ void ElemwiseGradBroadcast1CUDAKernel(
 
   if (dy) {
     h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
-    val = platform::reduceSum(val, tid, h);
+    val = reduceSum(val, tid, h);
     if (threadIdx.x == 0) {
       dy[j] = val;
     }
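
The reduceSum helper added above (moved here from platform/cuda_helper.h, see
the last file in this commit) reduces within each 32-lane warp using shuffle
intrinsics, parks the per-warp partial sums in the shared array shm[32], and
then reduces those partials in the first warp. A rough CPU-side illustration of
that two-stage idea, purely for intuition; it does not model the CUDA shuffle
intrinsics or thread synchronization:

#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  const int kWarpSize = 32;
  std::vector<float> vals(100, 1.0f);  // 100 "lanes", each contributing 1.0

  // Stage 1: reduce each group of 32 lanes (a "warp") to one partial sum.
  std::vector<float> partials((vals.size() + kWarpSize - 1) / kWarpSize, 0.0f);
  for (std::size_t i = 0; i < vals.size(); ++i) {
    partials[i / kWarpSize] += vals[i];
  }

  // Stage 2: reduce the per-warp partials (the role of shm[32] on the GPU).
  float total = 0.0f;
  for (float p : partials) total += p;

  std::printf("sum = %g\n", total);  // expect 100
  return 0;
}
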
@@ -432,7 +482,7 @@ static __global__ void ElemwiseGradBroadcast2CUDAKernel(
   if (dy) {
     int h = pre * post;
     h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
-    val = platform::reduceSum(val, tid, h);
+    val = reduceSum(val, tid, h);
     if (threadIdx.x == 0) {
       dy[j] = val;
     }
@@ -472,11 +522,11 @@ void ElemwiseGradCompute(const framework::ExecutionContext& ctx,
   auto y_dim = y.dims();
 
   axis = (axis == -1 ? x_dim.size() - y_dim.size() : axis);
-  trim_trailing_singular_dims(y_dim);
+  trim_trailing_singular_dims(&y_dim);
   axis = (y_dim.size() == 0) ? x_dim.size() : axis;
 
   int pre, n, post;
-  get_mid_dims(x_dim, y_dim, axis, pre, n, post);
+  get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post);
   if (post == 1) {
     int h = pre;
     int w = n;
@@ -514,7 +564,7 @@ void ElemwiseGradCompute(const framework::ExecutionContext& ctx,
       }
     }
   }
-};
+}
 
 template <typename DeviceContext, typename T, typename functor,
           typename broadcastfunctor, typename broadcast2functor>
@@ -543,11 +593,11 @@ void ElementwiseGradCompute(const framework::ExecutionContext& ctx,
   }
 
   axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
-  trim_trailing_singular_dims(y_dims);
+  trim_trailing_singular_dims(&y_dims);
   axis = (y_dims.size() == 0) ? x_dims.size() : axis;
 
   int pre, n, post;
-  get_mid_dims(x_dims, y_dims, axis, pre, n, post);
+  get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post);
 
   if (post == 1) {
     broadcastfunctor f;
@@ -582,11 +632,11 @@ void ElementwiseComputeEx(const framework::ExecutionContext& ctx,
   axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
   PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
                  "Axis should be in range [0, x_dims)");
-  trim_trailing_singular_dims(y_dims);
+  trim_trailing_singular_dims(&y_dims);
   axis = (y_dims.size() == 0) ? x_dims.size() : axis;
 
   int pre, n, post;
-  get_mid_dims(x_dims, y_dims, axis, pre, n, post);
+  get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post);
   if (post == 1) {
     functor.RunRowWise(n, pre);
     return;

paddle/fluid/platform/cuda_helper.h

Lines changed: 0 additions & 48 deletions
@@ -62,53 +62,5 @@ CUDA_ATOMIC_WRAPPER(Add, double) {
 }
 #endif
 
-// __shfl_down has been deprecated as of CUDA 9.0.
-#if CUDA_VERSION < 9000
-template <typename T>
-__forceinline__ __device__ T __shfl_down_sync(unsigned, T val, int delta) {
-  return __shfl_down(val, delta);
-}
-#define CREATE_SHFL_MASK(mask, predicate) mask = 0u;
-#else
-#define FULL_WARP_MASK 0xFFFFFFFF
-#define CREATE_SHFL_MASK(mask, predicate) \
-  mask = __ballot_sync(FULL_WARP_MASK, (predicate))
-#endif
-
-template <typename T>
-__device__ T reduceSum(T val, int tid, int len) {
-  // TODO(zcd): The warp size should be taken from the
-  // parameters of the GPU but not specified as 32 simply.
-  // To make the reduceSum more efficiently,
-  // I use Warp-Level Parallelism and assume the Warp size
-  // is 32 which may be different for different GPU,
-  // but most card's warp size is 32.
-  __shared__ T shm[32];
-  const int warpSize = 32;
-  unsigned mask = 0u;
-  CREATE_SHFL_MASK(mask, tid < len);
-
-  for (int offset = warpSize / 2; offset > 0; offset /= 2)
-    val += __shfl_down_sync(mask, val, offset);
-
-  if (tid < warpSize) shm[tid] = 0;
-
-  __syncthreads();
-
-  if (tid % warpSize == 0) {
-    shm[tid / warpSize] = val;
-  }
-
-  CREATE_SHFL_MASK(mask, tid < warpSize);
-
-  if (tid < warpSize) {
-    val = shm[tid];
-    for (int offset = warpSize / 2; offset > 0; offset /= 2)
-      val += __shfl_down_sync(mask, val, offset);
-  }
-
-  return val;
-}
-
 }  // namespace platform
 }  // namespace paddle
