|
37 | 37 | ) |
38 | 38 | from torch._inductor.template_heuristics import CUDAConfigHeuristic, GemmConfig |
39 | 39 | from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8 |
| 40 | +from torch.testing._internal.common_device_type import largeTensorTest |
40 | 41 | from torch.testing._internal.common_utils import ( |
41 | 42 | instantiate_parametrized_tests, |
42 | 43 | IS_WINDOWS, |
43 | 44 | parametrize, |
44 | 45 | TEST_WITH_ROCM, |
| 46 | + MI300_ARCH, |
| 47 | + runOnRocmArch, |
| 48 | + skipIfXpu, |
45 | 49 | ) |
46 | 50 | from torch.testing._internal.logging_utils import multiple_logs_to_string |
47 | 51 | from torch.utils._triton import has_triton_tma_device |
|
54 | 58 | from torch._inductor.virtualized import V |
55 | 59 | from torch.fx.experimental.proxy_tensor import make_fx |
56 | 60 | from torch.testing import FileCheck |
57 | | -from torch.testing._internal.common_utils import MI300_ARCH, runOnRocmArch, skipIfXpu |
58 | 61 | from torch.testing._internal.inductor_utils import ( |
59 | 62 | get_func_call, |
60 | 63 | get_kernel_launch, |
@@ -804,6 +807,8 @@ def test_conv_backend(self): |
804 | 807 |
|
805 | 808 | self.assertIn("NoValidChoicesError", str(context.exception)) |
806 | 809 |
|
| 810 | + # Some ROCm GPUs don't have enough VRAM to run all autotune configurations and padding benchmarks |
| 811 | + @largeTensorTest("30 GB", device=GPU_TYPE) |
807 | 812 | def test_non_contiguous_input_mm(self): |
808 | 813 | """ |
809 | 814 | Make sure the triton template can work with non-contiguous inputs without crash. |
@@ -856,6 +861,8 @@ def f(x, y): |
856 | 861 | # TODO: fix accuracy failure of the triton template on XPU. |
857 | 862 | # and enable this test case. |
858 | 863 | @skipIfXpu |
| 864 | + # Some ROCm GPUs don't have enough VRAM to run all autotune configurations and padding benchmarks |
| 865 | + @largeTensorTest("30 GB", device=GPU_TYPE) |
859 | 866 | def test_non_contiguous_input_mm_plus_mm(self): |
860 | 867 | x1 = rand_strided((50257, 32768), (1, 50304), device=GPU_TYPE) |
861 | 868 | y1 = rand_strided((32768, 768), (768, 1), device=GPU_TYPE) |
|
0 commit comments