Skip to content

Commit 8c7a369

Browse files
committed
fix large tensor sort on ROCm
Currently std::min -> ::min did not work as expected on ROCm when input values >= 2147483648
1 parent 08daf94 commit 8c7a369

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

aten/src/ATen/native/cuda/SortStable.cu

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -225,8 +225,9 @@ void launch_stable_sort_kernel(
225225
return;
226226
}
227227
228-
int64_t numel_or_intmax =
229-
std::min(numel, static_cast<int64_t>(std::numeric_limits<int>::max()));
228+
const int64_t intmax = static_cast<int64_t>(std::numeric_limits<int>::max());
229+
// On ROCm, std::min -> ::min did not work as expected on when input values >= 2147483648
230+
int64_t numel_or_intmax = numel < intmax ? numel : intmax;
230231
int64_t nsort = self.size(dim);
231232
int64_t nbatch = (numel_or_intmax / nsort) * nsort;
232233
TORCH_CHECK(nbatch > 0, "Cannot sort dimension of length ", nsort);
@@ -238,7 +239,8 @@ void launch_stable_sort_kernel(
238239
scalar_t* values_ptr = values.mutable_data_ptr<scalar_t>();
239240
int64_t remaining = numel;
240241
while (remaining > 0) {
241-
int64_t n = std::min(remaining, nbatch);
242+
// On ROCm, std::min -> ::min did not work as expected on when input values >= 2147483648
243+
int64_t n = remaining < nbatch ? remaining : nbatch;
242244
int64_t nsegments = n / nsort;
243245
244246
if (nsegments == 1 ||

0 commit comments

Comments
 (0)