Skip to content

Commit 51c12fb

Browse files
authored
fix argsort (#74434)
1 parent edd1126 commit 51c12fb

File tree

1 file changed

+14
-1
lines changed

1 file changed

+14
-1
lines changed

paddle/phi/kernels/gpu/argsort_kernel.cu

Lines changed: 14 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -198,6 +198,13 @@ void ArgFullSort(const phi::GPUContext& dev_ctx,
198198
const int64_t num_rows,
199199
const int64_t num_cols,
200200
const bool descending) {
201+
PADDLE_ENFORCE_LE(num_cols,
202+
std::numeric_limits<int>::max(),
203+
::common::errors::PreconditionNotMet(
204+
"The dimension being sorted should be less than "
205+
"2^31, but got %lld. Please check the input tensor. ",
206+
num_cols));
207+
201208
auto cu_stream = dev_ctx.stream();
202209
auto ComputeBlockSize = [](IndType col) {
203210
if (col > 512)
@@ -228,8 +235,14 @@ void ArgFullSort(const phi::GPUContext& dev_ctx,
228235
const int64_t total_elements = num_cols * num_rows;
229236
const int64_t segment_size = num_cols;
230237
const int64_t element_per_call = std::min(max_elements, total_elements);
238+
239+
// make sure the per-call element budget is at least segment_size, so that it holds at least one full segment
240+
const int64_t adjusted_elements_per_call =
241+
std::max(max_elements, segment_size);
242+
231243
// make sure batch size is a multiple of segment_size
232-
const int64_t batch_size = (element_per_call / segment_size) * segment_size;
244+
const int64_t batch_size =
245+
(adjusted_elements_per_call / segment_size) * segment_size;
233246
int64_t offset = 0;
234247
DenseTensor input_indices;
235248

0 commit comments

Comments
 (0)