Skip to content

Commit 01fb2be

Browse files
author
Tomasz Patejko
committed
MKL elementwise add: use the default implementation for integral types, float16, and/or GPU
1 parent 6f93248 commit 01fb2be

File tree

1 file changed

+13
-7
lines changed

1 file changed

+13
-7
lines changed

paddle/fluid/operators/elementwise_add_op.h

Lines changed: 13 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -36,9 +36,12 @@ void default_elementwise_add(const framework::ExecutionContext& ctx,
3636
}
3737

3838
template <typename DeviceContext, typename T>
39-
typename std::enable_if<std::is_floating_point<T>::value>::type elementwise_add(
40-
const framework::ExecutionContext& ctx, const framework::Tensor* x,
41-
const framework::Tensor* y, framework::Tensor* z) {
39+
typename std::enable_if<
40+
std::is_floating_point<T>::value &&
41+
std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
42+
elementwise_add(const framework::ExecutionContext& ctx,
43+
const framework::Tensor* x, const framework::Tensor* y,
44+
framework::Tensor* z) {
4245
auto eigen_x = framework::EigenVector<T>::Flatten(*x);
4346
auto eigen_y = framework::EigenVector<T>::Flatten(*y);
4447
auto eigen_z = framework::EigenVector<T>::Flatten(*z);
@@ -48,9 +51,12 @@ typename std::enable_if<std::is_floating_point<T>::value>::type elementwise_add(
4851
}
4952

5053
template <typename DeviceContext, typename T>
51-
typename std::enable_if<std::is_integral<T>::value>::type elementwise_add(
52-
const framework::ExecutionContext& ctx, const framework::Tensor* x,
53-
const framework::Tensor* y, framework::Tensor* z) {
54+
typename std::enable_if<
55+
!std::is_floating_point<T>::value ||
56+
!std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
57+
elementwise_add(const framework::ExecutionContext& ctx,
58+
const framework::Tensor* x, const framework::Tensor* y,
59+
framework::Tensor* z) {
5460
default_elementwise_add<DeviceContext, T>(ctx, x, y, z);
5561
}
5662

@@ -66,7 +72,7 @@ class ElementwiseAddKernel : public framework::OpKernel<T> {
6672
z->mutable_data<T>(ctx.GetPlace());
6773

6874
auto dims_equal = x->dims() == y->dims();
69-
if (platform::is_cpu_place(ctx.GetPlace()) && dims_equal) {
75+
if (dims_equal) {
7076
elementwise_add<DeviceContext, T>(ctx, x, y, z);
7177
} else {
7278
default_elementwise_add<DeviceContext, T>(ctx, x, y, z);

0 commit comments

Comments
 (0)