
Commit 846e99a

[slice]fix Fix a possible out-of-bounds bug with index int32 types &&… (#73468)
* [slice]fix Fix a possible out-of-bounds bug with index int32 types && slice-check
* slice-check
1 parent d2327f4 commit 846e99a

File tree: 3 files changed (+44, -63 lines)
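
Background for the fix: if index tensors are int32 and the flattened element offsets derived from them are accumulated in 32-bit arithmetic, any offset beyond INT32_MAX wraps, and pointer arithmetic based on the wrapped value lands out of bounds. The sketch below is illustrative only, not Paddle code; the stride, the index value, and all names are made up to show the wrap that promoting indices to int64 avoids.

// Illustrative only, not Paddle code. Shows why 32-bit index/offset math
// can go out of bounds on tensors with more than INT32_MAX elements.
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t row_stride = 1'000'000'000;  // hypothetical elements per row
  const int64_t row = 3;                     // index value supplied by the user

  // Correct 64-bit offset: 3'000'000'000 elements.
  const int64_t offset64 = row * row_stride;

  // The same offset squeezed through a 32-bit type, which is what happens when
  // index tensors and the derived offsets stay in int32: on two's-complement
  // targets the value wraps to a negative number, and pointer arithmetic based
  // on it is out of bounds.
  const int32_t offset32 = static_cast<int32_t>(offset64);

  std::printf("int64 offset: %lld\n", static_cast<long long>(offset64));
  std::printf("int32 offset: %d (wrapped)\n", offset32);
  return 0;
}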

paddle/fluid/pybind/eager_method.cc

Lines changed: 11 additions & 1 deletion
@@ -2033,7 +2033,17 @@ static PyObject* tensor__setitem_dygraph(TensorObject* self,
       int64_t slice_offset =
           static_cast<int64_t>(reinterpret_cast<char*>(sub_tensor.data()) -
                                reinterpret_cast<char*>(tensor.data()));
-      AdvancedIndex ad = AdvancedIndex(transed_sub_tensor, transed_index);
+
+      std::vector<paddle::Tensor> transed_index_int64;
+      for (auto& indice : transed_index) {
+        if (indice.defined() && indice.dtype() == paddle::DataType::INT32) {
+          indice = indice.cast(paddle::DataType::INT64);  // int32 -> int64
+        }
+        transed_index_int64.push_back(indice);
+      }
+
+      AdvancedIndex ad =
+          AdvancedIndex(transed_sub_tensor, transed_index_int64);
       transed_sub_tensor =
           index_elementwise_put__ad_func(tensor,
                                          ad.indices,
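
The eager_method.cc hunk above promotes every defined INT32 index tensor to INT64 before the AdvancedIndex is built. Below is a minimal standalone sketch of that promotion step; it reuses only the paddle::Tensor calls that appear in the diff (defined(), dtype(), cast()), while the helper name and the included header are assumptions rather than part of the commit.

// Hypothetical helper mirroring the promotion loop added in
// tensor__setitem_dygraph: cast any defined INT32 index tensor to INT64 so
// that downstream offset arithmetic runs in 64 bits.
#include <vector>

#include "paddle/phi/api/include/tensor.h"  // assumed header for paddle::Tensor

std::vector<paddle::Tensor> PromoteIndicesToInt64(
    const std::vector<paddle::Tensor>& indices) {
  std::vector<paddle::Tensor> promoted;
  promoted.reserve(indices.size());
  for (const auto& indice : indices) {
    if (indice.defined() && indice.dtype() == paddle::DataType::INT32) {
      promoted.push_back(indice.cast(paddle::DataType::INT64));  // int32 -> int64
    } else {
      promoted.push_back(indice);
    }
  }
  return promoted;
}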

paddle/phi/kernels/gpu/index_elementwise_put_grad_kernel.cu

Lines changed: 15 additions & 28 deletions
@@ -186,25 +186,14 @@ void LaunchIndexElementwisePutGradCudaKernel(
   if (x_grad) {
     phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad);

-    if (index_type == phi::DataType::INT32) {
-      GPUIndexElementwisePutGradKernel<T, int>(dev_ctx,
-                                               x_indices,
-                                               input_dims,
-                                               input_strides,
-                                               index_dims,
-                                               index_strides,
-                                               slice_offset,
-                                               x_grad);
-    } else if (index_type == phi::DataType::INT64) {
-      GPUIndexElementwisePutGradKernel<T, int64_t>(dev_ctx,
-                                                   x_indices,
-                                                   input_dims,
-                                                   input_strides,
-                                                   index_dims,
-                                                   index_strides,
-                                                   slice_offset,
-                                                   x_grad);
-    }
+    GPUIndexElementwisePutGradKernel<T, int64_t>(dev_ctx,
+                                                 x_indices,
+                                                 input_dims,
+                                                 input_strides,
+                                                 index_dims,
+                                                 index_strides,
+                                                 slice_offset,
+                                                 x_grad);
   }

   auto out_grad_dims = out_grad.dims();
@@ -323,15 +312,13 @@ void IndexElementwisePutGradKernel(
                                    DenseTensor* x_grad,
                                    DenseTensor* value_grad) {
   const auto& index_type = indices[0]->dtype();
-  PADDLE_ENFORCE_EQ(
-      index_type == phi::DataType::INT32 || index_type == phi::DataType::INT64,
-      true,
-      common::errors::InvalidArgument(
-          "Index holds the wrong type, it holds [%s], but "
-          "desires to be [%s] or [%s].",
-          index_type,
-          phi::DataType::INT32,
-          phi::DataType::INT64));
+  PADDLE_ENFORCE_EQ(index_type == phi::DataType::INT64,
+                    true,
+                    common::errors::InvalidArgument(
+                        "Index holds the wrong type, it holds [%s], but "
+                        "desires to be [%s].",
+                        index_type,
+                        phi::DataType::INT64));

   std::vector<DenseTensor> tmp_args;
   std::vector<const phi::DenseTensor*> int_indices_v =
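
Here and in index_elementwise_put_kernel.cu below, the runtime switch over the index dtype is removed: the binding layer now guarantees INT64 indices, so only the int64_t template instantiation is kept and the PADDLE_ENFORCE_EQ check is narrowed to INT64. For readers unfamiliar with the pattern being collapsed, here is a generic sketch with hypothetical names and a trivial stand-in kernel body.

// Generic sketch of dtype-based template dispatch (hypothetical names).
// The commit deletes the INT32 branch of a dispatcher shaped like this and
// keeps only the int64_t instantiation.
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <stdexcept>

enum class DataType { INT32, INT64 };

// Stand-in for a kernel templated on the index element type.
template <typename T, typename IndexT>
void PutKernelImpl(const IndexT* indices, std::size_t n, T* out) {
  for (std::size_t i = 0; i < n; ++i) {
    out[i] = static_cast<T>(indices[i]);  // placeholder body
  }
}

template <typename T>
void DispatchPut(DataType index_type, const void* indices, std::size_t n, T* out) {
  if (index_type == DataType::INT32) {         // branch removed by the commit
    PutKernelImpl<T, int32_t>(static_cast<const int32_t*>(indices), n, out);
  } else if (index_type == DataType::INT64) {  // the only remaining path
    PutKernelImpl<T, int64_t>(static_cast<const int64_t*>(indices), n, out);
  } else {
    throw std::invalid_argument("Index holds the wrong type");
  }
}

int main() {
  const int64_t idx[3] = {0, 1, 2};
  float out[3] = {};
  DispatchPut<float>(DataType::INT64, idx, 3, out);
  std::printf("%g %g %g\n", out[0], out[1], out[2]);
  return 0;
}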

paddle/phi/kernels/gpu/index_elementwise_put_kernel.cu

Lines changed: 18 additions & 34 deletions
@@ -113,42 +113,26 @@ void IndexElementwisePutKernel(const Context& dev_ctx,
                                const int64_t slice_offset,
                                DenseTensor* out) {
   const auto& index_type = index[0]->dtype();
-  PADDLE_ENFORCE_EQ(
-      index_type == phi::DataType::INT32 || index_type == phi::DataType::INT64,
-      true,
-      common::errors::InvalidArgument(
-          "Index holds the wrong type, it holds [%s], but "
-          "desires to be [%s] or [%s].",
-          index_type,
-          phi::DataType::INT32,
-          phi::DataType::INT64));
+  PADDLE_ENFORCE_EQ(index_type == phi::DataType::INT64,
+                    true,
+                    common::errors::InvalidArgument(
+                        "Index holds the wrong type, it holds [%s], but "
+                        "desires to be [%s].",
+                        index_type,
+                        phi::DataType::INT64));

-  if (out->numel() == 0) return;
   dev_ctx.template Alloc<T>(out);
-
-  if (index_type == phi::DataType::INT32) {
-    GPUIndexElementwisePutKernel<T, int>(dev_ctx,
-                                         x,
-                                         value,
-                                         index,
-                                         input_dims,
-                                         input_strides,
-                                         index_dims,
-                                         index_strides,
-                                         slice_offset,
-                                         out);
-  } else if (index_type == phi::DataType::INT64) {
-    GPUIndexElementwisePutKernel<T, int64_t>(dev_ctx,
-                                             x,
-                                             value,
-                                             index,
-                                             input_dims,
-                                             input_strides,
-                                             index_dims,
-                                             index_strides,
-                                             slice_offset,
-                                             out);
-  }
+  if (out->numel() == 0) return;
+  GPUIndexElementwisePutKernel<T, int64_t>(dev_ctx,
+                                           x,
+                                           value,
+                                           index,
+                                           input_dims,
+                                           input_strides,
+                                           index_dims,
+                                           index_strides,
+                                           slice_offset,
+                                           out);
 }

 }  // namespace phi
