@@ -26,6 +26,34 @@ struct AddFunctor {
26
26
inline HOSTDEVICE T operator ()(T a, T b) const { return a + b; }
27
27
};
28
28
29
+ template <typename DeviceContext, typename T>
30
+ void default_elementwise_add (const framework::ExecutionContext& ctx,
31
+ const framework::Tensor* x,
32
+ const framework::Tensor* y, framework::Tensor* z) {
33
+ int axis = ctx.Attr <int >(" axis" );
34
+ ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
35
+ AddFunctor<T>(), z);
36
+ }
37
+
38
+ template <typename DeviceContext, typename T>
39
+ typename std::enable_if<std::is_floating_point<T>::value>::type elementwise_add (
40
+ const framework::ExecutionContext& ctx, const framework::Tensor* x,
41
+ const framework::Tensor* y, framework::Tensor* z) {
42
+ auto eigen_x = framework::EigenVector<T>::Flatten (*x);
43
+ auto eigen_y = framework::EigenVector<T>::Flatten (*y);
44
+ auto eigen_z = framework::EigenVector<T>::Flatten (*z);
45
+
46
+ auto blas = math::GetBlas<DeviceContext, T>(ctx);
47
+ blas.VADD (x->numel (), eigen_x.data (), eigen_y.data (), eigen_z.data ());
48
+ }
49
+
50
+ template <typename DeviceContext, typename T>
51
+ typename std::enable_if<std::is_integral<T>::value>::type elementwise_add (
52
+ const framework::ExecutionContext& ctx, const framework::Tensor* x,
53
+ const framework::Tensor* y, framework::Tensor* z) {
54
+ default_elementwise_add<DeviceContext, T>(ctx, x, y, z);
55
+ }
56
+
29
57
template <typename DeviceContext, typename T>
30
58
class ElementwiseAddKernel : public framework ::OpKernel<T> {
31
59
public:
@@ -36,19 +64,12 @@ class ElementwiseAddKernel : public framework::OpKernel<T> {
36
64
const auto y = ctx.Input <Tensor>(" Y" );
37
65
auto z = ctx.Output <Tensor>(" Out" );
38
66
z->mutable_data <T>(ctx.GetPlace ());
39
- int axis = ctx.Attr <int >(" axis" );
40
67
41
68
auto dims_equal = x->dims () == y->dims ();
42
69
if (platform::is_cpu_place (ctx.GetPlace ()) && dims_equal) {
43
- auto eigen_x = framework::EigenVector<T>::Flatten (*x);
44
- auto eigen_y = framework::EigenVector<T>::Flatten (*y);
45
- auto eigen_z = framework::EigenVector<T>::Flatten (*z);
46
-
47
- auto blas = math::GetBlas<DeviceContext, T>(ctx);
48
- blas.VADD (x->numel (), eigen_x.data (), eigen_y.data (), eigen_z.data ());
70
+ elementwise_add<DeviceContext, T>(ctx, x, y, z);
49
71
} else {
50
- ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
51
- AddFunctor<T>(), z);
72
+ default_elementwise_add<DeviceContext, T>(ctx, x, y, z);
52
73
}
53
74
}
54
75
};
0 commit comments