Skip to content

Commit a9d5c44

Browse files
add arg check for argmin and argmax (#59976) (#60006)
1 parent a4eb237 commit a9d5c44

File tree

3 files changed

+15
-0
lines changed

3 files changed

+15
-0
lines changed

paddle/phi/kernels/cpu/arg_min_max_kernel.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,11 @@ void ArgMinMaxKernel(const Context& dev_ctx,
153153
bool flatten,
154154
DataType dtype,
155155
DenseTensor* out) {
156+
PADDLE_ENFORCE_GT(
157+
x.numel(),
158+
0,
159+
phi::errors::InvalidArgument(
160+
"argmin/argmax input numel must > 0, bug got %d", x.numel()));
156161
if (dtype == DataType::UNDEFINED) {
157162
phi::VisitDataTypeTiny(
158163
phi::DataType::INT64,

paddle/phi/kernels/gpu/arg_min_max_kernel.cu

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,11 @@ void ArgMinMaxOpCUDAKernel(const Context& dev_ctx,
211211
bool flatten,
212212
DataType dtype,
213213
DenseTensor* out) {
214+
PADDLE_ENFORCE_GT(
215+
x.numel(),
216+
0,
217+
phi::errors::InvalidArgument(
218+
"argmin/argmax input numel must > 0, bug got %d", x.numel()));
214219
if (dtype == DataType::UNDEFINED) {
215220
phi::VisitDataTypeTiny(
216221
phi::DataType::INT64,

paddle/phi/kernels/xpu/arg_min_max_kernel.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,11 @@ void ArgMaxKernel(const Context& dev_ctx,
3030
bool flatten,
3131
DataType dtype,
3232
DenseTensor* out) {
33+
PADDLE_ENFORCE_GT(
34+
x.numel(),
35+
0,
36+
phi::errors::InvalidArgument(
37+
"argmin/argmax input numel must > 0, bug got %d", x.numel()));
3338
using XPUType = typename XPUTypeTrait<T>::Type;
3439
PADDLE_ENFORCE_EQ(
3540
(dtype == DataType::UNDEFINED || dtype == DataType::INT32 ||

0 commit comments

Comments
 (0)