Commit b4af472

[AUTOGENERATED] [release/2.8] [rocm7.0_internal_testing] skip test_transformer_req_grad on Navi32/Navi4x (#2464)
Cherry-pick of #2385
Co-authored-by: Dmitry Nikolaev <[email protected]>
1 parent 2975ee1 · commit b4af472

File tree: 1 file changed, +3 -0 lines changed

1 file changed

+3
-0
lines changed

test/distributed/tensor/parallel/test_tp_examples.py

Lines changed: 3 additions & 0 deletions
@@ -27,6 +27,8 @@
     RowwiseParallel,
 )
 from torch.distributed.tensor.parallel.input_reshard import input_reshard
+from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FUSED_ATTENTION
+from torch.testing._internal.common_device_type import skipIf
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     parametrize,
@@ -412,6 +414,7 @@ def test_transformer_training(self, is_seq_parallel, dtype: torch.dtype):
         + f"{str(dtype).split('.')[-1]}_"
         + f"thaw_{'__'.join(sorted({n.rpartition('.')[0].replace('.', '_') for n in thaw})) if thaw else 'all'}",
     )
+    @skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Does not support fused scaled dot product attention")
     def test_transformer_req_grad(self, thaw_params, is_seq_parallel, dtype, exp_cnts):
         # Sample a subset of `requires_grad` patterns
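For context on what the added decorator does: `skipIf` from torch.testing._internal.common_device_type raises unittest.SkipTest when its condition is true, and PLATFORM_SUPPORTS_FUSED_ATTENTION is a capability flag from torch.testing._internal.common_cuda that is false on GPUs lacking fused scaled dot product attention kernels (such as the Navi32/Navi4x parts this commit targets). A minimal, self-contained sketch of the same pattern follows; the test class and test body are illustrative placeholders, not code from this repository:

# Sketch of the skip pattern introduced above. ExampleTest and its test
# body are hypothetical; only the decorator usage mirrors the commit.
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FUSED_ATTENTION
from torch.testing._internal.common_device_type import skipIf
from torch.testing._internal.common_utils import TestCase, run_tests


class ExampleTest(TestCase):
    # On platforms without fused SDPA kernels, the decorator raises
    # unittest.SkipTest before the test body runs.
    @skipIf(
        not PLATFORM_SUPPORTS_FUSED_ATTENTION,
        "Does not support fused scaled dot product attention",
    )
    def test_needs_fused_sdpa(self):
        pass  # would exercise fused scaled dot product attention here


if __name__ == "__main__":
    run_tests()

Decorating at the test-method level, as the commit does, skips only the one test that depends on fused attention while leaving the rest of the suite running on those platforms.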
