Skip to content

Commit 52007ea

Browse files
authored
Merge pull request #5872 from qingqing01/op_debug
Fix lstm_op and gru_op in debug mode.
2 parents 98700ce + 7fb1f7a commit 52007ea

File tree

1 file changed

+19
-1
lines changed

1 file changed

+19
-1
lines changed

paddle/operators/math/math_function.cu

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,25 @@ void set_constant_with_place<platform::GPUPlace>(
297297
// Explicit instantiations for platform::GPUPlace: RowwiseAdd for float and
// double, and ColwiseSum for float.
template struct RowwiseAdd<platform::GPUPlace, float>;
298298
template struct RowwiseAdd<platform::GPUPlace, double>;
299299
template struct ColwiseSum<platform::GPUPlace, float>;
300-
template struct ColwiseSum<platform::GPUPlace, double>;
300+
// template struct ColwiseSum<platform::GPUPlace, double>;
301+
// The ColwiseSum<platform::GPUPlace, double> failed in debug mode,
302+
// and only in this case, so it is reimplemented below using gemv.
303+
template <>
304+
void ColwiseSum<platform::GPUPlace, double>::operator()(
305+
const platform::DeviceContext& context, const framework::Tensor& input,
306+
framework::Tensor* vector) {
307+
auto in_dims = input.dims();
308+
auto size = input.numel() / in_dims[0];
309+
PADDLE_ENFORCE_EQ(vector->numel(), size);
310+
framework::Tensor one;
311+
one.mutable_data<double>({in_dims[0]}, context.GetPlace());
312+
SetConstant<platform::GPUPlace, double> set;
313+
set(context, &one, static_cast<double>(1.0));
314+
gemv<platform::GPUPlace, double>(context, true, static_cast<int>(in_dims[0]),
315+
static_cast<int>(in_dims[1]), 1.0,
316+
input.data<double>(), one.data<double>(),
317+
0.0, vector->data<double>());
318+
}
301319

302320
} // namespace math
303321
} // namespace operators

0 commit comments

Comments
 (0)