Skip to content

Commit de5dec8

Browse files
authored
[Cherry-pick]Fix gather op bug (#19169)
* fix gather op bug test=release/1.5
1 parent cc3ba76 commit de5dec8

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

paddle/fluid/operators/gather.cu.h

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,16 @@ __global__ void GatherCUDAKernel(const T* params, const IndexT* indices,
4949
template <typename T, typename IndexT = int>
5050
void GPUGather(const platform::DeviceContext& ctx, const Tensor& src,
5151
const Tensor& index, Tensor* output) {
52-
// PADDLE_ENFORCE(platform::is_gpu_place(place));
5352
// check index of shape 1-D
54-
PADDLE_ENFORCE(index.dims().size() == 1 ||
55-
(index.dims().size() == 2 && index.dims()[1] == 1));
53+
if (index.dims().size() == 1) {
54+
PADDLE_ENFORCE_GT(index.dims()[0], 0,
55+
"The index of gather_op should not be empty when the "
56+
"index's rank is 1.");
57+
} else if (index.dims().size() == 2) {
58+
PADDLE_ENFORCE_EQ(index.dims()[1], 1,
59+
" If the index's rank of gather_op is 2, the second "
60+
"dimension should be 1.");
61+
}
5662

5763
int index_size = index.dims()[0];
5864

0 commit comments

Comments
 (0)