@@ -27,7 +27,10 @@
 )
 from torch.distributed.tensor import DTensor, init_device_mesh, Shard
 from torch.distributed.tensor.debug import CommDebugMode
-from torch.testing._internal.common_cuda import TEST_CUDA
+from torch.testing._internal.common_cuda import (
+    PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
+    TEST_CUDA,
+)
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_fsdp import (
     check_sharded_parity,
@@ -41,7 +44,9 @@
 )
 from torch.testing._internal.common_utils import (
     get_cycles_per_ms,
+    NAVI4_ARCH,
     run_tests,
+    skipIfRocmArch,
     wrapSwapTensorsTest,
 )
 from torch.testing._internal.distributed._tensor.common_dtensor import (
@@ -94,6 +99,7 @@ def world_size(self) -> int:
         return 4
 
     @unittest.skipIf(not TEST_CUDA, "no cuda")
+    @skipIfRocmArch(NAVI4_ARCH)  # Supported in future releases
     def test_param_registration_after_forward(self):
         """Tests the parameter registration after forward."""
         device = torch.device("cuda", 0)
@@ -200,6 +206,7 @@ def world_size(self) -> int:
 
     @unittest.skipIf(not TEST_CUDA, "no cuda")
     @wrapSwapTensorsTest(True)
+    @skipIfRocmArch(NAVI4_ARCH)  # Supported in future releases
     def test_to_float64_after_init(self):
         """Tests that the user can cast the module to float64 after init."""
         # NOTE: Test fp64 instead of a lower precision dtype like bf16 for
@@ -310,6 +317,9 @@ def _shard_placement_fn(param: nn.Parameter) -> Optional[Shard]:
 
     @skip_if_lt_x_gpu(2)
     @compiled_fsdp_test(compile_compute_on_module=Transformer)
+    @unittest.skipIf(
+        not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Platform does not support fused SDPA"
+    )
    def test_train_parity_multi_group(self):
         """
         Tests train parity against DDP when using multiple parameter groups for
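For reviewers unfamiliar with these helpers, below is a minimal sketch of how the decorators introduced in this patch gate a test at collection time. It is illustrative only: the test class and test bodies are invented for this note, but the imports mirror the ones added in the diff (`skipIfRocmArch`/`NAVI4_ARCH` from `common_utils`, `PLATFORM_SUPPORTS_MEM_EFF_ATTENTION`/`TEST_CUDA` from `common_cuda`), and it assumes a PyTorch build where those internal helpers are available.

```python
# Sketch of how the skip decorators added in this PR gate tests.
# ExampleSkipTest and its test methods are hypothetical, not part of the PR.
import unittest

import torch
from torch.testing._internal.common_cuda import (
    PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
    TEST_CUDA,
)
from torch.testing._internal.common_utils import (
    NAVI4_ARCH,
    run_tests,
    skipIfRocmArch,
)


class ExampleSkipTest(unittest.TestCase):
    @unittest.skipIf(not TEST_CUDA, "no cuda")
    @skipIfRocmArch(NAVI4_ARCH)  # skipped on ROCm Navi4 until support lands
    def test_needs_cuda_but_not_navi4(self):
        # Runs only when CUDA is available and the ROCm arch is not Navi4.
        self.assertTrue(torch.cuda.is_available())

    @unittest.skipIf(not TEST_CUDA, "no cuda")
    @unittest.skipIf(
        not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Platform does not support fused SDPA"
    )
    def test_needs_mem_efficient_sdpa(self):
        # Runs only on platforms where the memory-efficient SDPA backend exists.
        q = k = v = torch.randn(1, 1, 8, 16, device="cuda")
        torch.nn.functional.scaled_dot_product_attention(q, k, v)


if __name__ == "__main__":
    run_tests()
```

As the name suggests, `skipIfRocmArch` should only trigger on ROCm builds whose GPU architecture matches `NAVI4_ARCH`, so CUDA runs and other ROCm architectures keep their existing coverage.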