Skip to content

Commit e299809

Browse files
committed
[release/2.6] unittest fixes for MI350
1 parent 2feede2 commit e299809

File tree

8 files changed

+20
-14
lines changed

8 files changed

+20
-14
lines changed

aten/src/ATen/native/cuda/SortStable.cu

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -226,8 +226,9 @@ void launch_stable_sort_kernel(
226226
return;
227227
}
228228
229-
int64_t numel_or_intmax =
230-
std::min(numel, static_cast<int64_t>(std::numeric_limits<int>::max()));
229+
const int64_t intmax = static_cast<int64_t>(std::numeric_limits<int>::max());
230+
// On ROCm, std::min -> ::min did not work as expected when input values >= 2147483648
231+
int64_t numel_or_intmax = numel < intmax ? numel : intmax;
231232
int64_t nsort = self.size(dim);
232233
int64_t nbatch = (numel_or_intmax / nsort) * nsort;
233234
TORCH_CHECK(nbatch > 0, "Cannot sort dimension of length ", nsort);
@@ -239,7 +240,8 @@ void launch_stable_sort_kernel(
239240
scalar_t* values_ptr = values.mutable_data_ptr<scalar_t>();
240241
int64_t remaining = numel;
241242
while (remaining > 0) {
242-
int64_t n = std::min(remaining, nbatch);
243+
// On ROCm, std::min -> ::min did not work as expected when input values >= 2147483648
244+
int64_t n = remaining < nbatch ? remaining : nbatch;
243245
int64_t nsegments = n / nsort;
244246
245247
if (nsegments == 1 ||

test/quantization/core/test_quantized_op.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import random
99
import sys
1010
import unittest
11+
from packaging.version import Version
1112
from typing import NamedTuple, List
1213

1314
import torch
@@ -65,7 +66,7 @@ class PointwisePostOp(NamedTuple):
6566
def avoid_vpmaddubsw_overflow_linear(
6667
batch_size, input_channels, output_channels, X, X_min, X_max, W, W_min, W_max
6768
):
68-
if sys.version_info >= (3, 13):
69+
if Version(np.__version__) >= Version("2.1"):
6970
raise unittest.SkipTest("numpy 2.1 overflow error")
7071
for i, j in np.ndindex((batch_size, output_channels)):
7172
for k in range(0, input_channels // 2 * 2, 2):

test/test_matmul_cuda.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -361,10 +361,9 @@ def test_float8_basics(self, device) -> None:
361361

362362
self._test_tautological_mm(device, size=64, out_dtype=torch.float16)
363363
self._test_tautological_mm(device, size=96, out_dtype=torch.float32)
364-
# hipblaslt does not yet support bfloat16 output
365-
if torch.version.hip is None:
366-
self._test_tautological_mm(device, size=80, out_dtype=torch.bfloat16)
367-
with self.assertRaises(RuntimeError):
364+
self._test_tautological_mm(device, size=80, out_dtype=torch.bfloat16)
365+
366+
with self.assertRaises(AssertionError if torch.version.hip or device == "cpu" else RuntimeError):
368367
self._test_tautological_mm(device, out_dtype=e5m2_type)
369368

370369
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)

test/test_scatter_gather_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def _test_scatter_base(self, fn, *, device, dtype, is_scalar, reduction,
158158
# When we are running opportunistic_fastatomics, we will expect some floating point rounding
159159
# errors as the order of operation is not guaranteed.
160160
if TEST_WITH_ROCM \
161-
and 'gfx94' in torch.cuda.get_device_properties(0).gcnArchName \
161+
and torch.cuda.get_device_properties(0).gcnArchName[0:5] in ('gfx94', 'gfx95')\
162162
and not torch.are_deterministic_algorithms_enabled():
163163
self.assertEqual(actual, expected, atol=1e-9, rtol=1e-6)
164164
else:

test/test_sort_and_select.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -209,21 +209,21 @@ def test_stable_sort(self, device, dtype):
209209
)
210210

211211
@onlyCUDA
212-
@dtypes(torch.uint8)
212+
@dtypes(torch.float16)
213213
@largeTensorTest("200GB") # Unfortunately 80GB A100 is not large enough
214214
def test_sort_large(self, device, dtype):
215215
t0 = torch.randperm(8192, device=device).to(dtype)
216216
t = t0.view(1, 8192).expand(2**18 + 1, -1).contiguous()
217217
v, i = t.sort()
218218
del t
219-
iv, im = i.var_mean(dim=0)
219+
iv, im = torch.var_mean(i.to(dtype), dim=0)
220220
del i
221-
vv, vm = v.var_mean(dim=0)
221+
vv, vm = torch.var_mean(v.to(dtype), dim=0)
222222
del v
223223
self.assertEqual(vv, torch.zeros_like(vv))
224224
self.assertEqual(iv, torch.zeros_like(iv))
225-
self.assertEqual(vm, torch.arange(255, dtype=dtype, device=device))
226-
self.assertEqual(im, t0.sort().indices)
225+
self.assertEqual(vm, torch.arange(8192, dtype=dtype, device=device))
226+
self.assertEqual(im, t0.sort().indices, exact_dtype=False)
227227

228228
@dtypes(torch.float32)
229229
def test_sort_restride(self, device, dtype):

test/test_transformers.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3302,6 +3302,8 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset,
33023302
fudge_factors['grad_query'] = 650.0
33033303
if dtype == torch.float32:
33043304
fudge_factors['grad_key'] = 90.0
3305+
if "gfx95" in torch.cuda.get_device_properties(0).gcnArchName:
3306+
fudge_factors['grad_value'] = 12.0
33053307

33063308
check_out_and_grad(
33073309
(out_ref, out_lp_ref, out),

torch/_tensor_str.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,7 @@ def _tensor_str(self, indent):
344344
torch.float8_e5m2fnuz,
345345
torch.float8_e4m3fn,
346346
torch.float8_e4m3fnuz,
347+
torch.float8_e8m0fnu,
347348
]:
348349
self = self.half()
349350

torch/cuda/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1695,6 +1695,7 @@ def addmm_kernel_impl(*args, **kwargs):
16951695
"is_bf16_supported",
16961696
"is_current_stream_capturing",
16971697
"is_initialized",
1698+
"is_tf32_supported",
16981699
"jiterator",
16991700
"list_gpu_processes",
17001701
"make_graphed_callables",

0 commit comments

Comments
 (0)