Skip to content

Commit 84a0b07

Browse files
committed
[release/2.6] fix failed tests on MI350
1 parent b360c6e commit 84a0b07

File tree

2 files changed

+10
-8
lines changed

2 files changed

+10
-8
lines changed

aten/src/ATen/native/cuda/SortStable.cu

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -226,8 +226,9 @@ void launch_stable_sort_kernel(
226226
return;
227227
}
228228
229-
int64_t numel_or_intmax =
230-
std::min(numel, static_cast<int64_t>(std::numeric_limits<int>::max()));
229+
// On ROCm, std::min -> ::min did not work as expected when input values are >= 2147483648
230+
const int64_t intmax = static_cast<int64_t>(std::numeric_limits<int>::max());
231+
int64_t numel_or_intmax = numel < intmax ? numel : intmax;
231232
int64_t nsort = self.size(dim);
232233
int64_t nbatch = (numel_or_intmax / nsort) * nsort;
233234
TORCH_CHECK(nbatch > 0, "Cannot sort dimension of length ", nsort);
@@ -239,7 +240,8 @@ void launch_stable_sort_kernel(
239240
scalar_t* values_ptr = values.mutable_data_ptr<scalar_t>();
240241
int64_t remaining = numel;
241242
while (remaining > 0) {
242-
int64_t n = std::min(remaining, nbatch);
243+
// On ROCm, std::min -> ::min did not work as expected when input values are >= 2147483648
244+
int64_t n = remaining < nbatch ? remaining : nbatch;
243245
int64_t nsegments = n / nsort;
244246
245247
if (nsegments == 1 ||

test/test_sort_and_select.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -209,21 +209,21 @@ def test_stable_sort(self, device, dtype):
209209
)
210210

211211
@onlyCUDA
212-
@dtypes(torch.uint8)
212+
@dtypes(torch.float16)
213213
@largeTensorTest("200GB") # Unfortunately 80GB A100 is not large enough
214214
def test_sort_large(self, device, dtype):
215215
t0 = torch.randperm(8192, device=device).to(dtype)
216216
t = t0.view(1, 8192).expand(2**18 + 1, -1).contiguous()
217217
v, i = t.sort()
218218
del t
219-
iv, im = i.var_mean(dim=0)
219+
iv, im = torch.var_mean(i.to(dtype), dim=0)
220220
del i
221-
vv, vm = v.var_mean(dim=0)
221+
vv, vm = torch.var_mean(v.to(dtype), dim=0)
222222
del v
223223
self.assertEqual(vv, torch.zeros_like(vv))
224224
self.assertEqual(iv, torch.zeros_like(iv))
225-
self.assertEqual(vm, torch.arange(255, dtype=dtype, device=device))
226-
self.assertEqual(im, t0.sort().indices)
225+
self.assertEqual(vm, torch.arange(8192, dtype=dtype, device=device))
226+
self.assertEqual(im, t0.sort().indices, exact_dtype=False)
227227

228228
@dtypes(torch.float32)
229229
def test_sort_restride(self, device, dtype):

0 commit comments

Comments
 (0)