Skip to content

Commit 4815519

Browse files
ykkk2333jzhang533
authored andcommitted
fix arg_max for int type, *test=kunlun (#41522)
1 parent 3f3f5a2 commit 4815519

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

paddle/fluid/operators/arg_max_op_xpu.cc

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,15 @@ class ArgMaxXPUKernel : public framework::OpKernel<T> {
2828
auto* out = ctx.Output<framework::LoDTensor>("Out");
2929
auto dtype = ctx.Attr<int>("dtype");
3030
PADDLE_ENFORCE_EQ(
31-
(dtype < 0 || dtype == 3), true,
31+
(dtype < 0 || dtype == 2 || dtype == 3), true,
3232
platform::errors::InvalidArgument(
33-
"The attribute of dtype in xpu argmin/argmax must be [%s], but "
33+
"The attribute of dtype in xpu argmin/argmax must be [%s] or [%s], "
34+
"but "
3435
"received [%s]",
3536
paddle::framework::DataTypeToString(
3637
framework::proto::VarType::INT64),
38+
paddle::framework::DataTypeToString(
39+
framework::proto::VarType::INT32),
3740
paddle::framework::DataTypeToString(
3841
static_cast<framework::proto::VarType::Type>(dtype))));
3942

0 commit comments

Comments
 (0)