@@ -43,17 +43,31 @@ class SumKernel : public framework::OpKernel<T> {
43
43
out->mutable_data <T>(context.GetPlace ());
44
44
}
45
45
auto result = EigenVector<T>::Flatten (*out);
46
+ auto &place =
47
+ *context.template device_context <DeviceContext>().eigen_device ();
48
+ int start = in_place ? 1 : 0 ;
46
49
if (!in_place) {
47
- math::SetConstant<DeviceContext, T> constant_functor;
48
- constant_functor (context.template device_context <DeviceContext>(), out,
49
- 0.0 );
50
+ if ((in_num >= 2 ) && in_vars[0 ]->IsType <framework::LoDTensor>() &&
51
+ in_vars[1 ]->IsType <framework::LoDTensor>()) {
52
+ auto &in_0 = in_vars[0 ]->Get <framework::LoDTensor>();
53
+ auto &in_1 = in_vars[1 ]->Get <framework::LoDTensor>();
54
+ if (in_0.numel () && in_1.numel ()) {
55
+ auto in_0_e = EigenVector<T>::Flatten (in_0);
56
+ auto in_1_e = EigenVector<T>::Flatten (in_1);
57
+ result.device (place) = in_0_e + in_1_e;
58
+ start = 2 ;
59
+ }
60
+ }
61
+ if (start != 2 ) {
62
+ math::SetConstant<DeviceContext, T> constant_functor;
63
+ constant_functor (context.template device_context <DeviceContext>(),
64
+ out, 0.0 );
65
+ }
50
66
}
51
67
52
68
math::SelectedRowsAddToTensor<DeviceContext, T> functor;
53
- auto &place =
54
- *context.template device_context <DeviceContext>().eigen_device ();
55
69
// If in_place, just skip the first tensor
56
- for (size_t i = in_place ? 1 : 0 ; i < in_num; i++) {
70
+ for (size_t i = start ; i < in_num; i++) {
57
71
if (in_vars[i]->IsType <framework::LoDTensor>()) {
58
72
auto &in_t = in_vars[i]->Get <framework::LoDTensor>();
59
73
if (in_t .numel () == 0 ) {
0 commit comments