Skip to content

Commit 6a1ddd6

Browse files
authored
[cherry pick]fix paddle tensor numel check (#41665)
1 parent 43ee4a3 commit 6a1ddd6

File tree

2 files changed

+11
-10
lines changed

2 files changed

+11
-10
lines changed

paddle/fluid/platform/device/gpu/gpu_launch_config.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,10 +99,11 @@ struct GpuLaunchConfig {
9999
inline GpuLaunchConfig GetGpuLaunchConfig1D(
100100
const platform::CUDADeviceContext& context, int64_t numel,
101101
int vec_size = 1) {
102-
PADDLE_ENFORCE_GT(numel, 0, platform::errors::InvalidArgument(
103-
"element quantity should be greater than 0,"
104-
" but received value is: %d.",
105-
numel));
102+
PADDLE_ENFORCE_GE(numel, 0,
103+
platform::errors::InvalidArgument(
104+
"element quantity should be greater than or equal to 0,"
105+
" but received value is: %d.",
106+
numel));
106107
// Get compute_capability
107108
const int capability = context.GetComputeCapability();
108109
/* If thread number per block is 64/128/256/512, cuda performs better.*/

paddle/phi/backends/gpu/gpu_launch_config.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -101,12 +101,12 @@ struct GpuLaunchConfig {
101101
inline GpuLaunchConfig GetGpuLaunchConfig1D(const phi::GPUContext& context,
102102
int64_t numel,
103103
int vec_size = 1) {
104-
PADDLE_ENFORCE_GT(
105-
numel,
106-
0,
107-
phi::errors::InvalidArgument("element quantity should be greater than 0,"
108-
" but received value is: %d.",
109-
numel));
104+
PADDLE_ENFORCE_GE(numel,
105+
0,
106+
phi::errors::InvalidArgument(
107+
"element quantity should be greater than or equal to 0,"
108+
" but received value is: %d.",
109+
numel));
110110
// Get compute_capability
111111
const int capability = context.GetComputeCapability();
112112
/* If thread number per block is 64/128/256/512, cuda performs better.*/

0 commit comments

Comments
 (0)