Skip to content

Commit 4ac2729

Browse files
author
Yibing Liu
authored
Fix the get of attr pad_value under dtype float16 in pad2d op (#22909)
test=release/1.7
1 parent c3a87e3 commit 4ac2729

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

paddle/fluid/operators/pad2d_op.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,7 @@ class Pad2dCPUKernel : public framework::OpKernel<T> {
345345
GetPaddings(pads, context);
346346
auto mode = context.Attr<std::string>("mode");
347347
auto data_format = context.Attr<std::string>("data_format");
348-
T value = context.Attr<T>("pad_value");
348+
T value = static_cast<T>(context.Attr<float>("pad_value"));
349349

350350
auto* x = context.Input<Tensor>("X");
351351
auto in_dims = x->dims();

paddle/fluid/operators/pad2d_op.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,7 @@ class Pad2dCUDAKernel : public framework::OpKernel<T> {
314314
GetPaddings(pads, context);
315315
auto mode = context.Attr<std::string>("mode");
316316
auto data_format = context.Attr<std::string>("data_format");
317-
T value = context.Attr<T>("pad_value");
317+
T value = static_cast<T>(context.Attr<float>("pad_value"));
318318

319319
auto* x = context.Input<Tensor>("X");
320320
auto in_dims = x->dims();

0 commit comments

Comments
 (0)