@@ -297,7 +297,25 @@ void set_constant_with_place<platform::GPUPlace>(
297
297
// Explicit instantiations of RowwiseAdd (float, double) and ColwiseSum (float)
// for platform::GPUPlace.
template struct RowwiseAdd <platform::GPUPlace, float >;
298
298
template struct RowwiseAdd <platform::GPUPlace, double >;
299
299
template struct ColwiseSum <platform::GPUPlace, float >;
300
- template struct ColwiseSum <platform::GPUPlace, double >;
300
+ // template struct ColwiseSum<platform::GPUPlace, double>;
301
+ // Explicitly instantiating ColwiseSum<platform::GPUPlace, double> failed in
302
+ // debug mode, and only for this case, so it is reimplemented below.
303
+ template <>
304
+ void ColwiseSum<platform::GPUPlace, double >::operator ()(
305
+ const platform::DeviceContext& context, const framework::Tensor& input,
306
+ framework::Tensor* vector) {
307
+ auto in_dims = input.dims ();
308
+ auto size = input.numel () / in_dims[0 ];
309
+ PADDLE_ENFORCE_EQ (vector->numel (), size);
310
+ framework::Tensor one;
311
+ one.mutable_data <double >({in_dims[0 ]}, context.GetPlace ());
312
+ SetConstant<platform::GPUPlace, double > set;
313
+ set (context, &one, static_cast <double >(1.0 ));
314
+ gemv<platform::GPUPlace, double >(context, true , static_cast <int >(in_dims[0 ]),
315
+ static_cast <int >(in_dims[1 ]), 1.0 ,
316
+ input.data <double >(), one.data <double >(),
317
+ 0.0 , vector->data <double >());
318
+ }
301
319
302
320
} // namespace math
303
321
} // namespace operators
0 commit comments