Commit a46feda

[release/2.7][SWDEV-522381] NAVI3x Fixes (#2437)
- Skip test_sac_ilp.py UTs; they are also skipped upstream.
- Skip test_comm_mode_features.py UTs because PLATFORM_SUPPORTS_FUSED_ATTENTION is not true for NAVI32.
- test_schedule_multiproc.py: update the gradient comparison tolerance.

(cherry picked from commit 9c357c3)
1 parent 5e2f3e1 commit a46feda
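
As a reading aid (not part of the commit), a minimal sketch of the two skip mechanisms the fixes rely on. The decorator and constant names (skipIfRocmArch, NAVI_ARCH, PLATFORM_SUPPORTS_FUSED_ATTENTION) and the skip reason string are taken from the diffs below; the ExampleNaviSkips class, its test methods, and the sample SDPA call are hypothetical and assume a visible GPU.

    # Illustrative only: mirrors the skip pattern added in this commit.
    import unittest

    import torch
    from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FUSED_ATTENTION
    from torch.testing._internal.common_utils import (
        NAVI_ARCH,
        run_tests,
        skipIfRocmArch,
        TestCase,
    )


    class ExampleNaviSkips(TestCase):  # hypothetical test class
        # Skipped on NAVI3x GPUs, as done for test_sac_ilp_case2 below.
        @skipIfRocmArch(NAVI_ARCH)
        def test_requires_non_navi_gpu(self):
            self.assertTrue(torch.cuda.is_available())

        # Skipped when fused SDPA is unavailable, as done for
        # test_transformer_module_tracing below.
        @unittest.skipIf(
            not PLATFORM_SUPPORTS_FUSED_ATTENTION,
            "Does not support fused scaled dot product attention",
        )
        def test_requires_fused_attention(self):
            # (batch, heads, seq, head_dim) half-precision tensor on the GPU
            q = torch.randn(1, 1, 8, 16, device="cuda", dtype=torch.float16)
            out = torch.nn.functional.scaled_dot_product_attention(q, q, q)
            self.assertEqual(out.shape, q.shape)


    if __name__ == "__main__":
        run_tests()

The two gates are complementary: skipIfRocmArch keys off the reported GPU architecture, while PLATFORM_SUPPORTS_FUSED_ATTENTION keys off whether a fused SDPA backend is available; in both cases the test is reported as skipped rather than failing.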

3 files changed: 8 insertions(+), 3 deletions(-)

test/distributed/_tools/test_sac_ilp.py

Lines changed: 4 additions & 1 deletion
@@ -20,9 +20,11 @@
 from torch.testing._internal.common_cuda import TEST_CUDA
 from torch.testing._internal.common_utils import (
     run_tests,
-    skipIfRocm,
     skipIfTorchDynamo,
     TestCase,
+    skipIfRocm,
+    skipIfRocmArch,
+    NAVI_ARCH,
 )
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     ModelArgs,
@@ -178,6 +180,7 @@ def test_sac_ilp_case1(self):
 
     @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/115653")
     @unittest.skipIf(not TEST_CUDA, "CUDA not available")
+    @skipIfRocmArch(NAVI_ARCH)
     def test_sac_ilp_case2(self):
         """
         This is a case where the memory budget is not binding, meaning that no

test/distributed/pipelining/test_schedule_multiproc.py

Lines changed: 1 addition & 1 deletion
@@ -604,7 +604,7 @@ def test_schedule_with_native_zero_bubble(self, ScheduleClass):
         for name, p in stage_module.named_parameters():
             ref_p = ref_submod.get_parameter(name)
             try:
-                torch.testing.assert_close(p.grad, ref_p.grad, rtol=1e-5, atol=4e-5)
+                torch.testing.assert_close(p.grad, ref_p.grad, rtol=1e-5, atol=9e-5)
             except AssertionError:
                 print(
                     f"Parameter test failed for {submod_name}.{name}: {p.grad} vs {ref_p.grad}"

test/distributed/tensor/debug/test_comm_mode_features.py

Lines changed: 3 additions & 1 deletion
@@ -24,7 +24,8 @@
     with_comms,
 )
 
-
+from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FUSED_ATTENTION
+import unittest
 c10d_functional = torch.ops.c10d_functional
 
 
@@ -221,6 +222,7 @@ def test_MLP_module_tracing(self):
 
     @skip_unless_torch_gpu
     @with_comms
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Does not support fused scaled dot product attention")
     def test_transformer_module_tracing(self, is_seq_parallel=False):
         """
         tests module-level tracing for more complicated transformer module and
