
Commit 37a9437

Merge pull request #7538 from JiayiFeng/dev_elementwise_max_min: elementwise max min

2 parents 388aa51 + a37f6ad

13 files changed: +696 −41 lines

paddle/gserver/tests/sequence_recurrent_group.py

Lines changed: 8 additions & 8 deletions

@@ -1,16 +1,16 @@
 # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
 #
-#Licensed under the Apache License, Version 2.0 (the "License");
-#you may not use this file except in compliance with the License.
-#You may obtain a copy of the License at
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
-#Unless required by applicable law or agreed to in writing, software
-#distributed under the License is distributed on an "AS IS" BASIS,
-#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-#See the License for the specific language governing permissions and
-#limitations under the License.
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 from paddle.trainer_config_helpers import *

 ######################## data source ################################

paddle/operators/elementwise_add_op.h

Lines changed: 1 addition & 33 deletions

@@ -28,39 +28,7 @@ template <typename DeviceContext, typename T>
 class ElementwiseAddKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    using Tensor = framework::Tensor;
-
-    auto* x = ctx.Input<Tensor>("X");
-    auto* y = ctx.Input<Tensor>("Y");
-    auto* z = ctx.Output<Tensor>("Out");
-    z->mutable_data<T>(ctx.GetPlace());
-    TransformFunctor<AddFunctor<T>, T, DeviceContext> functor(
-        x, y, z, ctx.template device_context<DeviceContext>(), AddFunctor<T>());
-
-    auto x_dims = x->dims();
-    auto y_dims = y->dims();
-    PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
-                      "Rank of first input must >= rank of second input.");
-
-    if (x_dims == y_dims) {
-      functor.Run();
-      return;
-    }
-
-    int axis = ctx.Attr<int>("axis");
-    axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
-    PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
-                   "Axis should be in range [0, x_dims)");
-
-    int pre, n, post;
-    get_mid_dims(x_dims, y_dims, axis, pre, n, post);
-    if (post == 1) {
-      functor.RunRowWise(n, pre);
-      return;
-    } else {
-      functor.RunMidWise(n, pre, post);
-      return;
-    }
+    ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(ctx);
   }
 };
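The broadcast-dispatch logic deleted here presumably moved into ElementwiseComputeEx in elementwise_op_function.h, which this diff does not show. A minimal sketch of what that helper likely looks like, reconstructed from the deleted body (the signature is assumed from the call site above, not taken from the actual header):

// Hypothetical reconstruction of ElementwiseComputeEx, based on the dispatch
// code removed from ElementwiseAddKernel above. The real definition lives in
// paddle/operators/elementwise_op_function.h, which this diff omits.
template <typename Functor, typename DeviceContext, typename T>
void ElementwiseComputeEx(const framework::ExecutionContext& ctx) {
  using Tensor = framework::Tensor;
  auto* x = ctx.Input<Tensor>("X");
  auto* y = ctx.Input<Tensor>("Y");
  auto* z = ctx.Output<Tensor>("Out");
  z->mutable_data<T>(ctx.GetPlace());
  TransformFunctor<Functor, T, DeviceContext> functor(
      x, y, z, ctx.template device_context<DeviceContext>(), Functor());

  auto x_dims = x->dims();
  auto y_dims = y->dims();
  PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
                    "Rank of first input must >= rank of second input.");

  if (x_dims == y_dims) {  // same shape: apply the functor elementwise
    functor.Run();
    return;
  }

  // Broadcast Y along `axis`; -1 means align Y with the trailing dims of X.
  int axis = ctx.Attr<int>("axis");
  axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
  PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
                 "Axis should be in range [0, x_dims)");

  // Split X's shape into (pre, n, post) around the dims Y covers.
  int pre, n, post;
  get_mid_dims(x_dims, y_dims, axis, pre, n, post);
  if (post == 1) {
    functor.RunRowWise(n, pre);  // Y tiles over the leading dims only
  } else {
    functor.RunMidWise(n, pre, post);  // Y tiles over both sides
  }
}

Factoring this out is what lets each concrete op (add, max, min, ...) shrink to a one-line Compute body parameterized by its functor.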

paddle/operators/elementwise_max_op.cc

Lines changed: 45 additions & 0 deletions

@@ -0,0 +1,45 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/elementwise_max_op.h"
+#include "paddle/operators/elementwise_op.h"
+
+namespace paddle {
+namespace operators {
+class ElementwiseMaxOpMaker : public ElementwiseOpMaker {
+ public:
+  ElementwiseMaxOpMaker(OpProto* proto, OpAttrChecker* op_checker)
+      : ElementwiseOpMaker(proto, op_checker) {
+    SetComment("Max", "Out = max(X, Y)");
+    AddComment(comment_);
+  }
+};
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP(elementwise_max, ops::ElementwiseOp, ops::ElementwiseMaxOpMaker,
+            elementwise_max_grad, ops::ElementwiseOpGrad);
+REGISTER_OP_CPU_KERNEL(
+    elementwise_max,
+    ops::ElementwiseMaxKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::ElementwiseMaxKernel<paddle::platform::CPUDeviceContext, double>,
+    ops::ElementwiseMaxKernel<paddle::platform::CPUDeviceContext, int>,
+    ops::ElementwiseMaxKernel<paddle::platform::CPUDeviceContext, int64_t>);
+REGISTER_OP_CPU_KERNEL(
+    elementwise_max_grad,
+    ops::ElementwiseMaxGradKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::ElementwiseMaxGradKernel<paddle::platform::CPUDeviceContext, double>,
+    ops::ElementwiseMaxGradKernel<paddle::platform::CPUDeviceContext, int>,
+    ops::ElementwiseMaxGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
paddle/operators/elementwise_max_op.cu

Lines changed: 32 additions & 0 deletions

@@ -0,0 +1,32 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#define EIGEN_USE_GPU
+#include "paddle/operators/elementwise_max_op.h"
+
+namespace ops = paddle::operators;
+
+REGISTER_OP_CUDA_KERNEL(
+    elementwise_max,
+    ops::ElementwiseMaxKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::ElementwiseMaxKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::ElementwiseMaxKernel<paddle::platform::CUDADeviceContext, int>,
+    ops::ElementwiseMaxKernel<paddle::platform::CUDADeviceContext, int64_t>);
+REGISTER_OP_CUDA_KERNEL(
+    elementwise_max_grad,
+    ops::ElementwiseMaxGradKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::ElementwiseMaxGradKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::ElementwiseMaxGradKernel<paddle::platform::CUDADeviceContext, int>,
+    ops::ElementwiseMaxGradKernel<paddle::platform::CUDADeviceContext,
+                                  int64_t>);

paddle/operators/elementwise_max_op.h

Lines changed: 120 additions & 0 deletions

@@ -0,0 +1,120 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/operators/elementwise_op_function.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename T>
+struct MaxFunctor {
+  inline HOSTDEVICE T operator()(T a, T b) const { return a > b ? a : b; }
+};
+
+template <typename DeviceContext, typename T>
+class ElementwiseMaxKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    ElementwiseComputeEx<MaxFunctor<T>, DeviceContext, T>(ctx);
+  }
+};
+
+template <typename T>
+struct ElementwiseMaxGradFunctor {
+  template <typename Device, typename X, typename Y, typename Z, typename dX,
+            typename dY, typename dZ>
+  void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz) {
+    auto x_e = framework::EigenVector<T>::Flatten(*x);
+    auto y_e = framework::EigenVector<T>::Flatten(*y);
+    auto dz_e = framework::EigenVector<T>::Flatten(*dz);
+
+    if (dx) {
+      auto dx_e = framework::EigenVector<T>::Flatten(*dx);
+      dx_e.device(d) = (x_e > y_e).template cast<T>() * dz_e;
+    }
+    if (dy) {
+      auto dy_e = framework::EigenVector<T>::Flatten(*dy);
+      dy_e.device(d) = (x_e <= y_e).template cast<T>() * dz_e;
+    }
+  }
+};
+
+template <typename T>
+struct ElementwiseMaxBroadCastGradFunctor {
+  template <typename Device, typename X, typename Y, typename Z, typename dX,
+            typename dY, typename dZ, typename Pre, typename N>
+  void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz, Pre pre, N n) {
+    auto x_e = framework::EigenVector<T>::Flatten(*x);
+    auto y_e = framework::EigenVector<T>::Flatten(*y);
+    auto dz_e = framework::EigenVector<T>::Flatten(*dz);
+
+    auto y_e_bcast = y_e.reshape(Eigen::DSizes<int, 2>(1, n))
+                         .broadcast(Eigen::DSizes<int, 2>(pre, 1))
+                         .reshape(Eigen::DSizes<int, 1>(x_e.size()));
+
+    if (dx) {
+      auto dx_e = framework::EigenVector<T>::Flatten(*dx);
+      dx_e.device(d) = (x_e > y_e_bcast).template cast<T>() * dz_e;
+    }
+
+    if (dy) {
+      auto dy_e = framework::EigenVector<T>::Flatten(*dy);
+      dy_e.device(d) = ((x_e <= y_e_bcast).template cast<T>() * dz_e)
+                           .reshape(Eigen::DSizes<int, 2>(pre, n))
+                           .sum(Eigen::array<int, 1>{{0}});
+    }
+  }
+};
+
+template <typename T>
+struct ElementwiseMaxBroadCast2GradFunctor {
+  template <typename Device, typename X, typename Y, typename Z, typename dX,
+            typename dY, typename dZ, typename Pre, typename N, typename Post>
+  void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz, Pre pre, N n,
+                  Post post) {
+    auto x_e = framework::EigenVector<T>::Flatten(*x);
+    auto y_e = framework::EigenVector<T>::Flatten(*y);
+    auto dz_e = framework::EigenVector<T>::Flatten(*dz);
+
+    auto y_e_bcast = y_e.reshape(Eigen::DSizes<int, 3>(1, n, 1))
+                         .broadcast(Eigen::DSizes<int, 3>(pre, 1, post))
+                         .reshape(Eigen::DSizes<int, 1>(x_e.size()));
+    if (dx) {
+      auto dx_e = framework::EigenVector<T>::Flatten(*dx);
+      dx_e.device(d) = (x_e > y_e_bcast).template cast<T>() * dz_e;
+    }
+
+    if (dy) {
+      auto dy_e = framework::EigenVector<T>::Flatten(*dy);
+      dy_e.device(d) = ((x_e <= y_e_bcast).template cast<T>() * dz_e)
+                           .reshape(Eigen::DSizes<int, 3>(pre, n, post))
+                           .sum(Eigen::array<int, 2>{{0, 2}});
+    }
+  }
+};
+
+template <typename DeviceContext, typename T>
+class ElementwiseMaxGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    ElementwiseGradCompute<DeviceContext, T, ElementwiseMaxGradFunctor<T>,
+                           ElementwiseMaxBroadCastGradFunctor<T>,
+                           ElementwiseMaxBroadCast2GradFunctor<T>>(ctx);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
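The three gradient functors above all encode the same subgradient of the elementwise maximum, applied through the chain rule:

\[
\frac{\partial L}{\partial X} = \mathbf{1}[X > Y] \odot \frac{\partial L}{\partial Z},
\qquad
\frac{\partial L}{\partial Y} = \mathbf{1}[X \le Y] \odot \frac{\partial L}{\partial Z}.
\]

At ties (X = Y) this convention routes the entire gradient to Y. In the broadcast variants, Y is first tiled up to X's flattened shape to evaluate the mask, and dY is then summed back over the broadcast axes so it matches Y's shape: axis 0 in the row-wise case, axes 0 and 2 in the mid-wise case.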
paddle/operators/elementwise_min_op.cc

Lines changed: 45 additions & 0 deletions

@@ -0,0 +1,45 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/elementwise_min_op.h"
+#include "paddle/operators/elementwise_op.h"
+
+namespace paddle {
+namespace operators {
+class ElementwiseMinOpMaker : public ElementwiseOpMaker {
+ public:
+  ElementwiseMinOpMaker(OpProto* proto, OpAttrChecker* op_checker)
+      : ElementwiseOpMaker(proto, op_checker) {
+    SetComment("Min", "Out = min(X, Y)");
+    AddComment(comment_);
+  }
+};
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP(elementwise_min, ops::ElementwiseOp, ops::ElementwiseMinOpMaker,
+            elementwise_min_grad, ops::ElementwiseOpGrad);
+REGISTER_OP_CPU_KERNEL(
+    elementwise_min,
+    ops::ElementwiseMinKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::ElementwiseMinKernel<paddle::platform::CPUDeviceContext, double>,
+    ops::ElementwiseMinKernel<paddle::platform::CPUDeviceContext, int>,
+    ops::ElementwiseMinKernel<paddle::platform::CPUDeviceContext, int64_t>);
+REGISTER_OP_CPU_KERNEL(
+    elementwise_min_grad,
+    ops::ElementwiseMinGradKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::ElementwiseMinGradKernel<paddle::platform::CPUDeviceContext, double>,
+    ops::ElementwiseMinGradKernel<paddle::platform::CPUDeviceContext, int>,
+    ops::ElementwiseMinGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
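This registration references ops::ElementwiseMinKernel and ops::ElementwiseMinGradKernel, which are defined in elementwise_min_op.h, a file this excerpt does not show. By symmetry with elementwise_max_op.h above, its core presumably reduces to a flipped comparison functor; a sketch under that assumption only:

// Assumed core of paddle/operators/elementwise_min_op.h (not shown in this
// diff); mirrors MaxFunctor/ElementwiseMaxKernel with the comparison flipped.
template <typename T>
struct MinFunctor {
  inline HOSTDEVICE T operator()(T a, T b) const { return a < b ? a : b; }
};

template <typename DeviceContext, typename T>
class ElementwiseMinKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    ElementwiseComputeEx<MinFunctor<T>, DeviceContext, T>(ctx);
  }
};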
paddle/operators/elementwise_min_op.cu

Lines changed: 32 additions & 0 deletions

@@ -0,0 +1,32 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#define EIGEN_USE_GPU
+#include "paddle/operators/elementwise_min_op.h"
+
+namespace ops = paddle::operators;
+
+REGISTER_OP_CUDA_KERNEL(
+    elementwise_min,
+    ops::ElementwiseMinKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::ElementwiseMinKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::ElementwiseMinKernel<paddle::platform::CUDADeviceContext, int>,
+    ops::ElementwiseMinKernel<paddle::platform::CUDADeviceContext, int64_t>);
+REGISTER_OP_CUDA_KERNEL(
+    elementwise_min_grad,
+    ops::ElementwiseMinGradKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::ElementwiseMinGradKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::ElementwiseMinGradKernel<paddle::platform::CUDADeviceContext, int>,
+    ops::ElementwiseMinGradKernel<paddle::platform::CUDADeviceContext,
+                                  int64_t>);
