Skip to content

Commit 1455054

Browse files
[rocm7.1_internal_testing] fix large tensor sort on ROCm (#2543)
Currently, std::min -> ::min did not work as expected on ROCm when input values >= 2147483648. Replace std::min with a ternary statement. Alternatively, std::min can be replaced by explicit typing: std::min<int64_t>. Fixes on ROCm: test_sort_and_select.py::TestSortAndSelectCUDA::test_sort_large_cuda_float16, which failed with: RuntimeError: Cannot sort dimension of length 8192. Combines upstream PRs: - pytorch#161054 to fix std::min on ROCm - pytorch#155546 to fix the Python test - pytorch#159939 to change the test dtype from int8 to float16. Fixes: SWDEV-526432
1 parent 0ea0592 commit 1455054

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

aten/src/ATen/native/cuda/SortStable.cu

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -225,8 +225,9 @@ void launch_stable_sort_kernel(
225225
return;
226226
}
227227
228-
int64_t numel_or_intmax =
229-
std::min(numel, static_cast<int64_t>(std::numeric_limits<int>::max()));
228+
const int64_t intmax = static_cast<int64_t>(std::numeric_limits<int>::max());
229+
// On ROCm, std::min -> ::min did not work as expected when input values >= 2147483648
230+
int64_t numel_or_intmax = numel < intmax ? numel : intmax;
230231
int64_t nsort = self.size(dim);
231232
int64_t nbatch = (numel_or_intmax / nsort) * nsort;
232233
TORCH_CHECK(nbatch > 0, "Cannot sort dimension of length ", nsort);
@@ -238,7 +239,8 @@ void launch_stable_sort_kernel(
238239
scalar_t* values_ptr = values.mutable_data_ptr<scalar_t>();
239240
int64_t remaining = numel;
240241
while (remaining > 0) {
241-
int64_t n = std::min(remaining, nbatch);
242+
// On ROCm, std::min -> ::min did not work as expected when input values >= 2147483648
243+
int64_t n = remaining < nbatch ? remaining : nbatch;
242244
int64_t nsegments = n / nsort;
243245
244246
if (nsegments == 1 ||

test/test_sort_and_select.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ def test_stable_sort(self, device, dtype):
215215
)
216216

217217
@onlyCUDA
218-
@dtypes(torch.uint8)
218+
@dtypes(torch.float16)
219219
@largeTensorTest("200GB") # Unfortunately 80GB A100 is not large enough
220220
def test_sort_large(self, device, dtype):
221221
t0 = torch.randperm(8192, device=device).to(dtype)

0 commit comments

Comments
 (0)