Skip to content

Commit 06c86d4

Browse files
authored
test=develop, bug fix for index_select and roll op (#25251) (#25360)
1 parent bddfa21 commit 06c86d4

File tree

2 files changed

+19
-2
lines changed

2 files changed

+19
-2
lines changed

paddle/fluid/operators/index_select_op.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,23 @@ void IndexSelectInner(const framework::ExecutionContext& context,
5252
TensorToVector(index, context.device_context(), &index_vec);
5353
std::vector<T> out_vec(output->numel());
5454

55+
for (int i = 0; i < index_size; i++) {
56+
PADDLE_ENFORCE_GE(
57+
index_vec[i], 0,
58+
platform::errors::InvalidArgument(
59+
"Variable value (index) of OP(index_select) "
60+
"expected >= 0 and < %ld, but got %ld. Please check input "
61+
"value.",
62+
input_dim[dim], index_vec[i]));
63+
PADDLE_ENFORCE_LT(
64+
index_vec[i], input_dim[dim],
65+
platform::errors::InvalidArgument(
66+
"Variable value (index) of OP(index_select) "
67+
"expected >= 0 and < %ld, but got %ld. Please check input "
68+
"value.",
69+
input_dim[dim], index_vec[i]));
70+
}
71+
5572
VLOG(3) << "Index_Select_Debug; outer_nums: " << outer_nums
5673
<< "; slice_size: " << slice_size << "; input_width: " << input_width
5774
<< "; output_width: " << output_width

python/paddle/fluid/tests/unittests/test_roll_op.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def test_check_grad_normal(self):
4949
class TestRollOpCase2(TestRollOp):
5050
def init_dtype_type(self):
5151
self.dtype = np.float32
52-
self.x_shape = (100, 100, 5)
52+
self.x_shape = (100, 10, 5)
5353
self.shifts = [8, -1]
5454
self.dims = [-1, -2]
5555

@@ -59,7 +59,7 @@ def input_data(self):
5959
self.data_x = np.array(
6060
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]])
6161

62-
def test_roll_api(self):
62+
def test_roll_op_api(self):
6363
self.input_data()
6464

6565
# case 1:

0 commit comments

Comments
 (0)