
Commit b548ecb

panyx0718 and sneaxiy authored and committed

add stack_op

2 parents bc4f537 + a2c0e52 commit b548ecb

File tree: 7 files changed, +540 −79 lines


paddle/fluid/framework/array.h

Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <cstdint>
#include "paddle/fluid/platform/hostdevice.h"

namespace paddle {
namespace framework {
template <typename T, size_t N>
class Array {
  static_assert(N > 0, "The size of array must be larger than 0");

 public:
  HOSTDEVICE Array() {}

  HOSTDEVICE explicit Array(const T &val) {
    for (size_t i = 0; i < N; ++i) data_[i] = val;
  }

  HOSTDEVICE const T *Get() const { return data_; }

  HOSTDEVICE T *GetMutable() { return data_; }

  HOSTDEVICE T &operator[](size_t index) { return data_[index]; }

  HOSTDEVICE const T &operator[](size_t index) const { return data_[index]; }

  HOSTDEVICE constexpr size_t size() const { return N; }

 private:
  T data_[N];
};

}  // namespace framework
}  // namespace paddle
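
Unlike std::vector, Array is a trivially copyable aggregate, so an instance can be passed by value directly in a CUDA kernel's argument list; stack_op.cu below relies on exactly that. A minimal host-side sketch of the API (ArrayDemo is a hypothetical name, not part of the commit):

#include "paddle/fluid/framework/array.h"

// Hypothetical usage sketch, not part of the commit.
void ArrayDemo() {
  // The explicit fill constructor writes the value into every slot.
  paddle::framework::Array<int, 4> arr(-1);
  for (size_t i = 0; i < arr.size(); ++i) {
    arr[i] = static_cast<int>(i);  // mutable operator[]
  }
  const int *ro = arr.Get();       // read-only pointer to the buffer
  int *rw = arr.GetMutable();      // writable pointer to the same buffer
  rw[0] = ro[3];                   // arr is now {3, 1, 2, 3}
}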

paddle/fluid/operators/stack_op.cc

Lines changed: 66 additions & 0 deletions
@@ -0,0 +1,66 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/stack_op.h"

namespace paddle {
namespace operators {

struct CPUStackFunctor {
  template <typename DeviceContext, typename T>
  void operator()(const DeviceContext& ctx, const std::vector<const T*>& x,
                  T* y, int pre, int n, int post) const {
    int total_num = pre * post * n;
    for (int idx = 0; idx < total_num; ++idx) {
      int i = idx / (n * post);
      int which_x = idx / post - i * n;
      int x_index = i * post + idx % post;
      y[idx] = x[which_x][x_index];
    }
  }
};

struct CPUStackGradFunctor {
  template <typename DeviceContext, typename T>
  void operator()(const DeviceContext& ctx, std::vector<T*>& dx,  // NOLINT
                  const T* dy, int pre, int n, int post) const {
    int total_num = pre * post * n;
    for (int idx = 0; idx < total_num; ++idx) {
      int i = idx / (n * post);
      int which_x = idx / post - i * n;
      int x_index = i * post + idx % post;
      dx[which_x][x_index] = dy[idx];
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace plat = paddle::platform;
namespace ops = paddle::operators;
REGISTER_OPERATOR(stack, ops::StackOp, ops::StackOpMaker,
                  ops::StackGradOpDescMaker);
REGISTER_OPERATOR(stack_grad, ops::StackOpGrad);

REGISTER_OP_CPU_KERNEL(
    stack,
    ops::StackKernel<plat::CPUDeviceContext, float, ops::CPUStackFunctor>,
    ops::StackKernel<plat::CPUDeviceContext, double, ops::CPUStackFunctor>);

REGISTER_OP_CPU_KERNEL(stack_grad,
                       ops::StackGradKernel<plat::CPUDeviceContext, float,
                                            ops::CPUStackGradFunctor>,
                       ops::StackGradKernel<plat::CPUDeviceContext, double,
                                            ops::CPUStackGradFunctor>);
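
Both functors flatten the work into pre * n * post scalar copies: the output is laid out as [pre, n, post], so a flat index decomposes as idx = (i * n + which_x) * post + r, where i spans the dimensions before the stack axis, which_x selects the input tensor, and x_index = i * post + r addresses that input viewed as [pre, post]. A standalone sketch of the same index math (hypothetical, not part of the commit), stacking two row-major 2x3 tensors along axis 1, i.e. pre = 2, n = 2, post = 3:

#include <cstdio>
#include <vector>

int main() {
  const float a[6] = {0, 1, 2, 3, 4, 5};
  const float b[6] = {10, 11, 12, 13, 14, 15};
  std::vector<const float*> x = {a, b};
  const int pre = 2, n = 2, post = 3;

  float y[pre * n * post];
  for (int idx = 0; idx < pre * n * post; ++idx) {
    int i = idx / (n * post);             // index into the "pre" dims
    int which_x = idx / post - i * n;     // which input tensor to read
    int x_index = i * post + idx % post;  // offset inside that input
    y[idx] = x[which_x][x_index];
  }
  // Output has shape [2, 2, 3]; rows of a and b interleave:
  // expected: 0 1 2 10 11 12 3 4 5 13 14 15
  for (int idx = 0; idx < pre * n * post; ++idx) printf("%g ", y[idx]);
  printf("\n");
  return 0;
}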

paddle/fluid/operators/stack_op.cu

Lines changed: 109 additions & 0 deletions
@@ -0,0 +1,109 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <thrust/device_vector.h>
#include "paddle/fluid/framework/array.h"
#include "paddle/fluid/operators/stack_op.h"

namespace paddle {
namespace operators {

template <typename T, typename VecXType>
__global__ void StackCUDAKernel(VecXType x, T* y, int total_num, int n,
                                int post) {
  int idx = threadIdx.x + blockIdx.x * blockDim.x;
  if (idx < total_num) {
    int i = idx / (n * post);
    int which_x = idx / post - i * n;
    int x_index = i * post + idx % post;
    y[idx] = x[which_x][x_index];
  }
}

template <typename T, typename VecDxType>
__global__ void StackGradCUDAKernel(VecDxType dx, const T* dy, int total_num,
                                    int n, int post) {
  int idx = threadIdx.x + blockIdx.x * blockDim.x;
  if (idx < total_num) {
    int i = idx / (n * post);
    int which_x = idx / post - i * n;
    int x_index = i * post + idx % post;
    dx[which_x][x_index] = dy[idx];
  }
}

struct GPUStackFunctor {
  template <typename DeviceContext, typename T>
  void operator()(const DeviceContext& ctx, const std::vector<const T*>& x,
                  T* y, int pre, int n, int post) const {
    int total_num = pre * post * n;
    int threads = 512;
    int grid = (total_num + threads - 1) / threads;

    constexpr auto kMaxThreshold = 16;
    if (n <= kMaxThreshold) {
      framework::Array<const T*, kMaxThreshold> arr;
      for (int i = 0; i < n; ++i) arr[i] = x[i];
      StackCUDAKernel<<<grid, threads, 0, ctx.stream()>>>(arr, y, total_num, n,
                                                          post);
    } else {
      VLOG(10) << "Stack more than " << kMaxThreshold
               << " tensors may be slow on GPU.";
      thrust::device_vector<const T*> dev_x(x);
      StackCUDAKernel<<<grid, threads, 0, ctx.stream()>>>(dev_x.data().get(), y,
                                                          total_num, n, post);
    }
  }
};

struct GPUStackGradFunctor {
  template <typename DeviceContext, typename T>
  void operator()(const DeviceContext& ctx, std::vector<T*>& dx,  // NOLINT
                  const T* dy, int pre, int n, int post) const {
    int total_num = pre * post * n;
    int threads = 512;
    int grid = (total_num + threads - 1) / threads;

    constexpr auto kMaxThreshold = 16;
    if (n <= kMaxThreshold) {
      framework::Array<T*, kMaxThreshold> arr;
      for (int i = 0; i < n; ++i) arr[i] = dx[i];
      StackGradCUDAKernel<<<grid, threads, 0, ctx.stream()>>>(
          arr, dy, total_num, n, post);
    } else {
      VLOG(10) << "Stack more than " << kMaxThreshold
               << " tensors may be slow on GPU.";
      thrust::device_vector<T*> dev_dx(dx);
      StackGradCUDAKernel<<<grid, threads, 0, ctx.stream()>>>(
          dev_dx.data().get(), dy, total_num, n, post);
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace plat = paddle::platform;
namespace ops = paddle::operators;

REGISTER_OP_CUDA_KERNEL(
    stack,
    ops::StackKernel<plat::CUDADeviceContext, float, ops::GPUStackFunctor>,
    ops::StackKernel<plat::CUDADeviceContext, double, ops::GPUStackFunctor>);

REGISTER_OP_CUDA_KERNEL(stack_grad,
                        ops::StackGradKernel<plat::CUDADeviceContext, float,
                                             ops::GPUStackGradFunctor>,
                        ops::StackGradKernel<plat::CUDADeviceContext, double,
                                             ops::GPUStackGradFunctor>);
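
The two launch paths trade a device allocation against kernel-argument size: for n <= kMaxThreshold the input pointers travel by value inside a framework::Array, arriving in the kernel's argument space with no extra allocation; for larger stacks the code falls back to a thrust::device_vector, whose construction allocates device memory and copies the pointer table over (hence the VLOG warning). A self-contained CUDA sketch of the pass-by-value technique; PtrArray and the demo sizes are assumptions standing in for framework::Array, not the commit's code:

#include <cstdio>
#include <cuda_runtime.h>

// Hypothetical stand-in for framework::Array<const float*, N>:
// a trivially copyable pointer table that a kernel can take by value.
template <int N>
struct PtrArray {
  const float* ptrs[N];
  __host__ __device__ const float* operator[](int i) const { return ptrs[i]; }
};

__global__ void GatherFirst(PtrArray<2> x, float* y) {
  // The pointer table itself arrived in the kernel's argument space,
  // so no cudaMalloc or host-to-device copy was needed for it.
  int t = threadIdx.x;
  if (t < 2) y[t] = x[t][0];
}

int main() {
  float h0 = 1.f, h1 = 2.f;
  float *d0, *d1, *dy;
  cudaMalloc(&d0, sizeof(float));
  cudaMalloc(&d1, sizeof(float));
  cudaMalloc(&dy, 2 * sizeof(float));
  cudaMemcpy(d0, &h0, sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d1, &h1, sizeof(float), cudaMemcpyHostToDevice);

  PtrArray<2> arr;
  arr.ptrs[0] = d0;
  arr.ptrs[1] = d1;
  GatherFirst<<<1, 32>>>(arr, dy);  // arr copied by value into the launch

  float out[2];
  cudaMemcpy(out, dy, 2 * sizeof(float), cudaMemcpyDeviceToHost);
  printf("%g %g\n", out[0], out[1]);  // expected: 1 2
  cudaFree(d0);
  cudaFree(d1);
  cudaFree(dy);
  return 0;
}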
