Commit 98c4c78

Author: Wojciech Uss
Modify relu native implementation 2 (#30996) (#31348)
1 parent 325bfc3 commit 98c4c78

File tree: 8 files changed, +49 / -23 lines

cmake/cuda.cmake

Lines changed: 2 additions & 0 deletions
@@ -216,6 +216,8 @@ endif(WIN32)
 set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -w")
 # Set :expt-relaxed-constexpr to suppress Eigen warnings
 set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
+# Set :expt-extended-lambda to enable HOSTDEVICE annotation on lambdas
+set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-extended-lambda")
 
 if(WIN32)
   set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler \"/wd4244 /wd4267 /wd4819 \"")
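Note: Paddle's HOSTDEVICE macro expands to __host__ __device__ under nvcc, and nvcc only accepts that annotation on lambdas ("extended lambdas") when --expt-extended-lambda is passed; that is what the new flag enables for the lambda inside ReluCPUFunctor below. A minimal sketch of the idea, assuming a HOSTDEVICE macro defined in the usual conditional way (not copied from the repo):

#if defined(__CUDACC__)
#define HOSTDEVICE __host__ __device__
#else
#define HOSTDEVICE
#endif

#include <cstdio>

int main() {
  // With nvcc --expt-extended-lambda this lambda is callable from both host
  // and device code; with a plain C++ compiler it is an ordinary host lambda.
  auto relu = [] HOSTDEVICE(float v) { return v > 0.f ? v : 0.f; };
  printf("%.1f %.1f\n", relu(-1.f), relu(2.f));  // prints 0.0 2.0
  return 0;
}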

paddle/fluid/operators/activation_op.cc

Lines changed: 1 addition & 1 deletion
@@ -1051,7 +1051,7 @@ REGISTER_OPERATOR(
     ops::ActivationOpDoubleGrad2<ops::ReluGradFunctor<float>::FwdDeps()>,
     ops::ActivationDoubleGradOpInplaceInferer);
 
-REGISTER_ACTIVATION_CPU_KERNEL(relu, Relu, ReluFunctor, ReluGradFunctor);
+REGISTER_ACTIVATION_CPU_KERNEL(relu, Relu, ReluCPUFunctor, ReluGradFunctor);
 
 REGISTER_OP_CPU_KERNEL(
     relu_grad_grad,

paddle/fluid/operators/activation_op.cu

Lines changed: 1 addition & 1 deletion
@@ -60,7 +60,7 @@ REGISTER_OP_CUDA_KERNEL(
 /* ========================================================================== */
 
 /* =========================== relu register ============================ */
-REGISTER_ACTIVATION_CUDA_KERNEL(relu, Relu, ReluFunctor, ReluGradFunctor);
+REGISTER_ACTIVATION_CUDA_KERNEL(relu, Relu, ReluCUDAFunctor, ReluGradFunctor);
 
 REGISTER_OP_CUDA_KERNEL(
     relu_grad_grad,

paddle/fluid/operators/activation_op.h

Lines changed: 11 additions & 1 deletion
@@ -318,7 +318,17 @@ struct ExpGradFunctor : public BaseActivationFunctor<T> {
 
 // relu(x) = max(x, 0)
 template <typename T>
-struct ReluFunctor : public BaseActivationFunctor<T> {
+struct ReluCPUFunctor : public BaseActivationFunctor<T> {
+  template <typename Device, typename X, typename Out>
+  void operator()(Device d, X x, Out out) const {
+    out.device(d) = x.unaryExpr([] HOSTDEVICE(T v) {
+      return v > static_cast<T>(0) ? v : static_cast<T>(0);
+    });
+  }
+};
+
+template <typename T>
+struct ReluCUDAFunctor : public BaseActivationFunctor<T> {
   template <typename Device, typename X, typename Out>
   void operator()(Device d, X x, Out out) const {
     out.device(d) = x.cwiseMax(static_cast<T>(0));
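Both functors implement relu(x) = max(x, 0); the CPU variant builds the Eigen expression with unaryExpr and a HOSTDEVICE lambda, the CUDA variant with cwiseMax. A host-only sketch (simplified copies of the functors without the BaseActivationFunctor base, and HOSTDEVICE assumed to expand to nothing outside nvcc) showing that both paths yield the same values:

#include <unsupported/Eigen/CXX11/Tensor>
#include <iostream>

#define HOSTDEVICE  // assumption: empty when not compiling with nvcc

template <typename T>
struct ReluCPUFunctor {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.unaryExpr([] HOSTDEVICE(T v) {
      return v > static_cast<T>(0) ? v : static_cast<T>(0);
    });
  }
};

template <typename T>
struct ReluCUDAFunctor {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.cwiseMax(static_cast<T>(0));
  }
};

int main() {
  Eigen::Tensor<float, 1> x(4), a(4), b(4);
  x.setValues({-2.f, -0.5f, 0.5f, 2.f});
  // TensorMaps mirror how the operator kernels hand Eigen views to the functors.
  Eigen::TensorMap<Eigen::Tensor<float, 1>> xm(x.data(), 4), am(a.data(), 4),
      bm(b.data(), 4);
  Eigen::DefaultDevice cpu;
  ReluCPUFunctor<float>()(cpu, xm, am);
  ReluCUDAFunctor<float>()(cpu, xm, bm);
  for (int i = 0; i < 4; ++i)
    std::cout << a(i) << " " << b(i) << "\n";  // 0 0 / 0 0 / 0.5 0.5 / 2 2
  return 0;
}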

paddle/fluid/operators/fused/fused_bn_activation_op.cu

Lines changed: 1 addition & 1 deletion
@@ -93,7 +93,7 @@ class FusedBatchNormActKernel<platform::CUDADeviceContext, T>
     auto y_v = framework::EigenVector<T>::Flatten(*y);
     auto &dev = *dev_ctx.eigen_device();
     if (act_type == "relu") {
-      ReluFunctor<T>()(dev, x_v, y_v);
+      ReluCUDAFunctor<T>()(dev, x_v, y_v);
     } else {
       PADDLE_THROW(
           platform::errors::Unimplemented("Unsupported activation type"));

paddle/fluid/operators/gru_unit_op.h

Lines changed: 19 additions & 10 deletions
@@ -18,6 +18,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/activation_op.h"
 #include "paddle/fluid/operators/math/blas.h"
+#include "paddle/fluid/platform/place.h"
 
 namespace paddle {
 namespace operators {
@@ -37,19 +38,24 @@ template <typename DeviceContext, typename T>
 class GRUUnitKernel : public framework::OpKernel<T> {
  public:
   template <typename Device, typename X, typename Y>
-  void ActCompute(const int act_type, const Device& d, X x, Y y) const {
-    if (act_type == identity)
+  void ActCompute(const int act_type, const Device& d, X x, Y y,
+                  platform::Place place) const {
+    if (act_type == identity) {
       y.device(d) = x;
-    else if (act_type == sigmoid)
+    } else if (act_type == sigmoid) {
       SigmoidFunctor<T>()(d, x, y);
-    else if (act_type == tanh)
+    } else if (act_type == tanh) {
       TanhFunctor<T>()(d, x, y);
-    else if (act_type == relu)
-      ReluFunctor<T>()(d, x, y);
-    else
+    } else if (act_type == relu) {
+      if (place == platform::CPUPlace())
+        ReluCPUFunctor<T>()(d, x, y);
+      else
+        ReluCUDAFunctor<T>()(d, x, y);
+    } else {
       PADDLE_THROW(platform::errors::Unimplemented(
           "Unsupported activation type, only supports identity, sigmoid, tanh "
          "and relu."));
+    }
   }
 
   void Compute(const framework::ExecutionContext& context) const override {
@@ -97,11 +103,13 @@ class GRUUnitKernel : public framework::OpKernel<T> {
     Eigen::array<int, 2> extents{{batch_size, frame_size}};
     Eigen::array<int, 2> u_offsets{{0, 0}};
     ActCompute(context.Attr<int>("gate_activation"), place,
-               g.slice(u_offsets, extents), g.slice(u_offsets, extents));
+               g.slice(u_offsets, extents), g.slice(u_offsets, extents),
+               context.GetPlace());
     auto u = g.slice(u_offsets, extents);  // update gate
     Eigen::array<int, 2> r_offsets{{0, frame_size}};
     ActCompute(context.Attr<int>("gate_activation"), place,
-               g.slice(r_offsets, extents), g.slice(r_offsets, extents));
+               g.slice(r_offsets, extents), g.slice(r_offsets, extents),
+               context.GetPlace());
     auto r = g.slice(r_offsets, extents);  // reset gate
     r_h_p.device(place) = r * h_p;  // reset previous hidden state
     blas.GEMM(false, false, batch_size, frame_size, frame_size, 1,
@@ -111,7 +119,8 @@ class GRUUnitKernel : public framework::OpKernel<T> {
 
     Eigen::array<int, 2> c_offsets{{0, frame_size * 2}};
     ActCompute(context.Attr<int>("activation"), place,
-               g.slice(c_offsets, extents), g.slice(c_offsets, extents));
+               g.slice(c_offsets, extents), g.slice(c_offsets, extents),
+               context.GetPlace());
     auto c = g.slice(c_offsets, extents);  // output candidate
 
     // calculate final output
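The same run-time choice between ReluCPUFunctor and ReluCUDAFunctor, driven by the place the kernel forwards via context.GetPlace(), is repeated in lstmp_op.h below. A self-contained sketch of the pattern, using simplified stand-ins (a Place enum and plain loops) rather than Paddle's platform::Place and the Eigen functors:

#include <algorithm>
#include <cstdio>
#include <vector>

enum class Place { kCPU, kCUDA };  // stand-in for platform::Place

// Stand-ins for ReluCPUFunctor / ReluCUDAFunctor: the math is identical,
// only the Eigen expression each builds differs in the real code.
void ReluCPU(std::vector<float>* v) {
  for (auto& e : *v) e = e > 0.f ? e : 0.f;
}
void ReluCUDA(std::vector<float>* v) {
  for (auto& e : *v) e = std::max(e, 0.f);
}

// Mirrors the new ActCompute branch: the caller passes the execution place
// and the relu implementation is picked at run time.
void ActComputeRelu(std::vector<float>* v, Place place) {
  if (place == Place::kCPU)
    ReluCPU(v);
  else
    ReluCUDA(v);
}

int main() {
  std::vector<float> gate = {-1.f, 0.25f, -0.5f, 3.f};
  ActComputeRelu(&gate, Place::kCPU);  // the GRU kernel passes context.GetPlace()
  for (float e : gate) printf("%.2f ", e);  // 0.00 0.25 0.00 3.00
  printf("\n");
  return 0;
}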

paddle/fluid/operators/lstmp_op.h

Lines changed: 13 additions & 8 deletions
@@ -22,6 +22,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/math/detail/activation_functions.h"
 #include "paddle/fluid/operators/math/lstm_compute.h"
 #include "paddle/fluid/operators/math/sequence2batch.h"
+#include "paddle/fluid/platform/place.h"
 #include "paddle/fluid/platform/transform.h"
 
 namespace paddle {
@@ -81,18 +82,22 @@ class LSTMPKernel : public framework::OpKernel<T> {
 public:
   template <typename Device, typename X, typename Y>
   void ActCompute(const math::detail::ActivationType act_type, const Device& d,
-                  X x, Y y) const {
-    if (act_type == math::detail::ActivationType::kIdentity)
+                  X x, Y y, platform::Place place) const {
+    if (act_type == math::detail::ActivationType::kIdentity) {
       y.device(d) = x;
-    else if (act_type == math::detail::ActivationType::kSigmoid)
+    } else if (act_type == math::detail::ActivationType::kSigmoid) {
       SigmoidFunctor<T>()(d, x, y);
-    else if (act_type == math::detail::ActivationType::kTanh)
+    } else if (act_type == math::detail::ActivationType::kTanh) {
       TanhFunctor<T>()(d, x, y);
-    else if (act_type == math::detail::ActivationType::kReLU)
-      ReluFunctor<T>()(d, x, y);
-    else
+    } else if (act_type == math::detail::ActivationType::kReLU) {
+      if (place == platform::CPUPlace())
+        ReluCPUFunctor<T>()(d, x, y);
+      else
+        ReluCUDAFunctor<T>()(d, x, y);
+    } else {
       PADDLE_THROW(
           platform::errors::InvalidArgument("unsupported activation type"));
+    }
   }
 
   void Compute(const framework::ExecutionContext& ctx) const override {
@@ -225,7 +230,7 @@ class LSTMPKernel : public framework::OpKernel<T> {
                  &proj_t, static_cast<T>(0.0));
       if (proj_act != math::detail::ActivationType::kIdentity) {
         auto proj_t_dev = EigenMatrix<T>::From(proj_t);
-        ActCompute(cell_act, place, proj_t_dev, proj_t_dev);
+        ActCompute(cell_act, place, proj_t_dev, proj_t_dev, ctx.GetPlace());
       }
       if (proj_clip && proj_clip > 0.0) {
         T* x_data = proj_t.data<T>();

paddle/fluid/operators/rnn_op.h

Lines changed: 1 addition & 1 deletion
@@ -979,7 +979,7 @@ class RNNCPUKernel : public framework::OpKernel<T> {
     } else if (is_rnn_relu(ctx)) {
       gate_num = 1;
       RnnFunc<
-          SimpleRNNCell<T, ReluFunctor, math::detail::ActivationType::kReLU>,
+          SimpleRNNCell<T, ReluCPUFunctor, math::detail::ActivationType::kReLU>,
          Layer, SingleLayer, BidirLayer, T>(
           ctx, input, weight_list, pre_state[0], nullptr, sequence_length,
           state[0], nullptr, output, dropout_mask, num_layers, gate_num,
