@@ -36,9 +36,12 @@ void default_elementwise_add(const framework::ExecutionContext& ctx,
36
36
}
37
37
38
38
template <typename DeviceContext, typename T>
39
- typename std::enable_if<std::is_floating_point<T>::value>::type elementwise_add (
40
- const framework::ExecutionContext& ctx, const framework::Tensor* x,
41
- const framework::Tensor* y, framework::Tensor* z) {
39
+ typename std::enable_if<
40
+ std::is_floating_point<T>::value &&
41
+ std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
42
+ elementwise_add (const framework::ExecutionContext& ctx,
43
+ const framework::Tensor* x, const framework::Tensor* y,
44
+ framework::Tensor* z) {
42
45
auto eigen_x = framework::EigenVector<T>::Flatten (*x);
43
46
auto eigen_y = framework::EigenVector<T>::Flatten (*y);
44
47
auto eigen_z = framework::EigenVector<T>::Flatten (*z);
@@ -48,9 +51,12 @@ typename std::enable_if<std::is_floating_point<T>::value>::type elementwise_add(
48
51
}
49
52
50
53
template <typename DeviceContext, typename T>
51
- typename std::enable_if<std::is_integral<T>::value>::type elementwise_add (
52
- const framework::ExecutionContext& ctx, const framework::Tensor* x,
53
- const framework::Tensor* y, framework::Tensor* z) {
54
+ typename std::enable_if<
55
+ !std::is_floating_point<T>::value ||
56
+ !std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
57
+ elementwise_add (const framework::ExecutionContext& ctx,
58
+ const framework::Tensor* x, const framework::Tensor* y,
59
+ framework::Tensor* z) {
54
60
default_elementwise_add<DeviceContext, T>(ctx, x, y, z);
55
61
}
56
62
@@ -66,7 +72,7 @@ class ElementwiseAddKernel : public framework::OpKernel<T> {
66
72
z->mutable_data <T>(ctx.GetPlace ());
67
73
68
74
auto dims_equal = x->dims () == y->dims ();
69
- if (platform::is_cpu_place (ctx. GetPlace ()) && dims_equal) {
75
+ if (dims_equal) {
70
76
elementwise_add<DeviceContext, T>(ctx, x, y, z);
71
77
} else {
72
78
default_elementwise_add<DeviceContext, T>(ctx, x, y, z);
0 commit comments